//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>

using namespace mlir;
using namespace mlir::linalg;

/// Include the definitions of the copy operation interface.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// Interface utility functions
//===----------------------------------------------------------------------===//

bool linalg::detail::canOpOperandsBeDroppedImpl(
    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
  SmallVector<AffineMap> indexingMaps;
  for (auto &opOperand : linalgOp->getOpOperands()) {
    if (llvm::is_contained(droppedOperands, &opOperand))
      continue;
    indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if the op has no loops.
    return linalgOp.getNumLoops() == 0;
  }
  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}
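
// Illustrative sketch (not exercised in this file): for a two-loop
// `linalg.generic` whose operands carry the indexing maps
//   %0: (d0, d1) -> (d0, d1)   %1: (d0, d1) -> (d0)   %2: (d0, d1) -> (d0, d1)
// dropping %1 keeps a concatenated map that still covers both d0 and d1, so it
// stays invertible and %1 can be dropped; dropping %0 and %2 together leaves
// only (d0, d1) -> (d0), which no longer covers d1, so they cannot.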

//===----------------------------------------------------------------------===//
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//

bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
  // Structural.
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return false;

  // Operands and maps.
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return false;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
      !mapRange.back().isIdentity()) {
    return false;
  }
  // Region.
  return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
}
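
// A minimal sketch of an op that satisfies these checks: one input, one init,
// identity maps, all-parallel iterators, and a body whose only operation is
// the terminator yielding the input block argument, e.g.
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32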

//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//

/// If the value is defined by a chain of unary side effect-free operations, go
/// up the use-def chain until the first value that isn't defined by such an op.
// TODO: relax to multi-operands with constants, which are technically unary ops
// as needed (e.g. add5).
static Value getSourceSkipUnary(Value value) {
  Operation *op = value.getDefiningOp();
  while (op && op->getNumOperands() == 1) {
    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface || !iface.hasNoEffect())
      break;
    value = op->getOperand(0);
    op = value.getDefiningOp();
  }
  return value;
}

bool mlir::linalg::detail::isContractionBody(
    Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
    llvm::raw_ostream &errs) {
  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
    errs << "no terminator in the block";
    return false;
  }

  if (block.getNumArguments() != 3) {
    errs << "expected block with 3 arguments";
    return false;
  }

  Operation *terminator = block.getTerminator();
  if (terminator->getNumOperands() != 1) {
    errs << "expected terminator with 1 operand";
    return false;
  }

  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
  Operation *reductionOp = yielded.getDefiningOp();
  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
    errs << "expected reduction op to be binary";
    return false;
  }

  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));

  if (reductionLHS != block.getArgument(2) &&
      reductionRHS != block.getArgument(2)) {
    errs << "expected reduction to take block argument #2 as one of the "
            "operands (modulo unary casts)";
    return false;
  }

  Value contributed = getSourceSkipUnary(
      isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
  Operation *elementwiseOp = contributed.getDefiningOp();
  if (elementwiseOp->getNumResults() != 1 ||
      elementwiseOp->getNumOperands() != 2) {
    errs << "expected elementwise op to be binary";
    return false;
  }

  if (!isaPair(elementwiseOp, reductionOp)) {
    errs << "expected reduction/elementwise op kind not satisfied";
    return false;
  }

  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
  if ((elementwiseLHS == block.getArgument(0) &&
       elementwiseRHS == block.getArgument(1)) ||
      (elementwiseLHS == block.getArgument(1) &&
       elementwiseRHS == block.getArgument(0))) {
    return true;
  }

  errs << "expected elementwise op to apply to block arguments (modulo unary "
          "casts)";
  return false;
}
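
// A minimal body accepted by this predicate (the canonical matmul region),
// shown only as an illustration:
//   ^bb0(%a: f32, %b: f32, %acc: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %acc, %0 : f32
//     linalg.yield %1 : f32
// Here `arith.mulf` is the elementwise op on block arguments #0/#1 and
// `arith.addf` is the reduction op consuming block argument #2.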

/// Returns true if the two operations are of the kinds specified by a pair of
/// consecutive template arguments.
template <typename AddOpTy, typename MulOpTy, typename... Args>
static bool isPairTemplateImpl(Operation *add, Operation *mul) {
  static_assert(sizeof...(Args) % 2 == 0,
                "expected an even number of template arguments");
  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
    return true;

  if constexpr (sizeof...(Args) > 0)
    return isPairTemplateImpl<Args...>(add, mul);
  else
    return false;
}

/// Returns true if the block is a body of a contraction with the kinds of
/// operations given pairwise by template arguments.
template <typename... Args>
static bool isContractionBody(Block &block) {
  return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
}

/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
/// iterators of type `iter` that index the `opOperand` as a permutation.
/// This is useful to infer various subcomputations on a given `linalgOp`.
/// This is performed by looking up each result in the matching indexing map and
/// determining whether:
///   - It is a single AffineDimExpr.
///   - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
                                utils::IteratorType iter) {
  llvm::SmallDenseSet<int64_t> res;
  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
  for (AffineExpr e : indexingMap.getResults()) {
    if (auto d = dyn_cast<AffineDimExpr>(e)) {
      if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
            return e.isFunctionOfDim(d.getPosition());
          }) == 1)
        res.insert(d.getPosition());
    }
  }
  return res;
}
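
// Worked example (illustrative only): for the LHS of a plain matmul, whose
// indexing map is (d0, d1, d2) -> (d0, d2) with iterator types
// [parallel, parallel, reduction], querying `parallel` returns {0} and
// querying `reduction` returns {2}; d1 does not appear in the map at all.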

namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace

/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
///   1. The m dimension is involved in an outer-product along LHS
///      (i.e. it is a permutation on RES and LHS and does not appear in RHS).
///   2. The n dimension is involved in an outer-product along RHS
///      (i.e. it is a permutation on RES and RHS and does not appear in LHS).
///   3. The k dimension appears as a permutation on LHS and RHS.
///   4. m, n and k appear only once in any given indexing.
///   5. Optional batch dimensions that appear in all operands are captured.
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();

  llvm::SmallDenseSet<int64_t> a = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(0), par);
  llvm::SmallDenseSet<int64_t> b = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(1), par);
  llvm::SmallDenseSet<int64_t> c = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInitOperand(0), par);

  // A & C - B are the iterators involved in an outer-product along A (the LHS).
  llvm::SmallDenseSet<int64_t> ac = a;
  llvm::set_intersect(ac, c);
  llvm::set_subtract(ac, b);
  // B & C - A are the iterators involved in an outer-product along B (the RHS).
  llvm::SmallDenseSet<int64_t> bc = b;
  llvm::set_intersect(bc, c);
  llvm::set_subtract(bc, a);
  // A & B & C are the "batch" dimensions.
  llvm::SmallDenseSet<int64_t> batches = a;
  llvm::set_intersect(batches, b);
  llvm::set_intersect(batches, c);

  // A & B red are the reduction dimensions.
  llvm::SmallDenseSet<int64_t> ra = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(0), red);
  llvm::SmallDenseSet<int64_t> rb = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(1), red);
  llvm::set_intersect(ra, rb);

  // Return each set in sorted order.
  ContractionDimensions dimensions{
      SmallVector<unsigned, 2>(batches.begin(), batches.end()),
      SmallVector<unsigned, 2>(ac.begin(), ac.end()),
      SmallVector<unsigned, 2>(bc.begin(), bc.end()),
      SmallVector<unsigned, 2>(ra.begin(), ra.end())};
  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
  llvm::sort(dimensions.m.begin(), dimensions.m.end());
  llvm::sort(dimensions.n.begin(), dimensions.n.end());
  llvm::sort(dimensions.k.begin(), dimensions.k.end());
  return dimensions;
}
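
// Worked examples (illustrative only):
//  - linalg.matmul, maps (d0, d2), (d2, d1), (d0, d1):
//      batch = [], m = [0], n = [1], k = [2].
//  - linalg.batch_matmul, maps (d0, d1, d3), (d0, d3, d2), (d0, d1, d2):
//      batch = [0], m = [1], n = [2], k = [3].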

namespace mlir::linalg::detail {
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail

mlir::linalg::detail::MatchContractionResult
mlir::linalg::detail::isContractionInterfaceImpl(
    Operation *op, ContractionDimensions *dimensions) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchContractionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
    return MatchContractionResult::WrongNumOperands;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (linalgOp.getNumReductionLoops() == 0)
    return MatchContractionResult::NoReduction;
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return MatchContractionResult::NotProjectedPermutations;
  // TODO: more fields than add/mul.
  // clang-format off
  if (!::isContractionBody<
        arith::MulFOp, arith::AddFOp,
        arith::MulIOp, arith::AddIOp,
        complex::MulOp, complex::AddOp,
        arith::AndIOp, arith::OrIOp>(
      *linalgOp.getBlock())) {
    return MatchContractionResult::NotAddMul;
  }
  // clang-format on

  if (dimensions) {
    FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
    assert(succeeded(res) && "unexpected failure to infer contraction dims");
    *dimensions = *res;
  }
  return MatchContractionResult::Success;
}

StringRef
mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
  switch (res) {
  case MatchContractionResult::NotLinalgOp:
    return "expected a LinalgOp";
  case MatchContractionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output";
  case MatchContractionResult::NoReduction:
    return "expected at least 1 reduction";
  case MatchContractionResult::NotProjectedPermutations:
    return "expected indexing maps to be projected permutations";
  case MatchContractionResult::NotAddMul:
    return "expected add/mul op in the body";
  case MatchContractionResult::Success:
    return "";
  }
  llvm_unreachable("unhandled MatchContractionResult case");
}

bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
  if (!linalgOp)
    return false;
  Operation *op = linalgOp.getOperation();
  return isa<ContractionOpInterface>(op) ||
         linalg::detail::isContractionInterfaceImpl(op) ==
             linalg::detail::MatchContractionResult::Success;
}

/// Verify that a LinalgOp `op` is a contraction.
/// A Linalg contraction is defined in general terms:
///   1. Has 2 input and 1 output shapes.
///   2. Has at least one reduction dimension.
///   3. Has only projected permutation indexing maps.
///   4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
///      (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar
///      unary operations that may change the type (e.g. for mixed-precision).
/// As a consequence, when vectorization of such an op occurs, the only special
/// behavior is that the (unique) MulOpType is vectorized into a
/// `vector.contract`. All other ops are handled in a generic fashion.
/// In the future, we may wish to allow more input arguments and elementwise and
/// constant operations that do not involve the reduction dimension(s).
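///
/// For instance, a mixed-precision body of the following shape matches, with
/// `arith.extf` playing the role of u3/u4 (shown only as an illustration):
///   ^bb0(%a: f16, %b: f16, %c: f32):
///     %0 = arith.extf %a : f16 to f32
///     %1 = arith.extf %b : f16 to f32
///     %2 = arith.mulf %0, %1 : f32
///     %3 = arith.addf %c, %2 : f32
///     linalg.yield %3 : f32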
LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
  auto res = isContractionInterfaceImpl(op);
  if (res != MatchContractionResult::Success)
    return op->emitError(getMatchContractionMessage(res));
  return success();
}

//===----------------------------------------------------------------------===//
// ConvolutionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Of the given two expressions returns one that is of type T (`lhs` gets
/// preference over `rhs`).
template <typename T>
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
  return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
}

namespace {
/// Walk the indexing expressions for the input of a convolution operation to
/// verify it is of the right form, either
///   - AffineDimExpr
///   - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///     (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// This classifies the AffineDimExprs as convolved or unconvolved dimensions
/// and verifies that each dimension occurs only once.
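///
/// For example (illustrative only), with a stride of 2 and a dilation of 3 the
/// input access `d0 * 2 + d3 * 3` records d0 and d3 as convolved dims paired
/// with each other, with coefficients 2 and 3 kept in the stride/dilation
/// mapping, while a plain access `d1` records d1 as an unconvolved dim.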
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  // Stores dimensions used in expressions of the above form.
  llvm::SmallDenseSet<int64_t> convolvedDims;
  // Stores the dual mapping between LHS and RHS of convolution exprs.
  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
  // Stores single use dimensions used by an AffineDimExpr.
  llvm::SmallDenseSet<int64_t> unConvolvedDims;
  // Stores a mapping from convolved dims to their coefficient.
  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;

  // Removes dims with multiple uses in the source input map from dimension
  // sets tracked by this walker.
  void clearMultiUseDims(AffineMap map) {
    for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
      if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
            return e.isFunctionOfDim(dimPos);
          }) > 1) {
        convolvedDims.erase(dimPos);
        unConvolvedDims.erase(dimPos);
        // If a duplicate dim is marked as convolved, the pair of the duplicate
        // dim must be removed from the map as well.
        if (convolvedDimMapping.contains(dimPos)) {
          int64_t pairedDim = convolvedDimMapping[dimPos];
          convolvedDims.erase(pairedDim);
          unConvolvedDims.erase(pairedDim);
          strideAndDilationMapping.erase(pairedDim);
          convolvedDimMapping.erase(dimPos);
          convolvedDimMapping.erase(pairedDim);
        }
      }
    }
  }

  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
      return failure();
    }
    unConvolvedDims.insert(position);
    return success();
  }

  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In pre-order visit, top level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
    auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
    if (failed(lhsDimPos) || failed(rhsDimPos))
      return failure();
    convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
    convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
    return success();
  }

  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      // Stride/dilation for this dim is implicitly 1.
      strideAndDilationMapping[dim] =
          getAffineConstantExpr(1, expr.getContext());
      convolvedDims.insert(dim);
      return dim;
    }
    if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for symbol expression.
      AffineExpr mulExpr =
          getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
      // If there was no symbol expr, check for constant expression.
      if (!mulExpr) {
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      }
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      if (!mulExpr || !dimExpr)
        return failure();
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      strideAndDilationMapping[dim] = mulExpr;
      convolvedDims.insert(dim);
      return dim;
    }
    return failure();
  }
};
} // namespace

static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected map to have projected permutations");
  llvm::SmallDenseSet<int64_t> preservedDims;
  for (auto expr : map.getResults())
    preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
  return preservedDims;
}

static SmallVector<int64_t, 2>
getConstantsFromExprList(SmallVector<AffineExpr, 2> exprs) {
  SmallVector<int64_t, 2> vals;
  for (auto e : exprs) {
    auto constantExpr = dyn_cast<AffineConstantExpr>(e);
    assert(constantExpr && "Found non-constant stride/dilation");
    vals.push_back(constantExpr.getValue());
  }
  return vals;
}

/// Classifies dimensions in the `linalgOp` used by a convolution
/// subcomputation, as captured by `inputExprWalker`. If
/// `allowEmptyConvolvedDims` is not set, this will fail if there is not at
/// least one convolved dimension pair (output image + filter loop). Convolution
/// dimensions are specified in sorted order; strides match the order of the
/// output image dimensions, while dilations match the order of the filter loop
/// dimensions.
static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
                         ConvAccessExprWalker &inputExprWalker,
                         bool allowEmptyConvolvedDims) {
  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInputOperand(1), par);
  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
      linalgOp, linalgOp.getDpsInitOperand(0), par);

  // unConvolvedDims & outputDims - filterDims are the batch iterators.
  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(batch, outputDims);
  llvm::set_subtract(batch, filterDims);

  // convolvedDims & outputDims are the output image iterators.
  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
  llvm::set_intersect(oi, outputDims);

  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
  llvm::SmallDenseSet<int64_t> oc = filterDims;
  llvm::set_intersect(oc, outputDims);
  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);

  // filterDims & outputDims & unConvolvedDims are the depth iterators.
  llvm::SmallDenseSet<int64_t> depth = filterDims;
  llvm::set_intersect(depth, outputDims);
  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);

  llvm::SmallDenseSet<int64_t> filterReducedDims =
      findPermutationsIndexingOperand(linalgOp, linalgOp.getDpsInputOperand(1),
                                      red);

  // convolvedDims & filterReducedDims are the filter loop iterators.
  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
  llvm::set_intersect(fl, filterReducedDims);

  // unConvolvedDims & filterReducedDims are the input channel iterators.
  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(ic, filterReducedDims);

  if (oi.empty() && !allowEmptyConvolvedDims)
    return failure();

  // Return each set in sorted order.
  ConvolutionDimensions dimensions{
      SmallVector<unsigned, 2>(batch.begin(), batch.end()),
      SmallVector<unsigned, 2>(oi.begin(), oi.end()),
      SmallVector<unsigned, 2>(oc.begin(), oc.end()),
      SmallVector<unsigned, 2>(fl.begin(), fl.end()),
      SmallVector<unsigned, 2>(ic.begin(), ic.end()),
      SmallVector<unsigned, 2>(depth.begin(), depth.end()),
      /*strides=*/SmallVector<int64_t, 2>{},
      /*dilations=*/SmallVector<int64_t, 2>{}};
  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
  llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
  llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
  llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
  llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
  llvm::sort(dimensions.depth.begin(), dimensions.depth.end());

  // Use the op carried strides/dilations attribute if present.
  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
  if (!nativeStrides) {
    SmallVector<AffineExpr, 2> strideExprs;
    for (unsigned oiDim : dimensions.outputImage)
      strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
    dimensions.strides = getConstantsFromExprList(strideExprs);
  } else {
    dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
  }
  auto nativeDilations =
      linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
  if (!nativeDilations) {
    SmallVector<AffineExpr, 2> dilationExprs;
    for (unsigned flDim : dimensions.filterLoop)
      dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
    dimensions.dilations = getConstantsFromExprList(dilationExprs);
  } else {
    dimensions.dilations =
        llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
  }
  return dimensions;
}

/// Find at least 1 parallel (output_image) and reduction (filter_loop)
/// dimension candidates that form a convolution subcomputation within
/// `linalgOp`. The LHS is assumed to be the convolution input while the
/// RHS is assumed to be the filter.
/// These dimensions are such that:
///   1. Optional batch dimensions that appear in the input and filter.
///   2. The output_image dimension is involved in a cross-correlation along LHS
///      (i.e. it is a permutation on RES and LHS and has an associated
///      filter_loop in RHS).
///   3. Optional output_channel dimension is involved in an outer-product along
///      RHS (i.e. it is a permutation on RES and RHS and does not appear in
///      LHS).
///   4. Optional input_channel dimension appears as a permutation on LHS and
///      RHS.
///   5. The filter_loop dimension appears as a permutation on the RHS and
///      represents the shape of the kernel cross-correlated along a
///      corresponding output_image dim.
///   6. The input_channel dimension appears as a permutation on LHS and RHS.
///   7. All dimensions appear only once in any given indexing map.
/// This allows e.g. detecting that some convolution is embedded within
/// `linalgOp` with some orthogonal heuristic.
/// When multiple dimension occurrences exist that match any classification,
/// indices are returned in sorted order.
/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
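///
/// Worked example (illustrative only): for `linalg.conv_1d`, with indexing
/// maps (d0 + d1), (d1), (d0) and iterator types [parallel, reduction], the
/// inferred dimensions are batch = [], outputImage = [0], outputChannel = [],
/// filterLoop = [1], inputChannel = [], depth = [], strides = [1],
/// dilations = [1].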
FailureOr<ConvolutionDimensions>
mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  for (AffineExpr expr : indexingMaps[0].getResults())
    (void)inputExprWalker.visit(expr);
  inputExprWalker.clearMultiUseDims(indexingMaps[0]);

  return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
                                  /*allowEmptyConvolvedDims=*/false);
}

namespace mlir::linalg::detail {
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction
};
} // namespace mlir::linalg::detail

mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
    Operation *op, ConvolutionDimensions *dimensions) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutation.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  llvm::SmallDenseSet<int64_t> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as non-convolved in input, not present in
  //   filter.
  // - Output image dimension : present in output, convolved dims in input, not
  //   present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
  llvm::SmallDenseSet<int64_t> allLoopDims;
  for (auto outputExpr : indexingMaps.back().getResults()) {
    int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Batch dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Output image Loop dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(outputDim) &&
        !inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Output channel dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Depth multiplier.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  for (auto filterExpr : indexingMaps[1].getResults()) {
    int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
    if (outputDims.count(filterDim) &&
        !inputExprWalker.unConvolvedDims.count(filterDim) &&
        !inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension. This is already seen, continue;
      continue;
    }
    if (inputExprWalker.convolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Filter loop dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Input channel dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        outputDims.count(filterDim)) {
      // Depthwise loop. Already seen.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  if (dimensions) {
    FailureOr<ConvolutionDimensions> res =
        inferConvolutionDimsImpl(linalgOp, inputExprWalker,
                                 /*allowEmptyConvolvedDims=*/true);
    assert(succeeded(res) && "unexpected failure to infer convolution dims");
    *dimensions = *res;
  }

  return MatchConvolutionResult::Success;
}

StringRef
mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
  switch (res) {
  case MatchConvolutionResult::NotLinalgOp:
    return "expected a LinalgOp";
  case MatchConvolutionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output";
  case MatchConvolutionResult::WrongInputIndexingMap:
    return "unexpected input index map for convolutions";
  case MatchConvolutionResult::NotProjectedPermutations:
    return "expected output/filter indexing maps to be projected permutations";
  case MatchConvolutionResult::NonConvolutionLoop:
    return "unexpected loop dimension for convolution op";
  case MatchConvolutionResult::OutputDimsNotParallel:
    return "expected all iterators used to access outputs to be parallel";
  case MatchConvolutionResult::NonOutputDimNotReduction:
    return "expected all iterators not used to access outputs to be reduction";
  case MatchConvolutionResult::Success:
    return "";
  }
  llvm_unreachable("unhandled MatchConvolutionResult case");
}

bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp) {
  return linalg::detail::isConvolutionInterfaceImpl(linalgOp.getOperation()) ==
         linalg::detail::MatchConvolutionResult::Success;
}

LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
  MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
  if (res != MatchConvolutionResult::Success)
    return op->emitError(getMatchConvolutionMessage(res));
  return success();
}

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//

enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NotScalarInput
};

static MatchFillResult isFillInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchFillResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return MatchFillResult::WrongNumOperands;

  OpOperand *value = linalgOp.getDpsInputOperand(0);
  if (!linalgOp.isScalar(value))
    return MatchFillResult::NotScalarInput;

  return MatchFillResult::Success;
}

LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
  auto res = isFillInterfaceImpl(op);
  if (res == MatchFillResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchFillResult::WrongNumOperands)
    return op->emitError("expected op with 1 input and 1 output");
  if (res == MatchFillResult::NotScalarInput)
    return op->emitError("expected op with scalar input");

  return success();
}

//===----------------------------------------------------------------------===//
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//

SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                                Location loc) {
  SmallVector<OpFoldResult> res;
  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
    for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
      res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  SmallVector<int64_t, 4> res;
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  for (OpOperand &opOperand : getOperation()->getOpOperands())
    llvm::append_range(res, getShape(&opOperand));
  return res;
}

SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  auto viewSizes = createFlatListOfOperandDims(b, loc);
  SmallVector<Range, 4> res(numDims);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = dyn_cast<AffineDimExpr>(result)) {
      if (res[d.getPosition()].offset)
        continue;
      res[d.getPosition()] =
          Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
    }
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
  SmallVector<int64_t, 4> res(numDims, 0);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = dyn_cast<AffineDimExpr>(result))
      res[d.getPosition()] = allShapeSizes[idx];
  }
  return res;
}
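
// Worked example (illustrative only): for a static matmul A:4x8 * B:8x16 ->
// C:4x16, the loops-to-shapes map is (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1),
// the flattened static dims are [4, 8, 8, 16, 4, 16], and the computed static
// loop sizes are [d0, d1, d2] = [4, 16, 8].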

/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
      : positions(std::move(positions)) {}

  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
  }

  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.test(dimExpr.getPosition());
  }

  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  llvm::SmallBitVector positions;
};

static std::pair<int64_t, int64_t>
getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
  int64_t inputRankSum = 0;
  int64_t outputRankSum = 0;
  for (OpOperand *input : op.getDpsInputOperands())
    inputRankSum += op.getRank(input);
  for (OpOperand &output : op.getDpsInitsMutable())
    outputRankSum += op.getRank(&output);
  return {inputRankSum, inputRankSum + outputRankSum};
}
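
// Worked example (illustrative only): for a matmul with two 2-D inputs and one
// 2-D init, the flattened operand dimensions are 2 + 2 + 2, so this returns
// {4, 6}: the output shapes occupy positions [4, 6) of the loops-to-shapes map
// results.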

LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          rewriter, loc, resultShapesFromInputShapesMap,
          createFlatListOfOperandDims(b, loc));
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
      if (!shapedType.isDynamicDim(dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
      } else {
        // Dynamic dim: Return Value.
        OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}

/// Return the index in the indexingMaps vector that corresponds to this
/// `opOperand`.
int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
  auto operandNumber = opOperand->getOperandNumber();
  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
  if (!dpsIface.isDpsInput(opOperand))
    return operandNumber;
  unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
  assert(!dpsIface.isDpsInit(opOperand));
  // Account for potential inputs that are not DPS and may not appear in
  // `indexingMaps`.
  return cast<DestinationStyleOpInterface>(*this->getOperation())
             .getNumDpsInputs() +
         operandNumber - start;
}

LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by it are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
      linalgOp->getNumOperands())
    return op->emitOpError("expected the number of indexing_map (")
           << linalgOp.getIndexingMapsArray().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp->getNumOperands() << ")";

  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Symbols disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand.getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand.getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(&opOperand);
    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }

  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError("expected the shape-to-loops map to be non-null");

  // Check if given shapes match to inferred shapes.
  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);

  // Verify only static cases since we can't get exact dimension sizes and loop
  // ranges for dynamic cases in this stage.
  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
    for (int64_t &range : endLoopRangeValues)
      range -= 1;
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimensions or the case that the dimension size is 0.
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
          continue;

        // The first or last index should be the maximum or the minimum in the
        // inferred index ranges since the range is increasing or decreasing.
        // The size of dimensions of input/output operands and the maximum
        // value + 1 in the inferred range should be the same. But, for now we
        // only check whether the inferred ranges are within the bounds of the
        // input/output operands' sizes, because complicated affine expressions
        // such as d0 * 3 + d1 are not easy to handle. There are also cases this
        // check cannot catch, for example (d0, d1) -> (d1 - d0).
        int64_t inferredDimSize =
            std::max(startIndices[dim], endIndices[dim]) + 1;
        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
          std::string mapStr;
          {
            llvm::raw_string_ostream os(mapStr);
            os << indexingMap;
          }
          return op->emitOpError(
                     "unexpected result less than 0 at expression #")
                 << dim << " in " << mapStr;
        }
        if (dyn_cast<AffineDimExpr>(indexingMap.getResult(dim))) {
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be " << inferredDimSize << ", but found "
                   << shape[dim];
          }
        } else {
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be greater than or equal to "
                   << inferredDimSize << ", but found " << shape[dim];
          }
        }
      }
    }
  }

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(linalgOp->getRegion(0)))
    return op->emitOpError("expects to have 1 region with 1 block");

  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
  // types.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, that can never be read from. This will be subject to
  // consistency discussions (i.e. what to do with output tensors whose bbarg is
  // not used).
  Block &block = linalgOp->getRegion(0).front();

  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
    return op->emitOpError("expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    Type elementType = opOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(opOperand->get().getType());
    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError("expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  return success();
}