//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Utils/Utils.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "linalg-utils"

using namespace mlir;
using namespace presburger;
using namespace mlir::affine;
using namespace mlir::linalg;
using namespace mlir::scf;

namespace {

// Helper visitor to determine whether an AffineExpr is tiled.
// This is achieved by traversing every AffineDimExpr with position `pos` and
// checking whether the corresponding `tileSizes[pos]` is non-zero.
// This also enforces that only positive coefficients occur in multiplications.
//
// Example:
//   `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
//
struct TileCheck : public AffineExprVisitor<TileCheck> {
  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}

  void visitDimExpr(AffineDimExpr expr) {
    isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]);
  }
  void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
    visit(expr.getLHS());
    visit(expr.getRHS());
    if (expr.getKind() == AffineExprKind::Mul)
      assert(cast<AffineConstantExpr>(expr.getRHS()).getValue() > 0 &&
             "nonpositive multiplying coefficient");
  }
  bool isTiled = false;
  ArrayRef<OpFoldResult> tileSizes;
};

} // namespace

static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
  if (!expr)
    return false;
  TileCheck t(tileSizes);
  t.visit(expr);
  return t.isTiled;
}

// Checks whether the `map` varies with respect to a non-zero `tileSize`.
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
  if (!map)
    return false;
  for (unsigned r = 0; r < map.getNumResults(); ++r)
    if (isTiled(map.getResult(r), tileSizes))
      return true;
  return false;
}

std::optional<RegionMatcher::BinaryOpKind>
RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
  auto &region = op.getRegion();
  if (!region.hasOneBlock())
    return std::nullopt;

  Block &block = region.front();
  if (block.getNumArguments() != 2 ||
      !block.getArgument(0).getType().isSignlessIntOrFloat() ||
      !block.getArgument(1).getType().isSignlessIntOrFloat())
    return std::nullopt;

  auto &ops = block.getOperations();
  if (!llvm::hasSingleElement(block.without_terminator()))
    return std::nullopt;

  using mlir::matchers::m_Val;
  auto a = m_Val(block.getArgument(0));
  auto b = m_Val(block.getArgument(1));

  auto addPattern = m_Op<linalg::YieldOp>(m_Op<arith::AddIOp>(a, b));
  if (addPattern.match(&ops.back()))
    return BinaryOpKind::IAdd;

  return std::nullopt;
}
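
/// Example of a region that matches as BinaryOpKind::IAdd (illustrative):
///   ^bb0(%a: i32, %b: i32):
///     %0 = arith.addi %a, %b : i32
///     linalg.yield %0 : i32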

/// Explicit instantiation of loop nest generator for different loop types.
template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
template struct mlir::linalg::GenerateLoopNest<AffineForOp>;

/// Given a list of subview ranges, extract individual values for lower, upper
/// bounds and steps and put them into the corresponding vectors.
static void unpackRanges(OpBuilder &builder, Location loc,
                         ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
                         SmallVectorImpl<Value> &ubs,
                         SmallVectorImpl<Value> &steps) {
  for (Range range : ranges) {
    lbs.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.offset));
    ubs.emplace_back(getValueOrCreateConstantIndexOp(builder, loc, range.size));
    steps.emplace_back(
        getValueOrCreateConstantIndexOp(builder, loc, range.stride));
  }
}

//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//
//
/// The permutation can be obtained from two permutations:
///   a) Compute the permutation vector to move the last `numPackedDims` into
///      the `innerPosDims` of a shape of rank `rank`.
///   b) Compute the permutation vector to move outer dims if the
///      `outerPerm` parameter is not empty.
/// Apply (b) permutation on (a) permutation to get the final permutation.
static SmallVector<int64_t>
computePackUnPackPerm(int64_t rank, ArrayRef<int64_t> &innerDimsPos,
                      ArrayRef<int64_t> &outerPerm,
                      PackingMetadata &packingMetadata) {
  int64_t numPackedDims = innerDimsPos.size();
  auto lastDims =
      llvm::to_vector(llvm::seq<int64_t>(rank - numPackedDims, rank));
  packingMetadata = computePackingMetadata(rank, innerDimsPos);
  SmallVector<int64_t> innerPositionsPerm =
      computePermutationVector(rank, lastDims, packingMetadata.insertPositions);

  SmallVector<int64_t> outerPos = packingMetadata.outerPositions;
  if (!outerPerm.empty())
    applyPermutationToVector(outerPos, outerPerm);
  SmallVector<int64_t> outerPositionPerm =
      computePermutationVector(rank, packingMetadata.outerPositions, outerPos);

  SmallVector<int64_t> packInverseDestPermutation = innerPositionsPerm;
  applyPermutationToVector(packInverseDestPermutation, outerPositionPerm);
  return packInverseDestPermutation;
}
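
/// Worked example (illustrative, assuming the usual packing metadata where
/// the inner tile of source dim `p` is inserted at position `p + 1`):
/// rank = 3, innerDimsPos = [0], outerPerm = []. The packed dim trails as
/// dim 2 and must move to position 1, so (a) yields [0, 2, 1]; (b) is the
/// identity, and [0, 2, 1] is returned.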

namespace mlir {
namespace linalg {

SmallVector<int64_t> getPackInverseDestPerm(PackOp packOp,
                                            PackingMetadata &metadata) {

  int64_t packedRank = packOp.getDestType().getRank();
  ArrayRef<int64_t> innerDimPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerPerm = packOp.getOuterDimsPerm();
  SmallVector<int64_t> packInvDestPerm =
      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
  return packInvDestPerm;
}

SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
                                             PackingMetadata &metadata) {
  int64_t packedRank = unpackOp.getSourceType().getRank();
  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerPerm = unpackOp.getOuterDimsPerm();
  SmallVector<int64_t> unpackInvSrcPerm =
      computePackUnPackPerm(packedRank, innerDimPos, outerPerm, metadata);
  return unpackInvSrcPerm;
}

bool allIndexingsAreProjectedPermutation(LinalgOp op) {
  return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
    return m.isProjectedPermutation(/*allowZeroInResults=*/true);
  });
}

bool hasOnlyScalarElementwiseOp(Region &r) {
  if (!r.hasOneBlock())
    return false;
  for (Operation &op : r.front()) {
    if (!(isa<arith::ConstantOp, func::ConstantOp, tensor::ExtractOp,
              linalg::YieldOp, linalg::IndexOp, AffineApplyOp>(op) ||
          OpTrait::hasElementwiseMappableTraits(&op)) ||
        llvm::any_of(op.getResultTypes(),
                     [](Type type) { return !type.isIntOrIndexOrFloat(); }))
      return false;
  }
  return true;
}

bool isElementwise(LinalgOp op) {
  if (op.getNumLoops() != op.getNumParallelLoops())
    return false;

  if (!allIndexingsAreProjectedPermutation(op))
    return false;

  // TODO: relax the restrictions on indexing map.
  for (OpOperand &opOperand : op.getDpsInitsMutable()) {
    if (!op.getMatchingIndexingMap(&opOperand).isPermutation())
      return false;
  }
  return hasOnlyScalarElementwiseOp(op->getRegion(0));
}
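
/// Example (illustrative): a linalg.generic with all-"parallel" iterators,
/// permutation maps on its inits, and a body such as
///   ^bb0(%in: f32, %out: f32):
///     %0 = arith.addf %in, %in : f32
///     linalg.yield %0 : f32
/// is elementwise; any op with a "reduction" iterator is not.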

bool isParallelIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::parallel;
}

bool isReductionIterator(utils::IteratorType iteratorType) {
  return iteratorType == utils::IteratorType::reduction;
}

//===----------------------------------------------------------------------===//
// Convolution matcher utilities
//===----------------------------------------------------------------------===//

/// Returns the BlockArgument that leads to `val`, if any. Traverses optional
/// ext*/sitofp ops.
static BlockArgument getSourceBlockArgument(Value val) {
  BlockArgument blockArg = dyn_cast<BlockArgument>(val);
  if (blockArg)
    return blockArg;

  Operation *defOp = val.getDefiningOp();
  if (!dyn_cast_if_present<arith::ExtFOp>(defOp) &&
      !dyn_cast_if_present<arith::ExtSIOp>(defOp) &&
      !dyn_cast_if_present<arith::ExtUIOp>(defOp) &&
      !dyn_cast_if_present<arith::SIToFPOp>(defOp)) {
    return nullptr;
  }
  return dyn_cast<BlockArgument>(defOp->getOperand(0));
}
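
/// Example (illustrative): given `%0 = arith.extf %arg0 : f16 to f32`,
/// passing %0 returns %arg0; a value defined by any other kind of op (or a
/// cast whose source is not a block argument) yields nullptr.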

/// Utility function to match the zero point offset body of quantized
/// convolution ops.
///
/// Quantized convolutions have a body of the form:
///   %out + ((%input - %inputZp) * (%filter - %filterZp))
/// where:
///   - %input is the input tensor element (block arg 0)
///   - %filter is the filter tensor element (block arg 1)
///   - %inputZp is the input zero-point scalar (block arg 2)
///   - %filterZp is the filter zero-point scalar (block arg 3)
///   - %out is the output accumulator (block arg 4)
///
/// This function verifies that the multiplication operands are subtraction
/// operations matching this pattern.
static bool bodyMatcherForZeroPointOffsets(Operation *addOp, Operation *mulOp,
                                           Block *body) {
  // The multiplication should have two subtraction operands:
  // one for (input - inputZp) and one for (filter - filterZp).
  Operation *inputSubOp = mulOp->getOperand(0).getDefiningOp();
  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(inputSubOp))
    return false;

  Operation *filterSubOp = mulOp->getOperand(1).getDefiningOp();
  if (!isa_and_present<arith::SubIOp, arith::SubFOp>(filterSubOp))
    return false;

  // Extract block arguments from subtraction operands.
  BlockArgument inputBlockArg =
      getSourceBlockArgument(inputSubOp->getOperand(0));
  BlockArgument inputZpBlockArg =
      getSourceBlockArgument(inputSubOp->getOperand(1));
  BlockArgument filterBlockArg =
      getSourceBlockArgument(filterSubOp->getOperand(0));
  BlockArgument filterZpBlockArg =
      getSourceBlockArgument(filterSubOp->getOperand(1));
  BlockArgument outBlockArg =
      getSourceBlockArgument(addOp->getOperand(0));

  // Verify all block arguments are valid.
  if (!inputBlockArg || !inputZpBlockArg || !filterBlockArg ||
      !filterZpBlockArg || !outBlockArg)
    return false;

  // Verify all block arguments belong to the convolution body.
  if (inputBlockArg.getOwner() != body || inputZpBlockArg.getOwner() != body ||
      filterBlockArg.getOwner() != body ||
      filterZpBlockArg.getOwner() != body || outBlockArg.getOwner() != body)
    return false;

  // Verify block arguments have expected indices:
  // arg0: input, arg1: filter, arg2: inputZp, arg3: filterZp, arg4: output
  if (inputBlockArg.getArgNumber() != 0 || filterBlockArg.getArgNumber() != 1 ||
      inputZpBlockArg.getArgNumber() != 2 ||
      filterZpBlockArg.getArgNumber() != 3 || outBlockArg.getArgNumber() != 4)
    return false;

  return true;
}

/// Utility to match the block body of convolution ops. The body is expected
/// to yield:
///   %out + (%lhs * %rhs)
/// where %lhs, %rhs and %out are block arguments, and %lhs and %rhs can have
/// an optional upcast operation.
/// NOTE: for zero point offset convolution ops, %lhs and %rhs would instead be
///   %input - %input_scalar
/// where %input_scalar can have an optional upcast operation.
static bool bodyMatcherForConvolutionOps(Value yieldVal, Block *body,
                                         bool containsZeroPointOffset = false) {
  Operation *addOp = yieldVal.getDefiningOp();
  if (!isa_and_present<arith::AddIOp, arith::AddFOp>(addOp))
    return false;

  Operation *mulOp = addOp->getOperand(1).getDefiningOp();
  if (!isa_and_present<arith::MulIOp, arith::MulFOp>(mulOp))
    return false;

  if (containsZeroPointOffset) {
    return bodyMatcherForZeroPointOffsets(addOp, mulOp, body);
  }
  BlockArgument lhsBlockArg =
      getSourceBlockArgument(mulOp->getOperand(0));
  BlockArgument rhsBlockArg =
      getSourceBlockArgument(mulOp->getOperand(1));
  BlockArgument outBlockArg =
      getSourceBlockArgument(addOp->getOperand(0));
  if (!lhsBlockArg || !rhsBlockArg || !outBlockArg ||
      lhsBlockArg.getOwner() != body || rhsBlockArg.getOwner() != body ||
      outBlockArg.getOwner() != body || lhsBlockArg.getArgNumber() != 0 ||
      rhsBlockArg.getArgNumber() != 1 || outBlockArg.getArgNumber() != 2)
    return false;
  return true;
}
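
/// Example of a matching (non-quantized) body (illustrative):
///   ^bb0(%in: f32, %flt: f32, %acc: f32):
///     %m = arith.mulf %in, %flt : f32
///     %a = arith.addf %acc, %m : f32
///     linalg.yield %a : f32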

/// Utility to match block body for linalg.pool* ops.
template <typename... OpTypes>
static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
  Operation *defOp = yieldVal.getDefiningOp();
  if (!(isa_and_present<OpTypes>(defOp) || ...))
    return false;

  BlockArgument lhsArg =
      getSourceBlockArgument(defOp->getOperand(0));
  BlockArgument rhsArg =
      getSourceBlockArgument(defOp->getOperand(1));
  if (!lhsArg || !rhsArg || lhsArg.getOwner() != body ||
      rhsArg.getOwner() != body || lhsArg.getArgNumber() != 2 ||
      rhsArg.getArgNumber() != 0)
    return false;
  return true;
}
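
/// Example of a matching max-pool body (illustrative): the accumulator
/// (block arg 2) is the first operand and the input element (block arg 0)
/// the second:
///   ^bb0(%in: f32, %kernel: f32, %acc: f32):
///     %0 = arith.maximumf %acc, %in : f32
///     linalg.yield %0 : f32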

static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
  return bodyMatcherForPoolOps<arith::MaxSIOp, arith::MaximumFOp>(yieldVal,
                                                                  body);
}

// max_unsigned ops should not allow float data type.
// TODO(#164800): Retire OPDSL logic.
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
  return bodyMatcherForPoolOps<arith::MaxUIOp>(yieldVal, body);
}

static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
  return bodyMatcherForPoolOps<arith::MinSIOp, arith::MinimumFOp>(yieldVal,
                                                                  body);
}

// min_unsigned ops should not allow float data type.
// TODO(#164800): Retire OPDSL logic.
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
  return bodyMatcherForPoolOps<arith::MinUIOp>(yieldVal, body);
}

static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
  return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
}

static AffineExpr getAffineMapDim(ArrayAttr indexingMaps, uint32_t mapIndex,
                                  uint32_t dimIndex) {
  auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
  if (dimIndex < affineMap.getNumResults())
    return affineMap.getResult(dimIndex);
  return nullptr;
}

/// Check if `expr` is either:
///   - a dimension expr alone (implying multiplication by 1), or
///   - a multiplication of a dimension expr by any positive constant != 1.
/// In both cases the dimension expression is captured into `dim` and the
/// constant multiplier is returned. Returns -1 in case of a match failure.
static int64_t isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim) {
  if ((dim = dyn_cast<AffineDimExpr>(expr)))
    return 1;

  auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
    return -1;

  AffineExpr lhs = mulExpr.getLHS();
  AffineExpr rhs = mulExpr.getRHS();

  AffineConstantExpr cst = nullptr;
  if (((dim = dyn_cast<AffineDimExpr>(lhs)) &&
       (cst = dyn_cast<AffineConstantExpr>(rhs))) ||
      ((dim = dyn_cast<AffineDimExpr>(rhs)) &&
       (cst = dyn_cast<AffineConstantExpr>(lhs))))
    return cst.getValue();
  return -1;
}
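
/// Examples (illustrative): for `d1 * 2`, `dim` is set to d1 and 2 is
/// returned; for `d4` alone, `dim` is set to d4 and 1 is returned; for
/// `d1 + d2` or `d1 * d2`, -1 is returned.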

/// Given an array of AffineMaps `indexingMaps`, verify the following
/// commutatively:
///   indexingMaps[0].getResult(iDim) ==
///       indexingMaps[1].getResult(fDim) * <c0> +
///       indexingMaps[n-1].getResult(oDim) * <c1>
/// where:
///   - c0 and c1 can be any constant,
///   - n is the size of the indexingMaps' array,
///   - 0, 1 and n-1 are the input, filter and output map indices respectively,
///   - iDim, fDim and oDim are the input, filter and output dimension
///     indices in their respective indexing maps.
/// Example:
///   #inputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6)
///                           -> (d0, d1 * 2 + d4 * 3, d2 + d5, d6)>
///   #filterMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
///   #outputMap = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
///
/// Here,
///   #inputMap[1] = #outputMap[1] * 2 + #filterMap[0] * 3
/// Therefore,
///   matchConvDimAddExprPattern(indexingMaps, 1, 0, 1, dilation, stride)
/// would return true and update dilation = 3 and stride = 2.
static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim,
                                       unsigned fDim, unsigned oDim,
                                       int64_t &dilation, int64_t &stride) {
  unsigned inputMapIdx = 0, filterMapIdx = 1,
           outputMapIdx = indexingMaps.size() - 1;
  AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim);
  auto addExpr = dyn_cast_or_null<AffineBinaryOpExpr>(inpExpr);
  if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
    return false;

  AffineExpr dim0, dim1;
  int64_t c0 = isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0);
  int64_t c1 = isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1);

  if (c0 == -1 || c1 == -1)
    return false;
  // Pattern matched with dims and constants extracted.
  AffineExpr fExpr = getAffineMapDim(indexingMaps, filterMapIdx, fDim);
  AffineExpr oExpr = getAffineMapDim(indexingMaps, outputMapIdx, oDim);
  if (dim0 == fExpr && dim1 == oExpr) {
    dilation = c0;
    stride = c1;
    return true;
  }
  if (dim1 == fExpr && dim0 == oExpr) {
    dilation = c1;
    stride = c0;
    return true;
  }
  return false;
}

/// Returns true if the given indexing maps match the expected indexing maps.
static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
                              ArrayAttr indexingMaps, MLIRContext *context) {
  SmallVector<AffineMap, 4> expectedIndexingMaps =
      AffineMap::inferFromExprList(mapListExpected, context);
  return indexingMaps ==
         ArrayAttr::get(
             context, llvm::to_vector<4>(llvm::map_range(
                          expectedIndexingMaps, [&](AffineMap m) -> Attribute {
                            return AffineMapAttr::get(m);
                          })));
}

/// Enum representing pooling operation types used by ConvMatcherBuilder.
enum class PoolingType {
  None,
  Max,
  MaxUnsigned,
  Min,
  MinUnsigned,
  Sum,
};

/// Helper class for building convolution op matchers with minimal boilerplate.
/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
/// as Pooling ops.
///
/// Usage: Create an instance with the op, spatial rank, and output pointers for
/// extracted dilations/strides. Then chain matchStride() calls for each spatial
/// dimension, followed by matchMaps() to verify indexing maps, and finally
/// matchBody() to verify the operation body pattern.
///
/// The `matched` flag starts as `true` and is set to `false` if any match step
/// fails. This allows chaining multiple match calls; once any match fails, all
/// subsequent calls become no-ops and the final result is `false`.
///
/// The `dilations` and `strides` pointers are output parameters that get
/// populated with the extracted dilation and stride values from the operation's
/// indexing maps during matchStride() calls. These values are initially set to
/// 1 for each spatial dimension and updated as patterns are matched.
class ConvMatcherBuilder {
  LinalgOp op;
  MLIRContext *ctx;
  SmallVector<int64_t> *dilations, *strides;
  ArrayAttr indexingMaps;
  PoolingType poolingType;
  bool matched = true;

public:
  ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
                     SmallVector<int64_t> *s,
                     PoolingType poolingType = PoolingType::None)
      : op(op), ctx(op->getContext()), dilations(d), strides(s),
        indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
    *dilations = SmallVector<int64_t>(spatialRank, 1);
    *strides = SmallVector<int64_t>(spatialRank, 1);
  }

  /// Get affine dimension expression for dimension `i`.
  AffineExpr dim(unsigned i) { return getAffineDimExpr(i, ctx); }

  /// Build strided expression: base * stride[idx] + kernel * dilation[idx].
  AffineExpr strided(AffineExpr base, AffineExpr kernel, unsigned idx) {
    return base * (*strides)[idx] + kernel * (*dilations)[idx];
  }

  /// Match stride/dilation pattern for a spatial dimension.
  /// Returns *this for method chaining.
  ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
                                  unsigned idx) {
    if (matched) {
      matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
                                            (*dilations)[idx], (*strides)[idx]);
    }
    return *this;
  }

  /// Match expected indexing maps layout. Returns *this for method chaining.
  ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
    if (matched)
      matched &= convLayoutMatches(maps, indexingMaps, ctx);
    return *this;
  }

  /// Match body pattern. This should be called last.
  bool matchBody(bool containsZeroPointOffset = false) {
    if (!matched)
      return false;
    Block *body = op.getBlock();
    auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
    switch (poolingType) {
    case PoolingType::None:
      return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body,
                                          containsZeroPointOffset);
    case PoolingType::Max:
      return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
    case PoolingType::MaxUnsigned:
      return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
    case PoolingType::Min:
      return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
    case PoolingType::MinUnsigned:
      return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
    case PoolingType::Sum:
      return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
    }
    return false;
  }
};
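
// Illustrative use, mirroring the Conv1DOp matcher below (no additional API
// is assumed):
//   ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
//   AffineExpr W = m.dim(0), w = m.dim(1);
//   bool matches =
//       m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
//           .matchMaps({{m.strided(W, w, 0)}, {w}, {W}})
//           .matchBody();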

//===----------------------------------------------------------------------===//
// Matchers for specific convolution operations.
//===----------------------------------------------------------------------===//

template <>
bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
                                              SmallVector<int64_t> *dilations,
                                              SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv1DOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
  AffineExpr W = m.dim(0);
  AffineExpr w = m.dim(1);

  return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
      .matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
                  /*filterMap=*/{w},
                  /*outputMap=*/{W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv1DNwcWcfOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr F = m.dim(2);
  AffineExpr w = m.dim(3);
  AffineExpr c = m.dim(4);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
                  /*filterMap=*/{w, c, F},
                  /*outputMap=*/{N, W, F}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv1DNcwFcwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr F = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr c = m.dim(3);
  AffineExpr w = m.dim(4);

  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
      .matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
                  /*filterMap=*/{F, c, w},
                  /*outputMap=*/{N, F, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
                                              SmallVector<int64_t> *dilations,
                                              SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr H = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr h = m.dim(2);
  AffineExpr w = m.dim(3);

  return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
      .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
      .matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNhwcHwcfOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr F = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);
  AffineExpr c = m.dim(6);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
                  /*filterMap=*/{h, w, c, F},
                  /*outputMap=*/{N, H, W, F}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwcHwcfQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNhwcHwcfQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr F = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);
  AffineExpr c = m.dim(6);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
                  /*filterMap=*/{h, w, c, F},
                  /*scalarMap=*/{},
                  /*scalarMap=*/{},
                  /*outputMap=*/{N, H, W, F}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNhwcFhwcOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr F = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);
  AffineExpr c = m.dim(6);

  return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
                  /*filterMap=*/{F, h, w, c},
                  /*outputMap=*/{N, H, W, F}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwcFhwcQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNhwcFhwcQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr F = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);
  AffineExpr c = m.dim(6);

  return m.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), c},
                  /*filterMap=*/{F, h, w, c},
                  /*scalarMap=*/{},
                  /*scalarMap=*/{},
                  /*outputMap=*/{N, H, W, F}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNchwFchwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr F = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr c = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);

  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
      .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
                  /*filterMap=*/{F, c, h, w},
                  /*outputMap=*/{N, F, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNchwFchwQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNchwFchwQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr F = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr c = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);

  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
      .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, c, m.strided(H, h, 0), m.strided(W, w, 1)},
                  /*filterMap=*/{F, c, h, w},
                  /*scalarMap=*/{},
                  /*scalarMap=*/{},
                  /*outputMap=*/{N, F, H, W}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNgchwFgchwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNgchwFgchwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr G = m.dim(1);
  AffineExpr F = m.dim(2);
  AffineExpr H = m.dim(3);
  AffineExpr W = m.dim(4);
  AffineExpr c = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);

  return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
      .matchMaps(
          {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
           /*filterMap=*/{F, G, c, h, w},
           /*outputMap=*/{N, G, F, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNgchwGfchwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr G = m.dim(1);
  AffineExpr F = m.dim(2);
  AffineExpr H = m.dim(3);
  AffineExpr W = m.dim(4);
  AffineExpr c = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);

  return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
      .matchMaps(
          {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
           /*filterMap=*/{G, F, c, h, w},
           /*outputMap=*/{N, G, F, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNgchwGfchwQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNgchwGfchwQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr G = m.dim(1);
  AffineExpr F = m.dim(2);
  AffineExpr H = m.dim(3);
  AffineExpr W = m.dim(4);
  AffineExpr c = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);

  return m.matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/0)
      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/1)
      .matchMaps(
          {/*inputMap=*/{N, G, c, m.strided(H, h, 0), m.strided(W, w, 1)},
           /*filterMap=*/{G, F, c, h, w},
           /*scalarMap=*/{},
           /*scalarMap=*/{},
           /*outputMap=*/{N, G, F, H, W}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNhwgcGfhwcOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr G = m.dim(3);
  AffineExpr F = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);
  AffineExpr c = m.dim(7);

  return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
      .matchMaps(
          {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
           /*filterMap=*/{G, F, h, w, c},
           /*outputMap=*/{N, H, W, G, F}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv2DNhwgcGfhwcQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr G = m.dim(3);
  AffineExpr F = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);
  AffineExpr c = m.dim(7);

  return m.matchStride(/*iDim=*/1, /*fDim=*/2, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/3, /*oDim=*/2, /*idx=*/1)
      .matchMaps(
          {/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), G, c},
           /*filterMap=*/{G, F, h, w, c},
           /*scalarMap=*/{},
           /*scalarMap=*/{},
           /*outputMap=*/{N, H, W, G, F}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
                                              SmallVector<int64_t> *dilations,
                                              SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv3DOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr D = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr d = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
      .matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
      .matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
      .matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2)},
                  /*filterMap=*/{d, h, w},
                  /*outputMap=*/{D, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv3DNdhwcDhwcfOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr F = m.dim(4);
  AffineExpr d = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);
  AffineExpr c = m.dim(8);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2), c},
                  /*filterMap=*/{d, h, w, c, F},
                  /*outputMap=*/{N, D, H, W, F}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::Conv3DNdhwcDhwcfQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr F = m.dim(4);
  AffineExpr d = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);
  AffineExpr c = m.dim(8);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2), c},
                  /*filterMap=*/{d, h, w, c, F},
                  /*scalarMap=*/{},
                  /*scalarMap=*/{},
                  /*outputMap=*/{N, D, H, W, F}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::Conv3DNcdhwFcdhwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::Conv3DNcdhwFcdhwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr F = m.dim(1);
  AffineExpr D = m.dim(2);
  AffineExpr H = m.dim(3);
  AffineExpr W = m.dim(4);
  AffineExpr c = m.dim(5);
  AffineExpr d = m.dim(6);
  AffineExpr h = m.dim(7);
  AffineExpr w = m.dim(8);

  return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
      .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1)
      .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/2)
      .matchMaps({/*inputMap=*/{N, c, m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2)},
                  /*filterMap=*/{F, c, d, h, w},
                  /*outputMap=*/{N, F, D, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv1DNcwCwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr w = m.dim(3);

  return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
      .matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
                  /*filterMap=*/{C, w},
                  /*outputMap=*/{N, C, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv1DNwcWcOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr w = m.dim(3);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
                  /*filterMap=*/{w, C},
                  /*outputMap=*/{N, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/1, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr W = m.dim(1);
  AffineExpr C = m.dim(2);
  AffineExpr CM = m.dim(3);
  AffineExpr w = m.dim(4);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
                  /*filterMap=*/{w, C, CM},
                  /*outputMap=*/{N, W, C, CM}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv2DNchwChwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
                  /*filterMap=*/{C, h, w},
                  /*outputMap=*/{N, C, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv2DNhwcHwcOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w, C},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w, C},
                  /*scalarMap=*/{},
                  /*scalarMap=*/{},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv2DNhwcHwcmOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr CM = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w, C, CM},
                  /*outputMap=*/{N, H, W, C, CM}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNhwcHwcmQOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv2DNhwcHwcmQOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr CM = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w, C, CM},
                  /*scalarMap=*/{},
                  /*scalarMap=*/{},
                  /*outputMap=*/{N, H, W, C, CM}})
      .matchBody(/*containsZeroPointOffset=*/true);
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv3DNdhwcDhwcOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr d = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);
  AffineExpr C = m.dim(7);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2), C},
                  /*filterMap=*/{d, h, w, C},
                  /*outputMap=*/{N, D, H, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNcdhwCdhwOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv3DNcdhwCdhwOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr d = m.dim(4);
  AffineExpr h = m.dim(5);
  AffineExpr w = m.dim(6);
  AffineExpr C = m.dim(7);

  return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
      .matchStride(/*iDim=*/4, /*fDim=*/3, /*oDim=*/4, /*idx=*/2)
      .matchMaps({/*inputMap=*/{N, C, m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2)},
                  /*filterMap=*/{C, d, h, w},
                  /*outputMap=*/{N, C, D, H, W}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides);
  AffineExpr N = m.dim(0);
  AffineExpr D = m.dim(1);
  AffineExpr H = m.dim(2);
  AffineExpr W = m.dim(3);
  AffineExpr CM = m.dim(4);
  AffineExpr d = m.dim(5);
  AffineExpr h = m.dim(6);
  AffineExpr w = m.dim(7);
  AffineExpr C = m.dim(8);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
      .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
                                m.strided(W, w, 2), C},
                  /*filterMap=*/{d, h, w, C, CM},
                  /*outputMap=*/{N, D, H, W, C, CM}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNhwcMaxOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                       PoolingType::Max);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNhwcMinOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                       PoolingType::Min);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNhwcSumOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                       PoolingType::Sum);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNhwcMaxUnsignedOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                       PoolingType::MaxUnsigned);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody();
}

template <>
bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
    LinalgOp op, SmallVector<int64_t> *dilations,
    SmallVector<int64_t> *strides) {
  if (isa<linalg::PoolingNhwcMinUnsignedOp>(op))
    return true;

  assert(isaConvolutionOpInterface(op) &&
         "expected op to implement ConvolutionOpInterface");

  ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
                       PoolingType::MinUnsigned);
  AffineExpr N = m.dim(0);
  AffineExpr H = m.dim(1);
  AffineExpr W = m.dim(2);
  AffineExpr C = m.dim(3);
  AffineExpr h = m.dim(4);
  AffineExpr w = m.dim(5);

  return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
      .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
      .matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
                  /*filterMap=*/{h, w},
                  /*outputMap=*/{N, H, W, C}})
      .matchBody();
}

Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
                            Value source, Value pad, bool nofold,
                            ValueRange typeDynDims) {
  // Exit if `source` is not defined by an ExtractSliceOp.
  auto sliceOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
                                   typeDynDims);

  // Search the `source` use-def chain for padded LinalgOps.
  Value current = sliceOp.getSource();
  while (current) {
    auto linalgOp = current.getDefiningOp<LinalgOp>();
    if (!linalgOp)
      break;
    OpResult opResult = cast<OpResult>(current);
    current = linalgOp.getDpsInitOperand(opResult.getResultNumber())->get();
  }
  auto padOp = current ? current.getDefiningOp<tensor::PadOp>() : nullptr;

  // Exit if the search fails to match a tensor::PadOp at the end of the matched
  // LinalgOp sequence.
  if (!padOp)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
                                   typeDynDims);

  // Exit if the padded result type does not match.
  if (sliceOp.getSource().getType() != type)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
                                   typeDynDims);

  // Exit if the LinalgOps are not high padded.
  if (llvm::any_of(padOp.getMixedLowPad(), [](OpFoldResult ofr) {
        return getConstantIntValue(ofr) != static_cast<int64_t>(0);
      }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
                                   typeDynDims);

  // Exit if `padOpSliceOp`, which defines the slice used by
  // `padOp`, is rank-reducing.
  auto padOpSliceOp = padOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
  if (!padOpSliceOp ||
      sliceOp.getMixedSizes().size() != padOpSliceOp.getMixedSizes().size())
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
                                   typeDynDims);

  // Exit if the sizes of the dynamic sizes of `sliceOp` do not match the size
  // of the slice padded by `padOp`.
  if (llvm::any_of(
          llvm::zip(sliceOp.getMixedSizes(), padOpSliceOp.getMixedSizes()),
          [](std::tuple<OpFoldResult, OpFoldResult> it) {
            return !isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it));
          }))
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
                                   typeDynDims);

  // Exit if the padding values do not match.
  Attribute padOpPadAttr, padAttr;
  Value padOpPad = padOp.getConstantPaddingValue();
  if (!padOpPad || !matchPattern(padOpPad, m_Constant(&padOpPadAttr)) ||
      !matchPattern(pad, m_Constant(&padAttr)) || padOpPadAttr != padAttr)
    return tensor::createPadHighOp(type, source, pad, nofold, loc, b,
                                   typeDynDims);

  // Return the padded result if the padding values and sizes match.
  return sliceOp.getSource();
}

GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) {
  auto memrefTypeTo = cast<MemRefType>(to.getType());
#ifndef NDEBUG
  auto memrefTypeFrom = cast<MemRefType>(from.getType());
  assert(memrefTypeFrom.getRank() == memrefTypeTo.getRank() &&
         "`from` and `to` memref must have the same rank");
#endif // NDEBUG

  AffineMap id =
      AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext());
  SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(),
                                                 utils::IteratorType::parallel);
  return linalg::GenericOp::create(
      b, loc,
      /*inputs=*/from,
      /*outputs=*/to,
      /*indexingMaps=*/llvm::ArrayRef({id, id}),
      /*iteratorTypes=*/iteratorTypes,
      [](OpBuilder &b, Location loc, ValueRange args) {
        linalg::YieldOp::create(b, loc, args.front());
      });
}
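
/// The generated op is equivalent to (illustrative, with #id the identity
/// map of the memref rank):
///   linalg.generic {indexing_maps = [#id, #id],
///                   iterator_types = ["parallel", ...]}
///       ins(%from : memref<...>) outs(%to : memref<...>) {
///     ^bb0(%a: elemTy, %b: elemTy):
///       linalg.yield %a : elemTy
///   }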

/// Specialization to build an scf "for" nest.
template <>
void GenerateLoopNest<scf::ForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> procInfo) {
  assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
         "expected as many entries for proc info as number of loops, even if "
         "they are null entries");
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);
  LoopNest loopNest = mlir::scf::buildLoopNest(
      b, loc, lbs, ubs, steps, iterArgInitValues,
      [&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
        assert(iterArgs.size() == iterArgInitValues.size() &&
               "expect the number of output tensors and iter args to match");
        SmallVector<Value> operandValuesToUse = linalgOp->getOperands();
        if (!iterArgs.empty()) {
          operandValuesToUse = linalgOp.getDpsInputs();
          operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
        }
        return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
      });

  if (loopNest.loops.empty() || procInfo.empty())
    return;

  // Filter out scf.for loops that were created out of parallel dimensions.
  for (const auto &loop : llvm::enumerate(loopNest.loops)) {
    if (procInfo[loop.index()].distributionMethod ==
        DistributionMethod::Cyclic) {
      mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId,
                            procInfo[loop.index()].nprocs);
    }
  }
}

/// Specialization to build affine "for" nest.
template <>
void GenerateLoopNest<AffineForOp>::doit(
    OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
                                  ValueRange)>
        bodyBuilderFn,
    ArrayRef<linalg::ProcInfo> /*procInfo*/) {
  SmallVector<Value> iterArgInitValues;
  if (!linalgOp.hasPureBufferSemantics())
    llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
  assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
  SmallVector<Value, 4> lbs, ubs, steps;
  unpackRanges(b, loc, loopRanges, lbs, ubs, steps);

  // Affine loops require constant steps.
  SmallVector<int64_t, 4> constantSteps;
  constantSteps.reserve(steps.size());
  for (Value v : steps) {
    auto constVal = getConstantIntValue(v);
    assert(constVal.has_value() && "Affine loops require constant steps");
    constantSteps.push_back(constVal.value());
  }

  affine::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
                              [&](OpBuilder &b, Location loc, ValueRange ivs) {
                                bodyBuilderFn(b, loc, ivs,
                                              linalgOp->getOperands());
                              });
}
1730
1731 /// Update `lb`, `ub` and `step` to their per-processor values.
1732 void mlir::linalg::updateBoundsForCyclicDistribution(OpBuilder &b, Location loc, Value procId,
1733 Value nprocs, Value &lb, Value &ub,
1734 Value &step) {
1735 AffineExpr d0, d1;
1736 bindDims(b.getContext(), d0, d1);
1737 AffineExpr s0 = getAffineSymbolExpr(0, b.getContext());
1738 lb =
1739 affine::makeComposedAffineApply(b, loc, d0 + d1 * s0, {lb, procId, step});
1740 step = affine::makeComposedAffineApply(b, loc, d0 * s0, {nprocs, step});
1741}
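// A worked example of the affine maps above: with lb = 0, step = 1,
// procId = p and nprocs = N, the update yields lb' = 0 + p * 1 = p and
// step' = N * 1 = N, so processor p visits iterations p, p + N, p + 2N, ...
// -- the classic cyclic distribution of the iteration space.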
1742
1743/// Generates a loop nest consisting of scf.parallel and scf.for, depending
1744 /// on the `iteratorTypes`. Consecutive parallel loops create a single
1745/// scf.parallel operation; each sequential loop creates a new scf.for
1746/// operation. The body of the innermost loop is populated by
1747/// `bodyBuilderFn` that accepts a range of induction variables for all
1748/// loops. `ivStorage` is used to store the partial list of induction
1749/// variables.
1750 // TODO: this function could be made iterative instead. However, it
1751 // will have at most as many recursive calls as there are nested loops,
1752 // which rarely exceeds 10.
1753 static void generateParallelLoopNest(
1754 OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs,
1755 ValueRange steps, ArrayRef<utils::IteratorType> iteratorTypes,
1756 ArrayRef<linalg::ProcInfo> procInfo,
1757 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
1758 SmallVectorImpl<Value> &ivStorage) {
1759 assert(lbs.size() == ubs.size());
1760 assert(lbs.size() == steps.size());
1761 assert(lbs.size() == iteratorTypes.size());
1762 assert(procInfo.empty() || (lbs.size() == procInfo.size()));
1763
1764 // If there are no (more) loops to be generated, generate the body and be
1765 // done with it.
1766 if (iteratorTypes.empty()) {
1767 bodyBuilderFn(b, loc, ivStorage);
1768 return;
1769 }
1770
1771 // If there are no outer parallel loops, generate one sequential loop and
1772 // recurse.
1773 if (!isParallelIterator(iteratorTypes.front())) {
1774 LoopNest singleLoop = buildLoopNest(
1775 b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(),
1776 [&](OpBuilder &b, Location loc, ValueRange ivs) {
1777 ivStorage.append(ivs.begin(), ivs.end());
1778 generateParallelLoopNest(
1779 b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(),
1780 iteratorTypes.drop_front(),
1781 procInfo.empty() ? procInfo : procInfo.drop_front(),
1782 bodyBuilderFn, ivStorage);
1783 });
1784 return;
1785 }
1786
1787 unsigned nLoops = iteratorTypes.size();
1788 unsigned numProcessed = 0;
1789 DistributionMethod distributionMethod = DistributionMethod::None;
1790 if (procInfo.empty()) {
1791 numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size();
1792 } else {
1793 distributionMethod = procInfo.front().distributionMethod;
1794 numProcessed =
1795 nLoops - procInfo
1796 .drop_while([&](linalg::ProcInfo p) {
1797 return p.distributionMethod == distributionMethod;
1798 })
1799 .size();
1800 }
1801
1802 auto remainderProcInfo =
1803 procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed);
1804 switch (distributionMethod) {
1805 case DistributionMethod::None: {
1806 // Generate a single parallel loop-nest operation for all outermost
1807 // parallel loops and recurse.
1808 scf::ParallelOp::create(
1809 b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
1810 steps.take_front(numProcessed),
1811 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
1812 ivStorage.append(localIvs.begin(), localIvs.end());
1813 generateParallelLoopNest(
1814 nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
1815 ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
1816 iteratorTypes.drop_front(numProcessed), remainderProcInfo,
1817 bodyBuilderFn, ivStorage);
1818 });
1819 return;
1820 }
1821 case DistributionMethod::Cyclic: {
1822 // Generate a single parallel loop-nest operation for all outermost
1823 // parallel loops and recurse.
1824 scf::ParallelOp::create(
1825 b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed),
1826 steps.take_front(numProcessed),
1827 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
1828 ivStorage.append(localIvs.begin(), localIvs.end());
1829 generateParallelLoopNest(
1830 nestedBuilder, nestedLoc, lbs.drop_front(numProcessed),
1831 ubs.drop_front(numProcessed), steps.drop_front(numProcessed),
1832 iteratorTypes.drop_front(numProcessed), remainderProcInfo,
1833 bodyBuilderFn, ivStorage);
1834 });
1835 return;
1836 }
1837 case DistributionMethod::CyclicNumProcsGeNumIters: {
1838 // Check (for the processed loops) that the iteration is in-bounds.
1839 ArithBuilder ab(b, loc);
1840 Value cond = ab.slt(lbs[0], ubs[0]);
1841 for (unsigned i = 1; i < numProcessed; ++i)
1842 cond = ab._and(cond, ab.slt(lbs[i], ubs[i]));
1843 ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
1844 scf::IfOp::create(b, loc, cond, [&](OpBuilder &b, Location loc) {
1845 generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed),
1846 ubs.drop_front(numProcessed),
1847 steps.drop_front(numProcessed),
1848 iteratorTypes.drop_front(numProcessed),
1849 remainderProcInfo, bodyBuilderFn, ivStorage);
1850 scf::YieldOp::create(b, loc, ValueRange{});
1851 });
1852 return;
1853 }
1854 case DistributionMethod::CyclicNumProcsEqNumIters:
1855 // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
1856 // with inner loop generation.
1857 ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
1858 generateParallelLoopNest(
1859 b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
1860 steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
1861 remainderProcInfo, bodyBuilderFn, ivStorage);
1862 return;
1863 }
1864}
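// A hedged sketch of the CyclicNumProcsGeNumIters case above: since each
// processor runs at most one iteration of the distributed loops, the loop
// structure degenerates into a bounds check,
//
//   %cond = arith.cmpi slt, %lb, %ub
//   scf.if %cond { /* inner loops, with %iv bound to %lb */ }
//
// whereas CyclicNumProcsEqNumIters elides even the check and substitutes
// %lb for the induction variable directly.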
1865
1866/// Specialization for generating a mix of parallel and sequential scf loops.
1867template <>
1868 void GenerateLoopNest<scf::ParallelOp>::doit(
1869 OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, LinalgOp linalgOp,
1870 ArrayRef<utils::IteratorType> iteratorTypes,
1871 function_ref<scf::ValueVector(OpBuilder &, Location, ValueRange,
1872 ValueRange)>
1873 bodyBuilderFn,
1874 ArrayRef<linalg::ProcInfo> procInfo) {
1875 SmallVector<Value> iterArgInitValues;
1876 if (!linalgOp.hasPureBufferSemantics())
1877 llvm::append_range(iterArgInitValues, linalgOp.getDpsInits());
1878 assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
1879 // This function may be passed more iterator types than ranges.
1880 assert(iteratorTypes.size() >= loopRanges.size() &&
1881 "expected iterator type for all ranges");
1882 assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) &&
1883 "expected proc information for all loops when present");
1884 iteratorTypes = iteratorTypes.take_front(loopRanges.size());
1885 SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
1886 unsigned numLoops = iteratorTypes.size();
1887 ivs.reserve(numLoops);
1888 lbsStorage.reserve(numLoops);
1889 ubsStorage.reserve(numLoops);
1890 stepsStorage.reserve(numLoops);
1891
1892 // Get the loop lb, ub, and step.
1893 unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage);
1894
1895 // Modify the lb, ub, and step based on the distribution options.
1896 for (const auto &it : llvm::enumerate(procInfo)) {
1897 if (it.value().distributionMethod != linalg::DistributionMethod::None) {
1898 updateBoundsForCyclicDistribution(
1899 b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()],
1900 ubsStorage[it.index()], stepsStorage[it.index()]);
1901 }
1902 }
1903 ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
1904 generateParallelLoopNest(
1905 b, loc, lbs, ubs, steps, iteratorTypes, procInfo,
1906 [&](OpBuilder &b, Location loc, ValueRange ivs) {
1907 bodyBuilderFn(b, loc, ivs, linalgOp->getOperands());
1908 },
1909 ivs);
1910
1911 assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
1912}
1913
1914 static Operation *materializeTiledShape(OpBuilder &builder, Location loc,
1915 Value valueToTile,
1916 const SliceParameters &sliceParams) {
1917 auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
1918 auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType)
1919 .Case([&](MemRefType) {
1920 return memref::SubViewOp::create(
1921 builder, loc, valueToTile, sliceParams.offsets,
1922 sliceParams.sizes, sliceParams.strides);
1923 })
1924 .Case([&](RankedTensorType) {
1925 return tensor::ExtractSliceOp::create(
1926 builder, loc, valueToTile, sliceParams.offsets,
1927 sliceParams.sizes, sliceParams.strides);
1928 })
1929 .DefaultUnreachable("Unexpected shaped type");
1930 return sliceOp;
1931}
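// An illustrative example: materializing SliceParameters {offsets = [%o],
// sizes = [%sz], strides = [1]} over a ranked tensor %v produces
//
//   %s = tensor.extract_slice %v[%o] [%sz] [1]
//
// while the same parameters over a memref produce the analogous
// memref.subview.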
1932
1933 Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
1934 ArrayRef<OpFoldResult> tileSizes, AffineMap map,
1935 ArrayRef<OpFoldResult> lbs,
1936 ArrayRef<OpFoldResult> ubs,
1937 ArrayRef<OpFoldResult> subShapeSizes,
1938 bool omitPartialTileCheck) {
1939 SliceParameters sliceParams =
1940 computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
1941 ubs, subShapeSizes, omitPartialTileCheck);
1942 return materializeTiledShape(builder, loc, valueToTile, sliceParams);
1943}
1944
1945 SliceParameters
1946 computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
1947 ArrayRef<OpFoldResult> tileSizes, AffineMap map,
1948 ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
1949 ArrayRef<OpFoldResult> subShapeSizes,
1950 bool omitPartialTileCheck) {
1951 auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
1952 assert(shapedType && "only shaped types can be tiled");
1953 ArrayRef<int64_t> shape = shapedType.getShape();
1954 int64_t rank = shapedType.getRank();
1955
1956 // Compute offsets/sizes/strides for the tile.
1957 SliceParameters sliceParams;
1958 sliceParams.offsets.reserve(rank);
1959 sliceParams.sizes.reserve(rank);
1960 sliceParams.strides.reserve(rank);
1961 for (unsigned r = 0; r < rank; ++r) {
1962 LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
1963 if (!isTiled(map.getSubMap({r}), tileSizes)) {
1964 sliceParams.offsets.push_back(builder.getIndexAttr(0));
1965 OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
1966 sliceParams.sizes.push_back(dim);
1967 sliceParams.strides.push_back(builder.getIndexAttr(1));
1968 LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
1969 continue;
1970 }
1971 LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
1972
1973 // Tiling creates a new slice at the proper index; the slice step is 1
1974 // (i.e. the op does not subsample; stepping occurs in the loop).
1975 auto m = map.getSubMap({r});
1976 LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: submap: " << m << "\n");
1977 IRRewriter rewriter(builder);
1978 // The offset of the slice is m(lbs) - m(0).
1979 SmallVector<Attribute> zeros(lbs.size(), rewriter.getIndexAttr(0));
1980 SmallVector<Attribute> mAtZero;
1981 [[maybe_unused]] auto res = m.constantFold(zeros, mAtZero);
1982 assert(succeeded(res) && "affine map must be evaluatable (no symbols)");
1983 int64_t mAtZeroInt =
1984 cast<IntegerAttr>(mAtZero[0]).getValue().getSExtValue();
1985 OpFoldResult offset = makeComposedFoldedAffineApply(
1986 rewriter, loc, m.getResult(0) - mAtZeroInt, lbs);
1987 sliceParams.offsets.push_back(offset);
1988
1989 OpFoldResult closedIntSize =
1990 makeComposedFoldedAffineApply(rewriter, loc, m, subShapeSizes);
1991 // The resulting size needs to be turned back into a half-open interval.
1992 AffineExpr s0 = getAffineSymbolExpr(0, builder.getContext());
1993 OpFoldResult size =
1994 makeComposedFoldedAffineApply(rewriter, loc, s0 + 1, closedIntSize);
1995 LLVM_DEBUG(llvm::dbgs()
1996 << "computeSliceParameters: raw size: " << size << "\n");
1997 LLVM_DEBUG(llvm::dbgs()
1998 << "computeSliceParameters: new offset: " << offset << "\n");
1999 sliceParams.strides.push_back(builder.getIndexAttr(1));
2000
2001 if (omitPartialTileCheck) {
2002 // We statically know that the partial/boundary tile condition is
2003 // unnecessary.
2004 LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
2005 sliceParams.sizes.push_back(size);
2006 continue;
2007 }
2008
2009 // The size of the subview / extract_slice should be trimmed to avoid
2010 // out-of-bounds accesses, unless:
2011 // a. We statically know the subshape size divides the shape size evenly.
2012 // b. The subshape size is 1. According to the way the loops are set up,
2013 // tensors with "0" dimensions would never be constructed.
2014 int64_t shapeSize = shape[r];
2015 std::optional<int64_t> sizeCst = getConstantIntValue(size);
2016 auto hasTileSizeOne = sizeCst == 1;
2017 auto dividesEvenly = sizeCst && ShapedType::isStatic(shapeSize) &&
2018 ((shapeSize % *sizeCst) == 0);
2019 if (!hasTileSizeOne && !dividesEvenly) {
2020 LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: shapeSize=" << shapeSize
2021 << ", size: " << size
2022 << ": make sure in bound with affine.min\n");
2023
2024 AffineExpr dim0, dim1, dim2;
2025 MLIRContext *context = builder.getContext();
2026 bindDims(context, dim0, dim1, dim2);
2027
2028 // Get the dimension size for this dimension. We first compute the max
2029 // index and then add one. This is important because for convolution
2030 // ops, the input window dimension's affine map has the form
2031 // `(d0 * s0 + d1)`, where `d0`/`d1` is an output/filter window
2032 // dimension and `s0` is the stride. Directly using the dimension size
2033 // of the output/filter window dimensions would yield an incorrect result.
2034 AffineMap minusOneMap = AffineMap::inferFromExprList(
2035 {ArrayRef<AffineExpr>{dim0 - 1}}, context)
2036 .front();
2037 AffineMap plusOneMap = AffineMap::inferFromExprList(
2038 {ArrayRef<AffineExpr>{dim0 + 1}}, context)
2039 .front();
2040 SmallVector<OpFoldResult> maxIndices =
2041 llvm::to_vector(llvm::map_range(ubs, [&](OpFoldResult ub) {
2042 return makeComposedFoldedAffineApply(rewriter, loc, minusOneMap,
2043 {ub});
2044 }));
2045 OpFoldResult maxIndex =
2046 makeComposedFoldedAffineApply(rewriter, loc, m, maxIndices);
2047 OpFoldResult d =
2048 makeComposedFoldedAffineApply(rewriter, loc, plusOneMap, {maxIndex});
2049
2050 // Compute min(dim - offset, size) to avoid out-of-bounds accesses.
2051 AffineMap minMap = AffineMap::inferFromExprList(
2052 {ArrayRef<AffineExpr>{dim1 - dim2, dim0}}, context)
2053 .front();
2054 size =
2055 makeComposedFoldedAffineMin(rewriter, loc, minMap, {size, d, offset});
2056 }
2057 LLVM_DEBUG(llvm::dbgs() << "makeTiledShape: new size: " << size << "\n");
2058 sliceParams.sizes.push_back(size);
2059 }
2060 return sliceParams;
2061}
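// A hedged worked example: for a convolution input whose submap for
// dimension r is m = (d0 * 2 + d1), lbs = [%i, %j] and closed-interval
// subShapeSizes = [1, 2] (i.e. tile sizes 2 and 3), the code above computes
// offset = m(lbs) - m(0) = 2 * %i + %j and size = (2 * 1 + 2) + 1 = 5,
// matching the input window (2 - 1) * 2 + (3 - 1) + 1 needed by an output
// tile of 2 and a filter tile of 3 at stride 2, before the optional
// affine.min clamp.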
2062
2063 SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
2064 ArrayRef<OpFoldResult> ivs,
2065 ArrayRef<OpFoldResult> tileSizes) {
2066 SmallVector<OpFoldResult> offsets;
2067 for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
2068 LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
2069 bool isTiled = !isZeroInteger(tileSizes[idx]);
2070 offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
2071 LLVM_DEBUG(llvm::dbgs()
2072 << "computeTileOffsets: " << offsets.back() << "\n");
2073 }
2074 return offsets;
2075}
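// An illustrative example: with tileSizes = [2, 0, 4] and ivs = (%i, %j),
// the loop above produces offsets = [%i, 0, %j]; dimension 1 is untiled
// (tile size 0), so it keeps the constant offset 0 and consumes no iv.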
2076
2077 SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
2078 ArrayRef<OpFoldResult> tileSizes,
2079 ArrayRef<OpFoldResult> sizeBounds) {
2080 SmallVector<OpFoldResult> sizes;
2081 for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
2082 bool isTiled = !isZeroInteger(tileSizes[idx]);
2083 // Before composing, we need to make range a closed interval.
2084 OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
2085 AffineExpr d0 = getAffineDimExpr(0, b.getContext());
2086 IRRewriter rewriter(b);
2087 sizes.push_back(makeComposedFoldedAffineApply(rewriter, loc, d0 - 1, size));
2088 LLVM_DEBUG(llvm::dbgs() << "computeTileSizes: " << sizes.back() << "\n");
2089 }
2090 return sizes;
2091}
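// An illustrative example: with tileSizes = [2, 0] and sizeBounds =
// [%d0, %d1], the result is sizes = [1, %d1 - 1]; each entry is the tile
// size (or the full bound when untiled) minus one, i.e. the closed interval
// that computeSliceParameters later reopens by adding one.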
2092
2093 SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
2094 if (op.hasPureBufferSemantics())
2095 return {};
2096 return llvm::to_vector(
2097 llvm::map_range(op.getDpsInitsMutable(), [&](OpOperand &opOperand) {
2098 return operands[opOperand.getOperandNumber()].getType();
2099 }));
2100}
2101
2102 SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
2103 LinalgOp op, ValueRange operands,
2104 ValueRange results) {
2105 if (op.hasPureBufferSemantics())
2106 return {};
2107 SmallVector<Value> tensorResults;
2108 tensorResults.reserve(results.size());
2110 // Insert an insert_slice for each output tensor.
2110 unsigned resultIdx = 0;
2111 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2112 // TODO: use an interface/adaptor to avoid leaking position in
2113 // `tiledOperands`.
2114 Value outputTensor = operands[opOperand.getOperandNumber()];
2115 if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
2116 Value inserted = tensor::InsertSliceOp::create(
2117 builder, loc, sliceOp.getSource().getType(), results[resultIdx],
2118 sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(),
2119 sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2120 sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2121 tensorResults.push_back(inserted);
2122 } else {
2123 tensorResults.push_back(results[resultIdx]);
2124 }
2125 ++resultIdx;
2126 }
2127 return tensorResults;
2128}
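// A hedged IR sketch: when a tiled output was produced by
//
//   %s = tensor.extract_slice %t[%o] [%sz] [1]
//
// the corresponding tiled result %r is stitched back with
//
//   %u = tensor.insert_slice %r into %t[%o] [%sz] [1]
//
// reusing the slice's offsets, sizes and strides verbatim.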
2129
2130 SmallVector<std::optional<SliceParameters>>
2131 computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
2132 ValueRange valuesToTile, ArrayRef<OpFoldResult> ivs,
2133 ArrayRef<OpFoldResult> tileSizes,
2134 ArrayRef<OpFoldResult> sizeBounds,
2135 bool omitPartialTileCheck) {
2136 assert(ivs.size() == static_cast<size_t>(llvm::count_if(
2137 llvm::make_range(tileSizes.begin(), tileSizes.end()),
2138 [](OpFoldResult v) { return !isZeroInteger(v); })) &&
2139 "expected as many ivs as non-zero sizes");
2140
2141 // Construct (potentially temporary) mins and maxes on which to apply maps
2142 // that define tile subshapes.
2143 SmallVector<OpFoldResult> lbs =
2144 computeTileOffsets(builder, loc, ivs, tileSizes);
2145 SmallVector<OpFoldResult> subShapeSizes =
2146 computeTileSizes(builder, loc, tileSizes, sizeBounds);
2147
2148 assert(static_cast<int64_t>(valuesToTile.size()) <=
2149 linalgOp->getNumOperands() &&
2150 "more value to tile than operands.");
2151 SmallVector<std::optional<SliceParameters>> allSliceParams;
2152 allSliceParams.reserve(valuesToTile.size());
2153 for (auto [opOperand, val] :
2154 llvm::zip(linalgOp->getOpOperands(), valuesToTile)) {
2155 Value shapedOp = val;
2156 LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
2157 AffineMap map = linalgOp.getMatchingIndexingMap(&opOperand);
2158 // Use `opOperand` as is if it is not tiled and not an output tensor. Having
2159 // an extract/insert slice pair for all output tensors simplifies follow-up
2160 // transformations such as padding and bufferization since the
2161 // extract/insert slice pairs make the accessed iteration argument
2162 // subdomains explicit.
2163
2164 Type operandType = opOperand.get().getType();
2165 if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
2166 linalgOp.isDpsInit(&opOperand))) {
2167 allSliceParams.push_back(std::nullopt);
2168 LLVM_DEBUG(llvm::dbgs()
2169 << ": not tiled: use shape: " << operandType << "\n");
2170 continue;
2171 }
2172 LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
2173
2174 allSliceParams.push_back(computeSliceParameters(
2175 builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
2176 omitPartialTileCheck));
2177 }
2178
2179 return allSliceParams;
2180}
2181
2182 SmallVector<Value> makeTiledShapes(OpBuilder &builder, Location loc,
2183 LinalgOp linalgOp, ValueRange valuesToTile,
2184 ArrayRef<OpFoldResult> ivs,
2185 ArrayRef<OpFoldResult> tileSizes,
2186 ArrayRef<OpFoldResult> sizeBounds,
2187 bool omitPartialTileCheck) {
2188 SmallVector<std::optional<SliceParameters>> allSliceParameter =
2189 computeAllSliceParameters(builder, loc, linalgOp, valuesToTile, ivs,
2190 tileSizes, sizeBounds, omitPartialTileCheck);
2191 SmallVector<Value> tiledShapes;
2192 for (auto item : llvm::zip(valuesToTile, allSliceParameter)) {
2193 Value valueToTile = std::get<0>(item);
2194 std::optional<SliceParameters> sliceParams = std::get<1>(item);
2195 tiledShapes.push_back(
2196 sliceParams.has_value()
2197 ? materializeTiledShape(builder, loc, valueToTile, *sliceParams)
2198 ->getResult(0)
2199 : valueToTile);
2200 }
2201 return tiledShapes;
2202}
2203
2204void offsetIndices(OpBuilder &b, LinalgOp linalgOp,
2205 ArrayRef<OpFoldResult> offsets) {
2206 IRRewriter rewriter(b);
2207 offsetIndices(rewriter, linalgOp, offsets);
2208}
2209
2210void offsetIndices(RewriterBase &b, LinalgOp linalgOp,
2211 ArrayRef<OpFoldResult> offsets) {
2212 if (!linalgOp.hasIndexSemantics())
2213 return;
2214
2215 for (IndexOp indexOp : linalgOp.getBlock()->getOps<IndexOp>()) {
2216 if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()])
2217 continue;
2218 OpBuilder::InsertionGuard guard(b);
2219 b.setInsertionPointAfter(indexOp);
2220 AffineExpr index, offset;
2221 bindDims(b.getContext(), index, offset);
2222 OpFoldResult applied = makeComposedFoldedAffineApply(
2223 b, indexOp.getLoc(), index + offset,
2224 {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
2225 Value materialized =
2226 getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
2227 b.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
2228 return use.getOwner() != materialized.getDefiningOp();
2229 });
2230 }
2231}
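// An illustrative example: offsetting dimension 0 by %off rewrites
//
//   %i = linalg.index 0 : index
//
// so that all of its other uses are redirected to
//
//   %j = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %off)
//
// expressing indices computed inside the tile in the coordinates of the
// original iteration space.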
2232
2233 /// Get the reassociation maps to fold the result of an extract_slice (or the
2234 /// source of an insert_slice) operation with the given offsets and sizes to
2235 /// its rank-reduced version. This is only done for the cases where the size
2236 /// is 1 and the offset is 0. Strictly speaking, offset 0 is not required in
2237 /// general, but non-zero offsets are not handled by the SPIR-V backend at
2238 /// this point (and potentially cannot be handled).
2239std::optional<SmallVector<ReassociationIndices>>
2240 getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
2241 SmallVector<ReassociationIndices> reassociation;
2242 ReassociationIndices curr;
2243 for (const auto &it : llvm::enumerate(mixedSizes)) {
2244 auto dim = it.index();
2245 auto size = it.value();
2246 curr.push_back(dim);
2247 auto attr = llvm::dyn_cast_if_present<Attribute>(size);
2248 if (attr && cast<IntegerAttr>(attr).getInt() == 1)
2249 continue;
2250 reassociation.emplace_back(ReassociationIndices{});
2251 std::swap(reassociation.back(), curr);
2252 }
2253 // When the reassociation is not empty, fold the remaining unit
2254 // dimensions into the last dimension. If the reassociation so far is
2255 // empty, leave it empty; this will fold everything to a rank-0 tensor.
2256 if (!curr.empty() && !reassociation.empty())
2257 reassociation.back().append(curr.begin(), curr.end());
2258 return reassociation;
2259}
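// An illustrative example: mixedSizes = [1, %d, 1, 4] yields the
// reassociation [[0, 1], [2, 3]]: each unit dimension is folded into the
// next non-unit dimension, and trailing unit dimensions join the last group.
// If every size is 1, the empty reassociation is returned, folding the value
// to rank 0.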
2260
2261} // namespace linalg
2262} // namespace mlir