MLIR  20.0.0git
LinalgInterfaces.cpp
Go to the documentation of this file.
1 //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
19 #include "mlir/IR/AffineMap.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "llvm/ADT/SetOperations.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include <algorithm>
25 #include <numeric>
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 
30 /// Include the generated definitions of the Linalg op interfaces.
31 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
32 
33 //===----------------------------------------------------------------------===//
34 // Interface utility functions
35 //===----------------------------------------------------------------------===//
36 
38  linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
39  SmallVector<AffineMap> indexingMaps;
40  for (auto &opOperand : linalgOp->getOpOperands()) {
41  if (llvm::is_contained(droppedOperands, &opOperand))
42  continue;
43  indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
44  }
45  if (indexingMaps.empty()) {
46  // If there are no indexing maps, the operand can only be dropped
47  // if the op has no loops.
48  return linalgOp.getNumLoops() == 0;
49  }
50  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
51 }
52 
53 //===----------------------------------------------------------------------===//
54 // CopyOpInterface implementation
55 //===----------------------------------------------------------------------===//
56 
57 bool linalg::isaCopyOpInterface(LinalgOp op) {
58  // Check all loops are parallel and linalgOp is single input and output.
59  if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
60  return false;
61 
62  auto mapRange = op.getIndexingMapsArray();
63  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
64  !mapRange.back().isIdentity()) {
65  return false;
66  }
67  // Region.
68  return llvm::hasSingleElement(op.getBlock()->getOperations());
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // FillOpInterface implementation
73 //===----------------------------------------------------------------------===//
74 std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
75  // Structural.
76  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
77  !op.isSingleYieldOp())
78  return std::nullopt;
79 
80  // Input should be referenced and init should not.
81  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
82  op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
83  return std::nullopt;
84 
85  OpOperand *value = op.getDpsInputOperand(0);
86  if (!op.isScalar(value))
87  return std::nullopt;
88  return value->get();
89 }
90 
91 //===----------------------------------------------------------------------===//
92 // BroadcastOpInterface implementation
93 //===----------------------------------------------------------------------===//
94 std::optional<SmallVector<int64_t>>
96  // Structural.
97  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
98  !op.isSingleYieldOp())
99  return std::nullopt;
100 
101  auto srcTy = op.getDpsInputOperand(0)->get().getType();
102  auto dstTy = op.getDpsInitOperand(0)->get().getType();
103  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
104  !isa<MemRefType, RankedTensorType>(dstTy))
105  return std::nullopt;
106 
107  // Check output is identity map. Broadcast could additionally be
108  // employing permutation of indices and that would be expressible
109  // in linalg.generic but is not expressible for named broadcast op.
110  auto dstMap = op.getIndexingMapsArray()[1];
111  if (!dstMap.isIdentity())
112  return std::nullopt;
113 
114  SmallVector<int64_t> position;
115  auto srcMap = op.getIndexingMapsArray()[0];
116 
117  if (srcMap.getResults().size() >= dstMap.getResults().size())
118  return std::nullopt;
119 
120  // Check input map is monotonically increasing DimIds.
121  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
122  auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
123  if (!expr)
124  return std::nullopt;
125  int64_t pos = expr.getPosition();
126  if (i > 0 && pos <= position[i - 1])
127  return std::nullopt;
128  position.push_back(expr.getPosition());
129  }
130 
131  SmallVector<int64_t> broadcastedDims;
132  auto numDims = srcMap.getNumDims();
133  // This is quadratic but number of items is generally small.
134  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
135  if (!llvm::is_contained(position, dim))
136  broadcastedDims.push_back(dim);
137  }
138  return broadcastedDims;
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // TransposeOpInterface implementation
143 //===----------------------------------------------------------------------===//
144 std::optional<SmallVector<int64_t>>
146  // To specialize as a transpose op, the genericOp must be
147  // all parallel loops, single input, single output, and its body
148  // should be just a yield op, yielding input as output as is (no compute).
149  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
150  !op.isSingleYieldOp())
151  return std::nullopt;
152 
153  auto mapRange = op.getIndexingMapsArray();
154  if (mapRange.size() != 2)
155  return std::nullopt;
156 
157  auto mapOfInput = mapRange.front();
158  auto mapOfResult = mapRange.back();
159 
160  // linalg.transpose permutes the dimensions of input using this
161  // rule: dim(result, i) = dim(input, permutation[i])
162  if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
163  return std::nullopt;
164 
165  SmallVector<int64_t> permutation(mapOfInput.getNumDims());
166  for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
167  auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
168  permutation[expr.getPosition()] = i;
169  }
170  return permutation;
171 }
172 
173 //===----------------------------------------------------------------------===//
174 // Elementwise Single Unary/Binary-OpInterface implementation
175 //===----------------------------------------------------------------------===//
176 static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
177  unsigned arity) {
178  // Check all loops are parallel.
179  if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
180  return false;
181 
182  // Check there are arity-inputs, 1-output and all are identity-maps.
183  if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
184  !llvm::all_of(op.getIndexingMapsArray(),
185  [](AffineMap map) { return map.isIdentity(); }))
186  return false;
187 
188  // Init should not be referenced for elementwise operations.
189  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
190  return false;
191 
192  // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
193  // as resulting from producer-consumer fusion. Here, we restrict to two ops in
194  // the body, where the first is the elementwise single op and the second a
195  // yield.
196  Block *body = op.getBody();
197  if (body->getOperations().size() != 2)
198  return false;
199 
200  Operation *oper = &body->front();
201  if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
202  return false;
203 
204  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
205  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
206  yieldOp->getOperand(0).getDefiningOp() != oper)
207  return false;
208  return true;
209 }
210 
211 bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
212  // All basic elemwise checks.
214  return false;
215 
216  // Check input is actually used.
217  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
218  return false;
219  return true;
220 }
221 
222 bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
224  return false;
225 
226  // Check both inputs are used (elementwise).
227  OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
228  OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
229  if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
230  !op.payloadUsesValueFromOperand(inputOpOperand1))
231  return false;
232  return true;
233 }
234 
235 //===----------------------------------------------------------------------===//
236 // ContractionOpInterface implementation
237 //===----------------------------------------------------------------------===//
238 
239 /// If the value is defined by a chain of unary side-effect-free ops, go up the
240 /// use-def chain until the first value that isn't defined by such an op.
241 // TODO: relax to multi-operands with constants, which are technically unary ops
242 // as needed (e.g. add5).
244  Operation *op = value.getDefiningOp();
245  while (op && op->getNumOperands() == 1) {
246  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
247  if (!iface || !iface.hasNoEffect())
248  break;
249  value = op->getOperand(0);
250  op = value.getDefiningOp();
251  }
252  return value;
253 }
254 
256  Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
257  llvm::raw_ostream &errs) {
258  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
259  errs << "no terminator in the block";
260  return false;
261  }
262 
263  if (block.getNumArguments() != 3) {
264  errs << "expected block with 3 arguments";
265  return false;
266  }
267 
268  Operation *terminator = block.getTerminator();
269  if (terminator->getNumOperands() != 1) {
270  errs << "expected terminator with 1 operand";
271  return false;
272  }
273 
274  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
275  Operation *reductionOp = yielded.getDefiningOp();
276  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
277  errs << "expected reduction op to be binary";
278  return false;
279  }
280 
281  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
282  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
283 
284  if (reductionLHS != block.getArgument(2) &&
285  reductionRHS != block.getArgument(2)) {
286  errs << "expected reduction to take block argument #2 as one of the "
287  "operands (modulo unary casts)";
288  return false;
289  }
290 
291  Value contributed = getSourceSkipUnary(
292  isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
293  Operation *elementwiseOp = contributed.getDefiningOp();
294  if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
295  elementwiseOp->getNumOperands() != 2) {
296  errs << "expected elementwise op to be binary";
297  return false;
298  }
299 
300  if (!isaPair(elementwiseOp, reductionOp)) {
301  errs << "expected reduction/elementwise op kind not satisfied";
302  return false;
303  }
304 
305  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
306  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
307  if ((elementwiseLHS == block.getArgument(0) &&
308  elementwiseRHS == block.getArgument(1)) ||
309  (elementwiseLHS == block.getArgument(1) &&
310  elementwiseRHS == block.getArgument(0))) {
311  return true;
312  }
313 
314  errs << "expected elementwise op to apply to block arguments (modulo unary "
315  "casts)";
316  return false;
317 }
318 
319 /// Returns true if the two operations are of the kinds specified by a pair of
320 /// consecutive template arguments.
321 template <typename AddOpTy, typename MulOpTy, typename... Args>
322 static bool isPairTemplateImpl(Operation *add, Operation *mul) {
323  static_assert(sizeof...(Args) % 2 == 0,
324  "expected an even number of template arguments");
325  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
326  return true;
327 
328  if constexpr (sizeof...(Args) > 0)
329  return isPairTemplateImpl<Args...>(add, mul);
330  else
331  return false;
332 }
333 
334 /// Returns true if the block is a body of a contraction with the kinds of
335 /// operations given pairwise by template arguments.
336 template <typename... Args>
337 static bool isContractionBody(Block &block) {
338  return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
339 }
340 
341 /// Given an `indexingMap` and its corresponding `iterators`, returns
342 /// the positions of the iterators of type `iter` that are indexed by
343 /// the `indexingMap` as a permutation. This is useful to infer various
344 /// subcomputations on a `LinalgOp`. This is performed by looking up
345 /// each result in the `indexingMap` and determining whether:
346 /// - It is a single AffineDimExpr.
347 /// - It is the only result involving this AffineDimExpr.
348 static llvm::SmallDenseSet<int64_t>
351  utils::IteratorType iter) {
352  assert(iterators.size() == indexingMap.getNumDims());
353  llvm::SmallDenseSet<int64_t> res;
354  for (AffineExpr e : indexingMap.getResults()) {
355  if (auto d = dyn_cast<AffineDimExpr>(e)) {
356  if (iterators[d.getPosition()] == iter &&
357  llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
358  return e.isFunctionOfDim(d.getPosition());
359  }) == 1)
360  res.insert(d.getPosition());
361  }
362  }
363  return res;
364 }
365 
366 namespace {
367 auto par = utils::IteratorType::parallel;
368 auto red = utils::IteratorType::reduction;
369 } // namespace
370 
371 /// Infer the iterator types from the init affine map. This looks at which dims
372 /// are present in the map results, and returns an iterator types array with
373 /// parallel types for dims that are present, and reduction types for dims that
374 /// are not present.
375 static FailureOr<SmallVector<utils::IteratorType>>
377  if (!map.isProjectedPermutation())
378  return failure();
379  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
380  for (auto expr : map.getResults())
381  if (auto dim = dyn_cast<AffineDimExpr>(expr))
382  iterators[dim.getPosition()] = par;
383  return iterators;
384 }
385 
386 /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
387 /// a matmul subcomputation within `linalgOp`. These dimensions are such that:
388 /// 1. The m dimension is involved in an outer-product along LHS
389 /// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
390 /// 2. The n dimension is involved in an outer-product along RHS
391 /// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
392 /// 3. The k dimension appears as a permutation on LHS and RHS.
393 /// 4. m, n and k appear only once in any given indexing.
394 /// 5. Optional batch dimensions that appear in all operands are captured.
395 /// This allows e.g. detecting that some contraction is embedded within
396 /// `linalgOp` with some orthogonal heuristic.
397 static FailureOr<ContractionDimensions>
399  ArrayRef<utils::IteratorType> iterators) {
400  llvm::SmallDenseSet<int64_t> a =
401  findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
402  llvm::SmallDenseSet<int64_t> b =
403  findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
404  llvm::SmallDenseSet<int64_t> c =
405  findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
406 
407  // A & C - B are the iterators involved in an outer-product along A (the LHS).
408  llvm::SmallDenseSet<int64_t> ac = a;
409  llvm::set_intersect(ac, c);
410  llvm::set_subtract(ac, b);
411  // B & C - A are the iterators involved in an outer-product along B (the RHS).
412  llvm::SmallDenseSet<int64_t> bc = b;
413  llvm::set_intersect(bc, c);
414  llvm::set_subtract(bc, a);
415  // A & B & C are the "batch" dimensions.
416  llvm::SmallDenseSet<int64_t> batches = a;
417  llvm::set_intersect(batches, b);
418  llvm::set_intersect(batches, c);
419 
420  // A & B red are the reduction dimensions.
421  llvm::SmallDenseSet<int64_t> ra =
422  findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
423  llvm::SmallDenseSet<int64_t> rb =
424  findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
425  llvm::set_intersect(ra, rb);
426 
427  // Return each set in sorted order.
428  ContractionDimensions dimensions{
429  SmallVector<unsigned, 2>(batches.begin(), batches.end()),
430  SmallVector<unsigned, 2>(ac.begin(), ac.end()),
431  SmallVector<unsigned, 2>(bc.begin(), bc.end()),
432  SmallVector<unsigned, 2>(ra.begin(), ra.end())};
433  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
434  llvm::sort(dimensions.m.begin(), dimensions.m.end());
435  llvm::sort(dimensions.n.begin(), dimensions.n.end());
436  llvm::sort(dimensions.k.begin(), dimensions.k.end());
437  return dimensions;
438 }
439 
440 FailureOr<ContractionDimensions>
442  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
443  return failure();
444  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
445  linalgOp.getIteratorTypesArray());
446 }
447 
448 FailureOr<ContractionDimensions>
450  if (indexingMaps.size() != 3)
451  return failure();
452  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
453  if (failed(iterators))
454  return failure();
455  return inferContractionDimsImpl(indexingMaps, iterators.value());
456 }
457 
458 namespace mlir::linalg::detail {
460  Success = 0,
461  NotLinalgOp,
463  NoReduction,
465  NotAddMul
466 };
467 } // namespace mlir::linalg::detail
468 
472  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
473  if (!linalgOp)
474  return MatchContractionResult::NotLinalgOp;
475  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
476  return MatchContractionResult::WrongNumOperands;
477  auto mapRange = linalgOp.getIndexingMapsArray();
478  if (linalgOp.getNumReductionLoops() == 0)
479  return MatchContractionResult::NoReduction;
480  if (llvm::any_of(mapRange,
481  [](AffineMap m) { return !m.isProjectedPermutation(); }))
482  return MatchContractionResult::NotProjectedPermutations;
483  // TODO: more fields than add/mul.
484  // clang-format off
485  if (!::isContractionBody<
486  arith::MulFOp, arith::AddFOp,
487  arith::MulIOp, arith::AddIOp,
488  complex::MulOp, complex::AddOp,
489  arith::AndIOp, arith::OrIOp>(
490  *linalgOp.getBlock())) {
491  return MatchContractionResult::NotAddMul;
492  }
493  // clang-format on
494 
495  if (dimensions) {
496  FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
497  assert(succeeded(res) && "unexpected failure to infer contraction dims");
498  *dimensions = *res;
499  }
500  return MatchContractionResult::Success;
501 }
502 
503 StringRef
505  switch (res) {
506  case MatchContractionResult::NotLinalgOp:
507  return "expected a LinalgOp";
508  case MatchContractionResult::WrongNumOperands:
509  return "expected op with 2 inputs and 1 output";
510  case MatchContractionResult::NoReduction:
511  return "expected at least 1 reduction";
512  case MatchContractionResult::NotProjectedPermutations:
513  return "expected indexing maps to be projected permutations";
514  case MatchContractionResult::NotAddMul:
515  return "expected add/mul op in the body";
516  case MatchContractionResult::Success:
517  return "";
518  }
519  llvm_unreachable("unhandled MatchContractionResult case");
520 }
521 
523  if (!linalgOp)
524  return false;
525  Operation *op = linalgOp.getOperation();
526  return isa<ContractionOpInterface>(op) ||
529 }
530 
531 /// Verify that a LinalgOp `op` is a contraction.
532 /// A Linalg contraction is defined in general terms:
533 /// 1. Has 2 input and 1 output shapes.
534 /// 2. Has at least one reduction dimension.
535 /// 3. Has only projected permutation indexing maps.
536 /// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
537 /// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
538 /// operations that may change the type (e.g. for mixed-precision).
539 /// As a consequence, when vectorization of such an op occurs, the only special
540 /// behavior is that the (unique) MulOpType is vectorized into a
541 /// `vector.contract`. All other ops are handled in a generic fashion.
542 /// In the future, we may wish to allow more input arguments and elementwise and
543 /// constant operations that do not involve the reduction dimension(s).
545  auto res = isContractionInterfaceImpl(op);
546  if (res != MatchContractionResult::Success)
547  return op->emitError(getMatchContractionMessage(res));
548  return success();
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // ConvolutionOpInterface implementation
553 //===----------------------------------------------------------------------===//
554 
555 /// Of the given two expressions returns one that is of type T (`lhs` gets
556 /// preference over `rhs`)
557 template <typename T>
559  return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
560 }
561 
562 namespace {
563 /// Walk the indexing expressions for input of a convolution operation to verify
564 /// it is of the right form, either
565 /// - AffineDimExpr
566 /// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
567 /// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
568 ///
569 /// classifies the AffineDimExpr as convolved dimensions or unconvolved
570 /// dimensions and verifies each dimension occurs only once.
571 struct ConvAccessExprWalker
572  : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
573  // Stores dimensions used in expressions of the above form.
574  llvm::SmallDenseSet<int64_t> convolvedDims;
575  // Stores the dual mapping between LHS and RHS of convolution exprs.
576  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
577  // Stores single use dimensions used by an AffineDimExpr.
578  llvm::SmallDenseSet<int64_t> unConvolvedDims;
579  // Stores a mapping from convolved dims to their coefficient.
580  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
581 
582  // Removes dims with multiple uses in the source input map from dimension
583  // sets tracked by this walker.
584  void clearMultiUseDims(AffineMap map) {
585  for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
586  if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
587  return e.isFunctionOfDim(dimPos);
588  }) > 1) {
589  convolvedDims.erase(dimPos);
590  unConvolvedDims.erase(dimPos);
591  // If a duplicate dim is marked as convolved, the pair of the duplicate
592  // dim must be removed from the map as well.
593  auto it = convolvedDimMapping.find(dimPos);
594  if (it != convolvedDimMapping.end()) {
595  int64_t pairedDim = it->second;
596  convolvedDims.erase(pairedDim);
597  unConvolvedDims.erase(pairedDim);
598  strideAndDilationMapping.erase(pairedDim);
599  convolvedDimMapping.erase(dimPos);
600  convolvedDimMapping.erase(pairedDim);
601  }
602  }
603  }
604  }
605 
606  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
607  unsigned position = dimExpr.getPosition();
608  if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
609  return failure();
610  }
611  unConvolvedDims.insert(position);
612  return success();
613  }
614 
615  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
616 
617  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
618 
619  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
620  // In pre-order visit, top level op has to be an add op.
621  if (binaryExpr.getKind() != AffineExprKind::Add)
622  return failure();
623  auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
624  auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
625  if (failed(lhsDimPos) || failed(rhsDimPos))
626  return failure();
627  convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
628  convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
629  return success();
630  }
631 
632  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
633  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
634  int64_t dim = dimExpr.getPosition();
635  if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
636  return failure();
637  // Stride/dilation for this dim is implicitly 1.
638  strideAndDilationMapping[dim] =
640  convolvedDims.insert(dim);
641  return dim;
642  }
643  if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
644  if (symbolMulExpr.getKind() != AffineExprKind::Mul)
645  return failure();
646  auto lhsExpr = symbolMulExpr.getLHS();
647  auto rhsExpr = symbolMulExpr.getRHS();
648  // Check for symbol expression.
649  AffineExpr mulExpr =
650  getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
651  // If there was no symbol expr, check for constant expression.
652  if (!mulExpr) {
653  mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
654  }
655  auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
656  if (!mulExpr || !dimExpr)
657  return failure();
658  int64_t dim = dimExpr.getPosition();
659  if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
660  return failure();
661  strideAndDilationMapping[dim] = mulExpr;
662  convolvedDims.insert(dim);
663  return dim;
664  }
665  return failure();
666  }
667 };
668 } // namespace
669 
670 static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
671  assert(map.isProjectedPermutation() &&
672  "expected map to have projected permutations");
673  llvm::SmallDenseSet<int64_t> preservedDims;
674  for (auto expr : map.getResults())
675  preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
676  return preservedDims;
677 }
678 
682  for (auto e : exprs) {
683  auto constantExpr = dyn_cast<AffineConstantExpr>(e);
684  assert(constantExpr && "Found non-constant stride/dilation");
685  vals.push_back(constantExpr.getValue());
686  }
687  return vals;
688 }
689 
690 /// Classifies dimensions in the `linalgOp` used by a convolution
691 /// subcomputation, as captured by `inputExprWalker`. If
692 /// `allowEmptyConvolvedDims` is not set, this will fail if there is not
693 /// at least one convolved dimension pair (output image + filter loop).
694 /// Convolution dimensions are specified in sorted order, and strides match
695 /// the order of the output image dimensions, while the dilations match the
696 /// order of the filter loop dimensions.
697 static FailureOr<ConvolutionDimensions>
698 inferConvolutionDimsImpl(LinalgOp linalgOp,
699  ConvAccessExprWalker &inputExprWalker,
700  bool allowEmptyConvolvedDims) {
701  auto filterMap =
702  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
703  auto outputMap =
704  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
705  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
706  filterMap, linalgOp.getIteratorTypesArray(), par);
707  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
708  outputMap, linalgOp.getIteratorTypesArray(), par);
709 
710  // unConvolvedDims & outputDims - filterDims are the batch iterators.
711  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
712  llvm::set_intersect(batch, outputDims);
713  llvm::set_subtract(batch, filterDims);
714 
715  // convolvedDims & outputDims are the output image iterators.
716  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
717  llvm::set_intersect(oi, outputDims);
718 
719  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
720  llvm::SmallDenseSet<int64_t> oc = filterDims;
721  llvm::set_intersect(oc, outputDims);
722  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
723 
724  // filterDims & outputDims & unConvolvedDims are the depth iterators.
725  llvm::SmallDenseSet<int64_t> depth = filterDims;
726  llvm::set_intersect(depth, outputDims);
727  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
728 
729  llvm::SmallDenseSet<int64_t> filterReducedDims =
731  linalgOp.getIteratorTypesArray(), red);
732 
733  // convolvedDims & filterReducedDims are the filter loop iterators.
734  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
735  llvm::set_intersect(fl, filterReducedDims);
736 
737  // unConvolvedDims & filterReducedDims are the input channel iterators.
738  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
739  llvm::set_intersect(ic, filterReducedDims);
740 
741  if (oi.empty() && !allowEmptyConvolvedDims)
742  return failure();
743 
744  // Return each set in sorted order.
745  ConvolutionDimensions dimensions{
746  SmallVector<unsigned, 2>(batch.begin(), batch.end()),
747  SmallVector<unsigned, 2>(oi.begin(), oi.end()),
748  SmallVector<unsigned, 2>(oc.begin(), oc.end()),
749  SmallVector<unsigned, 2>(fl.begin(), fl.end()),
750  SmallVector<unsigned, 2>(ic.begin(), ic.end()),
751  SmallVector<unsigned, 2>(depth.begin(), depth.end()),
752  /*strides=*/SmallVector<int64_t, 2>{},
753  /*dilations=*/SmallVector<int64_t, 2>{}};
754  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
755  llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
756  llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
757  llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
758  llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
759  llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
760 
761  // Use the op carried strides/dilations attribute if present.
762  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
763  if (!nativeStrides) {
764  SmallVector<AffineExpr, 2> strideExprs;
765  for (unsigned oiDim : dimensions.outputImage)
766  strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
767  dimensions.strides = getConstantsFromExprList(strideExprs);
768  } else {
769  dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
770  }
771  auto nativeDilations =
772  linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
773  if (!nativeDilations) {
774  SmallVector<AffineExpr, 2> dilationExprs;
775  for (unsigned flDim : dimensions.filterLoop)
776  dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
777  dimensions.dilations = getConstantsFromExprList(dilationExprs);
778  } else {
779  dimensions.dilations =
780  llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
781  }
782  return dimensions;
783 }
784 
785 /// Find at least 1 parallel (output_image) and reduction (filter_loop)
786 /// dimension candidates that form a convolution subcomputation within
787 /// `linalgOp`. The LHS is assumed to be the convolution input while the
788 /// RHS is assumed as the filter.
789 /// These dimensions are such that:
790 /// 1. Optional batch dimensions that appear in the input and output (but not
///    in the filter).
791 /// 2. The output_image dimension is involved in a cross-correlation along LHS
792 /// (i.e. it is a permutation on RES and LHS and has an associated
793 /// filter_loop in RHS).
794 /// 3. Optional output_channel dimension is involved in an outer-product along
795 /// RHS (i.e. it is a permutation on RES and RHS and does not appear in
796 /// LHS).
797 /// 4. Optional input_channel dimension appears as a permutation on LHS and
798 /// RHS.
799 /// 5. The filter_loop dimension appears as a permutation on the RHS and
800 /// represents the shape of the kernel cross-correlated along a
801 /// corresponding output_image dim.
802 /// 6. The input_channel dimension appears as a permutation on LHS and RHS.
803 /// 7. All dimensions appear only once in any given indexing map.
804 /// This allows e.g. detecting that some convolution is embedded within
805 /// `linalgOp` with some orthogonal heuristic.
806 /// When multiple dimension occurrences exist that match any classification
807 /// indices are returned in sorted order.
808 /// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
809 FailureOr<ConvolutionDimensions>
811  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
812  return failure();
813 
814  auto indexingMaps = linalgOp.getIndexingMapsArray();
815 
816  // Check the input indexing map has the right form.
817  ConvAccessExprWalker inputExprWalker;
818  for (AffineExpr expr : indexingMaps[0].getResults())
819  (void)inputExprWalker.visit(expr);
820  inputExprWalker.clearMultiUseDims(indexingMaps[0]);
821 
822  return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
823  /*allowEmptyConvolvedDims=*/false);
824 }
825 
826 namespace mlir::linalg::detail {
// Result codes returned by isConvolutionInterfaceImpl; Success is 0.
// NOTE(review): source lines 827 and 830-836 (presumably the enum head and the
// remaining enumerators) are absent from this listing. The full set of codes
// can be read off the switch in getMatchConvolutionMessage.
828  Success = 0,
829  NotLinalgOp,
837 };
838 } // namespace mlir::linalg::detail
839 
842  Operation *op, ConvolutionDimensions *dimensions,
843  bool allowEmptyConvolvedDims) {
// Checks whether `op` is a LinalgOp whose loops can all be classified as
// convolution dimensions (see the classification list below). Returns the
// first failing MatchConvolutionResult, or Success. On success, if
// `dimensions` is non-null, it is filled via inferConvolutionDimsImpl.
// NOTE(review): the start of the signature (source lines 840-841) is absent
// from this listing.
844  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
845  if (!linalgOp)
846  return MatchConvolutionResult::NotLinalgOp;
// At least two DPS inputs (indexing map 0 = input, map 1 = filter) and
// exactly one init (the last indexing map = output).
847  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
848  return MatchConvolutionResult::WrongNumOperands;
849 
850  auto indexingMaps = linalgOp.getIndexingMapsArray();
851 
852  // Check the input indexing map has the right form.
// Any walker failure on an input-map result rejects the op.
853  ConvAccessExprWalker inputExprWalker;
854  if (llvm::any_of(indexingMaps[0].getResults(),
855  [&inputExprWalker](AffineExpr expr) {
856  return failed(inputExprWalker.visit(expr));
857  })) {
858  return MatchConvolutionResult::WrongInputIndexingMap;
859  }
860 
861  // Filter and output maps must be projected permutation.
862  if (!indexingMaps[1].isProjectedPermutation() ||
863  !indexingMaps.back().isProjectedPermutation())
864  return MatchConvolutionResult::NotProjectedPermutations;
865 
866  auto iteratorTypes = linalgOp.getIteratorTypesArray();
867 
868  llvm::SmallDenseSet<int64_t> outputDims =
869  getPreservedDims(indexingMaps.back());
870  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
871  // Make sure all loops are characterized as one of:
872  // - Batch loop : present in output, as non-convolved in input, not present in
873  // filter.
874  // - Output image dimension : present in output, convolved dims in input, not
875  // present in filter.
876  // - Output channel dimension : present in output, not present in input,
877  // present in filter.
878  // - Filter loop dimension : present in filter, convolved in input, not
879  // present in output.
880  // - Input channel dimension : unconvolved in input, not present in output,
881  // present in filter.
882  // - Depth multiplier : unconvolved in input, present in output, present in
883  // filter.
// First pass: every output-map dim must be one of the output-side kinds and
// must be a parallel iterator.
884  llvm::SmallDenseSet<int64_t> allLoopDims;
885  for (auto outputExpr : indexingMaps.back().getResults()) {
886  int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
887  if (inputExprWalker.unConvolvedDims.count(outputDim) &&
888  !filterDims.count(outputDim)) {
889  // Batch dimension.
890  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
891  return MatchConvolutionResult::OutputDimsNotParallel;
892  allLoopDims.insert(outputDim);
893  continue;
894  }
895  if (inputExprWalker.convolvedDims.count(outputDim) &&
896  !filterDims.count(outputDim)) {
897  // Output image Loop dimension.
898  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
899  return MatchConvolutionResult::OutputDimsNotParallel;
900  allLoopDims.insert(outputDim);
901  continue;
902  }
903  if (!inputExprWalker.convolvedDims.count(outputDim) &&
904  !inputExprWalker.unConvolvedDims.count(outputDim) &&
905  filterDims.count(outputDim)) {
906  // Output channel dimension.
907  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
908  return MatchConvolutionResult::OutputDimsNotParallel;
909  allLoopDims.insert(outputDim);
910  continue;
911  }
912  if (inputExprWalker.unConvolvedDims.count(outputDim) &&
913  filterDims.count(outputDim)) {
914  // Depth multiplier.
915  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
916  return MatchConvolutionResult::OutputDimsNotParallel;
917  allLoopDims.insert(outputDim);
918  continue;
919  }
920  return MatchConvolutionResult::NonConvolutionLoop;
921  }
// Second pass over the filter map. Output-channel and depthwise dims were
// already recorded above. Filter-loop and input-channel dims must be
// reduction iterators and must not have been classified already.
922  for (auto filterExpr : indexingMaps[1].getResults()) {
923  int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
924  if (outputDims.count(filterDim) &&
925  !inputExprWalker.unConvolvedDims.count(filterDim) &&
926  !inputExprWalker.convolvedDims.count(filterDim)) {
927  // Output channel dimension. This is already seen, continue;
928  continue;
929  }
930  if (inputExprWalker.convolvedDims.count(filterDim) &&
931  !outputDims.count(filterDim)) {
932  // Filter loop dimension.
933  if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
934  return MatchConvolutionResult::NonOutputDimNotReduction;
935  if (allLoopDims.count(filterDim))
936  return MatchConvolutionResult::NonConvolutionLoop;
937  allLoopDims.insert(filterDim);
938  continue;
939  }
940  if (inputExprWalker.unConvolvedDims.count(filterDim) &&
941  !outputDims.count(filterDim)) {
942  // Input channel dimension.
943  if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
944  return MatchConvolutionResult::NonOutputDimNotReduction;
945  if (allLoopDims.count(filterDim))
946  return MatchConvolutionResult::NonConvolutionLoop;
947  allLoopDims.insert(filterDim);
948  continue;
949  }
950  if (inputExprWalker.unConvolvedDims.count(filterDim) &&
951  outputDims.count(filterDim)) {
952  // Depthwise loop. Already seen.
953  continue;
954  }
955  return MatchConvolutionResult::NonConvolutionLoop;
956  }
957  // All loops must be covered now.
// allLoopDims is a set, so this also rejects any loop that appears in
// neither the output map nor the filter map.
958  if (allLoopDims.size() != linalgOp.getNumLoops())
959  return MatchConvolutionResult::NonConvolutionLoop;
960 
961  if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
962  return MatchConvolutionResult::EmptyConvolvedDims;
963 
// Dimension inference only runs when the caller asked for it. The checks
// above are expected to guarantee it succeeds (asserted).
964  if (dimensions) {
965  FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
966  linalgOp, inputExprWalker, allowEmptyConvolvedDims);
967  assert(succeeded(res) && "unexpected failure to infer convolution dims");
968  *dimensions = *res;
969  }
970 
971  return MatchConvolutionResult::Success;
972 }
973 
/// Maps a convolution matcher result code to its diagnostic message; Success
/// maps to the empty string.
// NOTE(review): the signature line (source line 975) is absent from this
// listing.
974 StringRef
976  switch (res) {
977  case MatchConvolutionResult::NotLinalgOp:
978  return "expected a LinalgOp";
// NOTE(review): isConvolutionInterfaceImpl only rejects fewer than 2 DPS
// inputs (getNumDpsInputs() < 2), so "2 inputs" effectively means "at least 2".
979  case MatchConvolutionResult::WrongNumOperands:
980  return "expected op with 2 inputs and 1 output";
981  case MatchConvolutionResult::WrongInputIndexingMap:
982  return "unexpected input index map for convolutions";
983  case MatchConvolutionResult::NotProjectedPermutations:
984  return "expected output/filter indexing maps to be projected permutations";
985  case MatchConvolutionResult::NonConvolutionLoop:
986  return "unexpected loop dimension for convolution op";
987  case MatchConvolutionResult::OutputDimsNotParallel:
988  return "expected all iterators used to access outputs to be parallel";
989  case MatchConvolutionResult::NonOutputDimNotReduction:
990  return "expected all iterators not used to access outputs to be reduction";
991  case MatchConvolutionResult::EmptyConvolvedDims:
992  return "expected convolved dim to be non-empty";
993  case MatchConvolutionResult::Success:
994  return "";
995  }
996  llvm_unreachable("unhandled MatchConvolutionResult case");
997 }
998 
/// Checks whether `linalgOp` conforms to ConvolutionOpInterface by running the
/// convolution matcher with a null dimensions out-parameter.
// NOTE(review): source lines 999, 1001 and 1003 are absent from this listing.
// They presumably hold the signature start, the call to
// isConvolutionInterfaceImpl, and the compared-against Success value.
1000  bool allowEmptyConvolvedDims) {
1002  linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
1004 }
1005 
/// Verifies ConvolutionOpInterface. Any non-Success matcher result is emitted
/// as an error on `op`, using the text from getMatchConvolutionMessage.
// NOTE(review): the signature and the computation of `res` (source lines
// 1006-1007) are absent from this listing.
1008  if (res != MatchConvolutionResult::Success)
1009  return op->emitError(getMatchConvolutionMessage(res));
1010  return success();
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // FillOpInterface implementation
1015 //===----------------------------------------------------------------------===//
1016 
/// Result codes returned by isFillInterfaceImpl; Success is 0.
// NOTE(review): source line 1021 is absent from this listing. The "expected op
// with scalar input" diagnostic in verifyFillInterface suggests it holds the
// non-scalar-input enumerator -- confirm.
1017 enum class MatchFillResult {
1018  Success = 0,
1019  NotLinalgOp,
1020  WrongNumOperands,
1022 };
1023 
/// Matches the FillOpInterface shape: a LinalgOp with exactly one DPS input and
/// one DPS init, where the input is a scalar.
// NOTE(review): the signature (1024) and the early-return statements of the
// checks below (1027, 1029, 1033) are absent from this listing.
1025  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1026  if (!linalgOp)
1028  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1030 
// The only DPS input must be a scalar.
1031  OpOperand *value = linalgOp.getDpsInputOperand(0);
1032  if (!linalgOp.isScalar(value))
1034 
1035  return MatchFillResult::Success;
1036 }
1037 
/// Verifies FillOpInterface by turning each isFillInterfaceImpl failure into an
/// error diagnostic on `op`.
// NOTE(review): the signature (1038) and the conditions guarding the second
// and third errors (1042, 1044) are absent from this listing.
1039  auto res = isFillInterfaceImpl(op);
1040  if (res == MatchFillResult::NotLinalgOp)
1041  return op->emitError("expected a LinalgOp");
1043  return op->emitError("expected op with 1 input and 1 output");
1045  return op->emitError("expected op with scalar input");
1046 
1047  return success();
1048 }
1049 
1050 //===----------------------------------------------------------------------===//
1051 // StructuredOpInterface implementation
1052 //===----------------------------------------------------------------------===//
1053 
/// Returns the size of every dimension of every operand, flattened in operand
/// order. Each entry is produced by createFoldedDimOp on the operand value.
// NOTE(review): the declaration of `res` (source line 1056) is absent from
// this listing.
1054 SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1055  Location loc) {
1057  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1058  for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1059  res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
1060  }
1061  return res;
1062 }
1063 
/// Returns the static shape of every operand, concatenated in operand order.
/// Asserts that the op has no dynamic shape.
// NOTE(review): the declaration of `res` (source line 1065) is absent from
// this listing.
1064 SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
1066  assert(!hasDynamicShape() && "expected operands to have static shapes");
1067  for (OpOperand &opOperand : getOperation()->getOpOperands())
1068  llvm::append_range(res, getShape(&opOperand));
1069  return res;
1070 }
1071 
1072 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1073  AffineMap map = getLoopsToShapesMap();
1074  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1075  auto viewSizes = createFlatListOfOperandDims(b, loc);
1076  SmallVector<Range, 4> res(numDims);
1077  for (unsigned idx = 0; idx < numRes; ++idx) {
1078  auto result = map.getResult(idx);
1079  if (auto d = dyn_cast<AffineDimExpr>(result)) {
1080  if (res[d.getPosition()].offset)
1081  continue;
1082  res[d.getPosition()] =
1083  Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1084  }
1085  }
1086  return res;
1087 }
1088 
1089 SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
1090  AffineMap map = getLoopsToShapesMap();
1091  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1092  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
1093  SmallVector<int64_t, 4> res(numDims, 0);
1094  for (unsigned idx = 0; idx < numRes; ++idx) {
1095  auto result = map.getResult(idx);
1096  if (auto d = dyn_cast<AffineDimExpr>(result))
1097  res[d.getPosition()] = allShapeSizes[idx];
1098  }
1099  return res;
1100 }
1101 
1102 /// Visitor to check if any of the given set of positions from AffineDimExprs
1103 /// are used within an AffineExpr.
/// `visit` returns true if the expression contains a dim whose position bit
/// is set in `positions`. Constants and symbols never match.
// NOTE(review): the struct head (1104) and the visitAffineBinaryOpExpr
// signature (1109) are absent from this listing.
1105  : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1106  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1107  : positions(std::move(positions)) {}
1108 
// Short-circuits: the RHS is visited only if the LHS contains no match.
1110  return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
1111  }
1112 
1113  bool visitDimExpr(AffineDimExpr dimExpr) {
1114  return positions.test(dimExpr.getPosition());
1115  }
1116 
1117  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1118 
1119  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1120 
1121 private:
// Bit i set means dim i is one of the positions being searched for.
1122  llvm::SmallBitVector positions;
1123 };
1124 
/// Returns the half-open range [first, second) where `first` is the summed rank
/// of all DPS inputs and `second` adds the summed rank of all DPS inits.
/// reifyResultShapes uses it as the span of init-shape results in the
/// loops-to-shapes map. That map presumably lists input shapes before init
/// shapes -- confirm.
// NOTE(review): the signature line (source line 1126) is absent from this
// listing.
1125 static std::pair<int64_t, int64_t>
1127  int64_t inputRankSum = 0;
1128  int64_t outputRankSum = 0;
1129  for (OpOperand *input : op.getDpsInputOperands())
1130  inputRankSum += op.getRank(input);
1131  for (OpOperand &output : op.getDpsInitsMutable())
1132  outputRankSum += op.getRank(&output);
1133  return {inputRankSum, inputRankSum + outputRankSum};
1134 }
1135 
/// Reifies the shape of each DPS init (one entry per init).
/// - Static dims become index attributes.
/// - Dynamic dims are derived from the operand shapes through the
///   loops-to-shapes and shapes-to-loops maps.
/// - If the derived expression for a dim still references an init-shape
///   position, a dim op on the init operand itself is used instead.
// NOTE(review): source lines 1137, 1170 and 1176 are absent from this listing.
1136 LogicalResult
1138  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1139  // An example that helps understand the logic below.
1140  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
1141  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
1142  // This is achieved as follows.
1143  // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
1144  // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
1145  // shapesToLoopsMap = (d0, d1, d2, d3, d4, d5) -> (d0, d3, d1)
1146  // resultShapesFromInputShapes = subMapOfResultShapes.compose(shapesToLoopsMap)
1147  // = (d0, d1, d2, d3, d4, d5) -> (d0 + d3, d3)
// (inversePermutation picks the first shape position that equals each loop
// dim, so loop d2 (k) maps to shape position 1, i.e. A's second dim.)
1148  AffineMap loopsToShapesMap = getLoopsToShapesMap();
1149 
1150  // Find the position in the above map that represents the shape of the
1151  // result:dim being inferred.
1152  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
1153 
1154  /// From loopsToShapesMap extract the submap that represents the shape of the
1155  /// (resultIdx, dim) needed.
1156  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
1157  resultShapesSubMapPos.first,
1158  resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1159  AffineMap resultShapesFromInputShapesMap =
1160  loopToResultsShapeMap.compose(getShapesToLoopsMap());
1161 
1162  // Check that the result dim map does not contain the positions corresponding
1163  // to the outputs.
1164  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
1165  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1166  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1167  Location loc = getOperation()->getLoc();
1168  IRRewriter rewriter(b);
// NOTE(review): the callee (source line 1170) is not shown. Given its
// arguments (rewriter, loc, map, flat operand dims) it is presumably
// makeComposedFoldedMultiResultAffineApply.
1169  SmallVector<OpFoldResult> allResultDimValues =
1171  rewriter, loc, resultShapesFromInputShapesMap,
1172  createFlatListOfOperandDims(b, loc));
// `pos` is a running index over all dims of all inits, i.e. into
// shapeExprs / allResultDimValues.
1173  int64_t pos = 0;
1174  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
1175  for (OpOperand &opOperand : getDpsInitsMutable()) {
// NOTE(review): the declaration of `shapes` (source line 1176) is absent from
// this listing.
1177  for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1178  auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1179  if (!shapedType.isDynamicDim(dim)) {
1180  // Static dim: Return IntegerAttr.
1181  shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
1182  } else {
1183  // Dynamic dim: Return Value.
1184  OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1185  ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
1186  : allResultDimValues[pos];
1187  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
1188  }
1189  pos++;
1190  }
1191  reifiedReturnShapes.emplace_back(std::move(shapes));
1192  }
1193  return success();
1194 }
1195 
1196 /// Return the index in the indexingMaps vector that corresponds to this
1197 /// `opOperand`.
1198 int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1199  auto operandNumber = opOperand->getOperandNumber();
1200  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1201  if (!dpsIface.isDpsInput(opOperand))
1202  return operandNumber;
1203  unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1204  assert(!dpsIface.isDpsInit(opOperand));
1205  // Account for potential inputs that are not DPS and may not appear in
1206  // `indexingMaps`.
1207  return cast<DestinationStyleOpInterface>(*this->getOperation())
1208  .getNumDpsInputs() +
1209  operandNumber - start;
1210 }
1211 
// Verifies the StructuredOpInterface invariants of `op`, in order:
// - pure tensor/buffer operands;
// - one symbol-free indexing map per operand, matching loop count and rank;
// - a non-null shapes-to-loops map;
// - static shapes consistent with the inferred loop ranges;
// - a single-block region whose arguments match the operand (element) types.
// NOTE(review): the signature line (source line 1212) is absent from this
// listing.
1213  LinalgOp linalgOp = cast<LinalgOp>(op);
1214 
1215  // Mixed tensor/buffer operands are not allowed.
1216  if (!linalgOp.hasPureTensorSemantics() &&
1217  !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1218  return op->emitOpError("expected to have pure tensor or buffer semantics");
1219 
1220  // Before checking indexing maps, we need to make sure the attributes
1221  // referenced by it are valid.
1222  if (linalgOp.hasDynamicIndexingMaps())
1223  if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1224  return failure();
1225 
1226  // All input/output operands must be indexed.
1227  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
1228  linalgOp->getNumOperands())
1229  return op->emitOpError("expected the number of indexing_map (")
1230  << linalgOp.getIndexingMapsArray().size()
1231  << ") to be equal to the number of input/output operands ("
1232  << linalgOp->getNumOperands() << ")";
1233 
1234  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1235  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1236 
1237  // Symbols disallowed.
1238  if (indexingMap.getNumSymbols() != 0)
1239  return op->emitOpError("unexpected symbols in indexing_map #")
1240  << opOperand.getOperandNumber();
1241 
1242  // Domain must be consistent.
1243  unsigned numLoops = linalgOp.getNumLoops();
1244  if (indexingMap.getNumDims() != numLoops)
1245  return op->emitOpError("expected indexing_map #")
1246  << opOperand.getOperandNumber() << " to have " << numLoops
1247  << " dim(s) to match the number of loops";
1248 
1249  int64_t rank = linalgOp.getRank(&opOperand);
1250  if (indexingMap.getNumResults() != rank)
1251  return op->emitOpError("expected operand rank (")
1252  << rank << ") to match the result rank of indexing_map #"
1253  << opOperand.getOperandNumber() << " ("
1254  << indexingMap.getNumResults() << ")";
1255  }
1256 
// NOTE(review): redDims is filled here but never read again in this function.
// Confirm getReductionDims has no needed side effect before removing it.
1257  SmallVector<unsigned> redDims;
1258  linalgOp.getReductionDims(redDims);
1259 
1260  if (!linalgOp.getShapesToLoopsMap())
1261  return op->emitOpError("expected the shape-to-loops map to be non-null");
1262 
1263  // Check if given shapes match to inferred shapes.
1264  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
1265  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
1266 
1267  // Verify only static cases since we can't get exact dimension sizes and loop
1268  // ranges for dynamic cases in this stage.
1269  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
// Convert trip counts into the last (inclusive) iteration index. Composing
// each indexing map with the all-zero start vector and this end vector then
// yields the first and last index accessed in every operand dimension.
1270  for (int64_t &range : endLoopRangeValues)
1271  range -= 1;
1272  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1273  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1274  SmallVector<int64_t, 4> startIndices =
1275  indexingMap.compose(startLoopRangeValues);
1276  SmallVector<int64_t, 4> endIndices =
1277  indexingMap.compose(endLoopRangeValues);
1278  ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
1279  for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
1280  // Ignore dynamic dimension or the case that the dimension size is 0
1281  if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
1282  continue;
1283 
1284  // The first index or last index should be the maximum or the minimum in
1285  // the inferred index ranges since the range is increasing or
1286  // decreasing. The size of dimensions of input/output operands and the
1287  // maximum value + 1 in the inferred range should be the same. But, for
1288  // now we check if the inferred ranges are in boundary of input/output
1289  // operands' size or not in case that Affine Expressions are complicated
1290  // such as d0 * 3
1291  // + d1 since it is not easy to handle the issues.
1292  // Found the case that this solution can't check, for example, (d0, d1)
1293  // -> (d1 - d0)
1294  int64_t inferredDimSize =
1295  std::max(startIndices[dim], endIndices[dim]) + 1;
1296  if (std::min(startIndices[dim], endIndices[dim]) < 0) {
1297  std::string mapStr;
1298  {
1299  llvm::raw_string_ostream os(mapStr);
1300  os << indexingMap;
1301  }
1302  return op->emitOpError(
1303  "unexpected result less than 0 at expression #")
1304  << dim << " in " << mapStr;
1305  }
// A bare-dim result must match the operand size exactly. Any other
// expression only needs to stay within bounds.
1306  if (dyn_cast<AffineDimExpr>(indexingMap.getResult(dim))) {
1307  if (inferredDimSize != shape[dim]) {
1308  return op->emitOpError("inferred input/output operand #")
1309  << opOperand.getOperandNumber() << " has shape's dimension #"
1310  << dim << " to be " << inferredDimSize << ", but found "
1311  << shape[dim];
1312  }
1313  } else {
1314  if (inferredDimSize > shape[dim]) {
1315  return op->emitOpError("inferred input/output operand #")
1316  << opOperand.getOperandNumber() << " has shape's dimension #"
1317  << dim << " to be greater than or equal to "
1318  << inferredDimSize << ", but found " << shape[dim];
1319  }
1320  }
1321  }
1322  }
1323  }
1324 
1325  // Check the region has exactly one block.
1326  if (linalgOp->getNumRegions() != 1 ||
1327  !llvm::hasSingleElement(linalgOp->getRegion(0)))
1328  return op->emitOpError("expects to have 1 region with 1 block");
1329 
1330  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
1331  // types.
1332  // TODO: once ranked shape types are plugged in, we may want to drop the
1333  // corresponding bbargs, that can never be read from. This will be subject to
1334  // consistency discussions (i.e. what to do with output tensors whose bbarg is
1335  // not used).
1336  Block &block = linalgOp->getRegion(0).front();
1337 
1338  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1339  return op->emitOpError("expected as many non-induction variable region "
1340  "arguments as the number of input/output operands");
1341 
// Memref/ranked-tensor operands are compared by element type. Any other
// operand type must equal the block argument type itself.
1342  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1343  Type elementType = opOperand->get().getType();
1344  if (isa<MemRefType, RankedTensorType>(elementType))
1345  elementType = getElementTypeOrSelf(opOperand->get().getType());
1346  Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1347  if (elementType != argType)
1348  return op->emitOpError("expected type of bb argument #")
1349  << opOperand->getOperandNumber() << " (" << argType << ")"
1350  << " to match element or self type of the corresponding operand ("
1351  << elementType << ")";
1352  }
1353 
1354  return success();
1355 }
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
MatchFillResult
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Affine binary operation expression.
Definition: AffineExpr.h:227
AffineExpr getLHS() const
Definition: AffineExpr.cpp:340
AffineExpr getRHS() const
Definition: AffineExpr.cpp:343
An integer constant appearing in affine expression.
Definition: AffineExpr.h:252
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:236
unsigned getPosition() const
Definition: AffineExpr.cpp:348
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
MLIRContext * getContext() const
Definition: AffineExpr.cpp:33
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:662
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:244
Block represents an ordered list of Operations.
Definition: Block.h:31
bool empty()
Definition: Block.h:146
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation & back()
Definition: Block.h:150
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
OpListType & getOperations()
Definition: Block.h:135
Operation & front()
Definition: Block.h:151
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:215
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:764
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:753
unsigned getNumOperands()
Definition: Operation.h:341
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1239
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:99
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:90
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:791
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Definition: AffineMap.cpp:836
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.