//===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>

using namespace mlir;
using namespace mlir::linalg;

/// Include the definitions of the copy operation interface.
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// Interface utility functions
//===----------------------------------------------------------------------===//

bool linalg::detail::canOpOperandsBeDroppedImpl(
    linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
  SmallVector<AffineMap> indexingMaps;
  for (auto &opOperand : linalgOp->getOpOperands()) {
    if (llvm::is_contained(droppedOperands, &opOperand))
      continue;
    indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if the op has no loops.
    return linalgOp.getNumLoops() == 0;
  }
  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}
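
// For example, for a generic op with two operands whose maps are
//   (d0, d1) -> (d0, d1) and (d0, d1) -> (d0),
// the second operand can be dropped: the remaining identity map still covers
// both loops, so its concatenation stays invertible. Dropping the first
// operand instead loses d1, and inversePermutation returns the null AffineMap.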

//===----------------------------------------------------------------------===//
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//

bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
  // Structural.
  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
    return false;

  // Operands and maps.
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return false;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
      !mapRange.back().isIdentity()) {
    return false;
  }
  // Region.
  return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
}
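
// A minimal sketch of an op accepted by this check (identity maps,
// all-parallel iterators, and a body that just yields the input element):
//
//   %0 = linalg.generic
//          {indexing_maps = [affine_map<(d0) -> (d0)>,
//                            affine_map<(d0) -> (d0)>],
//           iterator_types = ["parallel"]}
//          ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
//        ^bb0(%a: f32, %b: f32):
//          linalg.yield %a : f32
//        } -> tensor<8xf32>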

//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//

/// If the value is defined by a chain of unary side effect-free operations,
/// go up the use-def chain until the first value that isn't defined by such
/// an op.
// TODO: relax to multi-operands with constants, which are technically unary ops
// as needed (e.g. add5).
static Value getSourceSkipUnary(Value value) {
  Operation *op = value.getDefiningOp();
  while (op && op->getNumOperands() == 1) {
    auto iface = dyn_cast<MemoryEffectOpInterface>(op);
    if (!iface || !iface.hasNoEffect())
      break;
    value = op->getOperand(0);
    op = value.getDefiningOp();
  }
  return value;
}
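
// For example, if %v is defined by the chain
//   %0 = arith.extf %x : f16 to f32
//   %v = arith.negf %0 : f32
// then getSourceSkipUnary(%v) skips both unary, side effect-free ops and
// returns %x.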

bool mlir::linalg::detail::isContractionBody(
    Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
    llvm::raw_ostream &errs) {
  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
    errs << "no terminator in the block";
    return false;
  }

  if (block.getNumArguments() != 3) {
    errs << "expected block with 3 arguments";
    return false;
  }

  Operation *terminator = block.getTerminator();
  if (terminator->getNumOperands() != 1) {
    errs << "expected terminator with 1 operand";
    return false;
  }

  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
  Operation *reductionOp = yielded.getDefiningOp();
  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
    errs << "expected reduction op to be binary";
    return false;
  }

  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));

  if (reductionLHS != block.getArgument(2) &&
      reductionRHS != block.getArgument(2)) {
    errs << "expected reduction to take block argument #2 as one of the "
            "operands (modulo unary casts)";
    return false;
  }

  Value contributed = getSourceSkipUnary(
      isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
  Operation *elementwiseOp = contributed.getDefiningOp();
  if (elementwiseOp->getNumResults() != 1 ||
      elementwiseOp->getNumOperands() != 2) {
    errs << "expected elementwise op to be binary";
    return false;
  }

  if (!isaPair(elementwiseOp, reductionOp)) {
    errs << "expected reduction/elementwise op kind not satisfied";
    return false;
  }

  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
  if ((elementwiseLHS == block.getArgument(0) &&
       elementwiseRHS == block.getArgument(1)) ||
      (elementwiseLHS == block.getArgument(1) &&
       elementwiseRHS == block.getArgument(0))) {
    return true;
  }

  errs << "expected elementwise op to apply to block arguments (modulo unary "
          "casts)";
  return false;
}
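
// The canonical accepted body shape, here for the f32 matmul case, is:
//
//   ^bb0(%a: f32, %b: f32, %acc: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %acc, %0 : f32
//     linalg.yield %1 : f32
//
// with unary, side effect-free casts tolerated on any of the edges.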

/// Returns true if the two operations are of the kinds specified by a pair of
/// consecutive template arguments.
template <typename AddOpTy, typename MulOpTy, typename... Args>
static bool isPairTemplateImpl(Operation *add, Operation *mul) {
  static_assert(sizeof...(Args) % 2 == 0,
                "expected an even number of template arguments");
  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
    return true;

  if constexpr (sizeof...(Args) > 0)
    return isPairTemplateImpl<Args...>(add, mul);
  else
    return false;
}

/// Returns true if the block is a body of a contraction with the kinds of
/// operations given pairwise by template arguments.
template <typename... Args>
static bool isContractionBody(Block &block) {
  return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
}

/// Given an `indexingMap` and its corresponding `iterators`, returns
/// the positions of the iterators of type `iter` that are indexed by
/// the `indexingMap` as a permutation. This is useful to infer various
/// subcomputations on a `LinalgOp`. This is performed by looking up
/// each result in the `indexingMap` and determining whether:
/// - It is a single AffineDimExpr.
/// - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperand(AffineMap indexingMap,
                                ArrayRef<utils::IteratorType> iterators,
                                utils::IteratorType iter) {
  assert(iterators.size() == indexingMap.getNumDims());
  llvm::SmallDenseSet<int64_t> res;
  for (AffineExpr e : indexingMap.getResults()) {
    if (auto d = dyn_cast<AffineDimExpr>(e)) {
      if (iterators[d.getPosition()] == iter &&
          llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
            return e.isFunctionOfDim(d.getPosition());
          }) == 1)
        res.insert(d.getPosition());
    }
  }
  return res;
}
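
// For example, with iterators = [parallel, parallel, reduction] and
//   indexingMap = (d0, d1, d2) -> (d0, d2)
// querying for parallel iterators returns {0} and querying for reduction
// iterators returns {2}; both dims appear exactly once as a bare dim expr.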

namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace

/// Infer the iterator types from the init affine map. This looks at which dims
/// are present in the map results, and returns an iterator types array with
/// parallel types for dims that are present, and reduction types for dims that
/// are not present.
static FailureOr<SmallVector<utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
  if (!map.isProjectedPermutation())
    return failure();
  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
  for (auto expr : map.getResults())
    if (auto dim = dyn_cast<AffineDimExpr>(expr))
      iterators[dim.getPosition()] = par;
  return iterators;
}
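
// For example, the matmul-style init map (d0, d1, d2) -> (d0, d1) yields
// [parallel, parallel, reduction]: d0 and d1 appear in the results while d2
// does not.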

/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
/// 1. The m dimension is involved in an outer-product along LHS
///    (i.e. it is a permutation on RES and LHS and does not appear in RHS).
/// 2. The n dimension is involved in an outer-product along RHS
///    (i.e. it is a permutation on RES and RHS and does not appear in LHS).
/// 3. The k dimension appears as a permutation on LHS and RHS.
/// 4. m, n and k appear only once in any given indexing.
/// 5. Optional batch dimensions that appear in all operands are captured.
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
static FailureOr<ContractionDimensions>
inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
                         ArrayRef<utils::IteratorType> iterators) {
  llvm::SmallDenseSet<int64_t> a =
      findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
  llvm::SmallDenseSet<int64_t> b =
      findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
  llvm::SmallDenseSet<int64_t> c =
      findPermutationsIndexingOperand(indexingMaps[2], iterators, par);

  // A & C - B are the iterators involved in an outer-product along A (the LHS).
  llvm::SmallDenseSet<int64_t> ac = a;
  llvm::set_intersect(ac, c);
  llvm::set_subtract(ac, b);
  // B & C - A are the iterators involved in an outer-product along B (the RHS).
  llvm::SmallDenseSet<int64_t> bc = b;
  llvm::set_intersect(bc, c);
  llvm::set_subtract(bc, a);
  // A & B & C are the "batch" dimensions.
  llvm::SmallDenseSet<int64_t> batches = a;
  llvm::set_intersect(batches, b);
  llvm::set_intersect(batches, c);

  // A & B red are the reduction dimensions.
  llvm::SmallDenseSet<int64_t> ra =
      findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
  llvm::SmallDenseSet<int64_t> rb =
      findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
  llvm::set_intersect(ra, rb);

  // Return each set in sorted order.
  ContractionDimensions dimensions{
      SmallVector<unsigned, 2>(batches.begin(), batches.end()),
      SmallVector<unsigned, 2>(ac.begin(), ac.end()),
      SmallVector<unsigned, 2>(bc.begin(), bc.end()),
      SmallVector<unsigned, 2>(ra.begin(), ra.end())};
  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
  llvm::sort(dimensions.m.begin(), dimensions.m.end());
  llvm::sort(dimensions.n.begin(), dimensions.n.end());
  llvm::sort(dimensions.k.begin(), dimensions.k.end());
  return dimensions;
}
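
// For a plain matmul, with indexing maps
//   LHS: (d0, d1, d2) -> (d0, d2)
//   RHS: (d0, d1, d2) -> (d2, d1)
//   RES: (d0, d1, d2) -> (d0, d1)
// and iterators [parallel, parallel, reduction], this infers
// batch = {}, m = {0}, n = {1}, k = {2}.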

FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();
  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
                                  linalgOp.getIteratorTypesArray());
}

FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
  if (indexingMaps.size() != 3)
    return failure();
  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
  if (failed(iterators))
    return failure();
  return inferContractionDimsImpl(indexingMaps, iterators.value());
}

namespace mlir::linalg::detail {
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail

mlir::linalg::detail::MatchContractionResult
mlir::linalg::detail::isContractionInterfaceImpl(
    Operation *op, ContractionDimensions *dimensions) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchContractionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
    return MatchContractionResult::WrongNumOperands;
  auto mapRange = linalgOp.getIndexingMapsArray();
  if (linalgOp.getNumReductionLoops() == 0)
    return MatchContractionResult::NoReduction;
  if (llvm::any_of(mapRange,
                   [](AffineMap m) { return !m.isProjectedPermutation(); }))
    return MatchContractionResult::NotProjectedPermutations;
  // TODO: more fields than add/mul.
  // clang-format off
  if (!::isContractionBody<
        arith::MulFOp, arith::AddFOp,
        arith::MulIOp, arith::AddIOp,
        complex::MulOp, complex::AddOp,
        arith::AndIOp, arith::OrIOp>(
      *linalgOp.getBlock())) {
    return MatchContractionResult::NotAddMul;
  }
  // clang-format on

  if (dimensions) {
    FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
    assert(succeeded(res) && "unexpected failure to infer contraction dims");
    *dimensions = *res;
  }
  return MatchContractionResult::Success;
}

StringRef
mlir::linalg::detail::getMatchContractionMessage(MatchContractionResult res) {
  switch (res) {
  case MatchContractionResult::NotLinalgOp:
    return "expected a LinalgOp";
  case MatchContractionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output";
  case MatchContractionResult::NoReduction:
    return "expected at least 1 reduction";
  case MatchContractionResult::NotProjectedPermutations:
    return "expected indexing maps to be projected permutations";
  case MatchContractionResult::NotAddMul:
    return "expected add/mul op in the body";
  case MatchContractionResult::Success:
    return "";
  }
  llvm_unreachable("unhandled MatchContractionResult case");
}

bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) {
  if (!linalgOp)
    return false;
  Operation *op = linalgOp.getOperation();
  return isa<ContractionOpInterface>(op) ||
         (mlir::linalg::detail::isContractionInterfaceImpl(op) ==
          mlir::linalg::detail::MatchContractionResult::Success);
}

/// Verify that a LinalgOp `op` is a contraction.
/// A Linalg contraction is defined in general terms:
/// 1. Has 2 input and 1 output shapes.
/// 2. Has at least one reduction dimension.
/// 3. Has only projected permutation indexing maps.
/// 4. Its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
///    (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar
///    unary operations that may change the type (e.g. for mixed-precision).
/// As a consequence, when vectorization of such an op occurs, the only special
/// behavior is that the (unique) MulOpType is vectorized into a
/// `vector.contract`. All other ops are handled in a generic fashion.
/// In the future, we may wish to allow more input arguments and elementwise and
/// constant operations that do not involve the reduction dimension(s).
LogicalResult mlir::linalg::detail::verifyContractionInterface(Operation *op) {
  auto res = isContractionInterfaceImpl(op);
  if (res != MatchContractionResult::Success)
    return op->emitError(getMatchContractionMessage(res));
  return success();
}

//===----------------------------------------------------------------------===//
// ConvolutionOpInterface implementation
//===----------------------------------------------------------------------===//

/// Of the given two expressions returns one that is of type T (`lhs` gets
/// preference over `rhs`).
template <typename T>
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs) {
  return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
}

namespace {
/// Walk the indexing expressions for the input of a convolution operation to
/// verify it is of the right form, either:
/// - AffineDimExpr
/// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
///   (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
///
/// It classifies each AffineDimExpr as a convolved or unconvolved dimension
/// and verifies each dimension occurs only once.
struct ConvAccessExprWalker
    : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
  // Stores dimensions used in expressions of the above form.
  llvm::SmallDenseSet<int64_t> convolvedDims;
  // Stores the dual mapping between LHS and RHS of convolution exprs.
  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
  // Stores single use dimensions used by an AffineDimExpr.
  llvm::SmallDenseSet<int64_t> unConvolvedDims;
  // Stores a mapping from convolved dims to their coefficient.
  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;

  // Removes dims with multiple uses in the source input map from dimension
  // sets tracked by this walker.
  void clearMultiUseDims(AffineMap map) {
    for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
      if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
            return e.isFunctionOfDim(dimPos);
          }) > 1) {
        convolvedDims.erase(dimPos);
        unConvolvedDims.erase(dimPos);
        // If a duplicate dim is marked as convolved, the pair of the duplicate
        // dim must be removed from the map as well.
        if (convolvedDimMapping.contains(dimPos)) {
          int64_t pairedDim = convolvedDimMapping[dimPos];
          convolvedDims.erase(pairedDim);
          unConvolvedDims.erase(pairedDim);
          strideAndDilationMapping.erase(pairedDim);
          convolvedDimMapping.erase(dimPos);
          convolvedDimMapping.erase(pairedDim);
        }
      }
    }
  }

  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
    unsigned position = dimExpr.getPosition();
    if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
      return failure();
    }
    unConvolvedDims.insert(position);
    return success();
  }

  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }

  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }

  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
    // In pre-order visit, top level op has to be an add op.
    if (binaryExpr.getKind() != AffineExprKind::Add)
      return failure();
    auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
    auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
    if (failed(lhsDimPos) || failed(rhsDimPos))
      return failure();
    convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
    convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
    return success();
  }

  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      // Stride/dilation for this dim is implicitly 1.
      strideAndDilationMapping[dim] =
          getAffineConstantExpr(1, expr.getContext());
      convolvedDims.insert(dim);
      return dim;
    }
    if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
      if (symbolMulExpr.getKind() != AffineExprKind::Mul)
        return failure();
      auto lhsExpr = symbolMulExpr.getLHS();
      auto rhsExpr = symbolMulExpr.getRHS();
      // Check for symbol expression.
      AffineExpr mulExpr =
          getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
      // If there was no symbol expr, check for constant expression.
      if (!mulExpr) {
        mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
      }
      auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
      if (!mulExpr || !dimExpr)
        return failure();
      int64_t dim = dimExpr.getPosition();
      if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
        return failure();
      strideAndDilationMapping[dim] = mulExpr;
      convolvedDims.insert(dim);
      return dim;
    }
    return failure();
  }
};
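
// For example, walking the input access of a strided 1-D convolution,
//   (d0, d1, d2) -> (d0 * 2 + d2),
// records d0 and d2 as convolvedDims, pairs them in convolvedDimMapping, and
// maps d0 to coefficient 2 (the stride) and d2 to the implicit coefficient 1
// in strideAndDilationMapping.
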
} // namespace

static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
  assert(map.isProjectedPermutation() &&
         "expected map to have projected permutations");
  llvm::SmallDenseSet<int64_t> preservedDims;
  for (auto expr : map.getResults())
    preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
  return preservedDims;
}

static SmallVector<int64_t, 2>
getConstantsFromExprList(const SmallVector<AffineExpr, 2> &exprs) {
  SmallVector<int64_t, 2> vals;
  for (auto e : exprs) {
    auto constantExpr = dyn_cast<AffineConstantExpr>(e);
    assert(constantExpr && "Found non-constant stride/dilation");
    vals.push_back(constantExpr.getValue());
  }
  return vals;
}

/// Classifies dimensions in the `linalgOp` used by a convolution
/// subcomputation, as captured by `inputExprWalker`. If
/// `allowEmptyConvolvedDims` is not set, this will fail if there is not at
/// least one convolved dimension pair (output image + filter loop).
/// Convolution dimensions are specified in sorted order, and strides match the
/// order of the output image dimensions, while the dilations match the order
/// of the filter loop dimensions.
static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
                         ConvAccessExprWalker &inputExprWalker,
                         bool allowEmptyConvolvedDims) {
  auto filterMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
  auto outputMap =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
      filterMap, linalgOp.getIteratorTypesArray(), par);
  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
      outputMap, linalgOp.getIteratorTypesArray(), par);

  // unConvolvedDims & outputDims - filterDims are the batch iterators.
  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(batch, outputDims);
  llvm::set_subtract(batch, filterDims);

  // convolvedDims & outputDims are the output image iterators.
  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
  llvm::set_intersect(oi, outputDims);

  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
  llvm::SmallDenseSet<int64_t> oc = filterDims;
  llvm::set_intersect(oc, outputDims);
  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);

  // filterDims & outputDims & unConvolvedDims are the depth iterators.
  llvm::SmallDenseSet<int64_t> depth = filterDims;
  llvm::set_intersect(depth, outputDims);
  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);

  llvm::SmallDenseSet<int64_t> filterReducedDims =
      findPermutationsIndexingOperand(filterMap,
                                      linalgOp.getIteratorTypesArray(), red);

  // convolvedDims & filterReducedDims are the filter loop iterators.
  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
  llvm::set_intersect(fl, filterReducedDims);

  // unConvolvedDims & filterReducedDims are the input channel iterators.
  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
  llvm::set_intersect(ic, filterReducedDims);

  if (oi.empty() && !allowEmptyConvolvedDims)
    return failure();

  // Return each set in sorted order.
  ConvolutionDimensions dimensions{
      SmallVector<unsigned, 2>(batch.begin(), batch.end()),
      SmallVector<unsigned, 2>(oi.begin(), oi.end()),
      SmallVector<unsigned, 2>(oc.begin(), oc.end()),
      SmallVector<unsigned, 2>(fl.begin(), fl.end()),
      SmallVector<unsigned, 2>(ic.begin(), ic.end()),
      SmallVector<unsigned, 2>(depth.begin(), depth.end()),
      /*strides=*/SmallVector<int64_t, 2>{},
      /*dilations=*/SmallVector<int64_t, 2>{}};
  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
  llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
  llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
  llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
  llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
  llvm::sort(dimensions.depth.begin(), dimensions.depth.end());

  // Use the op carried strides/dilations attribute if present.
  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
  if (!nativeStrides) {
    SmallVector<AffineExpr, 2> strideExprs;
    for (unsigned oiDim : dimensions.outputImage)
      strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
    dimensions.strides = getConstantsFromExprList(strideExprs);
  } else {
    dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
  }
  auto nativeDilations =
      linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
  if (!nativeDilations) {
    SmallVector<AffineExpr, 2> dilationExprs;
    for (unsigned flDim : dimensions.filterLoop)
      dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
    dimensions.dilations = getConstantsFromExprList(dilationExprs);
  } else {
    dimensions.dilations =
        llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
  }
  return dimensions;
}

/// Find at least 1 parallel (output_image) and reduction (filter_loop)
/// dimension candidates that form a convolution subcomputation within
/// `linalgOp`. The LHS is assumed to be the convolution input while the
/// RHS is assumed to be the filter.
/// These dimensions are such that:
/// 1. Optional batch dimensions that appear in the input and filter.
/// 2. The output_image dimension is involved in a cross-correlation along LHS
///    (i.e. it is a permutation on RES and LHS and has an associated
///    filter_loop in RHS).
/// 3. Optional output_channel dimension is involved in an outer-product along
///    RHS (i.e. it is a permutation on RES and RHS and does not appear in
///    LHS).
/// 4. Optional input_channel dimension appears as a permutation on LHS and
///    RHS.
/// 5. The filter_loop dimension appears as a permutation on the RHS and
///    represents the shape of the kernel cross-correlated along a
///    corresponding output_image dim.
/// 6. The input_channel dimension appears as a permutation on LHS and RHS.
/// 7. All dimensions appear only once in any given indexing map.
/// This allows e.g. detecting that some convolution is embedded within
/// `linalgOp` with some orthogonal heuristic.
/// When multiple dimension occurrences exist that match any classification,
/// indices are returned in sorted order.
/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
FailureOr<ConvolutionDimensions>
mlir::linalg::inferConvolutionDims(LinalgOp linalgOp) {
  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
    return failure();

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  for (AffineExpr expr : indexingMaps[0].getResults())
    (void)inputExprWalker.visit(expr);
  inputExprWalker.clearMultiUseDims(indexingMaps[0]);

  return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
                                  /*allowEmptyConvolvedDims=*/false);
}
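
// For example, for a linalg.conv_2d_nhwc_hwcf op, whose loops are ordered
// (N, OH, OW, F, KH, KW, C) = (d0, ..., d6), this infers
//   batch = [0], outputImage = [1, 2], outputChannel = [3],
//   filterLoop = [4, 5], inputChannel = [6], depth = [],
// with strides/dilations read from the op's attributes when present.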

namespace mlir::linalg::detail {
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction
};
} // namespace mlir::linalg::detail

mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
    Operation *op, ConvolutionDimensions *dimensions) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutation.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  llvm::SmallDenseSet<int64_t> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as non-convolved in input, not present in
  //   filter.
  // - Output image dimension : present in output, convolved dims in input, not
  //   present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
  llvm::SmallDenseSet<int64_t> allLoopDims;
  for (auto outputExpr : indexingMaps.back().getResults()) {
    int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Batch dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Output image loop dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(outputDim) &&
        !inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Output channel dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Depth multiplier.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  for (auto filterExpr : indexingMaps[1].getResults()) {
    int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
    if (outputDims.count(filterDim) &&
        !inputExprWalker.unConvolvedDims.count(filterDim) &&
        !inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension. This is already seen, continue.
      continue;
    }
    if (inputExprWalker.convolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Filter loop dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Input channel dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        outputDims.count(filterDim)) {
      // Depthwise loop. Already seen.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  if (dimensions) {
    FailureOr<ConvolutionDimensions> res =
        inferConvolutionDimsImpl(linalgOp, inputExprWalker,
                                 /*allowEmptyConvolvedDims=*/true);
    assert(succeeded(res) && "unexpected failure to infer convolution dims");
    *dimensions = *res;
  }

  return MatchConvolutionResult::Success;
}

StringRef
mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
  switch (res) {
  case MatchConvolutionResult::NotLinalgOp:
    return "expected a LinalgOp";
  case MatchConvolutionResult::WrongNumOperands:
    return "expected op with 2 inputs and 1 output";
  case MatchConvolutionResult::WrongInputIndexingMap:
    return "unexpected input index map for convolutions";
  case MatchConvolutionResult::NotProjectedPermutations:
    return "expected output/filter indexing maps to be projected permutations";
  case MatchConvolutionResult::NonConvolutionLoop:
    return "unexpected loop dimension for convolution op";
  case MatchConvolutionResult::OutputDimsNotParallel:
    return "expected all iterators used to access outputs to be parallel";
  case MatchConvolutionResult::NonOutputDimNotReduction:
    return "expected all iterators not used to access outputs to be reduction";
  case MatchConvolutionResult::Success:
    return "";
  }
  llvm_unreachable("unhandled MatchConvolutionResult case");
}

bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp) {
  return linalg::detail::isConvolutionInterfaceImpl(linalgOp.getOperation()) ==
         linalg::detail::MatchConvolutionResult::Success;
}

LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
  MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
  if (res != MatchConvolutionResult::Success)
    return op->emitError(getMatchConvolutionMessage(res));
  return success();
}

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//

enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NotScalarInput
};

static MatchFillResult isFillInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchFillResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return MatchFillResult::WrongNumOperands;

  OpOperand *value = linalgOp.getDpsInputOperand(0);
  if (!linalgOp.isScalar(value))
    return MatchFillResult::NotScalarInput;

  return MatchFillResult::Success;
}

LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
  auto res = isFillInterfaceImpl(op);
  if (res == MatchFillResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchFillResult::WrongNumOperands)
    return op->emitError("expected op with 1 input and 1 output");
  if (res == MatchFillResult::NotScalarInput)
    return op->emitError("expected op with scalar input");

  return success();
}

//===----------------------------------------------------------------------===//
// StructuredOpInterface implementation
//===----------------------------------------------------------------------===//

SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                                Location loc) {
  SmallVector<OpFoldResult> res;
  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
    for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
      res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  SmallVector<int64_t, 4> res;
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  for (OpOperand &opOperand : getOperation()->getOpOperands())
    llvm::append_range(res, getShape(&opOperand));
  return res;
}

SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  auto viewSizes = createFlatListOfOperandDims(b, loc);
  SmallVector<Range, 4> res(numDims);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = dyn_cast<AffineDimExpr>(result)) {
      if (res[d.getPosition()].offset)
        continue;
      res[d.getPosition()] =
          Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
    }
  }
  return res;
}

SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
  AffineMap map = getLoopsToShapesMap();
  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
  SmallVector<int64_t, 4> res(numDims, 0);
  for (unsigned idx = 0; idx < numRes; ++idx) {
    auto result = map.getResult(idx);
    if (auto d = dyn_cast<AffineDimExpr>(result))
      res[d.getPosition()] = allShapeSizes[idx];
  }
  return res;
}
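
// For example, for a static matmul A(4x8) * B(8x16) -> O(4x16), the
// loops-to-shapes map is (d0, d1, d2) -> (d0, d2, d2, d1, d0, d1), the flat
// static shape list is [4, 8, 8, 16, 4, 16], and the computed loop sizes are
// [4, 16, 8], i.e. m = 4, n = 16, k = 8.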

/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
      : positions(std::move(positions)) {}

  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
  }

  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.test(dimExpr.getPosition());
  }

  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  llvm::SmallBitVector positions;
};

static std::pair<int64_t, int64_t>
getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
  int64_t inputRankSum = 0;
  int64_t outputRankSum = 0;
  for (OpOperand *input : op.getDpsInputOperands())
    inputRankSum += op.getRank(input);
  for (OpOperand &output : op.getDpsInitsMutable())
    outputRankSum += op.getRank(&output);
  return {inputRankSum, inputRankSum + outputRankSum};
}
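
// For example, for a matmul with 2-d inputs A, B and a 2-d init O, the
// loops-to-shapes map has six results (two per operand) and the results
// describing the output shape occupy the half-open range [4, 6).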

LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultShapes.compose(shapesToLoopsMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          rewriter, loc, resultShapesFromInputShapesMap,
          createFlatListOfOperandDims(b, loc));
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
      if (!shapedType.isDynamicDim(dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
      } else {
        // Dynamic dim: Return Value.
        OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}

/// Return the index in the indexingMaps vector that corresponds to this
/// `opOperand`.
int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
  auto operandNumber = opOperand->getOperandNumber();
  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
  if (!dpsIface.isDpsInput(opOperand))
    return operandNumber;
  unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
  assert(!dpsIface.isDpsInit(opOperand));
  // Account for potential inputs that are not DPS and may not appear in
  // `indexingMaps`.
  return cast<DestinationStyleOpInterface>(*this->getOperation())
             .getNumDpsInputs() +
         operandNumber - start;
}

LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);

  // Mixed tensor/buffer operands are not allowed.
  if (!linalgOp.hasPureTensorSemantics() &&
      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
    return op->emitOpError("expected to have pure tensor or buffer semantics");

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by it are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
      linalgOp->getNumOperands())
    return op->emitOpError("expected the number of indexing_map (")
           << linalgOp.getIndexingMapsArray().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp->getNumOperands() << ")";

  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Symbols disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand.getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand.getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(&opOperand);
    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }

  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError("expected the shape-to-loops map to be non-null");

  // Check if given shapes match to inferred shapes.
  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);

  // Verify only static cases since we can't get exact dimension sizes and loop
  // ranges for dynamic cases in this stage.
  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
    for (int64_t &range : endLoopRangeValues)
      range -= 1;
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimensions or the case that the dimension size is 0.
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
          continue;

        // The first index or last index should be the maximum or the minimum in
        // the inferred index ranges since the range is increasing or
        // decreasing. The size of dimensions of input/output operands and the
        // maximum value + 1 in the inferred range should be the same. But, for
        // now we check if the inferred ranges are in boundary of input/output
        // operands' size or not in case that Affine Expressions are complicated
        // such as d0 * 3 + d1, since it is not easy to handle the issues.
        // Found the case that this solution can't check, for example,
        // (d0, d1) -> (d1 - d0).
        int64_t inferredDimSize =
            std::max(startIndices[dim], endIndices[dim]) + 1;
        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
          std::string mapStr;
          {
            llvm::raw_string_ostream os(mapStr);
            os << indexingMap;
          }
          return op->emitOpError(
                     "unexpected result less than 0 at expression #")
                 << dim << " in " << mapStr;
        }
        if (dyn_cast<AffineDimExpr>(indexingMap.getResult(dim))) {
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be " << inferredDimSize << ", but found "
                   << shape[dim];
          }
        } else {
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be greater than or equal to "
                   << inferredDimSize << ", but found " << shape[dim];
          }
        }
      }
    }
  }

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(linalgOp->getRegion(0)))
    return op->emitOpError("expects to have 1 region with 1 block");

  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
  // types.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, that can never be read from. This will be subject to
  // consistency discussions (i.e. what to do with output tensors whose bbarg is
  // not used).
  Block &block = linalgOp->getRegion(0).front();

  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
    return op->emitOpError("expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    Type elementType = opOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(opOperand->get().getType());
    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError("expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  return success();
}