MLIR  20.0.0git
LinalgInterfaces.cpp
Go to the documentation of this file.
1 //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include <algorithm>
25 
26 using namespace mlir;
27 using namespace mlir::linalg;
28 
29 /// Include the definitions of the copy operation interface.
30 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
31 
32 //===----------------------------------------------------------------------===//
33 // Interface utility functions
34 //===----------------------------------------------------------------------===//
35 
37  linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
38  SmallVector<AffineMap> indexingMaps;
39  for (auto &opOperand : linalgOp->getOpOperands()) {
40  if (llvm::is_contained(droppedOperands, &opOperand))
41  continue;
42  indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
43  }
44  if (indexingMaps.empty()) {
45  // If there are no indexing maps, the operand can only be dropped
46  // if the op has no loops.
47  return linalgOp.getNumLoops() == 0;
48  }
49  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
50 }
51 
52 //===----------------------------------------------------------------------===//
53 // CopyOpInterface implementation
54 //===----------------------------------------------------------------------===//
55 
56 bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
57  // Structural.
58  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
59  return false;
60 
61  // Operands and maps.
62  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
63  return false;
64  auto mapRange = linalgOp.getIndexingMapsArray();
65  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
66  !mapRange.back().isIdentity()) {
67  return false;
68  }
69  // Region.
70  return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // FillOpInterface implementation
75 //===----------------------------------------------------------------------===//
76 std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
77  // Structural.
78  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
79  genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
80  return std::nullopt;
81 
82  // Input should be referenced and init should not.
83  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
84  genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
85  return std::nullopt;
86 
87  OpOperand *value = genericOp.getDpsInputOperand(0);
88  if (!genericOp.isScalar(value))
89  return std::nullopt;
90 
91  Block *body = genericOp.getBody();
92  if (body->getOperations().size() != 1)
93  return std::nullopt;
94 
95  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
96  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
97  yieldOp->getOperand(0) != body->getArgument(0))
98  return std::nullopt;
99  return value->get();
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // Elementwise Single Unary/Binary-OpInterface implementation
104 //===----------------------------------------------------------------------===//
105 static bool
106 isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
107  unsigned arity) {
108  // Check all loops are parallel.
109  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
110  genericOp.getNumLoops() < 1)
111  return false;
112 
113  // Check there are arity-inputs, 1-output and all are identity-maps.
114  if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
115  !llvm::all_of(genericOp.getIndexingMapsArray(),
116  [](AffineMap map) { return map.isIdentity(); }))
117  return false;
118 
119  // Init should not be referenced for elementwise operations.
120  if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
121  return false;
122 
123  // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
124  // as resulting from producer-consumer fusion. Here, we restrict to two ops in
125  // the body, where the first is the elementwise single op and the second a
126  // yield.
127  Block *body = genericOp.getBody();
128  if (body->getOperations().size() != 2)
129  return false;
130 
131  Operation *op = &body->front();
132  if (op->getNumOperands() != arity || op->getNumResults() != 1)
133  return false;
134 
135  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
136  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
137  yieldOp->getOperand(0).getDefiningOp() != op)
138  return false;
139  return true;
140 }
141 
142 bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
143  // All basic elemwise checks.
145  return false;
146 
147  // Check input is actully used.
148  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
149  return false;
150  return true;
151 }
152 
153 bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
155  return false;
156 
157  // Check both inputs are used (elementwise).
158  OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
159  OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
160  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
161  !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
162  return false;
163  return true;
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // ContractionOpInterface implementation
168 //===----------------------------------------------------------------------===//
169 
170 /// If the value is defined by a chain of unary side effect-free, go up the
171 /// use-def chain until the first value that isn't defined by such an op.
172 // TODO: relax to multi-operands with constants, which are technically unary ops
173 // as needed (e.g. add5).
175  Operation *op = value.getDefiningOp();
176  while (op && op->getNumOperands() == 1) {
177  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
178  if (!iface || !iface.hasNoEffect())
179  break;
180  value = op->getOperand(0);
181  op = value.getDefiningOp();
182  }
183  return value;
184 }
185 
187  Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
188  llvm::raw_ostream &errs) {
189  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
190  errs << "no terminator in the block";
191  return false;
192  }
193 
194  if (block.getNumArguments() != 3) {
195  errs << "expected block with 3 arguments";
196  return false;
197  }
198 
199  Operation *terminator = block.getTerminator();
200  if (terminator->getNumOperands() != 1) {
201  errs << "expected terminator with 1 operand";
202  return false;
203  }
204 
205  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
206  Operation *reductionOp = yielded.getDefiningOp();
207  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
208  errs << "expected reduction op to be binary";
209  return false;
210  }
211 
212  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
213  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
214 
215  if (reductionLHS != block.getArgument(2) &&
216  reductionRHS != block.getArgument(2)) {
217  errs << "expected reduction to take block argument #2 as one of the "
218  "operands (modulo unary casts)";
219  return false;
220  }
221 
222  Value contributed = getSourceSkipUnary(
223  isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
224  Operation *elementwiseOp = contributed.getDefiningOp();
225  if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
226  elementwiseOp->getNumOperands() != 2) {
227  errs << "expected elementwise op to be binary";
228  return false;
229  }
230 
231  if (!isaPair(elementwiseOp, reductionOp)) {
232  errs << "expected reduction/elementwise op kind not satisfied";
233  return false;
234  }
235 
236  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
237  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
238  if ((elementwiseLHS == block.getArgument(0) &&
239  elementwiseRHS == block.getArgument(1)) ||
240  (elementwiseLHS == block.getArgument(1) &&
241  elementwiseRHS == block.getArgument(0))) {
242  return true;
243  }
244 
245  errs << "expected elementwise op to apply to block arguments (modulo unary "
246  "casts)";
247  return false;
248 }
249 
250 /// Returns true if the two operations are of the kinds specified by a pair of
251 /// consecutive template arguments.
252 template <typename AddOpTy, typename MulOpTy, typename... Args>
253 static bool isPairTemplateImpl(Operation *add, Operation *mul) {
254  static_assert(sizeof...(Args) % 2 == 0,
255  "expected an even number of template arguments");
256  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
257  return true;
258 
259  if constexpr (sizeof...(Args) > 0)
260  return isPairTemplateImpl<Args...>(add, mul);
261  else
262  return false;
263 }
264 
265 /// Returns true if the block is a body of a contraction with the kinds of
266 /// operations given pairwise by template arguments.
267 template <typename... Args>
268 static bool isContractionBody(Block &block) {
269  return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
270 }
271 
272 /// Given an `indexingMap` and its corresponding `iterators`, returns
273 /// the positions of the iterators of type `iter` that are indexed by
274 /// the `indexingMap` as a permutation. This is useful to infer various
275 /// subcomputations on a `LinalgOp`. This is performed by looking up
276 /// each result in the `indexingMap` and determining whether:
277 /// - It is a single AffineDimExpr.
278 /// - It is the only result involving this AffineDimExpr.
279 static llvm::SmallDenseSet<int64_t>
282  utils::IteratorType iter) {
283  assert(iterators.size() == indexingMap.getNumDims());
284  llvm::SmallDenseSet<int64_t> res;
285  for (AffineExpr e : indexingMap.getResults()) {
286  if (auto d = dyn_cast<AffineDimExpr>(e)) {
287  if (iterators[d.getPosition()] == iter &&
288  llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
289  return e.isFunctionOfDim(d.getPosition());
290  }) == 1)
291  res.insert(d.getPosition());
292  }
293  }
294  return res;
295 }
296 
297 namespace {
298 auto par = utils::IteratorType::parallel;
299 auto red = utils::IteratorType::reduction;
300 } // namespace
301 
302 /// Infer the iterator types from the init affine map. This looks at which dims
303 /// are present in the map results, and returns an iterator types array with
304 /// parallel types for dims that are present, and reduction types for dims that
305 /// are not present.
306 static FailureOr<SmallVector<utils::IteratorType>>
308  if (!map.isProjectedPermutation())
309  return failure();
310  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
311  for (auto expr : map.getResults())
312  if (auto dim = dyn_cast<AffineDimExpr>(expr))
313  iterators[dim.getPosition()] = par;
314  return iterators;
315 }
316 
317 /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
318 /// a matmul subcomputation within `linalgOp`. These dimensions are such that:
319 /// 1. The m dimension is involved in an outer-product along LHS
320 /// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
321 /// 2. The n dimension is involved in an outer-product along RHS
322 /// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
323 /// 3. The k dimension appears as a permutation on LHS and RHS.
324 /// 4. m, n and k appear only once in any given indexing.
325 /// 5. Optional batch dimensions that appear in all operands are captured.
326 /// This allows e.g. detecting that some contraction is embedded within
327 /// `linalgOp` with some orthogonal heuristic.
328 static FailureOr<ContractionDimensions>
330  ArrayRef<utils::IteratorType> iterators) {
331  llvm::SmallDenseSet<int64_t> a =
332  findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
333  llvm::SmallDenseSet<int64_t> b =
334  findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
335  llvm::SmallDenseSet<int64_t> c =
336  findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
337 
338  // A & C - B are the iterators involved in an outer-product along A (the LHS).
339  llvm::SmallDenseSet<int64_t> ac = a;
340  llvm::set_intersect(ac, c);
341  llvm::set_subtract(ac, b);
342  // B & C - A are the iterators involved in an outer-product along B (the RHS).
343  llvm::SmallDenseSet<int64_t> bc = b;
344  llvm::set_intersect(bc, c);
345  llvm::set_subtract(bc, a);
346  // A & B & C are the "batch" dimensions.
347  llvm::SmallDenseSet<int64_t> batches = a;
348  llvm::set_intersect(batches, b);
349  llvm::set_intersect(batches, c);
350 
351  // A & B red are the reduction dimensions.
352  llvm::SmallDenseSet<int64_t> ra =
353  findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
354  llvm::SmallDenseSet<int64_t> rb =
355  findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
356  llvm::set_intersect(ra, rb);
357 
358  // Return each set in sorted order.
359  ContractionDimensions dimensions{
360  SmallVector<unsigned, 2>(batches.begin(), batches.end()),
361  SmallVector<unsigned, 2>(ac.begin(), ac.end()),
362  SmallVector<unsigned, 2>(bc.begin(), bc.end()),
363  SmallVector<unsigned, 2>(ra.begin(), ra.end())};
364  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
365  llvm::sort(dimensions.m.begin(), dimensions.m.end());
366  llvm::sort(dimensions.n.begin(), dimensions.n.end());
367  llvm::sort(dimensions.k.begin(), dimensions.k.end());
368  return dimensions;
369 }
370 
371 FailureOr<ContractionDimensions>
373  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
374  return failure();
375  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
376  linalgOp.getIteratorTypesArray());
377 }
378 
379 FailureOr<ContractionDimensions>
381  if (indexingMaps.size() != 3)
382  return failure();
383  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
384  if (failed(iterators))
385  return failure();
386  return inferContractionDimsImpl(indexingMaps, iterators.value());
387 }
388 
namespace mlir::linalg::detail {
/// Result codes of matching an op against the contraction structure; the
/// non-Success values map to the diagnostics in getMatchContractionMessage.
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail
399 
403  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
404  if (!linalgOp)
405  return MatchContractionResult::NotLinalgOp;
406  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
407  return MatchContractionResult::WrongNumOperands;
408  auto mapRange = linalgOp.getIndexingMapsArray();
409  if (linalgOp.getNumReductionLoops() == 0)
410  return MatchContractionResult::NoReduction;
411  if (llvm::any_of(mapRange,
412  [](AffineMap m) { return !m.isProjectedPermutation(); }))
413  return MatchContractionResult::NotProjectedPermutations;
414  // TODO: more fields than add/mul.
415  // clang-format off
416  if (!::isContractionBody<
417  arith::MulFOp, arith::AddFOp,
418  arith::MulIOp, arith::AddIOp,
419  complex::MulOp, complex::AddOp,
420  arith::AndIOp, arith::OrIOp>(
421  *linalgOp.getBlock())) {
422  return MatchContractionResult::NotAddMul;
423  }
424  // clang-format on
425 
426  if (dimensions) {
427  FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
428  assert(succeeded(res) && "unexpected failure to infer contraction dims");
429  *dimensions = *res;
430  }
431  return MatchContractionResult::Success;
432 }
433 
434 StringRef
436  switch (res) {
437  case MatchContractionResult::NotLinalgOp:
438  return "expected a LinalgOp";
439  case MatchContractionResult::WrongNumOperands:
440  return "expected op with 2 inputs and 1 output";
441  case MatchContractionResult::NoReduction:
442  return "expected at least 1 reduction";
443  case MatchContractionResult::NotProjectedPermutations:
444  return "expected indexing maps to be projected permutations";
445  case MatchContractionResult::NotAddMul:
446  return "expected add/mul op in the body";
447  case MatchContractionResult::Success:
448  return "";
449  }
450  llvm_unreachable("unhandled MatchContractionResult case");
451 }
452 
454  if (!linalgOp)
455  return false;
456  Operation *op = linalgOp.getOperation();
457  return isa<ContractionOpInterface>(op) ||
460 }
461 
462 /// Verify that a LinalgOp `op` is a contraction.
463 /// A Linalg contraction is defined in general terms:
464 /// 1. Has 2 input and 1 output shapes.
465 /// 2. Has at least one reduction dimension.
466 /// 3. Has only projected permutation indexing maps.
467 /// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
468 /// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
469 /// operations that may change the type (e.g. for mixed-precision).
470 /// As a consequence, when vectorization of such an op occurs, the only special
471 /// behavior is that the (unique) MulOpType is vectorized into a
472 /// `vector.contract`. All other ops are handled in a generic fashion.
473 /// In the future, we may wish to allow more input arguments and elementwise and
474 /// constant operations that do not involve the reduction dimension(s).
476  auto res = isContractionInterfaceImpl(op);
477  if (res != MatchContractionResult::Success)
478  return op->emitError(getMatchContractionMessage(res));
479  return success();
480 }
481 
482 //===----------------------------------------------------------------------===//
483 // ConvolutionOpInterface implementation
484 //===----------------------------------------------------------------------===//
485 
486 /// Of the given two expressions returns one that is of type T (`lhs` gets
487 /// preference over `rhs`)
488 template <typename T>
490  return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
491 }
492 
493 namespace {
494 /// Walk the indexing expressions for input of a convolution operation to verify
495 /// its of the right form, either
496 /// - AffineDimExpr
497 /// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
498 /// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
499 ///
500 /// classifies the AffineDimExpr as convolved dimensions or unconvolved
501 /// dimensions and verifies each dimension occurs only once.
502 struct ConvAccessExprWalker
503  : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
504  // Stores dimensions used in expressions of the above form.
505  llvm::SmallDenseSet<int64_t> convolvedDims;
506  // Stores the dual mapping between LHS and RHS of convolution exprs.
507  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
508  // Stores single use dimensions used by an AffineDimExpr.
509  llvm::SmallDenseSet<int64_t> unConvolvedDims;
510  // Stores a mapping from convolved dims to their coefficient.
511  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
512 
513  // Removes dims with multiple uses in the source input map from dimension
514  // sets tracked by this walker.
515  void clearMultiUseDims(AffineMap map) {
516  for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
517  if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
518  return e.isFunctionOfDim(dimPos);
519  }) > 1) {
520  convolvedDims.erase(dimPos);
521  unConvolvedDims.erase(dimPos);
522  // If a duplicate dim is marked as convolved, the pair of the duplicate
523  // dim must be removed from the map as well.
524  auto it = convolvedDimMapping.find(dimPos);
525  if (it != convolvedDimMapping.end()) {
526  int64_t pairedDim = it->second;
527  convolvedDims.erase(pairedDim);
528  unConvolvedDims.erase(pairedDim);
529  strideAndDilationMapping.erase(pairedDim);
530  convolvedDimMapping.erase(dimPos);
531  convolvedDimMapping.erase(pairedDim);
532  }
533  }
534  }
535  }
536 
537  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
538  unsigned position = dimExpr.getPosition();
539  if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
540  return failure();
541  }
542  unConvolvedDims.insert(position);
543  return success();
544  }
545 
546  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
547 
548  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
549 
550  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
551  // In pre-order visit, top level op has to be an add op.
552  if (binaryExpr.getKind() != AffineExprKind::Add)
553  return failure();
554  auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
555  auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
556  if (failed(lhsDimPos) || failed(rhsDimPos))
557  return failure();
558  convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
559  convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
560  return success();
561  }
562 
563  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
564  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
565  int64_t dim = dimExpr.getPosition();
566  if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
567  return failure();
568  // Stride/dilation for this dim is implicitly 1.
569  strideAndDilationMapping[dim] =
571  convolvedDims.insert(dim);
572  return dim;
573  }
574  if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
575  if (symbolMulExpr.getKind() != AffineExprKind::Mul)
576  return failure();
577  auto lhsExpr = symbolMulExpr.getLHS();
578  auto rhsExpr = symbolMulExpr.getRHS();
579  // Check for symbol expression.
580  AffineExpr mulExpr =
581  getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
582  // If there was no symbol expr, check for constant expression.
583  if (!mulExpr) {
584  mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
585  }
586  auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
587  if (!mulExpr || !dimExpr)
588  return failure();
589  int64_t dim = dimExpr.getPosition();
590  if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
591  return failure();
592  strideAndDilationMapping[dim] = mulExpr;
593  convolvedDims.insert(dim);
594  return dim;
595  }
596  return failure();
597  }
598 };
599 } // namespace
600 
601 static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
602  assert(map.isProjectedPermutation() &&
603  "expected map to have projected permutations");
604  llvm::SmallDenseSet<int64_t> preservedDims;
605  for (auto expr : map.getResults())
606  preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
607  return preservedDims;
608 }
609 
613  for (auto e : exprs) {
614  auto constantExpr = dyn_cast<AffineConstantExpr>(e);
615  assert(constantExpr && "Found non-constant stride/dilation");
616  vals.push_back(constantExpr.getValue());
617  }
618  return vals;
619 }
620 
621 /// Classifies dimensions in the `linalgOp` used by a convolution
622 /// subcomputation, as captured by `inputExprWalker`. If
623 /// `allowEmptyConvolvedDims` is not set this this will fail if there is not
624 /// at least convolved dimension pair (output image + filter loop). Convolution
625 /// dimensions are specified in sorted order, and strides match the order of
626 /// the filter loop dimensions, while the dilations match the order of the
627 /// output image dimensions.
628 static FailureOr<ConvolutionDimensions>
629 inferConvolutionDimsImpl(LinalgOp linalgOp,
630  ConvAccessExprWalker &inputExprWalker,
631  bool allowEmptyConvolvedDims) {
632  auto filterMap =
633  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
634  auto outputMap =
635  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
636  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
637  filterMap, linalgOp.getIteratorTypesArray(), par);
638  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
639  outputMap, linalgOp.getIteratorTypesArray(), par);
640 
641  // unConvolvedDims & outputDims - filterDims are the batch iterators.
642  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
643  llvm::set_intersect(batch, outputDims);
644  llvm::set_subtract(batch, filterDims);
645 
646  // convolvedDims & outputDims are the output image iterators.
647  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
648  llvm::set_intersect(oi, outputDims);
649 
650  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
651  llvm::SmallDenseSet<int64_t> oc = filterDims;
652  llvm::set_intersect(oc, outputDims);
653  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
654 
655  // filterDims & outputDims & unConvolvedDims are the depth iterators.
656  llvm::SmallDenseSet<int64_t> depth = filterDims;
657  llvm::set_intersect(depth, outputDims);
658  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
659 
660  llvm::SmallDenseSet<int64_t> filterReducedDims =
662  linalgOp.getIteratorTypesArray(), red);
663 
664  // convolvedDims & filterReducedDims are the filter loop iterators.
665  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
666  llvm::set_intersect(fl, filterReducedDims);
667 
668  // unConvolvedDims & filterReducedDims are the input channel iterators.
669  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
670  llvm::set_intersect(ic, filterReducedDims);
671 
672  if (oi.empty() && !allowEmptyConvolvedDims)
673  return failure();
674 
675  // Return each set in sorted order.
676  ConvolutionDimensions dimensions{
677  SmallVector<unsigned, 2>(batch.begin(), batch.end()),
678  SmallVector<unsigned, 2>(oi.begin(), oi.end()),
679  SmallVector<unsigned, 2>(oc.begin(), oc.end()),
680  SmallVector<unsigned, 2>(fl.begin(), fl.end()),
681  SmallVector<unsigned, 2>(ic.begin(), ic.end()),
682  SmallVector<unsigned, 2>(depth.begin(), depth.end()),
683  /*strides=*/SmallVector<int64_t, 2>{},
684  /*dilations=*/SmallVector<int64_t, 2>{}};
685  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
686  llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
687  llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
688  llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
689  llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
690  llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
691 
692  // Use the op carried strides/dilations attribute if present.
693  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
694  if (!nativeStrides) {
695  SmallVector<AffineExpr, 2> strideExprs;
696  for (unsigned oiDim : dimensions.outputImage)
697  strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
698  dimensions.strides = getConstantsFromExprList(strideExprs);
699  } else {
700  dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
701  }
702  auto nativeDilations =
703  linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
704  if (!nativeDilations) {
705  SmallVector<AffineExpr, 2> dilationExprs;
706  for (unsigned flDim : dimensions.filterLoop)
707  dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
708  dimensions.dilations = getConstantsFromExprList(dilationExprs);
709  } else {
710  dimensions.dilations =
711  llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
712  }
713  return dimensions;
714 }
715 
716 /// Find at least 1 parallel (output_image) and reduction (filter_loop)
717 /// dimension candidates that form a convolution subcomputation within
718 /// `linalgOp`. The LHS is assumed to be the convolution input while the
719 /// RHS is assumed as the filter.
720 /// These dimensions are such that:
721 /// 1. Optional batch dimensions that appear in the input and filter.
722 /// 2. The output_image dimension is involved in a cross-correlation along LHS
723 /// (i.e. it is a permutation on RES and LHS and has an associated
724 /// filter_loop in RHS).
725 /// 3. Optional output_channel dimension is involved in an outer-product along
726 /// RHS (i.e. it is a permutation on RES and RHS and does not appear in
727 /// LHS).
728 /// 4. Optional input_channel dimension appears as a permutation on LHS and
729 /// RHS.
730 /// 5. The filter_loop dimension appears as a permutation on the RHS and
731 /// represents the shape of the kernel cross-correlated along a
732 /// corresponding output_image dim.
733 /// 6. The input_channel dimension appears as a permutation on LHS and RHS.
734 /// 7. All dimensions appear only once in any given indexing map.
735 /// This allows e.g. detecting that some convolution is embedded within
736 /// `linalgOp` with some orthogonal heuristic.
737 /// When multiple dimension occurrences exist that match any classification
738 /// indices are returned in sorted order.
739 /// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
740 FailureOr<ConvolutionDimensions>
742  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
743  return failure();
744 
745  auto indexingMaps = linalgOp.getIndexingMapsArray();
746 
747  // Check the input indexing map has the right form.
748  ConvAccessExprWalker inputExprWalker;
749  for (AffineExpr expr : indexingMaps[0].getResults())
750  (void)inputExprWalker.visit(expr);
751  inputExprWalker.clearMultiUseDims(indexingMaps[0]);
752 
753  return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
754  /*allowEmptyConvolvedDims=*/false);
755 }
756 
namespace mlir::linalg::detail {
/// Result codes of matching an op against the convolution structure.
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction,
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
770 
/// Structurally matches `op` against ConvolutionOpInterface. On success,
/// optionally classifies the loop dimensions into `dimensions` (batch, output
/// image, output channel, filter loop, input channel, depth). Returns a
/// MatchConvolutionResult code describing the first requirement that failed,
/// or Success.
mlir::linalg::detail::MatchConvolutionResult
mlir::linalg::detail::isConvolutionInterfaceImpl(
    Operation *op, ConvolutionDimensions *dimensions,
    bool allowEmptyConvolvedDims) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchConvolutionResult::NotLinalgOp;
  // A convolution has at least an image and a filter input, and exactly one
  // output.
  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
    return MatchConvolutionResult::WrongNumOperands;

  auto indexingMaps = linalgOp.getIndexingMapsArray();

  // Check the input indexing map has the right form. The walker records which
  // loop dims the input accesses as convolved (e.g. `d0 + d2`) vs. unconvolved
  // (plain dim) expressions; any other expression form is a mismatch.
  ConvAccessExprWalker inputExprWalker;
  if (llvm::any_of(indexingMaps[0].getResults(),
                   [&inputExprWalker](AffineExpr expr) {
                     return failed(inputExprWalker.visit(expr));
                   })) {
    return MatchConvolutionResult::WrongInputIndexingMap;
  }

  // Filter and output maps must be projected permutation.
  if (!indexingMaps[1].isProjectedPermutation() ||
      !indexingMaps.back().isProjectedPermutation())
    return MatchConvolutionResult::NotProjectedPermutations;

  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  llvm::SmallDenseSet<int64_t> outputDims =
      getPreservedDims(indexingMaps.back());
  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
  // Make sure all loops are characterized as one of:
  // - Batch loop : present in output, as non-convolved in input, not present in
  //   filter.
  // - Output image dimension : present in output, convolved dims in input, not
  //   present in filter.
  // - Output channel dimension : present in output, not present in input,
  //   present in filter.
  // - Filter loop dimension : present in filter, convolved in input, not
  //   present in output.
  // - Input channel dimension : unconvolved in input, not present in output,
  //   present in filter.
  // - Depth multiplier : unconvolved in input, present in output, present in
  //   filter.
  llvm::SmallDenseSet<int64_t> allLoopDims;
  // First pass: classify every dim appearing in the output map. Output maps
  // are projected permutations, so each result is a plain AffineDimExpr.
  for (auto outputExpr : indexingMaps.back().getResults()) {
    int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Batch dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.convolvedDims.count(outputDim) &&
        !filterDims.count(outputDim)) {
      // Output image Loop dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (!inputExprWalker.convolvedDims.count(outputDim) &&
        !inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Output channel dimension.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(outputDim) &&
        filterDims.count(outputDim)) {
      // Depth multiplier.
      if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
        return MatchConvolutionResult::OutputDimsNotParallel;
      allLoopDims.insert(outputDim);
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // Second pass: classify every dim appearing in the filter map. Dims shared
  // with the output were already recorded above; genuinely new dims must be
  // reductions and must not have been claimed by a different role.
  for (auto filterExpr : indexingMaps[1].getResults()) {
    int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
    if (outputDims.count(filterDim) &&
        !inputExprWalker.unConvolvedDims.count(filterDim) &&
        !inputExprWalker.convolvedDims.count(filterDim)) {
      // Output channel dimension. This is already seen, continue;
      continue;
    }
    if (inputExprWalker.convolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Filter loop dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        !outputDims.count(filterDim)) {
      // Input channel dimension.
      if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
        return MatchConvolutionResult::NonOutputDimNotReduction;
      if (allLoopDims.count(filterDim))
        return MatchConvolutionResult::NonConvolutionLoop;
      allLoopDims.insert(filterDim);
      continue;
    }
    if (inputExprWalker.unConvolvedDims.count(filterDim) &&
        outputDims.count(filterDim)) {
      // Depthwise loop. Already seen.
      continue;
    }
    return MatchConvolutionResult::NonConvolutionLoop;
  }
  // All loops must be covered now.
  if (allLoopDims.size() != linalgOp.getNumLoops())
    return MatchConvolutionResult::NonConvolutionLoop;

  if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
    return MatchConvolutionResult::EmptyConvolvedDims;

  if (dimensions) {
    // The structural match succeeded, so dimension inference must too.
    FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
        linalgOp, inputExprWalker, allowEmptyConvolvedDims);
    assert(succeeded(res) && "unexpected failure to infer convolution dims");
    *dimensions = *res;
  }

  return MatchConvolutionResult::Success;
}
904 
905 StringRef
907  switch (res) {
908  case MatchConvolutionResult::NotLinalgOp:
909  return "expected a LinalgOp";
910  case MatchConvolutionResult::WrongNumOperands:
911  return "expected op with 2 inputs and 1 output";
912  case MatchConvolutionResult::WrongInputIndexingMap:
913  return "unexpected input index map for convolutions";
914  case MatchConvolutionResult::NotProjectedPermutations:
915  return "expected output/filter indexing maps to be projected permutations";
916  case MatchConvolutionResult::NonConvolutionLoop:
917  return "unexpected loop dimension for convolution op";
918  case MatchConvolutionResult::OutputDimsNotParallel:
919  return "expected all iterators used to access outputs to be parallel";
920  case MatchConvolutionResult::NonOutputDimNotReduction:
921  return "expected all iterators not used to access outputs to be reduction";
922  case MatchConvolutionResult::EmptyConvolvedDims:
923  return "expected convolved dim to be non-empty";
924  case MatchConvolutionResult::Success:
925  return "";
926  }
927  llvm_unreachable("unhandled MatchConvolutionResult case");
928 }
929 
/// Returns true iff `linalgOp` structurally matches ConvolutionOpInterface.
/// Thin predicate wrapper: runs the full matcher without requesting dimension
/// classification and compares against Success.
bool mlir::linalg::detail::isaConvolutionOpInterface(
    LinalgOp linalgOp, bool allowEmptyConvolvedDims) {
  return isConvolutionInterfaceImpl(
             linalgOp.getOperation(), /*dimensions=*/nullptr,
             allowEmptyConvolvedDims) == MatchConvolutionResult::Success;
}
936 
/// Verifier hook for ConvolutionOpInterface: runs the structural matcher and
/// converts any failure code into an op error with the matching message.
LogicalResult mlir::linalg::detail::verifyConvolutionInterface(Operation *op) {
  MatchConvolutionResult res = isConvolutionInterfaceImpl(op);
  if (res != MatchConvolutionResult::Success)
    return op->emitError(getMatchConvolutionMessage(res));
  return success();
}
943 
944 //===----------------------------------------------------------------------===//
945 // FillOpInterface implementation
946 //===----------------------------------------------------------------------===//
947 
/// Outcome of structurally matching an operation against FillOpInterface;
/// Success is 0, every other value names the first requirement that failed.
enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NotScalarInput
};
954 
/// Structurally matches `op` against FillOpInterface: a LinalgOp with exactly
/// one scalar input and one output.
static MatchFillResult isFillInterfaceImpl(Operation *op) {
  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
  if (!linalgOp)
    return MatchFillResult::NotLinalgOp;
  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
    return MatchFillResult::WrongNumOperands;

  // The single input must be the scalar fill value, not a shaped operand.
  OpOperand *value = linalgOp.getDpsInputOperand(0);
  if (!linalgOp.isScalar(value))
    return MatchFillResult::NotScalarInput;

  return MatchFillResult::Success;
}
968 
/// Verifier hook for FillOpInterface: maps each MatchFillResult failure code
/// to a dedicated op error.
LogicalResult mlir::linalg::detail::verifyFillInterface(Operation *op) {
  auto res = isFillInterfaceImpl(op);
  if (res == MatchFillResult::NotLinalgOp)
    return op->emitError("expected a LinalgOp");
  if (res == MatchFillResult::WrongNumOperands)
    return op->emitError("expected op with 1 input and 1 output");
  if (res == MatchFillResult::NotScalarInput)
    return op->emitError("expected op with scalar input");

  return success();
}
980 
981 //===----------------------------------------------------------------------===//
982 // StructuredOpInterface implementation
983 //===----------------------------------------------------------------------===//
984 
/// Returns the dim sizes of every operand, flattened in operand order (all
/// dims of operand 0, then operand 1, ...). Each entry is produced by
/// createFoldedDimOp, so static sizes fold to attributes.
SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                                Location loc) {
  SmallVector<OpFoldResult> res;
  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
    // One entry per dimension of this operand.
    for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
      res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
  }
  return res;
}
994 
/// Static counterpart of createFlatListOfOperandDims: concatenates the static
/// shapes of all operands in operand order. Requires all shapes to be static.
SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
  SmallVector<int64_t, 4> res;
  assert(!hasDynamicShape() && "expected operands to have static shapes");
  for (OpOperand &opOperand : getOperation()->getOpOperands())
    llvm::append_range(res, getShape(&opOperand));
  return res;
}
1002 
1003 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1004  AffineMap map = getLoopsToShapesMap();
1005  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1006  auto viewSizes = createFlatListOfOperandDims(b, loc);
1007  SmallVector<Range, 4> res(numDims);
1008  for (unsigned idx = 0; idx < numRes; ++idx) {
1009  auto result = map.getResult(idx);
1010  if (auto d = dyn_cast<AffineDimExpr>(result)) {
1011  if (res[d.getPosition()].offset)
1012  continue;
1013  res[d.getPosition()] =
1014  Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1015  }
1016  }
1017  return res;
1018 }
1019 
1020 SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
1021  AffineMap map = getLoopsToShapesMap();
1022  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1023  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
1024  SmallVector<int64_t, 4> res(numDims, 0);
1025  for (unsigned idx = 0; idx < numRes; ++idx) {
1026  auto result = map.getResult(idx);
1027  if (auto d = dyn_cast<AffineDimExpr>(result))
1028  res[d.getPosition()] = allShapeSizes[idx];
1029  }
1030  return res;
1031 }
1032 
/// Visitor to check if any of the given set of positions from AffineDimExprs
/// are used within an AffineExpr.
struct HasAffineDimExprVisitor
    : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
      : positions(std::move(positions)) {}

  // A binary expression uses a tracked dim if either side does.
  bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
    return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
  }

  // True iff this dim's position is in the tracked set.
  bool visitDimExpr(AffineDimExpr dimExpr) {
    return positions.test(dimExpr.getPosition());
  }

  // Constants and symbols reference no dims.
  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }

  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }

private:
  // Bit-set of dim positions to look for.
  llvm::SmallBitVector positions;
};
1055 
/// Returns the half-open result range [start, end) that the op's init/output
/// operand shapes occupy in the loops-to-shapes map: inputs contribute the
/// first `inputRankSum` results, outputs the next `outputRankSum`.
static std::pair<int64_t, int64_t>
getResultsPositionInLoopsToShapeMap(LinalgOp &op) {
  int64_t inputRankSum = 0;
  int64_t outputRankSum = 0;
  for (OpOperand *input : op.getDpsInputOperands())
    inputRankSum += op.getRank(input);
  for (OpOperand &output : op.getDpsInitsMutable())
    outputRankSum += op.getRank(&output);
  return {inputRankSum, inputRankSum + outputRankSum};
}
1066 
/// Reifies the shape of each result (init operand) as a list of OpFoldResults:
/// static dims become IntegerAttrs; dynamic dims are expressed in terms of the
/// input operand shapes where possible, falling back to a dim op on the init.
LogicalResult
LinalgOp::reifyResultShapes(OpBuilder &b,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  // An example that helps understand the logic below.
  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
  // This is achieved as follows.
  //   loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
  //   subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
  //   shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d2)
  //   resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
  //     = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
  AffineMap loopsToShapesMap = getLoopsToShapesMap();

  // Find the position in the above map that represents the shape of the
  // result:dim being inferred.
  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);

  /// From loopsToShapesMap extract the submap that represents the shape of the
  /// (resultIdx, dim) needed.
  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
      resultShapesSubMapPos.first,
      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
  AffineMap resultShapesFromInputShapesMap =
      loopToResultsShapeMap.compose(getShapesToLoopsMap());

  // Check that the result dim map does not contain the positions corresponding
  // to the outputs.
  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  SmallVector<OpFoldResult> allResultDimValues =
      affine::makeComposedFoldedMultiResultAffineApply(
          rewriter, loc, resultShapesFromInputShapesMap,
          createFlatListOfOperandDims(b, loc));
  // `pos` walks the flattened (result, dim) positions in lockstep with
  // `shapeExprs`/`allResultDimValues`.
  int64_t pos = 0;
  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
  for (OpOperand &opOperand : getDpsInitsMutable()) {
    SmallVector<OpFoldResult> shapes;
    for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
      auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
      if (!shapedType.isDynamicDim(dim)) {
        // Static dim: Return IntegerAttr.
        shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
      } else {
        // Dynamic dim: Return Value. If the shape expression still references
        // an output dimension it cannot be computed from the inputs alone, so
        // fall back to a dim op on the init operand itself.
        OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
                               ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
                               : allResultDimValues[pos];
        shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
      }
      pos++;
    }
    reifiedReturnShapes.emplace_back(std::move(shapes));
  }
  return success();
}
1126 
1127 /// Return the index in the indexingMaps vector that corresponds to this
1128 /// `opOperand`.
1129 int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1130  auto operandNumber = opOperand->getOperandNumber();
1131  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1132  if (!dpsIface.isDpsInput(opOperand))
1133  return operandNumber;
1134  unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1135  assert(!dpsIface.isDpsInit(opOperand));
1136  // Account for potential inputs that are not DPS and may not appear in
1137  // `indexingMaps`.
1138  return cast<DestinationStyleOpInterface>(*this->getOperation())
1139  .getNumDpsInputs() +
1140  operandNumber - start;
1141 }
1142 
/// Verifier shared by all structured (Linalg) ops: checks operand semantics,
/// indexing-map well-formedness, static shape/loop-range consistency, and the
/// region/block-argument contract.
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
  LinalgOp linalgOp = cast<LinalgOp>(op);

  // Mixed tensor/buffer operands are not allowed.
  if (!linalgOp.hasPureTensorSemantics() &&
      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
    return op->emitOpError("expected to have pure tensor or buffer semantics");

  // Before checking indexing maps, we need to make sure the attributes
  // referenced by it are valid.
  if (linalgOp.hasDynamicIndexingMaps())
    if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
      return failure();

  // All input/output operands must be indexed.
  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
      linalgOp->getNumOperands())
    return op->emitOpError("expected the number of indexing_map (")
           << linalgOp.getIndexingMapsArray().size()
           << ") to be equal to the number of input/output operands ("
           << linalgOp->getNumOperands() << ")";

  // Per-operand indexing map checks: no symbols, domain matching the loop
  // count, and codomain matching the operand rank.
  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Symbols disallowed.
    if (indexingMap.getNumSymbols() != 0)
      return op->emitOpError("unexpected symbols in indexing_map #")
             << opOperand.getOperandNumber();

    // Domain must be consistent.
    unsigned numLoops = linalgOp.getNumLoops();
    if (indexingMap.getNumDims() != numLoops)
      return op->emitOpError("expected indexing_map #")
             << opOperand.getOperandNumber() << " to have " << numLoops
             << " dim(s) to match the number of loops";

    int64_t rank = linalgOp.getRank(&opOperand);
    if (indexingMap.getNumResults() != rank)
      return op->emitOpError("expected operand rank (")
             << rank << ") to match the result rank of indexing_map #"
             << opOperand.getOperandNumber() << " ("
             << indexingMap.getNumResults() << ")";
  }

  SmallVector<unsigned> redDims;
  linalgOp.getReductionDims(redDims);

  if (!linalgOp.getShapesToLoopsMap())
    return op->emitOpError("expected the shape-to-loops map to be non-null");

  // Check if given shapes match to inferred shapes.
  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);

  // Verify only static cases since we can't get exact dimension sizes and loop
  // ranges for dynamic cases in this stage.
  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
    // Switch to inclusive upper bounds: each loop spans [0, range - 1].
    for (int64_t &range : endLoopRangeValues)
      range -= 1;
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
      // Evaluate the map at both extremes of the loop iteration space.
      SmallVector<int64_t, 4> startIndices =
          indexingMap.compose(startLoopRangeValues);
      SmallVector<int64_t, 4> endIndices =
          indexingMap.compose(endLoopRangeValues);
      ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
      for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
        // Ignore dynamic dimension or the case that the dimension size is 0
        if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
          continue;

        // The first index or last index should be the maximum or the minimum in
        // the inferred index ranges since the range is increasing or
        // decreasing. The size of dimensions of input/output operands and the
        // maximum value + 1 in the inferred range should be the same. But, for
        // now we check if the inferred ranges are in boundary of input/output
        // operands' size or not in case that Affine Expressions are complicated
        // such as d0 * 3
        // + d1 since it is not easy to handle the issues.
        // Found the case that this solution can't check, for example, (d0, d1)
        // -> (d1 - d0)
        int64_t inferredDimSize =
            std::max(startIndices[dim], endIndices[dim]) + 1;
        if (std::min(startIndices[dim], endIndices[dim]) < 0) {
          std::string mapStr;
          {
            llvm::raw_string_ostream os(mapStr);
            os << indexingMap;
          }
          return op->emitOpError(
                     "unexpected result less than 0 at expression #")
                 << dim << " in " << mapStr;
        }
        if (dyn_cast<AffineDimExpr>(indexingMap.getResult(dim))) {
          // Plain dim expression: the inferred size must match exactly.
          if (inferredDimSize != shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be " << inferredDimSize << ", but found "
                   << shape[dim];
          }
        } else {
          // Compound expression: only require the accesses to stay in bounds.
          if (inferredDimSize > shape[dim]) {
            return op->emitOpError("inferred input/output operand #")
                   << opOperand.getOperandNumber() << " has shape's dimension #"
                   << dim << " to be greater than or equal to "
                   << inferredDimSize << ", but found " << shape[dim];
          }
        }
      }
    }
  }

  // Check the region has exactly one block.
  if (linalgOp->getNumRegions() != 1 ||
      !llvm::hasSingleElement(linalgOp->getRegion(0)))
    return op->emitOpError("expects to have 1 region with 1 block");

  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
  // types.
  // TODO: once ranked shape types are plugged in, we may want to drop the
  // corresponding bbargs, that can never be read from. This will be subject to
  // consistency discussions (i.e. what to do with output tensors whose bbarg is
  // not used).
  Block &block = linalgOp->getRegion(0).front();

  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
    return op->emitOpError("expected as many non-induction variable region "
                           "arguments as the number of input/output operands");

  // Each bbarg's type must equal the element type (for shaped operands) or the
  // operand type itself (for scalars).
  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
    Type elementType = opOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(opOperand->get().getType());
    Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
    if (elementType != argType)
      return op->emitOpError("expected type of bb argument #")
             << opOperand->getOperandNumber() << " (" << argType << ")"
             << " to match element or self type of the corresponding operand ("
             << elementType << ")";
  }

  return success();
}
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp, unsigned arity)
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
MatchFillResult
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Affine binary operation expression.
Definition: AffineExpr.h:227
AffineExpr getLHS() const
Definition: AffineExpr.cpp:340
AffineExpr getRHS() const
Definition: AffineExpr.cpp:343
An integer constant appearing in affine expression.
Definition: AffineExpr.h:252
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
unsigned getPosition() const
Definition: AffineExpr.cpp:348
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
MLIRContext * getContext() const
Definition: AffineExpr.cpp:33
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:662
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:244
Block represents an ordered list of Operations.
Definition: Block.h:31
bool empty()
Definition: Block.h:146
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation & back()
Definition: Block.h:150
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
OpListType & getOperations()
Definition: Block.h:135
Operation & front()
Definition: Block.h:151
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:215
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:753
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1239
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalgelementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:99
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:90
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Definition: AffineMap.cpp:836
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:631
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.