MLIR 23.0.0git
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the Linalg dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
29#include "mlir/IR/AffineExpr.h"
31#include "mlir/IR/AffineMap.h"
32#include "mlir/IR/Matchers.h"
33#include "llvm/ADT/SmallVectorExtras.h"
34#include "llvm/ADT/TypeSwitch.h"
35#include "llvm/Support/Debug.h"
36
37#include <optional>
38
39#define DEBUG_TYPE "linalg-utils"
40
41using namespace mlir;
42using namespace presburger;
43using namespace mlir::affine;
44using namespace mlir::linalg;
45using namespace mlir::scf;
46
47namespace {
48
49// Helper visitor to determine whether an AffineExpr is tiled.
50// This is achieved by traversing every AffineDimExpr with position `pos` and
51// checking whether the corresponding `tileSizes[pos]` is non-zero.
52// This also enforces only positive coefficients occur in multiplications.
53//
54// Example:
55// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
56//
57struct TileCheck : public AffineExprVisitor<TileCheck> {
58 TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
59
60 void visitDimExpr(AffineDimExpr expr) {
61 isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]);
62 }
63 void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
64 visit(expr.getLHS());
65 visit(expr.getRHS());
67 assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
68 "nonpositive multiplying coefficient");
69 }
70 bool isTiled = false;
71 ArrayRef<OpFoldResult> tileSizes;
72};
73
74} // namespace
75
76static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
77 if (!expr)
78 return false;
79 TileCheck t(tileSizes);
80 t.visit(expr);
81 return t.isTiled;
82}
83
84// Checks whether the `map varies with respect to a non-zero `tileSize`.
85static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
86 if (!map)
87 return false;
88 for (unsigned r = 0; r < map.getNumResults(); ++r)
89 if (isTiled(map.getResult(r), tileSizes))
90 return true;
91 return false;
92}
93
94std::optional<RegionMatcher::BinaryOpKind>
96 auto &region = op.getRegion();
97 if (!region.hasOneBlock())
98 return std::nullopt;
99
100 Block &block = region.front();
101 if (block.getNumArguments() != 2 ||
104 return std::nullopt;
105
106 auto &ops = block.getOperations();
107 if (!llvm::hasSingleElement(block.without_terminator()))
108 return std::nullopt;
109
111 auto a = m_Val(block.getArgument(0));
112 auto b = m_Val(block.getArgument(1));
113
114 auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
115 if (addPattern.match(&ops.back()))
116 return BinaryOpKind::IAdd;
117
118 return std::nullopt;
119}
120
121/// Explicit instantiation of loop nest generator for different loop types.
125
126/// Given a list of subview ranges, extract individual values for lower, upper
127/// bounds and steps and put them into the corresponding vectors.
128static void unpackRanges(OpBuilder &builder, Location loc,
131 SmallVectorImpl<Value> &steps) {
132 for (Range range : ranges) {
133 lbs.emplace_back(
134 getValueOrCreateConstantIndexOp(builder, loc, range.offset));
135 ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
136 steps.emplace_back(
137 getValueOrCreateConstantIndexOp(builder, loc, range.stride));
138 }
139}
140
141//===----------------------------------------------------------------------===//
142// General utilities
143//===----------------------------------------------------------------------===//
144//
145/// The permutation can be obtained from two permutations:
146/// a) Compute the permutation vector to move the last `numPackedDims` into
147/// the `innerPosDims` of a shape of rank `rank`.
148/// b) Compute the permutation vector to move outer dims if the
149/// `outerPerm` parameter is not empty.
150/// Apply (b) permutation on (a) permutation to get the final permutation.
151static SmallVector<int64_t>
153 ArrayRef<int64_t> &outerPerm,
154 PackingMetadata &packingMetadata) {
155 int64_t numPackedDims = innerDimsPos.size();
156 auto lastDims =
157 llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
158 packingMetadata = computePackingMetadata(rank, innerDimsPos);
159 SmallVector<int64_t> innerPositionsPerm =
160 computePermutationVector(rank, lastDims, packingMetadata.insertPositions);
161
162 SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
163 if (!outerPerm.empty())
164 applyPermutationToVector(outerPos, outerPerm);
165 SmallVector<int64_t> outerPositionPerm =
166 computePermutationVector(rank, packingMetadata.outerPositions, outerPos);
167
168 SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
169 applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
170 return packInverseDestPermutation;
171}
172
173namespace mlir {
174namespace linalg {
175
177 PackingMetadata &metadata) {
178
179 int64_t packedRank = packOp.getDestType().getRank();
180 ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
181 ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
182 SmallVector<int64_t> packInvDestPerm =
183 computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
184 return packInvDestPerm;
185}
186
188 PackingMetadata &metadata) {
189 int64_t packedRank = unpackOp.getSourceType().getRank();
190 ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
191 ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
192 SmallVector<int64_t> unpackInvSrcPerm =
193 computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
194 return unpackInvSrcPerm;
195}
196
198 return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
199 return m.isProjectedPermutation(/*allowZeroInResults=*/true);
200 });
201}
202
204 if (!r.hasOneBlock())
205 return false;
206 for (Operation &op : r.front()) {
207 if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
208 linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
210 llvm::any_of(op.getResultTypes(),
211 [](Type type) { return !type.isIntOrIndexOrFloat(); }))
212 return false;
213 }
214 return true;
215}
216
217bool isElementwise(LinalgOp op) {
218 if (op.getNumLoops() != op.getNumParallelLoops())
219 return false;
220
222 return false;
223
224 // TODO: relax the restrictions on indexing map.
225 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
226 if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
227 return false;
228 }
229 return hasOnlyScalarElementwiseOp(op->getRegion(0));
230}
231
232bool isParallelIterator(utils::IteratorType iteratorType) {
233 return iteratorType == utils::IteratorType::parallel;
234}
235
236bool isReductionIterator(utils::IteratorType iteratorType) {
237 return iteratorType == utils::IteratorType::reduction;
238}
239
240//===----------------------------------------------------------------------===//
241// Convolution matcher utilities
242//===----------------------------------------------------------------------===//
243
// NOTE(review): doxygen-extraction residue — upstream line numbers are fused
// into each line below, and upstream line 246 (this helper's signature, a
// `static BlockArgument ...(Value val)` definition whose exact name is used
// by the matcher bodies below) was dropped; restore it from upstream
// mlir/lib/Dialect/Linalg/Utils/Utils.cpp before compiling.
244/// Returns the BlockArgument that leads to `val`, if any. Traverses optional
245/// ext*/sitofp ops.
247 BlockArgument blockArg = dyn_cast<BlockArgument>(val);
248 if ((blockArg))
249 return blockArg;
250
// If `val` is not itself a block argument, look through exactly one
// widening cast (extf/extsi/extui/sitofp) to its source operand.
251 Operation *defOp = val.getDefiningOp();
252 if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
253 !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
254 !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
255 !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
256 return nullptr;
257 }
258 return dyn_cast<BlockArgument>(defOp->getOperand(0));
260
// NOTE(review): doxygen-extraction residue — upstream line numbers are fused
// into each line below; the function signature (upstream line 275, taking the
// accumulate op, the mul op and the body block) and the five helper calls that
// extract block arguments (upstream lines 289-297) were dropped. Restore from
// upstream mlir/lib/Dialect/Linalg/Utils/Utils.cpp before compiling.
261/// Utility function to match the zero point offset body of quantized
262/// convolution ops.
263///
264/// Quantized convolutions have a body of the form:
265/// %out + ((%input - %inputZp) * (%filter - %filterZp))
266/// where:
267/// - %input is the input tensor element (block arg 0)
268/// - %filter is the filter tensor element (block arg 1)
269/// - %inputZp is the input zero-point scalar (block arg 2)
270/// - %filterZp is the filter zero-point scalar (block arg 3)
271/// - %out is the output accumulator (block arg 4)
272///
273/// This function verifies that the multiplication operands are subtraction
274/// operations matching this pattern.
276 Block *body) {
277 // The multiplication should have two subtraction operands:
278 // one for (input - inputZp) and one for (filter - filterZp).
279 Operation *inputSubOp = mulOp->getOperand(0).getDefiningOp();
280 if (!isa_and_present<arith::SubIOp, arith::SubFOp>(inputSubOp))
281 return false;
282
283 Operation *filterSubOp = mulOp->getOperand(1).getDefiningOp();
284 if (!isa_and_present<arith::SubIOp, arith::SubFOp>(filterSubOp))
285 return false;
286
287 // Extract block arguments from subtraction operands.
288 BlockArgument inputBlockArg =
290 BlockArgument inputZpBlockArg =
292 BlockArgument filterBlockArg =
294 BlockArgument filterZpBlockArg =
296 BlockArgument outBlockArg =
298
299 // Verify all block arguments are valid.
300 if (!inputBlockArg || !inputZpBlockArg || !filterBlockArg ||
301 !filterZpBlockArg || !outBlockArg)
302 return false;
303
304 // Verify all block arguments belong to the convolution body.
305 if (inputBlockArg.getOwner() != body || inputZpBlockArg.getOwner() != body ||
306 filterBlockArg.getOwner() != body ||
307 filterZpBlockArg.getOwner() != body || outBlockArg.getOwner() != body)
308 return false;
309
310 // Verify block arguments have expected indices:
311 // arg0: input, arg1: filter, arg2: inputZp, arg3: filterZp, arg4: output
312 if (inputBlockArg.getArgNumber() != 0 || filterBlockArg.getArgNumber() != 1 ||
313 inputZpBlockArg.getArgNumber() != 2 ||
314 filterZpBlockArg.getArgNumber() != 3 || outBlockArg.getArgNumber() != 4)
315 return false;
316
317 return true;
318}
319
// NOTE(review): doxygen-extraction residue — upstream line numbers are fused
// into each line below, and the three helper calls extracting lhs/rhs/out
// block arguments (upstream lines 351-355) were dropped. Restore from
// upstream mlir/lib/Dialect/Linalg/Utils/Utils.cpp before compiling.
320/// Utility to match block body for convolution ops.
321/// The body is thus expected to yield :-
322/// %out + (%lhs * %rhs)
323/// where: %lhs, %rhs and %out are block arguments and
324/// %lhs and %rhs can have optional upcast operation.
325/// For i1 element types, the pattern matches:
326/// %out | (%lhs & %rhs)
327/// using arith.ori for accumulation and arith.andi for multiplication.
328/// NOTE: In case of zero point offset convolution ops %lhs and %rhs would be :-
329/// %input - %input_scalar
330/// where, %input_scalar can have optional upcast operation.
331static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
332 bool containsZeroPointOffset = false) {
// The accumulate op is either addi/addf, or ori for the i1 flavor.
333 bool isOrOp = false;
334 Operation *accOp = yieldVal.getDefiningOp();
335 if (!isa_and_present<arith::AddIOp, arith::AddFOp>(accOp)) {
336 if (!isa_and_present<arith::OrIOp>(accOp))
337 return false;
338 isOrOp = true;
339 }
340
// The "multiply" operand pairs with the accumulate kind: mul with add,
// and with or. Note: only operand 1 of the accumulate op is inspected.
341 Operation *mulOp = accOp->getOperand(1).getDefiningOp();
342 if (!isOrOp && !isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
343 return false;
344 if (isOrOp && !isa_and_present<arith::AndIOp>(mulOp))
345 return false;
346
347 if (containsZeroPointOffset) {
348 return bodyMatcherForZeroPointOffsets(accOp, mulOp, body);
349 }
350 BlockArgument lhsBlockArg =
352 BlockArgument rhsBlockArg =
354 BlockArgument outBlockArg =
// All three must be block arguments of `body` at indices 0 (lhs), 1 (rhs)
// and 2 (out) respectively.
356 if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
357 lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
358 outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
359 rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2)
360 return false;
361 return true;
362}
363
// NOTE(review): doxygen-extraction residue — upstream line numbers are fused
// into each line below, and the two helper calls extracting lhs/rhs block
// arguments (upstream lines 372 and 374) were dropped. Restore from upstream
// mlir/lib/Dialect/Linalg/Utils/Utils.cpp before compiling.
364/// Utility to match block body for linalg.pool* ops.
365template <typename... OpTypes>
366static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
// The yielded value must be produced by one of the accepted op kinds.
367 Operation *defOp = yieldVal.getDefiningOp();
368 if (!(isa_and_present<OpTypes>(defOp) || ...))
369 return false;
370
371 BlockArgument lhsArg =
373 BlockArgument rhsArg =
// Pooling bodies combine the accumulator (block arg 2) with the input
// element (block arg 0); both must belong to `body`.
375 if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
376 rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
377 rhsArg.getArgNumber() != 0)
378 return false;
379 return true;
380}
381
// Matches a signed max-pooling body. NOTE(review): the template-argument list
// passed to bodyMatcherForPoolOps (upstream line 383) was dropped by the
// doxygen extraction; restore it from upstream MLIR Utils.cpp.
382static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
384 body);
385}
386
387static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
388 return bodyMatcherForPoolOps<arith::MaxUIOp>(yieldVal, body);
389}
390
// Matches a signed min-pooling body. NOTE(review): the template-argument list
// passed to bodyMatcherForPoolOps (upstream line 392) was dropped by the
// doxygen extraction; restore it from upstream MLIR Utils.cpp.
391static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
393 body);
394}
395
396static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
397 return bodyMatcherForPoolOps<arith::MinUIOp>(yieldVal, body);
398}
399
// NOTE(review): the template-argument list passed to bodyMatcherForPoolOps
// (upstream line 403) was dropped by the doxygen extraction; restore it from
// upstream MLIR Utils.cpp (per the comment it should accept addi/addf and,
// for i1, ori — confirm upstream).
400/// Matches sum pooling body pattern. For i1 element types, arith.ori is used
401/// instead of arith.addi/arith.addf for accumulation.
402static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
404 yieldVal, body);
405}
406
407static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex,
408 uint32_t dimIndex) {
409 auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
410 if (dimIndex < affineMap.getNumResults())
411 return affineMap.getResult(dimIndex);
412 return nullptr;
413}
414
415/// Check if `expr` is either:
416/// - a dimension expr alone (implying multiplication by 1), or
417/// - a multiplication of dimension expr by any positive constant != 1
418/// In both cases we will capture the dimension expression into `dim` and
419/// return the constant multiplier. Returns -1 in case of a match failure.
421 if ((dim = dyn_cast<AffineDimExpr>(expr)))
422 return 1;
423
424 auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
425 if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
426 return -1;
427
428 AffineExpr lhs = mulExpr.getLHS();
429 AffineExpr rhs = mulExpr.getRHS();
430
431 AffineConstantExpr cst = nullptr;
432 if (((dim = dyn_cast<AffineDimExpr>(lhs)) &&
433 (cst = dyn_cast<AffineConstantExpr>(rhs))) ||
434 ((dim = dyn_cast<AffineDimExpr>(rhs)) &&
435 (cst = dyn_cast<AffineConstantExpr>(lhs))))
436 return cst.getValue();
437 return -1;
438}
439
440/// Given an array of AffineMaps `indexingMaps` verify the following
441/// commutatively:-
442/// indexingMaps[0].getResult(iDim) ==
443/// indexingMaps[1].getResult(fDim) * <c0> +
444/// indexingMaps[n-1].getResult(oDim) * <c1>
445/// where,
446/// - c0 and c1 can be any constant,
447/// - n is the size of the indexingMaps' array,
448/// - 0, 1 and n-1 are input, filter and output map indices respectively,
449/// - iDim, fDim and oDim are the input, filter and output dimension
450/// indices in their respective indexing maps
451/// Example:
452/// #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6)
453/// -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)>
454/// #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
455/// #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
456///
457/// Here,
458/// #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3
459/// Therefore,
460/// matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride)
461/// would return true and update dilation = 3 and stride = 2
462static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
463 unsigned fDim, unsigned oDim,
464 int64_t &dilation, int64_t &stride) {
465 unsigned inputMapIdx = 0, filterMapIdx = 1,
466 outputMapIdx = indexingMaps.size() - 1;
467 AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
468 auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
469 if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
470 return false;
471
472 AffineExpr dim0, dim1;
473 int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0);
474 int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1);
475
476 if (c0 == -1 || c1 == -1)
477 return false;
478 // Pattern matched with dims and constants extracted.
479 AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim);
480 AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim);
481 if (dim0 == fExpr && dim1 == oExpr) {
482 dilation = c0;
483 stride = c1;
484 return true;
485 }
486 if (dim1 == fExpr && dim0 == oExpr) {
487 dilation = c1;
488 stride = c0;
489 return true;
490 }
491 return false;
492}
493
494/// Returns true if the given indexing maps matches with the expected indexing
495/// maps.
497 ArrayAttr indexingMaps, MLIRContext *context) {
498 SmallVector<AffineMap, 4> expectedIndexingMaps =
499 AffineMap::inferFromExprList(mapListExpected, context);
500 return indexingMaps ==
501 ArrayAttr::get(context,
502 llvm::map_to_vector<4>(expectedIndexingMaps,
503 [&](AffineMap m) -> Attribute {
504 return AffineMapAttr::get(m);
505 }));
506}
507
508/// Enum representing pooling operation types used by ConvMatcherBuilder.
517
518/// Helper class for building convolution op matchers with minimal boilerplate.
519/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
520/// as Pooling ops.
521///
522/// Usage: Create an instance with the op, spatial rank, and output pointers for
523/// extracted dilations/strides. Then chain matchStride() calls for each spatial
524/// dimension, followed by matchMaps() to verify indexing maps, and finally
525/// matchBody() to verify the operation body pattern.
526///
527/// The `matched` flag starts as `true` and is set to `false` if any match step
528/// fails. This allows chaining multiple match calls; once any match fails, all
529/// subsequent calls become no-ops and the final result is `false`.
530///
531/// The `dilations` and `strides` pointers are output parameters that get
532/// populated with the extracted dilation and stride values from the operation's
533/// indexing maps during matchStride() calls. These values are initially set to
534/// 1 for each spatial dimension and updated as patterns are matched.
// NOTE(review): doxygen-extraction residue — upstream line numbers are fused
// into each line below, and the extraction dropped: the class declaration
// line (upstream 535), the constructor's `SmallVector<int64_t> *s` parameter
// line (545), the matchMaps() signature line (573), and every `case
// PoolingType::...:` label except Sum (586, 589, 591, 593, 595). Restore all
// of these from upstream mlir/lib/Dialect/Linalg/Utils/Utils.cpp.
536 LinalgOp op;
537 MLIRContext *ctx;
538 SmallVector<int64_t> *dilations, *strides;
539 ArrayAttr indexingMaps;
540 PoolingType poolingType;
541 bool matched = true;
542
543public:
// The constructor pre-fills both output vectors with 1 per spatial dim;
// matchStride() overwrites entries as patterns are matched.
544 ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
546 PoolingType poolingType = PoolingType::None)
547 : op(op), ctx(op->getContext()), dilations(d), strides(s),
548 indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
549 *dilations = SmallVector<int64_t>(spatialRank, 1);
550 *strides = SmallVector<int64_t>(spatialRank, 1);
551 }
552
553 /// Get affine dimension expression for dimension `i`.
554 AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); }
555
556 /// Build strided expression: base * stride[idx] + kernel * dilation[idx].
557 AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) {
558 return base * (*strides)[idx] + kernel * (*dilations)[idx];
559 }
560
561 /// Match stride/dilation pattern for a spatial dimension.
562 /// Returns *this for method chaining.
563 ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
564 unsigned idx) {
565 if (matched) {
566 matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
567 (*dilations)[idx], (*strides)[idx]);
568 }
569 return *this;
570 }
571
572 /// Match expected indexing maps layout. Returns *this for method chaining.
574 if (matched)
575 matched &= convLayoutMatches(maps, indexingMaps, ctx)
576 return *this;
577 }
578
579 /// Match body pattern. This should be called last.
580 bool matchBody(bool containsZeroPointOffset = false) {
581 if (!matched)
582 return false;
583 Block *body = op.getBlock();
584 auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
585 switch (poolingType) {
587 return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body,
588 containsZeroPointOffset);
590 return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
592 return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
594 return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
596 return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
597 case PoolingType::Sum:
598 return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
599 }
600 return false;
601 }
602};
603
604//===----------------------------------------------------------------------===//
605// Matchers for specific convolution operation.
606//===----------------------------------------------------------------------===//
607
// Matcher for plain 1-D convolution (linalg.conv_1d): unit strides/dilations,
// inputMap = {W*s + w*d}, filterMap = {w}, outputMap = {W}.
// NOTE(review): doxygen-extraction residue — upstream line numbers are fused
// into each line, and the specialization's signature (upstream lines 610-611,
// which also declare `result`) plus the early-bailout guard (line 619) were
// dropped; restore from upstream MLIR Utils.cpp before compiling.
608template <>
609std::optional<DilationsAndStrides>
612 if (isa<linalg::Conv1DOp>(op)) {
613 // Conv1DOp has no strides/dilations attributes, default to 1.
614 result.dilations = SmallVector<int64_t>(1, 1);
615 result.strides = SmallVector<int64_t>(1, 1);
616 return result;
617 }
618
620 return std::nullopt;
621
622 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
623 &result.strides);
624 AffineExpr W = m.dim(0);
625 AffineExpr w = m.dim(1);
626
627 if (m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
628 .matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
629 /*filterMap=*/{w},
630 /*outputMap=*/{W}})
631 .matchBody())
632 return result;
633 return std::nullopt;
634}
635
// Matcher for linalg.conv_1d_nwc_wcf layout (or a generic with the same maps).
// NOTE(review): doxygen-extraction residue — the specialization signature
// (upstream lines 638-639) and the early-bailout guard (line 647) were
// dropped; restore from upstream MLIR Utils.cpp before compiling.
636template <>
637std::optional<DilationsAndStrides>
640 if (auto convOp = dyn_cast<linalg::Conv1DNwcWcfOp>(op.getOperation())) {
641 result.dilations =
642 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
643 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
644 return result;
645 }
646
648 return std::nullopt;
649
650 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
651 &result.strides);
652 AffineExpr N = m.dim(0);
653 AffineExpr W = m.dim(1);
654 AffineExpr F = m.dim(2);
655 AffineExpr w = m.dim(3);
656 AffineExpr c = m.dim(4);
657
658 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
659 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
660 /*filterMap=*/{w, c, F},
661 /*outputMap=*/{N, W, F}})
662 .matchBody())
663 return result;
664 return std::nullopt;
665}
666
// Matcher for linalg.conv_1d_ncw_fcw layout (or a generic with the same maps).
// NOTE(review): doxygen-extraction residue — the specialization signature
// (upstream lines 669-670) and the early-bailout guard (line 678) were
// dropped; restore from upstream MLIR Utils.cpp before compiling.
667template <>
668std::optional<DilationsAndStrides>
671 if (auto convOp = dyn_cast<linalg::Conv1DNcwFcwOp>(op.getOperation())) {
672 result.dilations =
673 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
674 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
675 return result;
676 }
677
679 return std::nullopt;
680
681 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
682 &result.strides);
683 AffineExpr N = m.dim(0);
684 AffineExpr F = m.dim(1);
685 AffineExpr W = m.dim(2);
686 AffineExpr c = m.dim(3);
687 AffineExpr w = m.dim(4);
688
689 if (m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
690 .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
691 /*filterMap=*/{F, c, w},
692 /*outputMap=*/{N, F, W}})
693 .matchBody())
694 return result;
695 return std::nullopt;
696}
697
// Matcher for plain 2-D convolution (linalg.conv_2d): unit strides/dilations.
// NOTE(review): doxygen-extraction residue — the specialization signature
// (upstream lines 700-701) and the early-bailout guard (line 709) were
// dropped; restore from upstream MLIR Utils.cpp before compiling.
698template <>
699std::optional<DilationsAndStrides>
702 if (isa<linalg::Conv2DOp>(op)) {
703 // Conv2DOp has no strides/dilations attributes, default to 1.
704 result.dilations = SmallVector<int64_t>(2, 1);
705 result.strides = SmallVector<int64_t>(2, 1);
706 return result;
707 }
708
710 return std::nullopt;
711
712 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
713 &result.strides);
714 AffineExpr H = m.dim(0);
715 AffineExpr W = m.dim(1);
716 AffineExpr h = m.dim(2);
717 AffineExpr w = m.dim(3);
718
719 if (m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
720 .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
721 .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
722 /*filterMap=*/{h, w},
723 /*outputMap=*/{H, W}})
724 .matchBody())
725 return result;
726 return std::nullopt;
727}
728
// Matcher for linalg.conv_2d_nhwc_hwcf layout.
// NOTE(review): doxygen-extraction residue — the specialization signature
// (upstream lines 731-732) and the early-bailout guard (line 740) were
// dropped; restore from upstream MLIR Utils.cpp before compiling.
729template <>
730std::optional<DilationsAndStrides>
733 if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfOp>(op.getOperation())) {
734 result.dilations =
735 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
736 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
737 return result;
738 }
739
741 return std::nullopt;
742
743 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
744 &result.strides);
745 AffineExpr N = m.dim(0);
746 AffineExpr H = m.dim(1);
747 AffineExpr W = m.dim(2);
748 AffineExpr F = m.dim(3);
749 AffineExpr h = m.dim(4);
750 AffineExpr w = m.dim(5);
751 AffineExpr c = m.dim(6);
752
753 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
754 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
755 .matchMaps(
756 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
757 /*filterMap=*/{h, w, c, F},
758 /*outputMap=*/{N, H, W, F}})
759 .matchBody())
760 return result;
761 return std::nullopt;
762}
763
// Matcher for the quantized linalg.conv_2d_nhwc_hwcf_q layout: two extra
// empty scalar maps for the zero points, and a zero-point-offset body.
// NOTE(review): doxygen-extraction residue — the specialization signature
// (upstream lines 766-767) and the early-bailout guard (line 775) were
// dropped; restore from upstream MLIR Utils.cpp before compiling.
764template <>
765std::optional<DilationsAndStrides>
768 if (auto convOp = dyn_cast<linalg::Conv2DNhwcHwcfQOp>(op.getOperation())) {
769 result.dilations =
770 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
771 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
772 return result;
773 }
774
776 return std::nullopt;
777
778 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
779 &result.strides);
780 AffineExpr N = m.dim(0);
781 AffineExpr H = m.dim(1);
782 AffineExpr W = m.dim(2);
783 AffineExpr F = m.dim(3);
784 AffineExpr h = m.dim(4);
785 AffineExpr w = m.dim(5);
786 AffineExpr c = m.dim(6);
787
788 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
789 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
790 .matchMaps(
791 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
792 /*filterMap=*/{h, w, c, F},
793 /*scalarMap=*/{},
794 /*scalarMap=*/{},
795 /*outputMap=*/{N, H, W, F}})
796 .matchBody(/*containsZeroPointOffset=*/true))
797 return result;
798 return std::nullopt;
799}
800
// Matcher for linalg.conv_2d_nhwc_fhwc layout.
// NOTE(review): doxygen-extraction residue — the specialization signature
// (upstream lines 803-804) and the early-bailout guard (line 812) were
// dropped; restore from upstream MLIR Utils.cpp before compiling.
801template <>
802std::optional<DilationsAndStrides>
805 if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcOp>(op.getOperation())) {
806 result.dilations =
807 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
808 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
809 return result;
810 }
811
813 return std::nullopt;
814
815 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
816 &result.strides);
817 AffineExpr N = m.dim(0);
818 AffineExpr H = m.dim(1);
819 AffineExpr W = m.dim(2);
820 AffineExpr F = m.dim(3);
821 AffineExpr h = m.dim(4);
822 AffineExpr w = m.dim(5);
823 AffineExpr c = m.dim(6);
824
825 if (m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
826 .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
827 .matchMaps(
828 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
829 /*filterMap=*/{F, h, w, c},
830 /*outputMap=*/{N, H, W, F}})
831 .matchBody())
832 return result;
833 return std::nullopt;
834}
835
// Matcher for the quantized linalg.conv_2d_nhwc_fhwc_q layout.
// NOTE(review): doxygen-extraction residue — the specialization signature
// (upstream lines 838-839) and the early-bailout guard (line 847) were
// dropped; restore from upstream MLIR Utils.cpp before compiling.
836template <>
837std::optional<DilationsAndStrides>
840 if (auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcQOp>(op.getOperation())) {
841 result.dilations =
842 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
843 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
844 return result;
845 }
846
848 return std::nullopt;
849
850 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
851 &result.strides);
852 AffineExpr N = m.dim(0);
853 AffineExpr H = m.dim(1);
854 AffineExpr W = m.dim(2);
855 AffineExpr F = m.dim(3);
856 AffineExpr h = m.dim(4);
857 AffineExpr w = m.dim(5);
858 AffineExpr c = m.dim(6);
859
860 if (m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
861 .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
862 .matchMaps(
863 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
864 /*filterMap=*/{F, h, w, c},
865 /*scalarMap=*/{},
866 /*scalarMap=*/{},
867 /*outputMap=*/{N, H, W, F}})
868 .matchBody(/*containsZeroPointOffset=*/true))
869 return result;
870 return std::nullopt;
871}
872
// Matcher for linalg.conv_2d_nchw_fchw layout.
// NOTE(review): doxygen-extraction residue — the specialization signature
// (upstream lines 875-876) and the early-bailout guard (line 884) were
// dropped; restore from upstream MLIR Utils.cpp before compiling.
873template <>
874std::optional<DilationsAndStrides>
877 if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwOp>(op.getOperation())) {
878 result.dilations =
879 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
880 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
881 return result;
882 }
883
885 return std::nullopt;
886
887 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
888 &result.strides);
889 AffineExpr N = m.dim(0);
890 AffineExpr F = m.dim(1);
891 AffineExpr H = m.dim(2);
892 AffineExpr W = m.dim(3);
893 AffineExpr c = m.dim(4);
894 AffineExpr h = m.dim(5);
895 AffineExpr w = m.dim(6);
896
897 if (m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
898 .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
899 .matchMaps(
900 {/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
901 /*filterMap=*/{F, c, h, w},
902 /*outputMap=*/{N, F, H, W}})
903 .matchBody())
904 return result;
905 return std::nullopt;
906}
907
// Matcher for the quantized linalg.conv_2d_nchw_fchw_q layout.
// NOTE(review): doxygen-extraction residue — the specialization signature
// (upstream lines 910-911) and the early-bailout guard (line 919) were
// dropped; restore from upstream MLIR Utils.cpp before compiling.
908template <>
909std::optional<DilationsAndStrides>
912 if (auto convOp = dyn_cast<linalg::Conv2DNchwFchwQOp>(op.getOperation())) {
913 result.dilations =
914 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
915 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
916 return result;
917 }
918
920 return std::nullopt;
921
922 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
923 &result.strides);
924 AffineExpr N = m.dim(0);
925 AffineExpr F = m.dim(1);
926 AffineExpr H = m.dim(2);
927 AffineExpr W = m.dim(3);
928 AffineExpr c = m.dim(4);
929 AffineExpr h = m.dim(5);
930 AffineExpr w = m.dim(6);
931
932 if (m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
933 .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
934 .matchMaps(
935 {/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
936 /*filterMap=*/{F, c, h, w},
937 /*scalarMap=*/{},
938 /*scalarMap=*/{},
939 /*outputMap=*/{N, F, H, W}})
940 .matchBody(/*containsZeroPointOffset=*/true))
941 return result;
942 return std::nullopt;
943}
944
// Matcher for the grouped linalg.conv_2d_ngchw_fgchw layout.
// NOTE(review): doxygen-extraction residue — the specialization signature
// (upstream lines 947-948) and the early-bailout guard (line 956) were
// dropped; restore from upstream MLIR Utils.cpp before compiling.
945template <>
946std::optional<DilationsAndStrides>
949 if (auto convOp = dyn_cast<linalg::Conv2DNgchwFgchwOp>(op.getOperation())) {
950 result.dilations =
951 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
952 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
953 return result;
954 }
955
957 return std::nullopt;
958
959 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
960 &result.strides);
961 AffineExpr N = m.dim(0);
962 AffineExpr G = m.dim(1);
963 AffineExpr F = m.dim(2);
964 AffineExpr H = m.dim(3);
965 AffineExpr W = m.dim(4);
966 AffineExpr c = m.dim(5);
967 AffineExpr h = m.dim(6);
968 AffineExpr w = m.dim(7);
969
970 if (m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
971 .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
972 .matchMaps(
973 {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
974 /*filterMap=*/{F, G, c, h, w},
975 /*outputMap=*/{N, G, F, H, W}})
976 .matchBody())
977 return result;
978 return std::nullopt;
979}
980
// Matcher for the grouped linalg.conv_2d_ngchw_gfchw layout.
// NOTE(review): doxygen-extraction residue — the specialization signature
// (upstream lines 983-984) and the early-bailout guard (line 992) were
// dropped; restore from upstream MLIR Utils.cpp before compiling.
981template <>
982std::optional<DilationsAndStrides>
985 if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwOp>(op.getOperation())) {
986 result.dilations =
987 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
988 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
989 return result;
990 }
991
993 return std::nullopt;
994
995 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
996 &result.strides);
997 AffineExpr N = m.dim(0);
998 AffineExpr G = m.dim(1);
999 AffineExpr F = m.dim(2);
1000 AffineExpr H = m.dim(3);
1001 AffineExpr W = m.dim(4);
1002 AffineExpr c = m.dim(5);
1003 AffineExpr h = m.dim(6);
1004 AffineExpr w = m.dim(7);
1005
1006 if (m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
1007 .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
1008 .matchMaps(
1009 {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
1010 /*filterMap=*/{G, F, c, h, w},
1011 /*outputMap=*/{N, G, F, H, W}})
1012 .matchBody())
1013 return result;
1014 return std::nullopt;
1015}
1016
1017template <>
1018std::optional<DilationsAndStrides>
1021 if (auto convOp = dyn_cast<linalg::Conv2DNgchwGfchwQOp>(op.getOperation())) {
1022 result.dilations =
1023 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1024 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1025 return result;
1026 }
1027
1029 return std::nullopt;
1030
1031 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1032 &result.strides);
1033 AffineExpr N = m.dim(0);
1034 AffineExpr G = m.dim(1);
1035 AffineExpr F = m.dim(2);
1036 AffineExpr H = m.dim(3);
1037 AffineExpr W = m.dim(4);
1038 AffineExpr c = m.dim(5);
1039 AffineExpr h = m.dim(6);
1040 AffineExpr w = m.dim(7);
1041
1042 if (m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
1043 .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
1044 .matchMaps(
1045 {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
1046 /*filterMap=*/{G, F, c, h, w},
1047 /*scalarMap=*/{},
1048 /*scalarMap=*/{},
1049 /*outputMap=*/{N, G, F, H, W}})
1050 .matchBody(/*containsZeroPointOffset=*/true))
1051 return result;
1052 return std::nullopt;
1053}
1054
// Conv2DNhwgcGfhwcOp: NHWGC input with GFHWC filter; spatial dims sit at
// input positions 1/2, which map to filter positions 2/3 (see matchStride).
// NOTE(review): the declarator line and the guard before the early
// `return std::nullopt;` are not visible in this listing — verify upstream.
1055template <>
1056std::optional<DilationsAndStrides>
1059 if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcOp>(op.getOperation())) {
1060 result.dilations =
1061 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1062 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1063 return result;
1064 }
1065
1067 return std::nullopt;
1068
1069 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1070 &result.strides);
1071 AffineExpr N = m.dim(0);
1072 AffineExpr H = m.dim(1);
1073 AffineExpr W = m.dim(2);
1074 AffineExpr G = m.dim(3);
1075 AffineExpr F = m.dim(4);
1076 AffineExpr h = m.dim(5);
1077 AffineExpr w = m.dim(6);
1078 AffineExpr c = m.dim(7);
1079
1080 if (m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
1081 .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
1082 .matchMaps(
1083 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
1084 /*filterMap=*/{G, F, h, w, c},
1085 /*outputMap=*/{N, H, W, G, F}})
1086 .matchBody())
1087 return result;
1088 return std::nullopt;
1089}
1090
// Conv2DNhwgcGfhwcQOp: quantized variant of the above — two scalar
// zero-point operands and matchBody(/*containsZeroPointOffset=*/true).
1091template <>
1092std::optional<DilationsAndStrides>
1095 if (auto convOp = dyn_cast<linalg::Conv2DNhwgcGfhwcQOp>(op.getOperation())) {
1096 result.dilations =
1097 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1098 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1099 return result;
1100 }
1101
1103 return std::nullopt;
1104
1105 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1106 &result.strides);
1107 AffineExpr N = m.dim(0);
1108 AffineExpr H = m.dim(1);
1109 AffineExpr W = m.dim(2);
1110 AffineExpr G = m.dim(3);
1111 AffineExpr F = m.dim(4);
1112 AffineExpr h = m.dim(5);
1113 AffineExpr w = m.dim(6);
1114 AffineExpr c = m.dim(7);
1115
1116 if (m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
1117 .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
1118 .matchMaps(
1119 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
1120 /*filterMap=*/{G, F, h, w, c},
1121 /*scalarMap=*/{},
1122 /*scalarMap=*/{},
1123 /*outputMap=*/{N, H, W, G, F}})
1124 .matchBody(/*containsZeroPointOffset=*/true))
1125 return result;
1126 return std::nullopt;
1127}
1128
// Conv3DOp: rank-3 convolution; the named op carries no strides/dilations
// attributes, so both default to 1. The structural match uses only spatial
// (D/H/W) and window (d/h/w) dims — no batch or channel dims.
// NOTE(review): the declarator line and the guard before the early
// `return std::nullopt;` are not visible in this listing — verify upstream.
1129template <>
1130std::optional<DilationsAndStrides>
1133 if (isa<linalg::Conv3DOp>(op)) {
1134 // Conv3DOp has no strides/dilations attributes, default to 1.
1135 result.dilations = SmallVector<int64_t>(3, 1);
1136 result.strides = SmallVector<int64_t>(3, 1);
1137 return result;
1138 }
1139
1141 return std::nullopt;
1142
1143 ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1144 &result.strides);
1145 AffineExpr D = m.dim(0);
1146 AffineExpr H = m.dim(1);
1147 AffineExpr W = m.dim(2);
1148 AffineExpr d = m.dim(3);
1149 AffineExpr h = m.dim(4);
1150 AffineExpr w = m.dim(5);
1151
1152 if (m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
1153 .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
1154 .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
1155 .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
1156 m.strided(W, w, 2)},
1157 /*filterMap=*/{d, h, w},
1158 /*outputMap=*/{D, H, W}})
1159 .matchBody())
1160 return result;
1161 return std::nullopt;
1162}
1163
// Conv3DNdhwcDhwcfOp: channel-last (NDHWC) input with DHWCF filter.
1164template <>
1165std::optional<DilationsAndStrides>
1168 if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfOp>(op.getOperation())) {
1169 result.dilations =
1170 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1171 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1172 return result;
1173 }
1174
1176 return std::nullopt;
1177
1178 ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1179 &result.strides);
1180 AffineExpr N = m.dim(0);
1181 AffineExpr D = m.dim(1);
1182 AffineExpr H = m.dim(2);
1183 AffineExpr W = m.dim(3);
1184 AffineExpr F = m.dim(4);
1185 AffineExpr d = m.dim(5);
1186 AffineExpr h = m.dim(6);
1187 AffineExpr w = m.dim(7);
1188 AffineExpr c = m.dim(8);
1189
1190 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1191 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1192 .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
1193 .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
1194 m.strided(W, w, 2), c},
1195 /*filterMap=*/{d, h, w, c, F},
1196 /*outputMap=*/{N, D, H, W, F}})
1197 .matchBody())
1198 return result;
1199 return std::nullopt;
1200}
1201
// Conv3DNdhwcDhwcfQOp: quantized variant — two scalar zero-point operands
// and matchBody(/*containsZeroPointOffset=*/true).
1202template <>
1203std::optional<DilationsAndStrides>
1206 if (auto convOp = dyn_cast<linalg::Conv3DNdhwcDhwcfQOp>(op.getOperation())) {
1207 result.dilations =
1208 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1209 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1210 return result;
1211 }
1212
1214 return std::nullopt;
1215
1216 ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1217 &result.strides);
1218 AffineExpr N = m.dim(0);
1219 AffineExpr D = m.dim(1);
1220 AffineExpr H = m.dim(2);
1221 AffineExpr W = m.dim(3);
1222 AffineExpr F = m.dim(4);
1223 AffineExpr d = m.dim(5);
1224 AffineExpr h = m.dim(6);
1225 AffineExpr w = m.dim(7);
1226 AffineExpr c = m.dim(8);
1227
1228 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1229 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1230 .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
1231 .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
1232 m.strided(W, w, 2), c},
1233 /*filterMap=*/{d, h, w, c, F},
1234 /*scalarMap=*/{},
1235 /*scalarMap=*/{},
1236 /*outputMap=*/{N, D, H, W, F}})
1237 .matchBody(/*containsZeroPointOffset=*/true))
1238 return result;
1239 return std::nullopt;
1240}
1241
// Conv3DNcdhwFcdhwOp: channel-first (NCDHW) input with FCDHW filter.
1242template <>
1243std::optional<DilationsAndStrides>
1246 if (auto convOp = dyn_cast<linalg::Conv3DNcdhwFcdhwOp>(op.getOperation())) {
1247 result.dilations =
1248 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1249 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1250 return result;
1251 }
1252
1254 return std::nullopt;
1255
1256 ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1257 &result.strides);
1258 AffineExpr N = m.dim(0);
1259 AffineExpr F = m.dim(1);
1260 AffineExpr D = m.dim(2);
1261 AffineExpr H = m.dim(3);
1262 AffineExpr W = m.dim(4);
1263 AffineExpr c = m.dim(5);
1264 AffineExpr d = m.dim(6);
1265 AffineExpr h = m.dim(7);
1266 AffineExpr w = m.dim(8);
1267
1268 if (m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
1269 .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
1270 .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/2)
1271 .matchMaps({/*inputMap=*/{N, c, m.strided(D, d, 0),
1272 m.strided(H, h, 1), m.strided(W, w, 2)},
1273 /*filterMap=*/{F, c, d, h, w},
1274 /*outputMap=*/{N, F, D, H, W}})
1275 .matchBody())
1276 return result;
1277 return std::nullopt;
1278}
1279
// DepthwiseConv1DNcwCwOp: depthwise 1-D conv — the channel dim C appears in
// input, filter AND output maps (no cross-channel reduction).
// NOTE(review): the declarator line and the guard before the early
// `return std::nullopt;` are not visible in this listing — verify upstream.
1280template <>
1281std::optional<DilationsAndStrides>
1284 if (auto convOp =
1285 dyn_cast<linalg::DepthwiseConv1DNcwCwOp>(op.getOperation())) {
1286 result.dilations =
1287 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1288 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1289 return result;
1290 }
1291
1293 return std::nullopt;
1294
1295 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
1296 &result.strides);
1297 AffineExpr N = m.dim(0);
1298 AffineExpr W = m.dim(1);
1299 AffineExpr C = m.dim(2);
1300 AffineExpr w = m.dim(3);
1301
1302 if (m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
1303 .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
1304 /*filterMap=*/{C, w},
1305 /*outputMap=*/{N, C, W}})
1306 .matchBody())
1307 return result;
1308 return std::nullopt;
1309}
1310
// DepthwiseConv1DNwcWcOp: channel-last variant of the 1-D depthwise matcher.
1311template <>
1312std::optional<DilationsAndStrides>
1315 if (auto convOp =
1316 dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
1317 result.dilations =
1318 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1319 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1320 return result;
1321 }
1322
1324 return std::nullopt;
1325
1326 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
1327 &result.strides);
1328 AffineExpr N = m.dim(0);
1329 AffineExpr W = m.dim(1);
1330 AffineExpr C = m.dim(2);
1331 AffineExpr w = m.dim(3);
1332
1333 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1334 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
1335 /*filterMap=*/{w, C},
1336 /*outputMap=*/{N, W, C}})
1337 .matchBody())
1338 return result;
1339 return std::nullopt;
1340}
1341
// DepthwiseConv1DNwcWcmOp: depthwise with channel multiplier CM — the
// output map carries both C and CM.
1342template <>
1343std::optional<DilationsAndStrides>
1346 if (auto convOp =
1347 dyn_cast<linalg::DepthwiseConv1DNwcWcmOp>(op.getOperation())) {
1348 result.dilations =
1349 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1350 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1351 return result;
1352 }
1353
1355 return std::nullopt;
1356
1357 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
1358 &result.strides);
1359 AffineExpr N = m.dim(0);
1360 AffineExpr W = m.dim(1);
1361 AffineExpr C = m.dim(2);
1362 AffineExpr CM = m.dim(3);
1363 AffineExpr w = m.dim(4);
1364
1365 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1366 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
1367 /*filterMap=*/{w, C, CM},
1368 /*outputMap=*/{N, W, C, CM}})
1369 .matchBody())
1370 return result;
1371 return std::nullopt;
1372}
1373
// DepthwiseConv2DNchwChwOp: channel-first depthwise 2-D conv; C appears in
// all three maps (no cross-channel reduction).
// NOTE(review): each specialization's declarator line and the guard before
// the early `return std::nullopt;` are not visible in this listing — verify
// against the upstream file.
1374template <>
1375std::optional<DilationsAndStrides>
1378 if (auto convOp =
1379 dyn_cast<linalg::DepthwiseConv2DNchwChwOp>(op.getOperation())) {
1380 result.dilations =
1381 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1382 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1383 return result;
1384 }
1385
1387 return std::nullopt;
1388
1389 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1390 &result.strides);
1391 AffineExpr N = m.dim(0);
1392 AffineExpr H = m.dim(1);
1393 AffineExpr W = m.dim(2);
1394 AffineExpr C = m.dim(3);
1395 AffineExpr h = m.dim(4);
1396 AffineExpr w = m.dim(5);
1397
1398 if (m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
1399 .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
1400 .matchMaps(
1401 {/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
1402 /*filterMap=*/{C, h, w},
1403 /*outputMap=*/{N, C, H, W}})
1404 .matchBody())
1405 return result;
1406 return std::nullopt;
1407}
1408
// DepthwiseConv2DNhwcHwcOp: channel-last depthwise 2-D conv.
1409template <>
1410std::optional<DilationsAndStrides>
1413 if (auto convOp =
1414 dyn_cast<linalg::DepthwiseConv2DNhwcHwcOp>(op.getOperation())) {
1415 result.dilations =
1416 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1417 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1418 return result;
1419 }
1420
1422 return std::nullopt;
1423
1424 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1425 &result.strides);
1426 AffineExpr N = m.dim(0);
1427 AffineExpr H = m.dim(1);
1428 AffineExpr W = m.dim(2);
1429 AffineExpr C = m.dim(3);
1430 AffineExpr h = m.dim(4);
1431 AffineExpr w = m.dim(5);
1432
1433 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1434 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1435 .matchMaps(
1436 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1437 /*filterMap=*/{h, w, C},
1438 /*outputMap=*/{N, H, W, C}})
1439 .matchBody())
1440 return result;
1441 return std::nullopt;
1442}
1443
// DepthwiseConv2DNhwcHwcQOp: quantized variant — two scalar zero-point
// operands and matchBody(/*containsZeroPointOffset=*/true).
1444template <>
1445std::optional<DilationsAndStrides>
1448 if (auto convOp =
1449 dyn_cast<linalg::DepthwiseConv2DNhwcHwcQOp>(op.getOperation())) {
1450 result.dilations =
1451 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1452 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1453 return result;
1454 }
1455
1457 return std::nullopt;
1458
1459 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1460 &result.strides);
1461 AffineExpr N = m.dim(0);
1462 AffineExpr H = m.dim(1);
1463 AffineExpr W = m.dim(2);
1464 AffineExpr C = m.dim(3);
1465 AffineExpr h = m.dim(4);
1466 AffineExpr w = m.dim(5);
1467
1468 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1469 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1470 .matchMaps(
1471 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1472 /*filterMap=*/{h, w, C},
1473 /*scalarMap=*/{},
1474 /*scalarMap=*/{},
1475 /*outputMap=*/{N, H, W, C}})
1476 .matchBody(/*containsZeroPointOffset=*/true))
1477 return result;
1478 return std::nullopt;
1479}
1480
// DepthwiseConv2DNhwcHwcmOp: depthwise with channel multiplier CM.
1481template <>
1482std::optional<DilationsAndStrides>
1485 if (auto convOp =
1486 dyn_cast<linalg::DepthwiseConv2DNhwcHwcmOp>(op.getOperation())) {
1487 result.dilations =
1488 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1489 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1490 return result;
1491 }
1492
1494 return std::nullopt;
1495
1496 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1497 &result.strides);
1498 AffineExpr N = m.dim(0);
1499 AffineExpr H = m.dim(1);
1500 AffineExpr W = m.dim(2);
1501 AffineExpr C = m.dim(3);
1502 AffineExpr CM = m.dim(4);
1503 AffineExpr h = m.dim(5);
1504 AffineExpr w = m.dim(6);
1505
1506 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1507 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1508 .matchMaps(
1509 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1510 /*filterMap=*/{h, w, C, CM},
1511 /*outputMap=*/{N, H, W, C, CM}})
1512 .matchBody())
1513 return result;
1514 return std::nullopt;
1515}
1516
// DepthwiseConv2DNhwcHwcmQOp: quantized channel-multiplier variant.
1517template <>
1518std::optional<DilationsAndStrides>
1521 if (auto convOp =
1522 dyn_cast<linalg::DepthwiseConv2DNhwcHwcmQOp>(op.getOperation())) {
1523 result.dilations =
1524 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1525 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1526 return result;
1527 }
1528
1530 return std::nullopt;
1531
1532 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1533 &result.strides);
1534 AffineExpr N = m.dim(0);
1535 AffineExpr H = m.dim(1);
1536 AffineExpr W = m.dim(2);
1537 AffineExpr C = m.dim(3);
1538 AffineExpr CM = m.dim(4);
1539 AffineExpr h = m.dim(5);
1540 AffineExpr w = m.dim(6);
1541
1542 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1543 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1544 .matchMaps(
1545 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1546 /*filterMap=*/{h, w, C, CM},
1547 /*scalarMap=*/{},
1548 /*scalarMap=*/{},
1549 /*outputMap=*/{N, H, W, C, CM}})
1550 .matchBody(/*containsZeroPointOffset=*/true))
1551 return result;
1552 return std::nullopt;
1553}
1554
// DepthwiseConv3DNdhwcDhwcOp: channel-last depthwise 3-D conv; C appears in
// input, filter and output maps (no cross-channel reduction).
// NOTE(review): the declarator line and the guard before the early
// `return std::nullopt;` are not visible in this listing — verify upstream.
1555template <>
1556std::optional<DilationsAndStrides>
1559 if (auto convOp =
1560 dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcOp>(op.getOperation())) {
1561 result.dilations =
1562 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1563 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1564 return result;
1565 }
1566
1568 return std::nullopt;
1569
1570 ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1571 &result.strides);
1572 AffineExpr N = m.dim(0);
1573 AffineExpr D = m.dim(1);
1574 AffineExpr H = m.dim(2);
1575 AffineExpr W = m.dim(3);
1576 AffineExpr d = m.dim(4);
1577 AffineExpr h = m.dim(5);
1578 AffineExpr w = m.dim(6);
1579 AffineExpr C = m.dim(7);
1580
1581 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1582 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1583 .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
1584 .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
1585 m.strided(W, w, 2), C},
1586 /*filterMap=*/{d, h, w, C},
1587 /*outputMap=*/{N, D, H, W, C}})
1588 .matchBody())
1589 return result;
1590 return std::nullopt;
1591}
1592
// DepthwiseConv3DNcdhwCdhwOp: channel-first variant.
1593template <>
1594std::optional<DilationsAndStrides>
1597 if (auto convOp =
1598 dyn_cast<linalg::DepthwiseConv3DNcdhwCdhwOp>(op.getOperation())) {
1599 result.dilations =
1600 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1601 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1602 return result;
1603 }
1604
1606 return std::nullopt;
1607
1608 ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1609 &result.strides);
1610 AffineExpr N = m.dim(0);
1611 AffineExpr D = m.dim(1);
1612 AffineExpr H = m.dim(2);
1613 AffineExpr W = m.dim(3);
1614 AffineExpr d = m.dim(4);
1615 AffineExpr h = m.dim(5);
1616 AffineExpr w = m.dim(6);
1617 AffineExpr C = m.dim(7);
1618
1619 if (m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
1620 .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
1621 .matchStride(/*iDim=*/4, /*fDim=*/3, /*oDim=*/4, /*idx=*/2)
1622 .matchMaps({/*inputMap=*/{N, C, m.strided(D, d, 0),
1623 m.strided(H, h, 1), m.strided(W, w, 2)},
1624 /*filterMap=*/{C, d, h, w},
1625 /*outputMap=*/{N, C, D, H, W}})
1626 .matchBody())
1627 return result;
1628 return std::nullopt;
1629}
1630
// DepthwiseConv3DNdhwcDhwcmOp: depthwise with channel multiplier CM.
1631template <>
1632std::optional<DilationsAndStrides>
1635 if (auto convOp =
1636 dyn_cast<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op.getOperation())) {
1637 result.dilations =
1638 llvm::to_vector(convOp.getDilations().getValues<int64_t>());
1639 result.strides = llvm::to_vector(convOp.getStrides().getValues<int64_t>());
1640 return result;
1641 }
1642
1644 return std::nullopt;
1645
1646 ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
1647 &result.strides);
1648 AffineExpr N = m.dim(0);
1649 AffineExpr D = m.dim(1);
1650 AffineExpr H = m.dim(2);
1651 AffineExpr W = m.dim(3);
1652 AffineExpr CM = m.dim(4);
1653 AffineExpr d = m.dim(5);
1654 AffineExpr h = m.dim(6);
1655 AffineExpr w = m.dim(7);
1656 AffineExpr C = m.dim(8);
1657
1658 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1659 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1660 .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
1661 .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
1662 m.strided(W, w, 2), C},
1663 /*filterMap=*/{d, h, w, C, CM},
1664 /*outputMap=*/{N, D, H, W, C, CM}})
1665 .matchBody())
1666 return result;
1667 return std::nullopt;
1668}
1669
// PoolingNhwcMaxOp: pooling matcher — note the filter map covers only the
// spatial window dims {h, w}; the reduction kind is PoolingType::MaxSigned.
// NOTE(review): the declarator line and the guard before the early
// `return std::nullopt;` are not visible in this listing — verify upstream.
1670template <>
1671std::optional<DilationsAndStrides>
1674 if (auto poolOp = dyn_cast<linalg::PoolingNhwcMaxOp>(op.getOperation())) {
1675 result.dilations =
1676 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1677 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1678 return result;
1679 }
1680
1682 return std::nullopt;
1683
1684 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1685 &result.strides, PoolingType::MaxSigned);
1686 AffineExpr N = m.dim(0);
1687 AffineExpr H = m.dim(1);
1688 AffineExpr W = m.dim(2);
1689 AffineExpr C = m.dim(3);
1690 AffineExpr h = m.dim(4);
1691 AffineExpr w = m.dim(5);
1692
1693 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1694 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1695 .matchMaps(
1696 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1697 /*filterMap=*/{h, w},
1698 /*outputMap=*/{N, H, W, C}})
1699 .matchBody())
1700 return result;
1701 return std::nullopt;
1702}
1703
// PoolingNhwcMinOp: identical structure, reduction is PoolingType::MinSigned.
1704template <>
1705std::optional<DilationsAndStrides>
1708 if (auto poolOp = dyn_cast<linalg::PoolingNhwcMinOp>(op.getOperation())) {
1709 result.dilations =
1710 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1711 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1712 return result;
1713 }
1714
1716 return std::nullopt;
1717
1718 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1719 &result.strides, PoolingType::MinSigned);
1720 AffineExpr N = m.dim(0);
1721 AffineExpr H = m.dim(1);
1722 AffineExpr W = m.dim(2);
1723 AffineExpr C = m.dim(3);
1724 AffineExpr h = m.dim(4);
1725 AffineExpr w = m.dim(5);
1726
1727 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1728 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1729 .matchMaps(
1730 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1731 /*filterMap=*/{h, w},
1732 /*outputMap=*/{N, H, W, C}})
1733 .matchBody())
1734 return result;
1735 return std::nullopt;
1736}
1737
// PoolingNhwcSumOp: identical structure, reduction is PoolingType::Sum.
1738template <>
1739std::optional<DilationsAndStrides>
1742 if (auto poolOp = dyn_cast<linalg::PoolingNhwcSumOp>(op.getOperation())) {
1743 result.dilations =
1744 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1745 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1746 return result;
1747 }
1748
1750 return std::nullopt;
1751
1752 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1753 &result.strides, PoolingType::Sum);
1754 AffineExpr N = m.dim(0);
1755 AffineExpr H = m.dim(1);
1756 AffineExpr W = m.dim(2);
1757 AffineExpr C = m.dim(3);
1758 AffineExpr h = m.dim(4);
1759 AffineExpr w = m.dim(5);
1760
1761 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1762 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1763 .matchMaps(
1764 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1765 /*filterMap=*/{h, w},
1766 /*outputMap=*/{N, H, W, C}})
1767 .matchBody())
1768 return result;
1769 return std::nullopt;
1770}
1771
// PoolingNhwcMaxUnsignedOp: unsigned-max pooling variant.
// NOTE(review): the builder's second argument line (strides pointer and,
// presumably, PoolingType::MaxUnsigned) is missing from this listing —
// confirm against the upstream file.
1772template <>
1773std::optional<DilationsAndStrides>
1776 if (auto poolOp =
1777 dyn_cast<linalg::PoolingNhwcMaxUnsignedOp>(op.getOperation())) {
1778 result.dilations =
1779 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1780 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1781 return result;
1782 }
1783
1785 return std::nullopt;
1786
1787 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1789 AffineExpr N = m.dim(0);
1790 AffineExpr H = m.dim(1);
1791 AffineExpr W = m.dim(2);
1792 AffineExpr C = m.dim(3);
1793 AffineExpr h = m.dim(4);
1794 AffineExpr w = m.dim(5);
1795
1796 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1797 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1798 .matchMaps(
1799 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1800 /*filterMap=*/{h, w},
1801 /*outputMap=*/{N, H, W, C}})
1802 .matchBody())
1803 return result;
1804 return std::nullopt;
1805}
1806
// PoolingNhwcMinUnsignedOp: unsigned-min pooling variant.
// NOTE(review): the builder's second argument line (strides pointer and,
// presumably, PoolingType::MinUnsigned) is missing from this listing —
// confirm against the upstream file.
1807template <>
1808std::optional<DilationsAndStrides>
1811 if (auto poolOp =
1812 dyn_cast<linalg::PoolingNhwcMinUnsignedOp>(op.getOperation())) {
1813 result.dilations =
1814 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1815 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1816 return result;
1817 }
1818
1820 return std::nullopt;
1821
1822 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1824 AffineExpr N = m.dim(0);
1825 AffineExpr H = m.dim(1);
1826 AffineExpr W = m.dim(2);
1827 AffineExpr C = m.dim(3);
1828 AffineExpr h = m.dim(4);
1829 AffineExpr w = m.dim(5);
1830
1831 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1832 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
1833 .matchMaps(
1834 {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
1835 /*filterMap=*/{h, w},
1836 /*outputMap=*/{N, H, W, C}})
1837 .matchBody())
1838 return result;
1839 return std::nullopt;
1840}
1841
// PoolingNchwSumOp: channel-first 2-D sum pooling; spatial dims are input
// positions 2/3 and the filter map is just the window {h, w}.
// NOTE(review): the declarator line and the guard before the early
// `return std::nullopt;` are not visible in this listing — verify upstream.
1842template <>
1843std::optional<DilationsAndStrides>
1846 if (auto poolOp = dyn_cast<linalg::PoolingNchwSumOp>(op.getOperation())) {
1847 result.dilations =
1848 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1849 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1850 return result;
1851 }
1852
1854 return std::nullopt;
1855
1856 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1857 &result.strides, PoolingType::Sum);
1858 AffineExpr N = m.dim(0);
1859 AffineExpr C = m.dim(1);
1860 AffineExpr H = m.dim(2);
1861 AffineExpr W = m.dim(3);
1862 AffineExpr h = m.dim(4);
1863 AffineExpr w = m.dim(5);
1864
1865 if (m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
1866 .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
1867 .matchMaps(
1868 {/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
1869 /*filterMap=*/{h, w},
1870 /*outputMap=*/{N, C, H, W}})
1871 .matchBody())
1872 return result;
1873 return std::nullopt;
1874}
1875
// PoolingNchwMaxOp: identical structure, reduction is PoolingType::MaxSigned.
1876template <>
1877std::optional<DilationsAndStrides>
1880 if (auto poolOp = dyn_cast<linalg::PoolingNchwMaxOp>(op.getOperation())) {
1881 result.dilations =
1882 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1883 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1884 return result;
1885 }
1886
1888 return std::nullopt;
1889
1890 ConvMatcherBuilder m(op, /*spatialRank=*/2, &result.dilations,
1891 &result.strides, PoolingType::MaxSigned);
1892 AffineExpr N = m.dim(0);
1893 AffineExpr C = m.dim(1);
1894 AffineExpr H = m.dim(2);
1895 AffineExpr W = m.dim(3);
1896 AffineExpr h = m.dim(4);
1897 AffineExpr w = m.dim(5);
1898
1899 if (m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
1900 .matchStride(/*iDim=*/3, /*fDim=*/1, /*oDim=*/3, /*idx=*/1)
1901 .matchMaps(
1902 {/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
1903 /*filterMap=*/{h, w},
1904 /*outputMap=*/{N, C, H, W}})
1905 .matchBody())
1906 return result;
1907 return std::nullopt;
1908}
1909
// PoolingNwcSumOp: 1-D channel-last sum pooling (window dim w only).
// NOTE(review): each specialization's declarator line and the guard before
// the early `return std::nullopt;` are not visible in this listing — verify
// against the upstream file.
1910template <>
1911std::optional<DilationsAndStrides>
1914 if (auto poolOp = dyn_cast<linalg::PoolingNwcSumOp>(op.getOperation())) {
1915 result.dilations =
1916 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1917 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1918 return result;
1919 }
1920
1922 return std::nullopt;
1923
1924 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
1925 &result.strides, PoolingType::Sum);
1926 AffineExpr N = m.dim(0);
1927 AffineExpr W = m.dim(1);
1928 AffineExpr C = m.dim(2);
1929 AffineExpr w = m.dim(3);
1930
1931 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1932 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
1933 /*filterMap=*/{w},
1934 /*outputMap=*/{N, W, C}})
1935 .matchBody())
1936 return result;
1937 return std::nullopt;
1938}
1939
// PoolingNcwSumOp: channel-first variant of 1-D sum pooling.
1940template <>
1941std::optional<DilationsAndStrides>
1944 if (auto poolOp = dyn_cast<linalg::PoolingNcwSumOp>(op.getOperation())) {
1945 result.dilations =
1946 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1947 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1948 return result;
1949 }
1950
1952 return std::nullopt;
1953
1954 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
1955 &result.strides, PoolingType::Sum);
1956 AffineExpr N = m.dim(0);
1957 AffineExpr C = m.dim(1);
1958 AffineExpr W = m.dim(2);
1959 AffineExpr w = m.dim(3);
1960
1961 if (m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
1962 .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
1963 /*filterMap=*/{w},
1964 /*outputMap=*/{N, C, W}})
1965 .matchBody())
1966 return result;
1967 return std::nullopt;
1968}
1969
// PoolingNwcMaxOp: 1-D channel-last signed-max pooling.
1970template <>
1971std::optional<DilationsAndStrides>
1974 if (auto poolOp = dyn_cast<linalg::PoolingNwcMaxOp>(op.getOperation())) {
1975 result.dilations =
1976 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
1977 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
1978 return result;
1979 }
1980
1982 return std::nullopt;
1983
1984 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
1985 &result.strides, PoolingType::MaxSigned);
1986 AffineExpr N = m.dim(0);
1987 AffineExpr W = m.dim(1);
1988 AffineExpr C = m.dim(2);
1989 AffineExpr w = m.dim(3);
1990
1991 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
1992 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
1993 /*filterMap=*/{w},
1994 /*outputMap=*/{N, W, C}})
1995 .matchBody())
1996 return result;
1997 return std::nullopt;
1998}
1999
// PoolingNwcMaxUnsignedOp: unsigned-max variant.
// NOTE(review): the builder's second argument line (strides pointer and,
// presumably, PoolingType::MaxUnsigned) is missing from this listing —
// confirm against the upstream file.
2000template <>
2001std::optional<DilationsAndStrides>
2004 if (auto poolOp =
2005 dyn_cast<linalg::PoolingNwcMaxUnsignedOp>(op.getOperation())) {
2006 result.dilations =
2007 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2008 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2009 return result;
2010 }
2011
2013 return std::nullopt;
2014
2015 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
2017 AffineExpr N = m.dim(0);
2018 AffineExpr W = m.dim(1);
2019 AffineExpr C = m.dim(2);
2020 AffineExpr w = m.dim(3);
2021
2022 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
2023 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
2024 /*filterMap=*/{w},
2025 /*outputMap=*/{N, W, C}})
2026 .matchBody())
2027 return result;
2028 return std::nullopt;
2029}
2030
// PoolingNcwMaxOp: channel-first 1-D signed-max pooling.
2031template <>
2032std::optional<DilationsAndStrides>
2035 if (auto poolOp = dyn_cast<linalg::PoolingNcwMaxOp>(op.getOperation())) {
2036 result.dilations =
2037 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2038 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2039 return result;
2040 }
2041
2043 return std::nullopt;
2044
2045 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
2046 &result.strides, PoolingType::MaxSigned);
2047 AffineExpr N = m.dim(0);
2048 AffineExpr C = m.dim(1);
2049 AffineExpr W = m.dim(2);
2050 AffineExpr w = m.dim(3);
2051
2052 if (m.matchStride(/*iDim=*/2, /*fDim=*/0, /*oDim=*/2, /*idx=*/0)
2053 .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
2054 /*filterMap=*/{w},
2055 /*outputMap=*/{N, C, W}})
2056 .matchBody())
2057 return result;
2058 return std::nullopt;
2059}
2060
// PoolingNwcMinOp: 1-D channel-last signed-min pooling.
2061template <>
2062std::optional<DilationsAndStrides>
2065 if (auto poolOp = dyn_cast<linalg::PoolingNwcMinOp>(op.getOperation())) {
2066 result.dilations =
2067 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2068 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2069 return result;
2070 }
2071
2073 return std::nullopt;
2074
2075 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
2076 &result.strides, PoolingType::MinSigned);
2077 AffineExpr N = m.dim(0);
2078 AffineExpr W = m.dim(1);
2079 AffineExpr C = m.dim(2);
2080 AffineExpr w = m.dim(3);
2081
2082 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
2083 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
2084 /*filterMap=*/{w},
2085 /*outputMap=*/{N, W, C}})
2086 .matchBody())
2087 return result;
2088 return std::nullopt;
2089}
2090
// PoolingNwcMinUnsignedOp: unsigned-min variant.
// NOTE(review): the builder's second argument line (strides pointer and,
// presumably, PoolingType::MinUnsigned) is missing from this listing —
// confirm against the upstream file.
2091template <>
2092std::optional<DilationsAndStrides>
2095 if (auto poolOp =
2096 dyn_cast<linalg::PoolingNwcMinUnsignedOp>(op.getOperation())) {
2097 result.dilations =
2098 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2099 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2100 return result;
2101 }
2102
2104 return std::nullopt;
2105
2106 ConvMatcherBuilder m(op, /*spatialRank=*/1, &result.dilations,
2108 AffineExpr N = m.dim(0);
2109 AffineExpr W = m.dim(1);
2110 AffineExpr C = m.dim(2);
2111 AffineExpr w = m.dim(3);
2112
2113 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
2114 .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
2115 /*filterMap=*/{w},
2116 /*outputMap=*/{N, W, C}})
2117 .matchBody())
2118 return result;
2119 return std::nullopt;
2120}
2121
2122template <>
2123std::optional<DilationsAndStrides>
2126 if (auto poolOp = dyn_cast<linalg::PoolingNdhwcSumOp>(op.getOperation())) {
2127 result.dilations =
2128 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2129 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2130 return result;
2131 }
2132
2134 return std::nullopt;
2135
2136 ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
2137 &result.strides, PoolingType::Sum);
2138 AffineExpr N = m.dim(0);
2139 AffineExpr D = m.dim(1);
2140 AffineExpr H = m.dim(2);
2141 AffineExpr W = m.dim(3);
2142 AffineExpr C = m.dim(4);
2143 AffineExpr d = m.dim(5);
2144 AffineExpr h = m.dim(6);
2145 AffineExpr w = m.dim(7);
2146
2147 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
2148 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
2149 .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
2150 .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
2151 m.strided(W, w, 2), C},
2152 /*filterMap=*/{d, h, w},
2153 /*outputMap=*/{N, D, H, W, C}})
2154 .matchBody())
2155 return result;
2156 return std::nullopt;
2157}
2158
2159template <>
2160std::optional<DilationsAndStrides>
2163 if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMaxOp>(op.getOperation())) {
2164 result.dilations =
2165 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2166 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2167 return result;
2168 }
2169
2171 return std::nullopt;
2172
2173 ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
2174 &result.strides, PoolingType::MaxSigned);
2175 AffineExpr N = m.dim(0);
2176 AffineExpr D = m.dim(1);
2177 AffineExpr H = m.dim(2);
2178 AffineExpr W = m.dim(3);
2179 AffineExpr C = m.dim(4);
2180 AffineExpr d = m.dim(5);
2181 AffineExpr h = m.dim(6);
2182 AffineExpr w = m.dim(7);
2183
2184 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
2185 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
2186 .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
2187 .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
2188 m.strided(W, w, 2), C},
2189 /*filterMap=*/{d, h, w},
2190 /*outputMap=*/{N, D, H, W, C}})
2191 .matchBody())
2192 return result;
2193 return std::nullopt;
2194}
2195
2196template <>
2197std::optional<DilationsAndStrides>
2200 if (auto poolOp = dyn_cast<linalg::PoolingNdhwcMinOp>(op.getOperation())) {
2201 result.dilations =
2202 llvm::to_vector(poolOp.getDilations().getValues<int64_t>());
2203 result.strides = llvm::to_vector(poolOp.getStrides().getValues<int64_t>());
2204 return result;
2205 }
2206
2208 return std::nullopt;
2209
2210 ConvMatcherBuilder m(op, /*spatialRank=*/3, &result.dilations,
2211 &result.strides, PoolingType::MinSigned);
2212 AffineExpr N = m.dim(0);
2213 AffineExpr D = m.dim(1);
2214 AffineExpr H = m.dim(2);
2215 AffineExpr W = m.dim(3);
2216 AffineExpr C = m.dim(4);
2217 AffineExpr d = m.dim(5);
2218 AffineExpr h = m.dim(6);
2219 AffineExpr w = m.dim(7);
2220
2221 if (m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
2222 .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
2223 .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
2224 .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
2225 m.strided(W, w, 2), C},
2226 /*filterMap=*/{d, h, w},
2227 /*outputMap=*/{N, D, H, W, C}})
2228 .matchBody())
2229 return result;
2230 return std::nullopt;
2231}
2232
2233Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
2234 Value source, Value pad, bool nofold,
2235 ValueRange typeDynDims) {
2236 // Exit if `source` is not defined by an ExtractSliceOp.
2237 auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
2238 if (!sliceOp)
2239 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2240 typeDynDims);
2241
2242 // Search the `source` use-def chain for padded LinalgOps.
2243 Value current = sliceOp.getSource();
2244 while (current) {
2245 auto linalgOp = current.getDefiningOp<LinalgOp>();
2246 if (!linalgOp)
2247 break;
2248 OpResult opResult = cast<OpResult>(current);
2249 current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
2250 }
2251 auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;
2252
2253 // Exit if the search fails to match a tensor::PadOp at the end of the matched
2254 // LinalgOp sequence.
2255 if (!padOp)
2256 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2257 typeDynDims);
2258
2259 // Exit if the padded result type does not match.
2260 if (sliceOp.getSource().getType() != type)
2261 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2262 typeDynDims);
2263
2264 // Exit if the LinalgOps are not high padded.
2265 if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
2266 return getConstantIntValue(ofr) != static_cast<int64_t>(0);
2267 }))
2268 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2269 typeDynDims);
2270
2271 // Exit if `padOpSliceOp`, which defines the slice used by
2272 // `padOp`, is rank-reducing.
2273 auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
2274 if (!padOpSliceOp ||
2275 sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
2276 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2277 typeDynDims);
2278
2279 // Exit if the sizes of the dynamic sizes of `sliceOp` do not match the size
2280 // of the slice padded by `padOp`.
2281 if (llvm::any_of(
2282 llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
2283 [](std::tuple<OpFoldResult, OpFoldResult> it) {
2284 return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
2285 }))
2286 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2287 typeDynDims);
2288
2289 // Exit if the padding values do not match.
2290 Attribute padOpPadAttr, padAttr;
2291 Value padOpPad = padOp.getConstantPaddingValue();
2292 if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
2293 !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
2294 return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
2295 typeDynDims);
2296
2297 // Return the padded result if the padding values and sizes match.
2298 return sliceOp.getSource();
2299}
2300
2301GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
2302 auto memrefTypeTo = cast<MemRefType>(to.getType());
2303#ifndef NDEBUG
2304 auto memrefTypeFrom = cast<MemRefType>(from.getType());
2305 assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
2306 "`from` and `to` memref must have the same rank");
2307#endif // NDEBUG
2308
2309 AffineMap id =
2310 AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
2311 SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
2312 utils::IteratorType::parallel);
2313 return linalg::GenericOp::create(
2314 b, loc,
2315 /*inputs=*/from,
2316 /*outputs=*/to,
2317 /*indexingMaps=*/llvm::ArrayRef({id, id}),
2318 /*iteratorTypes=*/iteratorTypes,
2319 [](OpBuilder &b, Location loc, ValueRange args) {
2320 linalg::YieldOp::create(b, loc, args.front());
2321 });
2322}
2323
2324/// Specialization to build an scf "for" nest.
2325template <>
2327 OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
2328 ArrayRef<utils::IteratorType> iteratorTypes,
2330 ValueRange)>
2331 bodyBuilderFn,
2332 ArrayRef<linalg::ProcInfo> procInfo) {
2333 assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
2334 "expected as many entries for proc info as number of loops, even if "
2335 "they are null entries");
2336 SmallVector<Value> iterArgInitValues;
2337 if (!linalgOp.hasPureBufferSemantics())
2338 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
2339 SmallVector<Value, 4> lbs, ubs, steps;
2340 unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
2342 b, loc, lbs, ubs, steps, iterArgInitValues,
2343 [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
2344 assert(iterArgs.size() == iterArgInitValues.size() &&
2345 "expect the number of output tensors and iter args to match");
2346 SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
2347 if (!iterArgs.empty()) {
2348 operandValuesToUse = linalgOp.getDpsInputs();
2349 operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
2350 }
2351 return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
2352 });
2353
2354 if (loopNest.loops.empty() || procInfo.empty())
2355 return;
2356
2357 // Filter out scf.for loops that were created out of parallel dimensions.
2358 for (const auto &loop : llvm::enumerate(loopNest.loops)) {
2359 if (procInfo[loop.index()].distributionMethod ==
2361 mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
2362 procInfo[loop.index()].nprocs);
2363 }
2364 }
2365}
2366
2367/// Specialization to build affine "for" nest.
2368template <>
2370 OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
2371 ArrayRef<utils::IteratorType> iteratorTypes,
2373 ValueRange)>
2374 bodyBuilderFn,
2375 ArrayRef<linalg::ProcInfo> /*procInfo*/) {
2376 SmallVector<Value> iterArgInitValues;
2377 if (!linalgOp.hasPureBufferSemantics())
2378 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
2379 assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
2380 SmallVector<Value, 4> lbs, ubs, steps;
2381 unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
2382
2383 // Affine loops require constant steps.
2384 SmallVector<int64_t, 4> constantSteps;
2385 constantSteps.reserve(steps.size());
2386 for (Value v : steps) {
2387 auto constVal = getConstantIntValue(v);
2388 assert(constVal.has_value() && "Affine loops require constant steps");
2389 constantSteps.push_back(constVal.value());
2390 }
2391
2392 affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
2393 [&](OpBuilder &b, Location loc, ValueRange ivs) {
2394 bodyBuilderFn(b, loc, ivs,
2395 linalgOp->getOperands());
2396 });
2397}
2398
2399/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
2401 Value nprocs, Value &lb, Value &ub,
2402 Value &step) {
2403 AffineExpr d0, d1;
2404 bindDims(b.getContext(), d0, d1);
2405 AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
2406 lb =
2407 affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
2408 step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
2409}
2410
2411/// Generates a loop nest consisting of scf.parallel and scf.for, depending
2412/// on the `iteratorTypes.` Consecutive parallel loops create a single
2413/// scf.parallel operation; each sequential loop creates a new scf.for
2414/// operation. The body of the innermost loop is populated by
2415/// `bodyBuilderFn` that accepts a range of induction variables for all
2416/// loops. `ivStorage` is used to store the partial list of induction
2417/// variables.
2418// TODO: this function can be made iterative instead. However, it
2419// will have at most as many recursive calls as nested loops, which rarely
2420// exceeds 10.
2422 OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
2423 ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
2425 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2426 SmallVectorImpl<Value> &ivStorage) {
2427 assert(lbs.size() == ubs.size());
2428 assert(lbs.size() == steps.size());
2429 assert(lbs.size() == iteratorTypes.size());
2430 assert(procInfo.empty() || (lbs.size() == procInfo.size()));
2431
2432 // If there are no (more) loops to be generated, generate the body and be
2433 // done with it.
2434 if (iteratorTypes.empty()) {
2435 bodyBuilderFn(b, loc, ivStorage);
2436 return;
2437 }
2438
2439 // If there are no outer parallel loops, generate one sequential loop and
2440 // recurse.
2441 if (!isParallelIterator(iteratorTypes.front())) {
2442 LoopNest singleLoop = buildLoopNest(
2443 b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
2444 [&](OpBuilder &b, Location loc, ValueRange ivs) {
2445 ivStorage.append(ivs.begin(), ivs.end());
2446 generateParallelLoopNest(
2447 b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
2448 iteratorTypes.drop_front(),
2449 procInfo.empty() ? procInfo : procInfo.drop_front(),
2450 bodyBuilderFn, ivStorage);
2451 });
2452 return;
2453 }
2454
2455 unsigned nLoops = iteratorTypes.size();
2456 unsigned numProcessed = 0;
2457 DistributionMethod distributionMethod = DistributionMethod::None;
2458 if (procInfo.empty()) {
2459 numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
2460 } else {
2461 distributionMethod = procInfo.front().distributionMethod;
2462 numProcessed =
2463 nLoops - procInfo
2464 .drop_while([&](linalg::ProcInfo p) {
2465 return p.distributionMethod == distributionMethod;
2466 })
2467 .size();
2468 }
2469
2470 auto remainderProcInfo =
2471 procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
2472 switch (distributionMethod) {
2474 // Generate a single parallel loop-nest operation for all outermost
2475 // parallel loops and recurse.
2476 scf::ParallelOp::create(
2477 b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
2478 steps.take_front(numProcessed),
2479 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
2480 ivStorage.append(localIvs.begin(), localIvs.end());
2481 generateParallelLoopNest(
2482 nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
2483 ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
2484 iteratorTypes.drop_front(numProcessed), remainderProcInfo,
2485 bodyBuilderFn, ivStorage);
2486 });
2487 return;
2488 }
2490 // Generate a single parallel loop-nest operation for all outermost
2491 // parallel loops and recurse.
2492 scf::ParallelOp::create(
2493 b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
2494 steps.take_front(numProcessed),
2495 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
2496 ivStorage.append(localIvs.begin(), localIvs.end());
2497 generateParallelLoopNest(
2498 nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
2499 ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
2500 iteratorTypes.drop_front(numProcessed), remainderProcInfo,
2501 bodyBuilderFn, ivStorage);
2502 });
2503 return;
2504 }
2506 // Check (for the processed loops) that the iteration is in-bounds.
2507 ArithBuilder ab(b, loc);
2508 Value cond = ab.slt(lbs[0], ubs[0]);
2509 for (unsigned i = 1; i < numProcessed; ++i)
2510 cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
2511 ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
2512 scf::IfOp::create(b, loc, cond, [&](OpBuilder &b, Location loc) {
2513 generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
2514 ubs.drop_front(numProcessed),
2515 steps.drop_front(numProcessed),
2516 iteratorTypes.drop_front(numProcessed),
2517 remainderProcInfo, bodyBuilderFn, ivStorage);
2518 scf::YieldOp::create(b, loc, ValueRange{});
2519 });
2520 return;
2521 }
2523 // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
2524 // with inner loop generation.
2525 ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
2527 b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
2528 steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
2529 remainderProcInfo, bodyBuilderFn, ivStorage);
2530 return;
2531 }
2532}
2533
2534/// Specialization for generating a mix of parallel and sequential scf loops.
2535template <>
2537 OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
2538 ArrayRef<utils::IteratorType> iteratorTypes,
2540 ValueRange)>
2541 bodyBuilderFn,
2542 ArrayRef<linalg::ProcInfo> procInfo) {
2543 SmallVector<Value> iterArgInitValues;
2544 if (!linalgOp.hasPureBufferSemantics())
2545 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
2546 assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
2547 // This function may be passed more iterator types than ranges.
2548 assert(iteratorTypes.size() >= loopRanges.size() &&
2549 "expected iterator type for all ranges");
2550 assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
2551 "expected proc information for all loops when present");
2552 iteratorTypes = iteratorTypes.take_front(loopRanges.size());
2553 SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
2554 unsigned numLoops = iteratorTypes.size();
2555 ivs.reserve(numLoops);
2556 lbsStorage.reserve(numLoops);
2557 ubsStorage.reserve(numLoops);
2558 stepsStorage.reserve(numLoops);
2559
2560 // Get the loop lb, ub, and step.
2561 unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);
2562
2563 // Modify the lb, ub, and step based on the distribution options.
2564 for (const auto &it : llvm::enumerate(procInfo)) {
2565 if (it.value().distributionMethod != linalg::DistributionMethod::None) {
2567 b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
2568 ubsStorage[it.index()], stepsStorage[it.index()]);
2569 }
2570 }
2571 ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
2573 b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
2574 [&](OpBuilder &b, Location loc, ValueRange ivs) {
2575 bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
2576 },
2577 ivs);
2578
2579 assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
2580}
2581
2583 Value valueToTile,
2584 const SliceParameters &sliceParams) {
2585 auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
2586 auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
2587 .Case([&](MemRefType) {
2588 return memref::SubViewOp::create(
2589 builder, loc, valueToTile, sliceParams.offsets,
2590 sliceParams.sizes, sliceParams.strides);
2591 })
2592 .Case([&](RankedTensorType) {
2593 return tensor::ExtractSliceOp::create(
2594 builder, loc, valueToTile, sliceParams.offsets,
2595 sliceParams.sizes, sliceParams.strides);
2596 })
2597 .DefaultUnreachable("Unexpected shaped type");
2598 return sliceOp;
2599}
2600
2602 ArrayRef<OpFoldResult> tileSizes, AffineMap map,
2605 ArrayRef<OpFoldResult> subShapeSizes,
2606 bool omitPartialTileCheck) {
2607 SliceParameters sliceParams =
2608 computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
2609 ubs, subShapeSizes, omitPartialTileCheck);
2610 return materializeTiledShape(builder, loc, valueToTile, sliceParams);
2611}
2612
2615 ArrayRef<OpFoldResult> tileSizes, AffineMap map,
2617 ArrayRef<OpFoldResult> subShapeSizes,
2618 bool omitPartialTileCheck) {
2619 auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
2620 assert(shapedType && "only shaped types can be tiled");
2621 ArrayRef<int64_t> shape = shapedType.getShape();
2622 int64_t rank = shapedType.getRank();
2623
2624 // Compute offsets/sizes/strides for the tile.
2625 SliceParameters sliceParams;
2626 sliceParams.offsets.reserve(rank);
2627 sliceParams.sizes.reserve(rank);
2628 sliceParams.strides.reserve(rank);
2629 for (unsigned r = 0; r < rank; ++r) {
2630 LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
2631 if (!isTiled(map.getSubMap({r}), tileSizes)) {
2632 sliceParams.offsets.push_back(builder.getIndexAttr(0));
2633 OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
2634 sliceParams.sizes.push_back(dim);
2635 sliceParams.strides.push_back(builder.getIndexAttr(1));
2636 LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
2637 continue;
2638 }
2639 LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
2640
2641 // Tiling creates a new slice at the proper index, the slice step is 1
2642 // (i.e. the op does not subsample, stepping occurs in the loop).
2643 auto m = map.getSubMap({r});
2644 LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
2645 IRRewriter rewriter(builder);
2646 // The offset of the slice is m(lbs) - m(0).
2647 SmallVector<Attribute> zeros(lbs.size(), rewriter.getIndexAttr(0));
2648 SmallVector<Attribute> mAtZero;
2649 [[maybe_unused]] auto res = m.constantFold(zeros, mAtZero);
2650 assert(succeeded(res) && "affine_map must be evaluatable (not symbols)");
2651 int64_t mAtZeroInt =
2652 cast<IntegerAttr>(mAtZero[0]).getValue().getSExtValue();
2654 rewriter, loc, m.getResult(0) - mAtZeroInt, lbs);
2655 sliceParams.offsets.push_back(offset);
2656
2657 OpFoldResult closedIntSize =
2658 makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
2659 // Resulting size needs to be made half open interval again.
2660 AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
2661 OpFoldResult size =
2662 makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
2663 LLVM_DEBUG(llvm::dbgs()
2664 << "computeSliceParameters: raw size: " << size << "\n");
2665 LLVM_DEBUG(llvm::dbgs()
2666 << "computeSliceParameters: new offset: " << offset << "\n");
2667 sliceParams.strides.push_back(builder.getIndexAttr(1));
2668
2669 if (omitPartialTileCheck) {
2670 // We statically know that the partial/boundary tile condition is
2671 // unnecessary.
2672 LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
2673 sliceParams.sizes.push_back(size);
2674 continue;
2675 }
2676
2677 // The size of the subview / extract_slice should be trimmed to avoid
2678 // out-of-bounds accesses, unless:
2679 // a. We statically know the subshape size divides the shape size evenly.
2680 // b. The subshape size is 1. According to the way the loops are set up,
2681 // tensors with "0" dimensions would never be constructed.
2682 int64_t shapeSize = shape[r];
2683 std::optional<int64_t> sizeCst = getConstantIntValue(size);
2684 auto hasTileSizeOne = sizeCst == 1;
2685 auto dividesEvenly = sizeCst && ShapedType::isStatic(shapeSize) &&
2686 ((shapeSize % *sizeCst) == 0);
2687 if (!hasTileSizeOne && !dividesEvenly) {
2688 LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
2689 << ", size: " << size
2690 << ": make sure in bound with affine.min\n");
2691
2692 AffineExpr dim0, dim1, dim2;
2693 MLIRContext *context = builder.getContext();
2694 bindDims(context, dim0, dim1, dim2);
2695
2696 // Get the dimension size for this dimension. We need to first calculate
2697 // the max index and then plus one. This is important because for
2698 // convolution ops, we have its input window dimension's affine map of the
2699 // form `(d0 * s0 + d1)`, where `d0`/`d1 is an output/filter window
2700 // dimension and `s0` is stride. Directly use the dimension size of
2701 // output/filer window dimensions will cause incorrect calculation.
2703 {ArrayRef<AffineExpr>{dim0 - 1}}, context)
2704 .front();
2706 {ArrayRef<AffineExpr>{dim0 + 1}}, context)
2707 .front();
2708 SmallVector<OpFoldResult> maxIndices =
2709 llvm::map_to_vector(ubs, [&](OpFoldResult ub) {
2710 return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
2711 {ub});
2712 });
2713 OpFoldResult maxIndex =
2714 makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
2715 OpFoldResult d =
2716 makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});
2717
2718 // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
2720 {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
2721 .front();
2722 size =
2723 makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
2724 }
2725 LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
2726 sliceParams.sizes.push_back(size);
2727 }
2728 return sliceParams;
2729}
2730
2733 ArrayRef<OpFoldResult> tileSizes) {
2735 for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
2736 LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
2737 bool isTiled = !isZeroInteger(tileSizes[idx]);
2738 offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
2739 LLVM_DEBUG(llvm::dbgs()
2740 << "computeTileOffsets: " << offsets.back() << "\n");
2741 }
2742 return offsets;
2743}
2744
2746 ArrayRef<OpFoldResult> tileSizes,
2747 ArrayRef<OpFoldResult> sizeBounds) {
2749 for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
2750 bool isTiled = !isZeroInteger(tileSizes[idx]);
2751 // Before composing, we need to make range a closed interval.
2752 OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
2753 AffineExpr d0 = getAffineDimExpr(0, b.getContext());
2754 IRRewriter rewriter(b);
2755 sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
2756 LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
2757 }
2758 return sizes;
2759}
2760
2762 if (op.hasPureBufferSemantics())
2763 return {};
2764 return llvm::map_to_vector(
2765 op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
2766 return operands[opOperand.getOperandNumber()].getType();
2767 });
2768}
2769
2771 LinalgOp op, ValueRange operands,
2772 ValueRange results) {
2773 if (op.hasPureBufferSemantics())
2774 return {};
2775 SmallVector<Value> tensorResults;
2776 tensorResults.reserve(results.size());
2777 // Insert a insert_slice for each output tensor.
2778 unsigned resultIdx = 0;
2779 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2780 // TODO: use an interface/adaptor to avoid leaking position in
2781 // `tiledOperands`.
2782 Value outputTensor = operands[opOperand.getOperandNumber()];
2783 if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
2784 Value inserted = tensor::InsertSliceOp::create(
2785 builder, loc, sliceOp.getSource().getType(), results[resultIdx],
2786 sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
2787 sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2788 sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2789 tensorResults.push_back(inserted);
2790 } else {
2791 tensorResults.push_back(results[resultIdx]);
2792 }
2793 ++resultIdx;
2794 }
2795 return tensorResults;
2796}
2797
2799computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
2800 ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
2801 ArrayRef<OpFoldResult> tileSizes,
2802 ArrayRef<OpFoldResult> sizeBounds,
2803 bool omitPartialTileCheck) {
2804 assert(ivs.size() == static_cast<size_t>(llvm::count_if(
2805 llvm::make_range(tileSizes.begin(), tileSizes.end()),
2806 [](OpFoldResult v) { return !isZeroInteger(v); })) &&
2807 "expected as many ivs as non-zero sizes");
2808
2809 // Construct (potentially temporary) mins and maxes on which to apply maps
2810 // that define tile subshapes.
2812 computeTileOffsets(builder, loc, ivs, tileSizes);
2813 SmallVector<OpFoldResult> subShapeSizes =
2814 computeTileSizes(builder, loc, tileSizes, sizeBounds);
2815
2816 assert(static_cast<int64_t>(valuesToTile.size()) <=
2817 linalgOp->getNumOperands() &&
2818 "more value to tile than operands.");
2820 allSliceParams.reserve(valuesToTile.size());
2821 for (auto [opOperand, val] :
2822 llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
2823 Value shapedOp = val;
2824 LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
2825 AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
2826 // Use `opOperand` as is if it is not tiled and not an output tensor. Having
2827 // an extract/insert slice pair for all output tensors simplifies follow up
2828 // transformations such as padding and bufferization since the
2829 // extract/insert slice pairs make the accessed iteration argument
2830 // subdomains explicit.
2831
2832 Type operandType = opOperand.get().getType();
2833 if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
2834 linalgOp.isDpsInit(&opOperand))) {
2835 allSliceParams.push_back(std::nullopt);
2836 LLVM_DEBUG(llvm::dbgs()
2837 << ": not tiled: use shape: " << operandType << "\n");
2838 continue;
2839 }
2840 LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
2841
2842 allSliceParams.push_back(computeSliceParameters(
2843 builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
2844 omitPartialTileCheck));
2845 }
2846
2847 return allSliceParams;
2848}
2849
2851 LinalgOp linalgOp, ValueRange valuesToTile,
2853 ArrayRef<OpFoldResult> tileSizes,
2854 ArrayRef<OpFoldResult> sizeBounds,
2855 bool omitPartialTileCheck) {
2857 computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
2858 tileSizes, sizeBounds, omitPartialTileCheck);
2859 SmallVector<Value> tiledShapes;
2860 for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
2861 Value valueToTile = std::get<0>(item);
2862 std::optional<SliceParameters> sliceParams = std::get<1>(item);
2863 tiledShapes.push_back(
2864 sliceParams.has_value()
2865 ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
2866 ->getResult(0)
2867 : valueToTile);
2868 }
2869 return tiledShapes;
2870}
2871
2872void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
2873 ArrayRef<OpFoldResult> offsets) {
2874 IRRewriter rewriter(b);
2875 offsetIndices(rewriter, linalgOp, offsets);
2876}
2877
2878void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
2879 ArrayRef<OpFoldResult> offsets) {
2880 if (!linalgOp.hasIndexSemantics())
2881 return;
2882
2883 for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
2884 if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
2885 continue;
2887 b.setInsertionPointAfter(indexOp);
2888 AffineExpr index, offset;
2889 bindDims(b.getContext(), index, offset);
2891 b, indexOp.getLoc(), index + offset,
2892 {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
2893 Value materialized =
2894 getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
2895 b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
2896 return use.getOwner() != materialized.getDefiningOp();
2897 });
2898 }
2899}
2900
2901/// Get the reassociation maps to fold the result of a extract_slice (or source
2902/// of a insert_slice) operation with given offsets, and sizes to its
2903/// rank-reduced version. This is only done for the cases where the size is 1
2904/// and offset is 0. Strictly speaking the offset 0 is not required in general,
2905/// but non-zero offsets are not handled by SPIR-V backend at this point (and
2906/// potentially cannot be handled).
2907std::optional<SmallVector<ReassociationIndices>>
2911 for (const auto &it : llvm::enumerate(mixedSizes)) {
2912 auto dim = it.index();
2913 auto size = it.value();
2914 curr.push_back(dim);
2915 auto attr = llvm::dyn_cast_if_present<Attribute>(size);
2916 if (attr && cast<IntegerAttr>(attr).getInt() == 1)
2917 continue;
2918 reassociation.emplace_back(ReassociationIndices{});
2919 std::swap(reassociation.back(), curr);
2920 }
2921 // When the reassociations are not empty, then fold the remaining
2922 // unit-dimensions into the last dimension. If the reassociations so far is
2923 // empty, then leave it empty. This will fold everything to a rank-0 tensor.
2924 if (!curr.empty() && !reassociation.empty())
2925 reassociation.back().append(curr.begin(), curr.end());
2926 return reassociation;
2927}
2928
2929} // namespace linalg
2930} // namespace mlir
static SmallVector< int64_t > computePackUnPackPerm(int64_t rank, ArrayRef< int64_t > &innerDimsPos, ArrayRef< int64_t > &outerPerm, PackingMetadata &packingMetadata)
The permutation can be obtained from two permutations: a) Compute the permutation vector to move the ...
Definition Utils.cpp:152
static bool isTiled(AffineExpr expr, ArrayRef< OpFoldResult > tileSizes)
Definition Utils.cpp:76
static void unpackRanges(OpBuilder &builder, Location loc, ArrayRef< Range > ranges, SmallVectorImpl< Value > &lbs, SmallVectorImpl< Value > &ubs, SmallVectorImpl< Value > &steps)
Given a list of subview ranges, extract individual values for lower, upper bounds and steps and put t...
Definition Utils.cpp:128
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition PDL.cpp:62
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be inserted(the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
Affine binary operation expression.
Definition AffineExpr.h:214
AffineExpr getLHS() const
AffineExpr getRHS() const
An integer constant appearing in affine expression.
Definition AffineExpr.h:239
int64_t getValue() const
A dimensional identifier appearing in an affine expression.
Definition AffineExpr.h:223
unsigned getPosition() const
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineExpr getResult(unsigned idx) const
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
unsigned getArgNumber() const
Returns the number of this argument.
Definition Value.h:321
Block * getOwner() const
Returns the block that owns this argument.
Definition Value.h:318
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:222
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
MLIRContext * getContext() const
Definition Builders.h:56
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
This is a value defined by a result of an operation.
Definition Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition Value.h:469
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isSignlessIntOrFloat() const
Return true of this is a signless integer or a float type.
Definition Types.cpp:110
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
Helper class for building convolution op matchers with minimal boilerplate.
Definition Utils.cpp:535
ConvMatcherBuilder & matchStride(unsigned iDim, unsigned fDim, unsigned oDim, unsigned idx)
Match stride/dilation pattern for a spatial dimension.
Definition Utils.cpp:563
bool matchBody(bool containsZeroPointOffset=false)
Match body pattern. This should be called last.
Definition Utils.cpp:580
AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx)
Build strided expression: base * stride[idx] + kernel * dilation[idx].
Definition Utils.cpp:557
AffineExpr dim(unsigned i)
Get affine dimension expression for dimension i.
Definition Utils.cpp:554
ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector< int64_t > *d, SmallVector< int64_t > *s, PoolingType poolingType=PoolingType::None)
Definition Utils.cpp:544
ConvMatcherBuilder & matchMaps(ArrayRef< ArrayRef< AffineExpr > > maps)
Match expected indexing maps layout. Returns *this for method chaining.
Definition Utils.cpp:573
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcmQOp >(LinalgOp op)
Definition Utils.cpp:1519
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNchwChwOp >(LinalgOp op)
Definition Utils.cpp:1376
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcOp >(LinalgOp op)
Definition Utils.cpp:1411
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv3DNdhwcDhwcmOp >(LinalgOp op)
Definition Utils.cpp:1633
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNcwMaxOp >(LinalgOp op)
Definition Utils.cpp:2033
SmallVector< int64_t > getUnPackInverseSrcPerm(linalg::UnPackOp, PackingMetadata &metadata)
Compute inverse permutation for the source tensor (i.e.
SmallVector< Value > makeTiledShapes(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Creates extract_slice/subview ops for all valuesToTile of the given linalgOp with builder,...
Definition Utils.cpp:2850
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMinOp >(LinalgOp op)
Definition Utils.cpp:2063
bool allIndexingsAreProjectedPermutation(LinalgOp op)
Check if all indexing maps are projected permutations.
Definition Utils.cpp:197
static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body, bool containsZeroPointOffset=false)
Utility to match block body for convolution ops.
Definition Utils.cpp:331
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv1DNwcWcmOp >(LinalgOp op)
Definition Utils.cpp:1344
bool isParallelIterator(utils::IteratorType iteratorType)
Check if iterator type has "parallel" semantics.
Definition Utils.cpp:232
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv1DNcwFcwOp >(LinalgOp op)
Definition Utils.cpp:669
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNchwSumOp >(LinalgOp op)
Definition Utils.cpp:1844
SmallVector< OpFoldResult > computeTileSizes(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds)
Computes tile sizes, given a list of tileSizes and dimension sizes (sizeBounds).
Definition Utils.cpp:2745
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMaxUnsignedOp >(LinalgOp op)
Definition Utils.cpp:1774
GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to)
Returns GenericOp that copies an n-D memref.
Definition Utils.cpp:2301
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwgcGfhwcOp >(LinalgOp op)
Definition Utils.cpp:1057
static void generateParallelLoopNest(OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ArrayRef< utils::IteratorType > iteratorTypes, ArrayRef< linalg::ProcInfo > procInfo, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, SmallVectorImpl< Value > &ivStorage)
Generates a loop nest consisting of scf.parallel and scf.for, depending on the iteratorTypes.
Definition Utils.cpp:2421
SmallVector< OpFoldResult > computeTileOffsets(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes)
Computes tile offsets, given a list of loop ivs and tileSizes.
Definition Utils.cpp:2731
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNgchwGfchwQOp >(LinalgOp op)
Definition Utils.cpp:1019
PoolingType
Enum representing pooling operation types used by ConvMatcherBuilder.
Definition Utils.cpp:509
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv1DNcwCwOp >(LinalgOp op)
Definition Utils.cpp:1282
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body)
Definition Utils.cpp:396
static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp, Block *body)
Utility function to match the zero point offset body of quantized convolution ops.
Definition Utils.cpp:275
static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body)
Definition Utils.cpp:382
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcmOp >(LinalgOp op)
Definition Utils.cpp:1483
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv1DNwcWcOp >(LinalgOp op)
Definition Utils.cpp:1313
bool isReductionIterator(utils::IteratorType iteratorType)
Check if iterator type has "reduction" semantics.
Definition Utils.cpp:236
bool hasOnlyScalarElementwiseOp(Region &r)
Detect whether r has only ConstantOp, ElementwiseMappable and YieldOp.
Definition Utils.cpp:203
static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex, uint32_t dimIndex)
Definition Utils.cpp:407
static BlockArgument getBlockArgumentWithOptionalCastOps(Value val)
Returns the BlockArgument that leads to val, if any.
Definition Utils.cpp:246
static bool bodyMatcherForPoolOps(Value yieldVal, Block *body)
Utility to match block body for linalg.pool* ops.
Definition Utils.cpp:366
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNchwFchwQOp >(LinalgOp op)
Definition Utils.cpp:910
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv1DNwcWcfOp >(LinalgOp op)
Definition Utils.cpp:638
std::optional< SmallVector< ReassociationIndices > > getReassociationMapForFoldingUnitDims(ArrayRef< OpFoldResult > mixedSizes)
Get the reassociation maps to fold the result of a extract_slice (or source of a insert_slice) operat...
Definition Utils.cpp:2908
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMinUnsignedOp >(LinalgOp op)
Definition Utils.cpp:2093
DistributionMethod
Scheme used to distribute loops to processors.
Definition Utils.h:276
@ None
No Distribution.
Definition Utils.h:321
@ CyclicNumProcsGeNumIters
Cyclic distribution where the number of processors can be assumed to be more than or equal to the num...
Definition Utils.h:306
@ Cyclic
Cyclic distribution where no assumption is made about the dynamic relationship between number of proc...
Definition Utils.h:288
@ CyclicNumProcsEqNumIters
Cyclic distribution where the number of processors can be assumed to be equal to the number of iterat...
Definition Utils.h:318
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body)
Definition Utils.cpp:387
SmallVector< Value > insertSlicesBack(OpBuilder &builder, Location loc, LinalgOp op, ValueRange operands, ValueRange results)
Creates insert_slice ops that insert results back into larger tensors they were originally extracted ...
Definition Utils.cpp:2770
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv3DNcdhwCdhwOp >(LinalgOp op)
Definition Utils.cpp:1595
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition Utils.cpp:217
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv3DNdhwcDhwcOp >(LinalgOp op)
Definition Utils.cpp:1557
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMaxUnsignedOp >(LinalgOp op)
Definition Utils.cpp:2002
void offsetIndices(OpBuilder &b, LinalgOp linalgOp, ArrayRef< OpFoldResult > offests)
Add the specified offsets to any linalg.index ops contained in the given linalgOp.
Definition Utils.cpp:2872
static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body)
Matches sum pooling body pattern.
Definition Utils.cpp:402
SmallVector< int64_t > getPackInverseDestPerm(linalg::PackOp packOp, PackingMetadata &metadata)
Compute inverse permutation for the destination tensor (i.e.
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcHwcfOp >(LinalgOp op)
Definition Utils.cpp:731
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DNcdhwFcdhwOp >(LinalgOp op)
Definition Utils.cpp:1244
SmallVector< std::optional< SliceParameters > > computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, ValueRange valuesToTile, ArrayRef< OpFoldResult > ivs, ArrayRef< OpFoldResult > tileSizes, ArrayRef< OpFoldResult > sizeBounds, bool omitPartialTileCheck)
Computes SliceParamaters for all valuesToTile of the given linalgOp, assuming linalgOp is being fused...
Definition Utils.cpp:2799
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DOp >(LinalgOp op)
Definition Utils.cpp:700
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcHwcfQOp >(LinalgOp op)
Definition Utils.cpp:766
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNchwFchwOp >(LinalgOp op)
Definition Utils.cpp:875
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcFhwcQOp >(LinalgOp op)
Definition Utils.cpp:838
Operation * makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Creates an extract_slice/subview op for a single valueToTile with builder.
Definition Utils.cpp:2601
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DOp >(LinalgOp op)
Definition Utils.cpp:1131
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNchwMaxOp >(LinalgOp op)
Definition Utils.cpp:1878
static bool convLayoutMatches(ArrayRef< ArrayRef< AffineExpr > > mapListExpected, ArrayAttr indexingMaps, MLIRContext *context)
Returns true if the given indexing maps matches with the expected indexing maps.
Definition Utils.cpp:496
static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body)
Definition Utils.cpp:391
static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim, int64_t &dilation, int64_t &stride)
Given an array of AffineMaps indexingMaps verify the following commutatively:- indexingMaps[0]....
Definition Utils.cpp:462
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcSumOp >(LinalgOp op)
Definition Utils.cpp:1740
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DNdhwcDhwcfQOp >(LinalgOp op)
Definition Utils.cpp:1204
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwgcGfhwcQOp >(LinalgOp op)
Definition Utils.cpp:1093
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMinOp >(LinalgOp op)
Definition Utils.cpp:1706
static Operation * materializeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, const SliceParameters &sliceParams)
Definition Utils.cpp:2582
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNgchwFgchwOp >(LinalgOp op)
Definition Utils.cpp:947
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNgchwGfchwOp >(LinalgOp op)
Definition Utils.cpp:983
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMaxOp >(LinalgOp op)
Definition Utils.cpp:1672
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNdhwcMinOp >(LinalgOp op)
Definition Utils.cpp:2198
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNhwcMinUnsignedOp >(LinalgOp op)
Definition Utils.cpp:1809
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type, Value source, Value padding, bool nofold, ValueRange typeDynDims={})
Create a tensor::PadOp that pads source to the shape of type whose sizes are assumed to be greater th...
Definition Utils.cpp:2233
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv3DNdhwcDhwcfOp >(LinalgOp op)
Definition Utils.cpp:1166
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcSumOp >(LinalgOp op)
Definition Utils.cpp:1912
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv1DOp >(LinalgOp op)
Definition Utils.cpp:610
static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim)
Check if expr is either:
Definition Utils.cpp:420
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNdhwcSumOp >(LinalgOp op)
Definition Utils.cpp:2124
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::DepthwiseConv2DNhwcHwcQOp >(LinalgOp op)
Definition Utils.cpp:1446
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::Conv2DNhwcFhwcOp >(LinalgOp op)
Definition Utils.cpp:803
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNwcMaxOp >(LinalgOp op)
Definition Utils.cpp:1972
void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc, Value procId, Value nprocs, Value &lb, Value &ub, Value &step)
Update the lb, ub and step to get per processor lb, ub and step.
Definition Utils.cpp:2400
SmallVector< Type > getTensorOutputTypes(LinalgOp op, ValueRange operands)
Returns the list of tensor output types produced when the given structured operation op is applied to...
Definition Utils.cpp:2761
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNcwSumOp >(LinalgOp op)
Definition Utils.cpp:1942
SliceParameters computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile, ArrayRef< OpFoldResult > tileSizes, AffineMap map, ArrayRef< OpFoldResult > lbs, ArrayRef< OpFoldResult > ubs, ArrayRef< OpFoldResult > subShapeSizes, bool omitPartialTileCheck)
Computes SliceParameters for a single valueToTile assuming that its user is being tiled with the give...
Definition Utils.cpp:2614
std::optional< DilationsAndStrides > matchConvolutionOpOfType< linalg::PoolingNdhwcMaxOp >(LinalgOp op)
Definition Utils.cpp:2161
auto m_Val(Value v)
Definition Matchers.h:539
LoopNest buildLoopNest(OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ValueRange iterArgs, function_ref< ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilder=nullptr)
Creates a perfect nest of "for" loops, i.e.
Definition SCF.cpp:777
SmallVector< Value > ValueVector
An owning vector of values, handy to return from functions.
Definition SCF.h:64
PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder, ValueRange dynOutDims={})
Definition Utils.cpp:23
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
detail::NameOpMatcher m_Op(StringRef opName)
Matches a named operation.
Definition Matchers.h:379
@ Mul
RHS of mul is always a constant or a symbolic expression.
Definition AffineExpr.h:43
SmallVector< int64_t > computePermutationVector(int64_t permSize, ArrayRef< int64_t > positions, ArrayRef< int64_t > desiredPositions)
Return a permutation vector of size permSize that would result in moving positions into desiredPositi...
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
detail::op_matcher< OpClass > m_Op()
Matches the given OpClass.
Definition Matchers.h:484
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition Utils.h:103
Value _and(Value lhs, Value rhs)
Definition Utils.cpp:312
Value slt(Value lhs, Value rhs)
Definition Utils.cpp:335
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
A struct containing dilations and strides inferred from convolution ops.
Definition Utils.h:110
Utility class used to generate nested loops with ranges described by loopRanges and loop type describ...
Definition Utils.h:390
static void doit(OpBuilder &b, Location loc, ArrayRef< Range > loopRanges, LinalgOp linalgOp, ArrayRef< utils::IteratorType > iteratorTypes, function_ref< scf::ValueVector(OpBuilder &, Location, ValueRange, ValueRange)> bodyBuilderFn, ArrayRef< linalg::ProcInfo > procInfo={})
Callback function type used to get processor ID, and number of processors used for distribution for a...
Definition Utils.h:326
DistributionMethod distributionMethod
Definition Utils.h:329
static std::optional< BinaryOpKind > matchAsScalarBinaryOp(GenericOp op)
Matches the given linalg op if its body is performing binary operation on int or float scalar values ...
Definition Utils.cpp:95
A struct containg offsets-sizes-strides arguments of the tiled shape.
Definition Utils.h:172
SmallVector< OpFoldResult > strides
Definition Utils.h:175
SmallVector< OpFoldResult > sizes
Definition Utils.h:174
SmallVector< OpFoldResult > offsets
Definition Utils.h:173
LoopVector loops
Definition SCF.h:67