MLIR  19.0.0git
LinalgInterfaces.cpp
Go to the documentation of this file.
1 //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include <algorithm>
25 
26 using namespace mlir;
27 using namespace mlir::linalg;
28 
29 /// Include the generated definitions of the Linalg interfaces.
30 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
31 
32 //===----------------------------------------------------------------------===//
33 // Interface utility functions
34 //===----------------------------------------------------------------------===//
35 
37  linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
38  SmallVector<AffineMap> indexingMaps;
39  for (auto &opOperand : linalgOp->getOpOperands()) {
40  if (llvm::is_contained(droppedOperands, &opOperand))
41  continue;
42  indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
43  }
44  if (indexingMaps.empty()) {
45  // If there are no indexing maps, the operand can only be dropped
46  // if the op has no loops.
47  return linalgOp.getNumLoops() == 0;
48  }
49  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
50 }
51 
52 //===----------------------------------------------------------------------===//
53 // CopyOpInterface implementation
54 //===----------------------------------------------------------------------===//
55 
56 bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
57  // Structural.
58  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
59  return false;
60 
61  // Operands and maps.
62  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
63  return false;
64  auto mapRange = linalgOp.getIndexingMapsArray();
65  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
66  !mapRange.back().isIdentity()) {
67  return false;
68  }
69  // Region.
70  return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // FillOpInterface implementation
75 //===----------------------------------------------------------------------===//
76 std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
77  // Structural.
78  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
79  genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
80  return std::nullopt;
81 
82  // Input should be referenced and init should not.
83  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
84  genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
85  return std::nullopt;
86 
87  OpOperand *value = genericOp.getDpsInputOperand(0);
88  if (!genericOp.isScalar(value))
89  return std::nullopt;
90 
91  Block *body = genericOp.getBody();
92  if (body->getOperations().size() != 1)
93  return std::nullopt;
94 
95  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
96  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
97  yieldOp->getOperand(0) != body->getArgument(0))
98  return std::nullopt;
99  return value->get();
100 }
101 
102 //===----------------------------------------------------------------------===//
103 // Elementwise Single Unary/Binary-OpInterface implementation
104 //===----------------------------------------------------------------------===//
105 static bool
106 isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
107  unsigned arity) {
108  // Check all loops are parallel.
109  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
110  genericOp.getNumLoops() < 1)
111  return false;
112 
113  // Check there are arity-inputs, 1-output and all are identity-maps.
114  if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
115  !llvm::all_of(genericOp.getIndexingMapsArray(),
116  [](AffineMap map) { return map.isIdentity(); }))
117  return false;
118 
119  // Init should not be referenced for elementwise operations.
120  if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
121  return false;
122 
123  // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
124  // as resulting from producer-consumer fusion. Here, we restrict to two ops in
125  // the body, where the first is the elementwise single op and the second a
126  // yield.
127  Block *body = genericOp.getBody();
128  if (body->getOperations().size() != 2)
129  return false;
130 
131  Operation *op = &body->front();
132  if (op->getNumOperands() != arity || op->getNumResults() != 1)
133  return false;
134 
135  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
136  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
137  yieldOp->getOperand(0).getDefiningOp() != op)
138  return false;
139  return true;
140 }
141 
142 bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
143  // All basic elemwise checks.
145  return false;
146 
147  // Check input is actually used.
148  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
149  return false;
150  return true;
151 }
152 
153 bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
155  return false;
156 
157  // Check both inputs are used (elementwise).
158  OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
159  OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
160  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
161  !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
162  return false;
163  return true;
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // ContractionOpInterface implementation
168 //===----------------------------------------------------------------------===//
169 
170 /// If the value is defined by a chain of unary side effect-free ops, go up
171 /// the use-def chain until the first value that isn't defined by such an op.
172 // TODO: relax to multi-operands with constants, which are technically unary ops
173 // as needed (e.g. add5).
175  Operation *op = value.getDefiningOp();
176  while (op && op->getNumOperands() == 1) {
177  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
178  if (!iface || !iface.hasNoEffect())
179  break;
180  value = op->getOperand(0);
181  op = value.getDefiningOp();
182  }
183  return value;
184 }
185 
187  Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
188  llvm::raw_ostream &errs) {
189  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
190  errs << "no terminator in the block";
191  return false;
192  }
193 
194  if (block.getNumArguments() != 3) {
195  errs << "expected block with 3 arguments";
196  return false;
197  }
198 
199  Operation *terminator = block.getTerminator();
200  if (terminator->getNumOperands() != 1) {
201  errs << "expected terminator with 1 operand";
202  return false;
203  }
204 
205  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
206  Operation *reductionOp = yielded.getDefiningOp();
207  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
208  errs << "expected reduction op to be binary";
209  return false;
210  }
211 
212  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
213  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
214 
215  if (reductionLHS != block.getArgument(2) &&
216  reductionRHS != block.getArgument(2)) {
217  errs << "expected reduction to take block argument #2 as one of the "
218  "operands (modulo unary casts)";
219  return false;
220  }
221 
222  Value contributed = getSourceSkipUnary(
223  isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
224  Operation *elementwiseOp = contributed.getDefiningOp();
225  if (elementwiseOp->getNumResults() != 1 ||
226  elementwiseOp->getNumOperands() != 2) {
227  errs << "expected elementwise op to be binary";
228  return false;
229  }
230 
231  if (!isaPair(elementwiseOp, reductionOp)) {
232  errs << "expected reduction/elementwise op kind not satisfied";
233  return false;
234  }
235 
236  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
237  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
238  if ((elementwiseLHS == block.getArgument(0) &&
239  elementwiseRHS == block.getArgument(1)) ||
240  (elementwiseLHS == block.getArgument(1) &&
241  elementwiseRHS == block.getArgument(0))) {
242  return true;
243  }
244 
245  errs << "expected elementwise op to apply to block arguments (modulo unary "
246  "casts)";
247  return false;
248 }
249 
250 /// Returns true if the two operations are of the kinds specified by a pair of
251 /// consecutive template arguments.
252 template <typename AddOpTy, typename MulOpTy, typename... Args>
253 static bool isPairTemplateImpl(Operation *add, Operation *mul) {
254  static_assert(sizeof...(Args) % 2 == 0,
255  "expected an even number of template arguments");
256  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
257  return true;
258 
259  if constexpr (sizeof...(Args) > 0)
260  return isPairTemplateImpl<Args...>(add, mul);
261  else
262  return false;
263 }
264 
265 /// Returns true if the block is a body of a contraction with the kinds of
266 /// operations given pairwise by template arguments.
267 template <typename... Args>
268 static bool isContractionBody(Block &block) {
269  return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
270 }
271 
272 /// Given an `indexingMap` and its corresponding `iterators`, returns
273 /// the positions of the iterators of type `iter` that are indexed by
274 /// the `indexingMap` as a permutation. This is useful to infer various
275 /// subcomputations on a `LinalgOp`. This is performed by looking up
276 /// each result in the `indexingMap` and determining whether:
277 /// - It is a single AffineDimExpr.
278 /// - It is the only result involving this AffineDimExpr.
279 static llvm::SmallDenseSet<int64_t>
282  utils::IteratorType iter) {
283  assert(iterators.size() == indexingMap.getNumDims());
284  llvm::SmallDenseSet<int64_t> res;
285  for (AffineExpr e : indexingMap.getResults()) {
286  if (auto d = dyn_cast<AffineDimExpr>(e)) {
287  if (iterators[d.getPosition()] == iter &&
288  llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
289  return e.isFunctionOfDim(d.getPosition());
290  }) == 1)
291  res.insert(d.getPosition());
292  }
293  }
294  return res;
295 }
296 
297 namespace {
298 auto par = utils::IteratorType::parallel;
299 auto red = utils::IteratorType::reduction;
300 } // namespace
301 
302 /// Infer the iterator types from the init affine map. This looks at which dims
303 /// are present in the map results, and returns an iterator types array with
304 /// parallel types for dims that are present, and reduction types for dims that
305 /// are not present.
306 static FailureOr<SmallVector<utils::IteratorType>>
308  if (!map.isProjectedPermutation())
309  return failure();
310  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
311  for (auto expr : map.getResults())
312  if (auto dim = dyn_cast<AffineDimExpr>(expr))
313  iterators[dim.getPosition()] = par;
314  return iterators;
315 }
316 
317 /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
318 /// a matmul subcomputation within `linalgOp`. These dimensions are such that:
319 /// 1. The m dimension is involved in an outer-product along LHS
320 /// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
321 /// 2. The n dimension is involved in an outer-product along RHS
322 /// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
323 /// 3. The k dimension appears as a permutation on LHS and RHS.
324 /// 4. m, n and k appear only once in any given indexing.
325 /// 5. Optional batch dimensions that appear in all operands are captured.
326 /// This allows e.g. detecting that some contraction is embedded within
327 /// `linalgOp` with some orthogonal heuristic.
328 static FailureOr<ContractionDimensions>
330  ArrayRef<utils::IteratorType> iterators) {
331  llvm::SmallDenseSet<int64_t> a =
332  findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
333  llvm::SmallDenseSet<int64_t> b =
334  findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
335  llvm::SmallDenseSet<int64_t> c =
336  findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
337 
338  // A & C - B are the iterators involved in an outer-product along A (the LHS).
339  llvm::SmallDenseSet<int64_t> ac = a;
340  llvm::set_intersect(ac, c);
341  llvm::set_subtract(ac, b);
342  // B & C - A are the iterators involved in an outer-product along B (the RHS).
343  llvm::SmallDenseSet<int64_t> bc = b;
344  llvm::set_intersect(bc, c);
345  llvm::set_subtract(bc, a);
346  // A & B & C are the "batch" dimensions.
347  llvm::SmallDenseSet<int64_t> batches = a;
348  llvm::set_intersect(batches, b);
349  llvm::set_intersect(batches, c);
350 
351  // A & B red are the reduction dimensions.
352  llvm::SmallDenseSet<int64_t> ra =
353  findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
354  llvm::SmallDenseSet<int64_t> rb =
355  findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
356  llvm::set_intersect(ra, rb);
357 
358  // Return each set in sorted order.
359  ContractionDimensions dimensions{
360  SmallVector<unsigned, 2>(batches.begin(), batches.end()),
361  SmallVector<unsigned, 2>(ac.begin(), ac.end()),
362  SmallVector<unsigned, 2>(bc.begin(), bc.end()),
363  SmallVector<unsigned, 2>(ra.begin(), ra.end())};
364  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
365  llvm::sort(dimensions.m.begin(), dimensions.m.end());
366  llvm::sort(dimensions.n.begin(), dimensions.n.end());
367  llvm::sort(dimensions.k.begin(), dimensions.k.end());
368  return dimensions;
369 }
370 
371 FailureOr<ContractionDimensions>
373  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
374  return failure();
375  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
376  linalgOp.getIteratorTypesArray());
377 }
378 
379 FailureOr<ContractionDimensions>
381  if (indexingMaps.size() != 3)
382  return failure();
383  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
384  if (failed(iterators))
385  return failure();
386  return inferContractionDimsImpl(indexingMaps, iterators.value());
387 }
388 
389 namespace mlir::linalg::detail {
391  Success = 0,
392  NotLinalgOp,
394  NoReduction,
396  NotAddMul
397 };
398 } // namespace mlir::linalg::detail
399 
403  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
404  if (!linalgOp)
405  return MatchContractionResult::NotLinalgOp;
406  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
407  return MatchContractionResult::WrongNumOperands;
408  auto mapRange = linalgOp.getIndexingMapsArray();
409  if (linalgOp.getNumReductionLoops() == 0)
410  return MatchContractionResult::NoReduction;
411  if (llvm::any_of(mapRange,
412  [](AffineMap m) { return !m.isProjectedPermutation(); }))
413  return MatchContractionResult::NotProjectedPermutations;
414  // TODO: more fields than add/mul.
415  // clang-format off
416  if (!::isContractionBody<
417  arith::MulFOp, arith::AddFOp,
418  arith::MulIOp, arith::AddIOp,
419  complex::MulOp, complex::AddOp,
420  arith::AndIOp, arith::OrIOp>(
421  *linalgOp.getBlock())) {
422  return MatchContractionResult::NotAddMul;
423  }
424  // clang-format on
425 
426  if (dimensions) {
427  FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
428  assert(succeeded(res) && "unexpected failure to infer contraction dims");
429  *dimensions = *res;
430  }
431  return MatchContractionResult::Success;
432 }
433 
434 StringRef
436  switch (res) {
437  case MatchContractionResult::NotLinalgOp:
438  return "expected a LinalgOp";
439  case MatchContractionResult::WrongNumOperands:
440  return "expected op with 2 inputs and 1 output";
441  case MatchContractionResult::NoReduction:
442  return "expected at least 1 reduction";
443  case MatchContractionResult::NotProjectedPermutations:
444  return "expected indexing maps to be projected permutations";
445  case MatchContractionResult::NotAddMul:
446  return "expected add/mul op in the body";
447  case MatchContractionResult::Success:
448  return "";
449  }
450  llvm_unreachable("unhandled MatchContractionResult case");
451 }
452 
454  if (!linalgOp)
455  return false;
456  Operation *op = linalgOp.getOperation();
457  return isa<ContractionOpInterface>(op) ||
460 }
461 
462 /// Verify that a LinalgOp `op` is a contraction.
463 /// A Linalg contraction is defined in general terms:
464 /// 1. Has 2 input and 1 output shapes.
465 /// 2. Has at least one reduction dimension.
466 /// 3. Has only projected permutation indexing maps.
467 /// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
468 /// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
469 /// operations that may change the type (e.g. for mixed-precision).
470 /// As a consequence, when vectorization of such an op occurs, the only special
471 /// behavior is that the (unique) MulOpType is vectorized into a
472 /// `vector.contract`. All other ops are handled in a generic fashion.
473 /// In the future, we may wish to allow more input arguments and elementwise and
474 /// constant operations that do not involve the reduction dimension(s).
476  auto res = isContractionInterfaceImpl(op);
477  if (res != MatchContractionResult::Success)
478  return op->emitError(getMatchContractionMessage(res));
479  return success();
480 }
481 
482 //===----------------------------------------------------------------------===//
483 // ConvolutionOpInterface implementation
484 //===----------------------------------------------------------------------===//
485 
486 /// Of the given two expressions returns one that is of type T (`lhs` gets
487 /// preference over `rhs`)
488 template <typename T>
490  return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
491 }
492 
493 namespace {
494 /// Walk the indexing expressions for input of a convolution operation to verify
495 /// it is of the right form, either
496 /// - AffineDimExpr
497 /// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
498 /// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
499 ///
500 /// classifies the AffineDimExpr as convolved dimensions or unconvolved
501 /// dimensions and verifies each dimension occurs only once.
502 struct ConvAccessExprWalker
503  : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
504  // Stores dimensions used in expressions of the above form.
505  llvm::SmallDenseSet<int64_t> convolvedDims;
506  // Stores the dual mapping between LHS and RHS of convolution exprs.
507  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
508  // Stores single use dimensions used by an AffineDimExpr.
509  llvm::SmallDenseSet<int64_t> unConvolvedDims;
510  // Stores a mapping from convolved dims to their coefficient.
511  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
512 
513  // Removes dims with multiple uses in the source input map from dimension
514  // sets tracked by this walker.
515  void clearMultiUseDims(AffineMap map) {
516  for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
517  if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
518  return e.isFunctionOfDim(dimPos);
519  }) > 1) {
520  convolvedDims.erase(dimPos);
521  unConvolvedDims.erase(dimPos);
522  // If a duplicate dim is marked as convolved, the pair of the duplicate
523  // dim must be removed from the map as well.
524  if (convolvedDimMapping.contains(dimPos)) {
525  int64_t pairedDim = convolvedDimMapping[dimPos];
526  convolvedDims.erase(pairedDim);
527  unConvolvedDims.erase(pairedDim);
528  strideAndDilationMapping.erase(pairedDim);
529  convolvedDimMapping.erase(dimPos);
530  convolvedDimMapping.erase(pairedDim);
531  }
532  }
533  }
534  }
535 
536  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
537  unsigned position = dimExpr.getPosition();
538  if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
539  return failure();
540  }
541  unConvolvedDims.insert(position);
542  return success();
543  }
544 
545  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
546 
547  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
548 
549  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
550  // In pre-order visit, top level op has to be an add op.
551  if (binaryExpr.getKind() != AffineExprKind::Add)
552  return failure();
553  auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
554  auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
555  if (failed(lhsDimPos) || failed(rhsDimPos))
556  return failure();
557  convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
558  convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
559  return success();
560  }
561 
562  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
563  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
564  int64_t dim = dimExpr.getPosition();
565  if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
566  return failure();
567  // Stride/dilation for this dim is implicitly 1.
568  strideAndDilationMapping[dim] =
570  convolvedDims.insert(dim);
571  return dim;
572  }
573  if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
574  if (symbolMulExpr.getKind() != AffineExprKind::Mul)
575  return failure();
576  auto lhsExpr = symbolMulExpr.getLHS();
577  auto rhsExpr = symbolMulExpr.getRHS();
578  // Check for symbol expression.
579  AffineExpr mulExpr =
580  getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
581  // If there was no symbol expr, check for constant expression.
582  if (!mulExpr) {
583  mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
584  }
585  auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
586  if (!mulExpr || !dimExpr)
587  return failure();
588  int64_t dim = dimExpr.getPosition();
589  if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
590  return failure();
591  strideAndDilationMapping[dim] = mulExpr;
592  convolvedDims.insert(dim);
593  return dim;
594  }
595  return failure();
596  }
597 };
598 } // namespace
599 
600 static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
601  assert(map.isProjectedPermutation() &&
602  "expected map to have projected permutations");
603  llvm::SmallDenseSet<int64_t> preservedDims;
604  for (auto expr : map.getResults())
605  preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
606  return preservedDims;
607 }
608 
612  for (auto e : exprs) {
613  auto constantExpr = dyn_cast<AffineConstantExpr>(e);
614  assert(constantExpr && "Found non-constant stride/dilation");
615  vals.push_back(constantExpr.getValue());
616  }
617  return vals;
618 }
619 
620 /// Classifies dimensions in the `linalgOp` used by a convolution
621 /// subcomputation, as captured by `inputExprWalker`. If
622 /// `allowEmptyConvolvedDims` is not set, this will fail if there is not
623 /// at least one convolved dimension pair (output image + filter loop).
624 /// Convolution dimensions are specified in sorted order; inferred strides
625 /// match the order of the output image dimensions, while inferred dilations
626 /// match the order of the filter loop dimensions.
627 static FailureOr<ConvolutionDimensions>
628 inferConvolutionDimsImpl(LinalgOp linalgOp,
629  ConvAccessExprWalker &inputExprWalker,
630  bool allowEmptyConvolvedDims) {
631  auto filterMap =
632  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
633  auto outputMap =
634  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
635  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
636  filterMap, linalgOp.getIteratorTypesArray(), par);
637  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
638  outputMap, linalgOp.getIteratorTypesArray(), par);
639 
640  // unConvolvedDims & outputDims - filterDims are the batch iterators.
641  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
642  llvm::set_intersect(batch, outputDims);
643  llvm::set_subtract(batch, filterDims);
644 
645  // convolvedDims & outputDims are the output image iterators.
646  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
647  llvm::set_intersect(oi, outputDims);
648 
649  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
650  llvm::SmallDenseSet<int64_t> oc = filterDims;
651  llvm::set_intersect(oc, outputDims);
652  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
653 
654  // filterDims & outputDims & unConvolvedDims are the depth iterators.
655  llvm::SmallDenseSet<int64_t> depth = filterDims;
656  llvm::set_intersect(depth, outputDims);
657  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
658 
659  llvm::SmallDenseSet<int64_t> filterReducedDims =
661  linalgOp.getIteratorTypesArray(), red);
662 
663  // convolvedDims & filterReducedDims are the filter loop iterators.
664  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
665  llvm::set_intersect(fl, filterReducedDims);
666 
667  // unConvolvedDims & filterReducedDims are the input channel iterators.
668  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
669  llvm::set_intersect(ic, filterReducedDims);
670 
671  if (oi.empty() && !allowEmptyConvolvedDims)
672  return failure();
673 
674  // Return each set in sorted order.
675  ConvolutionDimensions dimensions{
676  SmallVector<unsigned, 2>(batch.begin(), batch.end()),
677  SmallVector<unsigned, 2>(oi.begin(), oi.end()),
678  SmallVector<unsigned, 2>(oc.begin(), oc.end()),
679  SmallVector<unsigned, 2>(fl.begin(), fl.end()),
680  SmallVector<unsigned, 2>(ic.begin(), ic.end()),
681  SmallVector<unsigned, 2>(depth.begin(), depth.end()),
682  /*strides=*/SmallVector<int64_t, 2>{},
683  /*dilations=*/SmallVector<int64_t, 2>{}};
684  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
685  llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
686  llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
687  llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
688  llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
689  llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
690 
691  // Use the op carried strides/dilations attribute if present.
692  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
693  if (!nativeStrides) {
694  SmallVector<AffineExpr, 2> strideExprs;
695  for (unsigned oiDim : dimensions.outputImage)
696  strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
697  dimensions.strides = getConstantsFromExprList(strideExprs);
698  } else {
699  dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
700  }
701  auto nativeDilations =
702  linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
703  if (!nativeDilations) {
704  SmallVector<AffineExpr, 2> dilationExprs;
705  for (unsigned flDim : dimensions.filterLoop)
706  dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
707  dimensions.dilations = getConstantsFromExprList(dilationExprs);
708  } else {
709  dimensions.dilations =
710  llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
711  }
712  return dimensions;
713 }
714 
715 /// Find at least 1 parallel (output_image) and reduction (filter_loop)
716 /// dimension candidates that form a convolution subcomputation within
717 /// `linalgOp`. The LHS is assumed to be the convolution input while the
718 /// RHS is assumed as the filter.
719 /// These dimensions are such that:
720 /// 1. Optional batch dimensions appear in the input and output, not the filter.
721 /// 2. The output_image dimension is involved in a cross-correlation along LHS
722 /// (i.e. it is a permutation on RES and LHS and has an associated
723 /// filter_loop in RHS).
724 /// 3. Optional output_channel dimension is involved in an outer-product along
725 /// RHS (i.e. it is a permutation on RES and RHS and does not appear in
726 /// LHS).
727 /// 4. Optional input_channel dimension appears as a permutation on LHS and
728 /// RHS.
729 /// 5. The filter_loop dimension appears as a permutation on the RHS and
730 /// represents the shape of the kernel cross-correlated along a
731 /// corresponding output_image dim.
732 /// 6. The input_channel dimension appears as a permutation on LHS and RHS.
733 /// 7. All dimensions appear only once in any given indexing map.
734 /// This allows e.g. detecting that some convolution is embedded within
735 /// `linalgOp` with some orthogonal heuristic.
736 /// When multiple dimension occurrences exist that match any classification
737 /// indices are returned in sorted order.
738 /// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
739 FailureOr<ConvolutionDimensions>
741  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
742  return failure();
743 
744  auto indexingMaps = linalgOp.getIndexingMapsArray();
745 
746  // Check the input indexing map has the right form.
747  ConvAccessExprWalker inputExprWalker;
748  for (AffineExpr expr : indexingMaps[0].getResults())
749  (void)inputExprWalker.visit(expr);
750  inputExprWalker.clearMultiUseDims(indexingMaps[0]);
751 
752  return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
753  /*allowEmptyConvolvedDims=*/false);
754 }
755 
756 namespace mlir::linalg::detail {
758  Success = 0,
759  NotLinalgOp,
766 };
767 } // namespace mlir::linalg::detail
768 
771  Operation *op, ConvolutionDimensions *dimensions) {
772  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
773  if (!linalgOp)
774  return MatchConvolutionResult::NotLinalgOp;
775  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
776  return MatchConvolutionResult::WrongNumOperands;
777 
778  auto indexingMaps = linalgOp.getIndexingMapsArray();
779 
780  // Check the input indexing map has the right form.
781  ConvAccessExprWalker inputExprWalker;
782  if (llvm::any_of(indexingMaps[0].getResults(),
783  [&inputExprWalker](AffineExpr expr) {
784  return failed(inputExprWalker.visit(expr));
785  })) {
786  return MatchConvolutionResult::WrongInputIndexingMap;
787  }
788 
789  // Filter and output maps must be projected permutation.
790  if (!indexingMaps[1].isProjectedPermutation() ||
791  !indexingMaps.back().isProjectedPermutation())
792  return MatchConvolutionResult::NotProjectedPermutations;
793 
794  auto iteratorTypes = linalgOp.getIteratorTypesArray();
795 
796  llvm::SmallDenseSet<int64_t> outputDims =
797  getPreservedDims(indexingMaps.back());
798  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
799  // Make sure all loops are characterized as one of:
800  // - Batch loop : present in output, as non-convolved in input, not present in
801  // filter.
802  // - Output image dimension : present in output, convolved dims in input, not
803  // present in filter.
804  // - Output channel dimension : present in output, not present in input,
805  // present in filter.
806  // - Filter loop dimension : present in filter, convolved in input, not
807  // present in output.
808  // - Input channel dimension : unconvolved in input, not present in output,
809  // present in filter.
810  // - Depth multiplier : unconvolved in input, present in output, present in
811  // filter.
812  llvm::SmallDenseSet<int64_t> allLoopDims;
813  for (auto outputExpr : indexingMaps.back().getResults()) {
814  int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
815  if (inputExprWalker.unConvolvedDims.count(outputDim) &&
816  !filterDims.count(outputDim)) {
817  // Batch dimension.
818  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
819  return MatchConvolutionResult::OutputDimsNotParallel;
820  allLoopDims.insert(outputDim);
821  continue;
822  }
823  if (inputExprWalker.convolvedDims.count(outputDim) &&
824  !filterDims.count(outputDim)) {
825  // Output image Loop dimension.
826  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
827  return MatchConvolutionResult::OutputDimsNotParallel;
828  allLoopDims.insert(outputDim);
829  continue;
830  }
831  if (!inputExprWalker.convolvedDims.count(outputDim) &&
832  !inputExprWalker.unConvolvedDims.count(outputDim) &&
833  filterDims.count(outputDim)) {
834  // Output channel dimension.
835  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
836  return MatchConvolutionResult::OutputDimsNotParallel;
837  allLoopDims.insert(outputDim);
838  continue;
839  }
840  if (inputExprWalker.unConvolvedDims.count(outputDim) &&
841  filterDims.count(outputDim)) {
842  // Depth multiplier.
843  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
844  return MatchConvolutionResult::OutputDimsNotParallel;
845  allLoopDims.insert(outputDim);
846  continue;
847  }
848  return MatchConvolutionResult::NonConvolutionLoop;
849  }
850  for (auto filterExpr : indexingMaps[1].getResults()) {
851  int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
852  if (outputDims.count(filterDim) &&
853  !inputExprWalker.unConvolvedDims.count(filterDim) &&
854  !inputExprWalker.convolvedDims.count(filterDim)) {
855  // Output channel dimension. This is already seen, continue;
856  continue;
857  }
858  if (inputExprWalker.convolvedDims.count(filterDim) &&
859  !outputDims.count(filterDim)) {
860  // Filter loop dimension.
861  if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
862  return MatchConvolutionResult::NonOutputDimNotReduction;
863  if (allLoopDims.count(filterDim))
864  return MatchConvolutionResult::NonConvolutionLoop;
865  allLoopDims.insert(filterDim);
866  continue;
867  }
868  if (inputExprWalker.unConvolvedDims.count(filterDim) &&
869  !outputDims.count(filterDim)) {
870  // Input channel dimension.
871  if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
872  return MatchConvolutionResult::NonOutputDimNotReduction;
873  if (allLoopDims.count(filterDim))
874  return MatchConvolutionResult::NonConvolutionLoop;
875  allLoopDims.insert(filterDim);
876  continue;
877  }
878  if (inputExprWalker.unConvolvedDims.count(filterDim) &&
879  outputDims.count(filterDim)) {
880  // Depthwise loop. Already seen.
881  continue;
882  }
883  return MatchConvolutionResult::NonConvolutionLoop;
884  }
885  // All loops must be covered now.
886  if (allLoopDims.size() != linalgOp.getNumLoops())
887  return MatchConvolutionResult::NonConvolutionLoop;
888 
889  if (dimensions) {
890  FailureOr<ConvolutionDimensions> res =
891  inferConvolutionDimsImpl(linalgOp, inputExprWalker,
892  /*allowEmptyConvolvedDims=*/true);
893  assert(succeeded(res) && "unexpected failure to infer convolution dims");
894  *dimensions = *res;
895  }
896 
897  return MatchConvolutionResult::Success;
898 }
899 
900 StringRef
902  switch (res) {
903  case MatchConvolutionResult::NotLinalgOp:
904  return "expected a LinalgOp";
905  case MatchConvolutionResult::WrongNumOperands:
906  return "expected op with 2 inputs and 1 output";
907  case MatchConvolutionResult::WrongInputIndexingMap:
908  return "unexpected input index map for convolutions";
909  case MatchConvolutionResult::NotProjectedPermutations:
910  return "expected output/filter indexing maps to be projected permutations";
911  case MatchConvolutionResult::NonConvolutionLoop:
912  return "unexpected loop dimension for convolution op";
913  case MatchConvolutionResult::OutputDimsNotParallel:
914  return "expected all iterators used to access outputs to be parallel";
915  case MatchConvolutionResult::NonOutputDimNotReduction:
916  return "expected all iterators not used to access outputs to be reduction";
917  case MatchConvolutionResult::Success:
918  return "";
919  }
920  llvm_unreachable("unhandled MatchConvolutionResult case");
921 }
922 
924  return linalg::detail::isConvolutionInterfaceImpl(linalgOp.getOperation()) ==
926 }
927 
930  if (res != MatchConvolutionResult::Success)
931  return op->emitError(getMatchConvolutionMessage(res));
932  return success();
933 }
934 
935 //===----------------------------------------------------------------------===//
936 // FillOpInterface implementation
937 //===----------------------------------------------------------------------===//
938 
939 enum class MatchFillResult {
940  Success = 0,
941  NotLinalgOp,
942  WrongNumOperands,
944 };
945 
947  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
948  if (!linalgOp)
950  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
952 
953  OpOperand *value = linalgOp.getDpsInputOperand(0);
954  if (!linalgOp.isScalar(value))
956 
958 }
959 
961  auto res = isFillInterfaceImpl(op);
962  if (res == MatchFillResult::NotLinalgOp)
963  return op->emitError("expected a LinalgOp");
965  return op->emitError("expected op with 1 input and 1 output");
967  return op->emitError("expected op with scalar input");
968 
969  return success();
970 }
971 
972 //===----------------------------------------------------------------------===//
973 // StructuredOpInterface implementation
974 //===----------------------------------------------------------------------===//
975 
976 SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
977  Location loc) {
979  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
980  for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
981  res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
982  }
983  return res;
984 }
985 
986 SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
988  assert(!hasDynamicShape() && "expected operands to have static shapes");
989  for (OpOperand &opOperand : getOperation()->getOpOperands())
990  llvm::append_range(res, getShape(&opOperand));
991  return res;
992 }
993 
994 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
995  AffineMap map = getLoopsToShapesMap();
996  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
997  auto viewSizes = createFlatListOfOperandDims(b, loc);
998  SmallVector<Range, 4> res(numDims);
999  for (unsigned idx = 0; idx < numRes; ++idx) {
1000  auto result = map.getResult(idx);
1001  if (auto d = dyn_cast<AffineDimExpr>(result)) {
1002  if (res[d.getPosition()].offset)
1003  continue;
1004  res[d.getPosition()] =
1005  Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1006  }
1007  }
1008  return res;
1009 }
1010 
1011 SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
1012  AffineMap map = getLoopsToShapesMap();
1013  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1014  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
1015  SmallVector<int64_t, 4> res(numDims, 0);
1016  for (unsigned idx = 0; idx < numRes; ++idx) {
1017  auto result = map.getResult(idx);
1018  if (auto d = dyn_cast<AffineDimExpr>(result))
1019  res[d.getPosition()] = allShapeSizes[idx];
1020  }
1021  return res;
1022 }
1023 
1024 /// Visitor to check if any of the given set of positions from AffineDimExprs
1025 /// are used within an AffineExpr.
1027  : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1028  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1029  : positions(std::move(positions)) {}
1030 
1032  return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
1033  }
1034 
1035  bool visitDimExpr(AffineDimExpr dimExpr) {
1036  return positions.test(dimExpr.getPosition());
1037  }
1038 
1039  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1040 
1041  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1042 
1043 private:
1044  llvm::SmallBitVector positions;
1045 };
1046 
1047 static std::pair<int64_t, int64_t>
1049  int64_t inputRankSum = 0;
1050  int64_t outputRankSum = 0;
1051  for (OpOperand *input : op.getDpsInputOperands())
1052  inputRankSum += op.getRank(input);
1053  for (OpOperand &output : op.getDpsInitsMutable())
1054  outputRankSum += op.getRank(&output);
1055  return {inputRankSum, inputRankSum + outputRankSum};
1056 }
1057 
1058 LogicalResult
1060  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1061  // An example that helps understand the logic below.
1062  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
1063  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
1064  // This is achieved as follows.
1065  // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
1066  // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
1067  // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
1068  // resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
1069  // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
1070  AffineMap loopsToShapesMap = getLoopsToShapesMap();
1071 
1072  // Find the position in the above map that represents the shape of the
1073  // result:dim being inferred.
1074  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
1075 
1076  /// From loopsToShapesMap extract the submap that represents the shape of the
1077  /// (resultIdx, dim) needed.
1078  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
1079  resultShapesSubMapPos.first,
1080  resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1081  AffineMap resultShapesFromInputShapesMap =
1082  loopToResultsShapeMap.compose(getShapesToLoopsMap());
1083 
1084  // Check that the result dim map does not contain the positions corresponding
1085  // to the outputs.
1086  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
1087  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1088  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1089  Location loc = getOperation()->getLoc();
1090  IRRewriter rewriter(b);
1091  SmallVector<OpFoldResult> allResultDimValues =
1093  rewriter, loc, resultShapesFromInputShapesMap,
1094  createFlatListOfOperandDims(b, loc));
1095  int64_t pos = 0;
1096  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
1097  for (OpOperand &opOperand : getDpsInitsMutable()) {
1099  for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1100  auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1101  if (!shapedType.isDynamicDim(dim)) {
1102  // Static dim: Return IntegerAttr.
1103  shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
1104  } else {
1105  // Dynamic dim: Return Value.
1106  OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1107  ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
1108  : allResultDimValues[pos];
1109  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
1110  }
1111  pos++;
1112  }
1113  reifiedReturnShapes.emplace_back(std::move(shapes));
1114  }
1115  return success();
1116 }
1117 
1118 /// Return the index in the indexingMaps vector that corresponds to this
1119 /// `opOperand`.
1120 int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1121  auto operandNumber = opOperand->getOperandNumber();
1122  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1123  if (!dpsIface.isDpsInput(opOperand))
1124  return operandNumber;
1125  unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1126  assert(!dpsIface.isDpsInit(opOperand));
1127  // Account for potential inputs that are not DPS and may not appear in
1128  // `indexingMaps`.
1129  return cast<DestinationStyleOpInterface>(*this->getOperation())
1130  .getNumDpsInputs() +
1131  operandNumber - start;
1132 }
1133 
1135  LinalgOp linalgOp = cast<LinalgOp>(op);
1136 
1137  // Mixed tensor/buffer operands are not allowed.
1138  if (!linalgOp.hasPureTensorSemantics() &&
1139  !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1140  return op->emitOpError("expected to have pure tensor or buffer semantics");
1141 
1142  // Before checking indexing maps, we need to make sure the attributes
1143  // referenced by it are valid.
1144  if (linalgOp.hasDynamicIndexingMaps())
1145  if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1146  return failure();
1147 
1148  // All input/output operands must be indexed.
1149  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
1150  linalgOp->getNumOperands())
1151  return op->emitOpError("expected the number of indexing_map (")
1152  << linalgOp.getIndexingMapsArray().size()
1153  << ") to be equal to the number of input/output operands ("
1154  << linalgOp->getNumOperands() << ")";
1155 
1156  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1157  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1158 
1159  // Symbols disallowed.
1160  if (indexingMap.getNumSymbols() != 0)
1161  return op->emitOpError("unexpected symbols in indexing_map #")
1162  << opOperand.getOperandNumber();
1163 
1164  // Domain must be consistent.
1165  unsigned numLoops = linalgOp.getNumLoops();
1166  if (indexingMap.getNumDims() != numLoops)
1167  return op->emitOpError("expected indexing_map #")
1168  << opOperand.getOperandNumber() << " to have " << numLoops
1169  << " dim(s) to match the number of loops";
1170 
1171  int64_t rank = linalgOp.getRank(&opOperand);
1172  if (indexingMap.getNumResults() != rank)
1173  return op->emitOpError("expected operand rank (")
1174  << rank << ") to match the result rank of indexing_map #"
1175  << opOperand.getOperandNumber() << " ("
1176  << indexingMap.getNumResults() << ")";
1177  }
1178 
1179  SmallVector<unsigned> redDims;
1180  linalgOp.getReductionDims(redDims);
1181 
1182  if (!linalgOp.getShapesToLoopsMap())
1183  return op->emitOpError("expected the shape-to-loops map to be non-null");
1184 
1185  // Check if given shapes match to inferred shapes.
1186  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
1187  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
1188 
1189  // Verify only static cases since we can't get exact dimension sizes and loop
1190  // ranges for dynamic cases in this stage.
1191  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
1192  for (int64_t &range : endLoopRangeValues)
1193  range -= 1;
1194  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1195  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1196  SmallVector<int64_t, 4> startIndices =
1197  indexingMap.compose(startLoopRangeValues);
1198  SmallVector<int64_t, 4> endIndices =
1199  indexingMap.compose(endLoopRangeValues);
1200  ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
1201  for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
1202  // Ignore dynamic dimension or the case that the dimension size is 0
1203  if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
1204  continue;
1205 
1206  // The first index or last index should be the maximum or the minimum in
1207  // the inferred index ranges since the range is increasing or
1208  // decreasing. The size of dimensions of input/output operands and the
1209  // maximum value + 1 in the inferred range should be the same. But, for
1210  // now we check if the inferred ranges are in boundary of input/output
1211  // operands' size or not in case that Affine Expressions are complicated
1212  // such as d0 * 3
1213  // + d1 since it is not easy to handle the issues.
1214  // Found the case that this solution can't check, for example, (d0, d1)
1215  // -> (d1 - d0)
1216  int64_t inferredDimSize =
1217  std::max(startIndices[dim], endIndices[dim]) + 1;
1218  if (std::min(startIndices[dim], endIndices[dim]) < 0) {
1219  std::string mapStr;
1220  {
1221  llvm::raw_string_ostream os(mapStr);
1222  os << indexingMap;
1223  }
1224  return op->emitOpError(
1225  "unexpected result less than 0 at expression #")
1226  << dim << " in " << mapStr;
1227  }
1228  if (dyn_cast<AffineDimExpr>(indexingMap.getResult(dim))) {
1229  if (inferredDimSize != shape[dim]) {
1230  return op->emitOpError("inferred input/output operand #")
1231  << opOperand.getOperandNumber() << " has shape's dimension #"
1232  << dim << " to be " << inferredDimSize << ", but found "
1233  << shape[dim];
1234  }
1235  } else {
1236  if (inferredDimSize > shape[dim]) {
1237  return op->emitOpError("inferred input/output operand #")
1238  << opOperand.getOperandNumber() << " has shape's dimension #"
1239  << dim << " to be greater than or equal to "
1240  << inferredDimSize << ", but found " << shape[dim];
1241  }
1242  }
1243  }
1244  }
1245  }
1246 
1247  // Check the region has exactly one block.
1248  if (linalgOp->getNumRegions() != 1 ||
1249  !llvm::hasSingleElement(linalgOp->getRegion(0)))
1250  return op->emitOpError("expects to have 1 region with 1 block");
1251 
1252  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
1253  // types.
1254  // TODO: once ranked shape types are plugged in, we may want to drop the
1255  // corresponding bbargs, that can never be read from. This will be subject to
1256  // consistency discussions (i.e. what to do with output tensors whose bbarg is
1257  // not used).
1258  Block &block = linalgOp->getRegion(0).front();
1259 
1260  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1261  return op->emitOpError("expected as many non-induction variable region "
1262  "arguments as the number of input/output operands");
1263 
1264  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1265  Type elementType = opOperand->get().getType();
1266  if (isa<MemRefType, RankedTensorType>(elementType))
1267  elementType = getElementTypeOrSelf(opOperand->get().getType());
1268  Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1269  if (elementType != argType)
1270  return op->emitOpError("expected type of bb argument #")
1271  << opOperand->getOperandNumber() << " (" << argType << ")"
1272  << " to match element or self type of the corresponding operand ("
1273  << elementType << ")";
1274  }
1275 
1276  return success();
1277 }
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp, unsigned arity)
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
MatchFillResult
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Affine binary operation expression.
Definition: AffineExpr.h:227
AffineExpr getLHS() const
Definition: AffineExpr.cpp:339
AffineExpr getRHS() const
Definition: AffineExpr.cpp:342
An integer constant appearing in affine expression.
Definition: AffineExpr.h:252
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
unsigned getPosition() const
Definition: AffineExpr.cpp:347
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:34
MLIRContext * getContext() const
Definition: AffineExpr.cpp:32
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:626
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:582
unsigned getNumSymbols() const
Definition: AffineMap.cpp:385
unsigned getNumDims() const
Definition: AffineMap.cpp:381
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:394
unsigned getNumResults() const
Definition: AffineMap.cpp:389
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:398
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:543
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:244
Block represents an ordered list of Operations.
Definition: Block.h:31
bool empty()
Definition: Block.h:146
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation & back()
Definition: Block.h:150
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
OpListType & getOperations()
Definition: Block.h:135
Operation & front()
Definition: Block.h:151
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:753
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1239
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalgelementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:98
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
bool isaConvolutionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ConvolutionOpInterface.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:89
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:755
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Definition: AffineMap.cpp:800
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:630
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.