//===----------------------------------------------------------------------===//
// NOTE: This listing was extracted from the MLIR (22.0.0git) doxygen page for
// LinalgInterfaces.cpp; extraction artifacts have been cleaned up.
//===----------------------------------------------------------------------===//
1 //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 /// Include the definitions of the copy operation interface.
34 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
35 
36 //===----------------------------------------------------------------------===//
37 // Interface utility functions
38 //===----------------------------------------------------------------------===//
39 
41  linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
42  SmallVector<AffineMap> indexingMaps;
43  for (auto &opOperand : linalgOp->getOpOperands()) {
44  if (llvm::is_contained(droppedOperands, &opOperand))
45  continue;
46  indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
47  }
48  if (indexingMaps.empty()) {
49  // If there are no indexing maps, the operand can only be dropped
50  // if the op has no loops.
51  return linalgOp.getNumLoops() == 0;
52  }
54  indexingMaps, linalgOp.getContext())) != AffineMap();
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // CopyOpInterface implementation
59 //===----------------------------------------------------------------------===//
60 
61 bool linalg::isaCopyOpInterface(LinalgOp op) {
62  // Check all loops are parallel and linalgOp is single input and output.
63  if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
64  return false;
65 
66  auto mapRange = op.getIndexingMapsArray();
67  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68  !mapRange.back().isIdentity()) {
69  return false;
70  }
71  // Region.
72  return llvm::hasSingleElement(op.getBlock()->getOperations());
73 }
74 
75 //===----------------------------------------------------------------------===//
76 // FillOpInterface implementation
77 //===----------------------------------------------------------------------===//
78 /// Detects if a linalg.generic operation represents a fill with an inlined
79 /// constant. If so, returns the constant value. Otherwise, returns
80 /// std::nullopt.
81 static std::optional<Value> isaInlinedFillOp(GenericOp op) {
82  if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
83  op.getNumDpsInputs() != 0)
84  return std::nullopt;
85 
86  // Init should not be referenced.
87  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
88  return std::nullopt;
89 
90  Block *body = op.getBody();
91  if (body->getOperations().size() != 1)
92  return std::nullopt;
93 
94  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
95  if (!yieldOp || yieldOp.getNumOperands() != 1)
96  return std::nullopt;
97 
98  Value yieldOperand = yieldOp->getOperand(0);
99  if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
100  !yieldOperand.getDefiningOp<complex::ConstantOp>())
101  return std::nullopt;
102 
103  return yieldOperand;
104 }
105 
106 /// Detects if a linalg.generic operation represents an external scalar input.
107 /// If so, returns the constant value. Otherwise, returns std::nullopt.
108 static std::optional<Value> isaExternalFillOp(GenericOp op) {
109  // Structural.
110  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
111  !op.isSingleYieldOp())
112  return std::nullopt;
113 
114  // Input should be referenced and init should not.
115  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
116  op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
117  return std::nullopt;
118 
119  OpOperand *value = op.getDpsInputOperand(0);
120  if (!op.isScalar(value))
121  return std::nullopt;
122  return value->get();
123 }
124 
125 std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
126  if (auto fillVal = isaInlinedFillOp(op))
127  return fillVal;
128  return isaExternalFillOp(op);
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // BroadcastOpInterface implementation
133 //===----------------------------------------------------------------------===//
134 std::optional<SmallVector<int64_t>>
136  // Structural.
137  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
138  !op.isSingleYieldOp())
139  return std::nullopt;
140 
141  auto srcTy = op.getDpsInputOperand(0)->get().getType();
142  auto dstTy = op.getDpsInitOperand(0)->get().getType();
143  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
144  !isa<MemRefType, RankedTensorType>(dstTy))
145  return std::nullopt;
146 
147  // Check output is identity map. Broadcast could additionally be
148  // employing permutation of indices and that would be expressible
149  // in linalg.generic but is not expressible for named broadcast op.
150  auto dstMap = op.getIndexingMapsArray()[1];
151  if (!dstMap.isIdentity())
152  return std::nullopt;
153 
154  SmallVector<int64_t> position;
155  auto srcMap = op.getIndexingMapsArray()[0];
156 
157  if (srcMap.getResults().size() >= dstMap.getResults().size())
158  return std::nullopt;
159 
160  // Check input map is monotonically increasing DimIds.
161  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
162  auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
163  if (!expr)
164  return std::nullopt;
165  int64_t pos = expr.getPosition();
166  if (i > 0 && pos <= position[i - 1])
167  return std::nullopt;
168  position.push_back(expr.getPosition());
169  }
170 
171  SmallVector<int64_t> broadcastedDims;
172  auto numDims = srcMap.getNumDims();
173  // This is quadratic but number of items is generally small.
174  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
175  if (!llvm::is_contained(position, dim))
176  broadcastedDims.push_back(dim);
177  }
178  return broadcastedDims;
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // TransposeOpInterface implementation
183 //===----------------------------------------------------------------------===//
184 std::optional<SmallVector<int64_t>>
186  // To specialize as a transpose op, the genericOp must be
187  // all parallel loops, single input, single output, and its body
188  // should be just a yield op, yielding input as output as is (no compute).
189  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
190  !op.isSingleYieldOp())
191  return std::nullopt;
192 
193  auto mapRange = op.getIndexingMapsArray();
194  if (mapRange.size() != 2)
195  return std::nullopt;
196 
197  auto mapOfInput = mapRange.front();
198  auto mapOfResult = mapRange.back();
199 
200  // linalg.transpose permutes the dimensions of input using this
201  // rule: dim(result, i) = dim(input, permutation[i])
202  if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
203  return std::nullopt;
204 
205  SmallVector<int64_t> permutation(mapOfInput.getNumDims());
206  for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
207  auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
208  permutation[expr.getPosition()] = i;
209  }
210  return permutation;
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // Elementwise Single Unary/Binary-OpInterface implementation
215 //===----------------------------------------------------------------------===//
216 static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
217  unsigned arity) {
218  // Check all loops are parallel.
219  if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
220  return false;
221 
222  // Check there are arity-inputs, 1-output and all are identity-maps.
223  if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
224  !llvm::all_of(op.getIndexingMapsArray(),
225  [](AffineMap map) { return map.isIdentity(); }))
226  return false;
227 
228  // Init should not be referenced for elementwise operations.
229  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
230  return false;
231 
232  // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
233  // as resulting from producer-consumer fusion. Here, we restrict to two ops in
234  // the body, where the first is the elementwise single op and the second a
235  // yield.
236  Block *body = op.getBody();
237  if (body->getOperations().size() != 2)
238  return false;
239 
240  Operation *oper = &body->front();
241  if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
242  return false;
243 
244  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
245  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
246  yieldOp->getOperand(0).getDefiningOp() != oper)
247  return false;
248  return true;
249 }
250 
251 bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
252  // All basic elemwise checks.
254  return false;
255 
256  // Check input is actully used.
257  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
258  return false;
259  return true;
260 }
261 
262 bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
264  return false;
265 
266  // Check both inputs are used (elementwise).
267  OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
268  OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
269  if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
270  !op.payloadUsesValueFromOperand(inputOpOperand1))
271  return false;
272  return true;
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // ContractionOpInterface implementation
277 //===----------------------------------------------------------------------===//
278 
279 /// If the value is defined by a chain of unary side effect-free, go up the
280 /// use-def chain until the first value that isn't defined by such an op.
281 // TODO: relax to multi-operands with constants, which are technically unary ops
282 // as needed (e.g. add5).
284  Operation *op = value.getDefiningOp();
285  while (op && op->getNumOperands() == 1) {
286  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
287  if (!iface || !iface.hasNoEffect())
288  break;
289  value = op->getOperand(0);
290  op = value.getDefiningOp();
291  }
292  return value;
293 }
294 
296  Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
297  llvm::raw_ostream &errs) {
298  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
299  errs << "no terminator in the block";
300  return false;
301  }
302 
303  if (block.getNumArguments() != 3) {
304  errs << "expected block with 3 arguments";
305  return false;
306  }
307 
308  Operation *terminator = block.getTerminator();
309  if (terminator->getNumOperands() != 1) {
310  errs << "expected terminator with 1 operand";
311  return false;
312  }
313 
314  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
315  Operation *reductionOp = yielded.getDefiningOp();
316  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
317  errs << "expected reduction op to be binary";
318  return false;
319  }
320 
321  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
322  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
323 
324  if (reductionLHS != block.getArgument(2) &&
325  reductionRHS != block.getArgument(2)) {
326  errs << "expected reduction to take block argument #2 as one of the "
327  "operands (modulo unary casts)";
328  return false;
329  }
330 
331  Value contributed = getSourceSkipUnary(
332  isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
333  Operation *elementwiseOp = contributed.getDefiningOp();
334  if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
335  elementwiseOp->getNumOperands() != 2) {
336  errs << "expected elementwise op to be binary";
337  return false;
338  }
339 
340  if (!isaPair(elementwiseOp, reductionOp)) {
341  errs << "expected reduction/elementwise op kind not satisfied";
342  return false;
343  }
344 
345  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
346  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
347  if ((elementwiseLHS == block.getArgument(0) &&
348  elementwiseRHS == block.getArgument(1)) ||
349  (elementwiseLHS == block.getArgument(1) &&
350  elementwiseRHS == block.getArgument(0))) {
351  return true;
352  }
353 
354  errs << "expected elementwise op to apply to block arguments (modulo unary "
355  "casts)";
356  return false;
357 }
358 
359 /// Returns true if the two operations are of the kinds specified by a pair of
360 /// consecutive template arguments.
361 template <typename AddOpTy, typename MulOpTy, typename... Args>
362 static bool isPairTemplateImpl(Operation *add, Operation *mul) {
363  static_assert(sizeof...(Args) % 2 == 0,
364  "expected an even number of template arguments");
365  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
366  return true;
367 
368  if constexpr (sizeof...(Args) > 0)
369  return isPairTemplateImpl<Args...>(add, mul);
370  else
371  return false;
372 }
373 
374 /// Returns true if the block is a body of a contraction with the kinds of
375 /// operations given pairwise by template arguments.
376 template <typename... Args>
377 static bool isContractionBody(Block &block) {
378  return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
379 }
380 
381 /// Given an `indexingMap` and its corresponding `iterators`, returns
382 /// the positions of the iterators of type `iter` that are indexed by
383 /// the `indexingMap` as a permutation. This is useful to infer various
384 /// subcomputations on a `LinalgOp`. This is performed by looking up
385 /// each result in the `indexingMap` and determining whether:
386 /// - It is a single AffineDimExpr.
387 /// - It is the only result involving this AffineDimExpr.
388 static llvm::SmallDenseSet<int64_t>
391  utils::IteratorType iter) {
392  assert(iterators.size() == indexingMap.getNumDims());
393  llvm::SmallDenseSet<int64_t> res;
394  for (AffineExpr e : indexingMap.getResults()) {
395  if (auto d = dyn_cast<AffineDimExpr>(e)) {
396  if (iterators[d.getPosition()] == iter &&
397  llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
398  return e.isFunctionOfDim(d.getPosition());
399  }) == 1)
400  res.insert(d.getPosition());
401  }
402  }
403  return res;
404 }
405 
406 namespace {
407 auto par = utils::IteratorType::parallel;
408 auto red = utils::IteratorType::reduction;
409 } // namespace
410 
411 /// Infer the iterator types from the init affine map. This looks at which dims
412 /// are present in the map results, and returns an iterator types array with
413 /// parallel types for dims that are present, and reduction types for dims that
414 /// are not present.
415 static FailureOr<SmallVector<utils::IteratorType>>
417  if (!map.isProjectedPermutation())
418  return failure();
419  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
420  for (auto expr : map.getResults())
421  if (auto dim = dyn_cast<AffineDimExpr>(expr))
422  iterators[dim.getPosition()] = par;
423  return iterators;
424 }
425 
426 /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
427 /// a matmul subcomputation within `linalgOp`. These dimensions are such that:
428 /// 1. The m dimension is involved in an outer-product along LHS
429 /// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
430 /// 2. The n dimension is involved in an outer-product along RHS
431 /// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
432 /// 3. The k dimension appears as a permutation on LHS and RHS.
433 /// 4. m, n and k appear only once in any given indexing.
434 /// 5. Optional batch dimensions that appear in all operands are captured.
435 /// This allows e.g. detecting that some contraction is embedded within
436 /// `linalgOp` with some orthogonal heuristic.
437 static FailureOr<ContractionDimensions>
439  ArrayRef<utils::IteratorType> iterators) {
440  llvm::SmallDenseSet<int64_t> a =
441  findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
442  llvm::SmallDenseSet<int64_t> b =
443  findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
444  llvm::SmallDenseSet<int64_t> c =
445  findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
446 
447  // A & C - B are the iterators involved in an outer-product along A (the LHS).
448  llvm::SmallDenseSet<int64_t> ac = a;
449  llvm::set_intersect(ac, c);
450  llvm::set_subtract(ac, b);
451  // B & C - A are the iterators involved in an outer-product along B (the RHS).
452  llvm::SmallDenseSet<int64_t> bc = b;
453  llvm::set_intersect(bc, c);
454  llvm::set_subtract(bc, a);
455  // A & B & C are the "batch" dimensions.
456  llvm::SmallDenseSet<int64_t> batches = a;
457  llvm::set_intersect(batches, b);
458  llvm::set_intersect(batches, c);
459 
460  // A & B red are the reduction dimensions.
461  llvm::SmallDenseSet<int64_t> ra =
462  findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
463  llvm::SmallDenseSet<int64_t> rb =
464  findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
465  llvm::set_intersect(ra, rb);
466 
467  // Return each set in sorted order.
468  ContractionDimensions dimensions{
469  SmallVector<unsigned, 2>(batches.begin(), batches.end()),
470  SmallVector<unsigned, 2>(ac.begin(), ac.end()),
471  SmallVector<unsigned, 2>(bc.begin(), bc.end()),
472  SmallVector<unsigned, 2>(ra.begin(), ra.end())};
473  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
474  llvm::sort(dimensions.m.begin(), dimensions.m.end());
475  llvm::sort(dimensions.n.begin(), dimensions.n.end());
476  llvm::sort(dimensions.k.begin(), dimensions.k.end());
477  return dimensions;
478 }
479 
480 FailureOr<ContractionDimensions>
482  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
483  return failure();
484  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
485  linalgOp.getIteratorTypesArray());
486 }
487 
488 FailureOr<ContractionDimensions>
490  if (indexingMaps.size() != 3)
491  return failure();
492  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
493  if (failed(iterators))
494  return failure();
495  return inferContractionDimsImpl(indexingMaps, iterators.value());
496 }
497 
namespace mlir::linalg::detail {
/// Result codes for `isContractionInterfaceImpl`; the order matches the
/// messages produced by `getMatchContractionMessage`.
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail
508 
512  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
513  if (!linalgOp)
514  return MatchContractionResult::NotLinalgOp;
515  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
516  return MatchContractionResult::WrongNumOperands;
517  auto mapRange = linalgOp.getIndexingMapsArray();
518  if (linalgOp.getNumReductionLoops() == 0)
519  return MatchContractionResult::NoReduction;
520  if (llvm::any_of(mapRange,
521  [](AffineMap m) { return !m.isProjectedPermutation(); }))
522  return MatchContractionResult::NotProjectedPermutations;
523  // TODO: more fields than add/mul.
524  // clang-format off
525  if (!::isContractionBody<
526  arith::MulFOp, arith::AddFOp,
527  arith::MulIOp, arith::AddIOp,
528  complex::MulOp, complex::AddOp,
529  arith::AndIOp, arith::OrIOp>(
530  *linalgOp.getBlock())) {
531  return MatchContractionResult::NotAddMul;
532  }
533  // clang-format on
534 
535  if (dimensions) {
536  FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
537  assert(succeeded(res) && "unexpected failure to infer contraction dims");
538  *dimensions = *res;
539  }
540  return MatchContractionResult::Success;
541 }
542 
543 StringRef
545  switch (res) {
546  case MatchContractionResult::NotLinalgOp:
547  return "expected a LinalgOp";
548  case MatchContractionResult::WrongNumOperands:
549  return "expected op with 2 inputs and 1 output";
550  case MatchContractionResult::NoReduction:
551  return "expected at least 1 reduction";
552  case MatchContractionResult::NotProjectedPermutations:
553  return "expected indexing maps to be projected permutations";
554  case MatchContractionResult::NotAddMul:
555  return "expected add/mul op in the body";
556  case MatchContractionResult::Success:
557  return "";
558  }
559  llvm_unreachable("unhandled MatchContractionResult case");
560 }
561 
563  if (!linalgOp)
564  return false;
565  Operation *op = linalgOp.getOperation();
566  return isa<ContractionOpInterface>(op) ||
569 }
570 
571 /// Verify that a LinalgOp `op` is a contraction.
572 /// A Linalg contraction is defined in general terms:
573 /// 1. Has 2 input and 1 output shapes.
574 /// 2. Has at least one reduction dimension.
575 /// 3. Has only projected permutation indexing maps.
576 /// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
577 /// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
578 /// operations that may change the type (e.g. for mixed-precision).
579 /// As a consequence, when vectorization of such an op occurs, the only special
580 /// behavior is that the (unique) MulOpType is vectorized into a
581 /// `vector.contract`. All other ops are handled in a generic fashion.
582 /// In the future, we may wish to allow more input arguments and elementwise and
583 /// constant operations that do not involve the reduction dimension(s).
585  auto res = isContractionInterfaceImpl(op);
586  if (res != MatchContractionResult::Success)
587  return op->emitError(getMatchContractionMessage(res));
588  return success();
589 }
590 
591 //===----------------------------------------------------------------------===//
592 // ConvolutionOpInterface implementation
593 //===----------------------------------------------------------------------===//
594 
595 /// Of the given two expressions returns one that is of type T (`lhs` gets
596 /// preference over `rhs`)
597 template <typename T>
599  return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
600 }
601 
602 namespace {
603 /// Walk the indexing expressions for input of a convolution operation to verify
604 /// its of the right form, either
605 /// - AffineDimExpr
606 /// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
607 /// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
608 ///
609 /// classifies the AffineDimExpr as convolved dimensions or unconvolved
610 /// dimensions and verifies each dimension occurs only once.
611 struct ConvAccessExprWalker
612  : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
613  // Stores dimensions used in expressions of the above form.
614  llvm::SmallDenseSet<int64_t> convolvedDims;
615  // Stores the dual mapping between LHS and RHS of convolution exprs.
616  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
617  // Stores single use dimensions used by an AffineDimExpr.
618  llvm::SmallDenseSet<int64_t> unConvolvedDims;
619  // Stores a mapping from convolved dims to their coefficient.
620  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
621 
622  // Removes dims with multiple uses in the source input map from dimension
623  // sets tracked by this walker.
624  void clearMultiUseDims(AffineMap map) {
625  for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
626  if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
627  return e.isFunctionOfDim(dimPos);
628  }) > 1) {
629  convolvedDims.erase(dimPos);
630  unConvolvedDims.erase(dimPos);
631  // If a duplicate dim is marked as convolved, the pair of the duplicate
632  // dim must be removed from the map as well.
633  auto it = convolvedDimMapping.find(dimPos);
634  if (it != convolvedDimMapping.end()) {
635  int64_t pairedDim = it->second;
636  convolvedDims.erase(pairedDim);
637  unConvolvedDims.erase(pairedDim);
638  strideAndDilationMapping.erase(pairedDim);
639  convolvedDimMapping.erase(dimPos);
640  convolvedDimMapping.erase(pairedDim);
641  }
642  }
643  }
644  }
645 
646  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
647  unsigned position = dimExpr.getPosition();
648  if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
649  return failure();
650  }
651  unConvolvedDims.insert(position);
652  return success();
653  }
654 
655  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
656 
657  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
658 
659  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
660  // In pre-order visit, top level op has to be an add op.
661  if (binaryExpr.getKind() != AffineExprKind::Add)
662  return failure();
663  auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
664  auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
665  if (failed(lhsDimPos) || failed(rhsDimPos))
666  return failure();
667  convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
668  convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
669  return success();
670  }
671 
672  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
673  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
674  int64_t dim = dimExpr.getPosition();
675  if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
676  return failure();
677  // Stride/dilation for this dim is implicitly 1.
678  strideAndDilationMapping[dim] =
680  convolvedDims.insert(dim);
681  return dim;
682  }
683  if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
684  if (symbolMulExpr.getKind() != AffineExprKind::Mul)
685  return failure();
686  auto lhsExpr = symbolMulExpr.getLHS();
687  auto rhsExpr = symbolMulExpr.getRHS();
688  // Check for symbol expression.
689  AffineExpr mulExpr =
690  getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
691  // If there was no symbol expr, check for constant expression.
692  if (!mulExpr) {
693  mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
694  }
695  auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
696  if (!mulExpr || !dimExpr)
697  return failure();
698  int64_t dim = dimExpr.getPosition();
699  if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
700  return failure();
701  strideAndDilationMapping[dim] = mulExpr;
702  convolvedDims.insert(dim);
703  return dim;
704  }
705  return failure();
706  }
707 };
708 } // namespace
709 
710 static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
711  assert(map.isProjectedPermutation() &&
712  "expected map to have projected permutations");
713  llvm::SmallDenseSet<int64_t> preservedDims;
714  for (auto expr : map.getResults())
715  preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
716  return preservedDims;
717 }
718 
722  for (auto e : exprs) {
723  auto constantExpr = dyn_cast<AffineConstantExpr>(e);
724  assert(constantExpr && "Found non-constant stride/dilation");
725  vals.push_back(constantExpr.getValue());
726  }
727  return vals;
728 }
729 
730 /// Classifies dimensions in the `linalgOp` used by a convolution
731 /// subcomputation, as captured by `inputExprWalker`. If
732 /// `allowEmptyConvolvedDims` is not set this this will fail if there is not
733 /// at least convolved dimension pair (output image + filter loop). Convolution
734 /// dimensions are specified in sorted order, and strides match the order of
735 /// the filter loop dimensions, while the dilations match the order of the
736 /// output image dimensions.
737 static FailureOr<ConvolutionDimensions>
738 inferConvolutionDimsImpl(LinalgOp linalgOp,
739  ConvAccessExprWalker &inputExprWalker,
740  bool allowEmptyConvolvedDims) {
741  auto filterMap =
742  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
743  auto outputMap =
744  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
745  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
746  filterMap, linalgOp.getIteratorTypesArray(), par);
747  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
748  outputMap, linalgOp.getIteratorTypesArray(), par);
749 
750  // unConvolvedDims & outputDims - filterDims are the batch iterators.
751  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
752  llvm::set_intersect(batch, outputDims);
753  llvm::set_subtract(batch, filterDims);
754 
755  // convolvedDims & outputDims are the output image iterators.
756  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
757  llvm::set_intersect(oi, outputDims);
758 
759  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
760  llvm::SmallDenseSet<int64_t> oc = filterDims;
761  llvm::set_intersect(oc, outputDims);
762  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
763 
764  // filterDims & outputDims & unConvolvedDims are the depth iterators.
765  llvm::SmallDenseSet<int64_t> depth = filterDims;
766  llvm::set_intersect(depth, outputDims);
767  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
768 
769  llvm::SmallDenseSet<int64_t> filterReducedDims =
771  linalgOp.getIteratorTypesArray(), red);
772 
773  // convolvedDims & filterReducedDims are the filter loop iterators.
774  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
775  llvm::set_intersect(fl, filterReducedDims);
776 
777  // unConvolvedDims & filterReducedDims are the input channel iterators.
778  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
779  llvm::set_intersect(ic, filterReducedDims);
780 
781  if (oi.empty() && !allowEmptyConvolvedDims)
782  return failure();
783 
784  // Return each set in sorted order.
785  ConvolutionDimensions dimensions{
786  SmallVector<unsigned, 2>(batch.begin(), batch.end()),
787  SmallVector<unsigned, 2>(oi.begin(), oi.end()),
788  SmallVector<unsigned, 2>(oc.begin(), oc.end()),
789  SmallVector<unsigned, 2>(fl.begin(), fl.end()),
790  SmallVector<unsigned, 2>(ic.begin(), ic.end()),
791  SmallVector<unsigned, 2>(depth.begin(), depth.end()),
792  /*strides=*/SmallVector<int64_t, 2>{},
793  /*dilations=*/SmallVector<int64_t, 2>{}};
794  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
795  llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
796  llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
797  llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
798  llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
799  llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
800 
801  // Use the op carried strides/dilations attribute if present.
802  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
803  if (!nativeStrides) {
804  SmallVector<AffineExpr, 2> strideExprs;
805  for (unsigned oiDim : dimensions.outputImage)
806  strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
807  dimensions.strides = getConstantsFromExprList(strideExprs);
808  } else {
809  dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
810  }
811  auto nativeDilations =
812  linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
813  if (!nativeDilations) {
814  SmallVector<AffineExpr, 2> dilationExprs;
815  for (unsigned flDim : dimensions.filterLoop)
816  dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
817  dimensions.dilations = getConstantsFromExprList(dilationExprs);
818  } else {
819  dimensions.dilations =
820  llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
821  }
822  return dimensions;
823 }
824 
825 /// Find at least 1 parallel (output_image) and reduction (filter_loop)
826 /// dimension candidates that form a convolution subcomputation within
827 /// `linalgOp`. The LHS is assumed to be the convolution input while the
828 /// RHS is assumed as the filter.
829 /// These dimensions are such that:
830 /// 1. Optional batch dimensions that appear in the input and filter.
831 /// 2. The output_image dimension is involved in a cross-correlation along LHS
832 /// (i.e. it is a permutation on RES and LHS and has an associated
833 /// filter_loop in RHS).
834 /// 3. Optional output_channel dimension is involved in an outer-product along
835 /// RHS (i.e. it is a permutation on RES and RHS and does not appear in
836 /// LHS).
837 /// 4. Optional input_channel dimension appears as a permutation on LHS and
838 /// RHS.
839 /// 5. The filter_loop dimension appears as a permutation on the RHS and
840 /// represents the shape of the kernel cross-correlated along a
841 /// corresponding output_image dim.
842 /// 6. The input_channel dimension appears as a permutation on LHS and RHS.
843 /// 7. All dimensions appear only once in any given indexing map.
844 /// This allows e.g. detecting that some convolution is embedded within
845 /// `linalgOp` with some orthogonal heuristic.
846 /// When multiple dimension occurrences exist that match any classification
847 /// indices are returned in sorted order.
848 /// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
849 FailureOr<ConvolutionDimensions>
851  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
852  return failure();
853 
854  auto indexingMaps = linalgOp.getIndexingMapsArray();
855 
856  // Check the input indexing map has the right form.
857  ConvAccessExprWalker inputExprWalker;
858  for (AffineExpr expr : indexingMaps[0].getResults())
859  (void)inputExprWalker.visit(expr);
860  inputExprWalker.clearMultiUseDims(indexingMaps[0]);
861 
862  return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
863  /*allowEmptyConvolvedDims=*/false);
864 }
865 
namespace mlir::linalg::detail {
/// Positive and failure cases produced by `isConvolutionInterfaceImpl`;
/// translated into a human-readable diagnostic by
/// `getMatchConvolutionMessage`.
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction,
  EmptyConvolvedDims
};
} // namespace mlir::linalg::detail
879 
882  Operation *op, ConvolutionDimensions *dimensions,
883  bool allowEmptyConvolvedDims) {
884  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
885  if (!linalgOp)
886  return MatchConvolutionResult::NotLinalgOp;
887  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
888  return MatchConvolutionResult::WrongNumOperands;
889 
890  auto indexingMaps = linalgOp.getIndexingMapsArray();
891 
892  // Check the input indexing map has the right form.
893  ConvAccessExprWalker inputExprWalker;
894  if (llvm::any_of(indexingMaps[0].getResults(),
895  [&inputExprWalker](AffineExpr expr) {
896  return failed(inputExprWalker.visit(expr));
897  })) {
898  return MatchConvolutionResult::WrongInputIndexingMap;
899  }
900 
901  // Filter and output maps must be projected permutation.
902  if (!indexingMaps[1].isProjectedPermutation() ||
903  !indexingMaps.back().isProjectedPermutation())
904  return MatchConvolutionResult::NotProjectedPermutations;
905 
906  auto iteratorTypes = linalgOp.getIteratorTypesArray();
907 
908  llvm::SmallDenseSet<int64_t> outputDims =
909  getPreservedDims(indexingMaps.back());
910  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
911  // Make sure all loops are characterized as one of:
912  // - Batch loop : present in output, as non-convolved in input, not present in
913  // filter.
914  // - Output image dimension : present in output, convolved dims in input, not
915  // present in filter.
916  // - Output channel dimension : present in output, not present in input,
917  // present in filter.
918  // - Filter loop dimension : present in filter, convolved in input, not
919  // present in output.
920  // - Input channel dimension : unconvolved in input, not present in output,
921  // present in filter.
922  // - Depth multiplier : unconvolved in input, present in output, present in
923  // filter.
924  llvm::SmallDenseSet<int64_t> allLoopDims;
925  for (auto outputExpr : indexingMaps.back().getResults()) {
926  int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
927  if (inputExprWalker.unConvolvedDims.count(outputDim) &&
928  !filterDims.count(outputDim)) {
929  // Batch dimension.
930  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
931  return MatchConvolutionResult::OutputDimsNotParallel;
932  allLoopDims.insert(outputDim);
933  continue;
934  }
935  if (inputExprWalker.convolvedDims.count(outputDim) &&
936  !filterDims.count(outputDim)) {
937  // Output image Loop dimension.
938  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
939  return MatchConvolutionResult::OutputDimsNotParallel;
940  allLoopDims.insert(outputDim);
941  continue;
942  }
943  if (!inputExprWalker.convolvedDims.count(outputDim) &&
944  !inputExprWalker.unConvolvedDims.count(outputDim) &&
945  filterDims.count(outputDim)) {
946  // Output channel dimension.
947  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
948  return MatchConvolutionResult::OutputDimsNotParallel;
949  allLoopDims.insert(outputDim);
950  continue;
951  }
952  if (inputExprWalker.unConvolvedDims.count(outputDim) &&
953  filterDims.count(outputDim)) {
954  // Depth multiplier.
955  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
956  return MatchConvolutionResult::OutputDimsNotParallel;
957  allLoopDims.insert(outputDim);
958  continue;
959  }
960  return MatchConvolutionResult::NonConvolutionLoop;
961  }
962  for (auto filterExpr : indexingMaps[1].getResults()) {
963  int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
964  if (outputDims.count(filterDim) &&
965  !inputExprWalker.unConvolvedDims.count(filterDim) &&
966  !inputExprWalker.convolvedDims.count(filterDim)) {
967  // Output channel dimension. This is already seen, continue;
968  continue;
969  }
970  if (inputExprWalker.convolvedDims.count(filterDim) &&
971  !outputDims.count(filterDim)) {
972  // Filter loop dimension.
973  if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
974  return MatchConvolutionResult::NonOutputDimNotReduction;
975  if (allLoopDims.count(filterDim))
976  return MatchConvolutionResult::NonConvolutionLoop;
977  allLoopDims.insert(filterDim);
978  continue;
979  }
980  if (inputExprWalker.unConvolvedDims.count(filterDim) &&
981  !outputDims.count(filterDim)) {
982  // Input channel dimension.
983  if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
984  return MatchConvolutionResult::NonOutputDimNotReduction;
985  if (allLoopDims.count(filterDim))
986  return MatchConvolutionResult::NonConvolutionLoop;
987  allLoopDims.insert(filterDim);
988  continue;
989  }
990  if (inputExprWalker.unConvolvedDims.count(filterDim) &&
991  outputDims.count(filterDim)) {
992  // Depthwise loop. Already seen.
993  continue;
994  }
995  return MatchConvolutionResult::NonConvolutionLoop;
996  }
997  // All loops must be covered now.
998  if (allLoopDims.size() != linalgOp.getNumLoops())
999  return MatchConvolutionResult::NonConvolutionLoop;
1000 
1001  if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1002  return MatchConvolutionResult::EmptyConvolvedDims;
1003 
1004  if (dimensions) {
1005  FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
1006  linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1007  assert(succeeded(res) && "unexpected failure to infer convolution dims");
1008  *dimensions = *res;
1009  }
1010 
1011  return MatchConvolutionResult::Success;
1012 }
1013 
1014 StringRef
1016  switch (res) {
1017  case MatchConvolutionResult::NotLinalgOp:
1018  return "expected a LinalgOp";
1019  case MatchConvolutionResult::WrongNumOperands:
1020  return "expected op with 2 inputs and 1 output";
1021  case MatchConvolutionResult::WrongInputIndexingMap:
1022  return "unexpected input index map for convolutions";
1023  case MatchConvolutionResult::NotProjectedPermutations:
1024  return "expected output/filter indexing maps to be projected permutations";
1025  case MatchConvolutionResult::NonConvolutionLoop:
1026  return "unexpected loop dimension for convolution op";
1027  case MatchConvolutionResult::OutputDimsNotParallel:
1028  return "expected all iterators used to access outputs to be parallel";
1029  case MatchConvolutionResult::NonOutputDimNotReduction:
1030  return "expected all iterators not used to access outputs to be reduction";
1031  case MatchConvolutionResult::EmptyConvolvedDims:
1032  return "expected convolved dim to be non-empty";
1033  case MatchConvolutionResult::Success:
1034  return "";
1035  }
1036  llvm_unreachable("unhandled MatchConvolutionResult case");
1037 }
1038 
1040  bool allowEmptyConvolvedDims) {
1042  linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
1044 }
1045 
1048  if (res != MatchConvolutionResult::Success)
1049  return op->emitError(getMatchConvolutionMessage(res));
1050  return success();
1051 }
1052 
1053 //===----------------------------------------------------------------------===//
1054 // FillOpInterface implementation
1055 //===----------------------------------------------------------------------===//
1056 
/// Result codes for `isFillInterfaceImpl`; mapped to diagnostics in
/// `verifyFillInterface`.
enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NotScalarInput
};
1063 
1065  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1066  if (!linalgOp)
1068  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1070 
1071  OpOperand *value = linalgOp.getDpsInputOperand(0);
1072  if (!linalgOp.isScalar(value))
1074 
1075  return MatchFillResult::Success;
1076 }
1077 
1079  auto res = isFillInterfaceImpl(op);
1080  if (res == MatchFillResult::NotLinalgOp)
1081  return op->emitError("expected a LinalgOp");
1083  return op->emitError("expected op with 1 input and 1 output");
1085  return op->emitError("expected op with scalar input");
1086 
1087  return success();
1088 }
1089 
1090 //===----------------------------------------------------------------------===//
1091 // StructuredOpInterface implementation
1092 //===----------------------------------------------------------------------===//
1093 
1094 SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1095  Location loc) {
1097  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1098  for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1099  res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
1100  }
1101  return res;
1102 }
1103 
1104 SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
1106  assert(!hasDynamicShape() && "expected operands to have static shapes");
1107  for (OpOperand &opOperand : getOperation()->getOpOperands())
1108  llvm::append_range(res, getShape(&opOperand));
1109  return res;
1110 }
1111 
1112 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1113  AffineMap map = getLoopsToShapesMap();
1114  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1115  auto viewSizes = createFlatListOfOperandDims(b, loc);
1116  SmallVector<Range, 4> res(numDims);
1117  for (unsigned idx = 0; idx < numRes; ++idx) {
1118  auto result = map.getResult(idx);
1119  if (auto d = dyn_cast<AffineDimExpr>(result)) {
1120  if (res[d.getPosition()].offset)
1121  continue;
1122  res[d.getPosition()] =
1123  Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1124  }
1125  }
1126  return res;
1127 }
1128 
1129 /// Visitor to check if any of the given set of positions from AffineDimExprs
1130 /// are used within an AffineExpr.
1132  : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1133  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1134  : positions(std::move(positions)) {}
1135 
1137  return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
1138  }
1139 
1140  bool visitDimExpr(AffineDimExpr dimExpr) {
1141  return positions.test(dimExpr.getPosition());
1142  }
1143 
1144  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1145 
1146  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1147 
1148 private:
1149  llvm::SmallBitVector positions;
1150 };
1151 
1152 static std::pair<int64_t, int64_t>
1154  int64_t inputRankSum = 0;
1155  int64_t outputRankSum = 0;
1156  for (OpOperand *input : op.getDpsInputOperands())
1157  inputRankSum += op.getRank(input);
1158  for (OpOperand &output : op.getDpsInitsMutable())
1159  outputRankSum += op.getRank(&output);
1160  return {inputRankSum, inputRankSum + outputRankSum};
1161 }
1162 
1163 LogicalResult
1165  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1166  // An example that helps understand the logic below.
1167  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
1168  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
1169  // This is achieved as follows.
1170  // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
1171  // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
1172  // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
1173  // resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
1174  // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
1175  AffineMap loopsToShapesMap = getLoopsToShapesMap();
1176 
1177  // Find the position in the above map that represents the shape of the
1178  // result:dim being inferred.
1179  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
1180 
1181  /// From loopsToShapesMap extract the submap that represents the shape of the
1182  /// (resultIdx, dim) needed.
1183  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
1184  resultShapesSubMapPos.first,
1185  resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1186  AffineMap resultShapesFromInputShapesMap =
1187  loopToResultsShapeMap.compose(getShapesToLoopsMap());
1188 
1189  // Check that the result dim map does not contain the positions corresponding
1190  // to the outputs.
1191  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
1192  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1193  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1194  Location loc = getOperation()->getLoc();
1195  IRRewriter rewriter(b);
1196  SmallVector<OpFoldResult> allResultDimValues =
1198  rewriter, loc, resultShapesFromInputShapesMap,
1199  createFlatListOfOperandDims(b, loc));
1200  int64_t pos = 0;
1201  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
1202  for (OpOperand &opOperand : getDpsInitsMutable()) {
1204  for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1205  auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1206  if (!shapedType.isDynamicDim(dim)) {
1207  // Static dim: Return IntegerAttr.
1208  shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
1209  } else {
1210  // Dynamic dim: Return Value.
1211  OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1212  ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
1213  : allResultDimValues[pos];
1214  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
1215  }
1216  pos++;
1217  }
1218  reifiedReturnShapes.emplace_back(std::move(shapes));
1219  }
1220  return success();
1221 }
1222 
1223 /// Return the index in the indexingMaps vector that corresponds to this
1224 /// `opOperand`.
1225 int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1226  auto operandNumber = opOperand->getOperandNumber();
1227  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1228  if (!dpsIface.isDpsInput(opOperand))
1229  return operandNumber;
1230  unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1231  assert(!dpsIface.isDpsInit(opOperand));
1232  // Account for potential inputs that are not DPS and may not appear in
1233  // `indexingMaps`.
1234  return cast<DestinationStyleOpInterface>(*this->getOperation())
1235  .getNumDpsInputs() +
1236  operandNumber - start;
1237 }
1238 
1240  LinalgOp linalgOp = cast<LinalgOp>(op);
1241  // Mixed tensor/buffer operands are not allowed.
1242  if (!linalgOp.hasPureTensorSemantics() &&
1243  !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1244  return op->emitOpError("expected to have pure tensor or buffer semantics");
1245 
1246  // Before checking indexing maps, we need to make sure the attributes
1247  // referenced by it are valid.
1248  if (linalgOp.hasDynamicIndexingMaps())
1249  if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1250  return failure();
1251 
1252  // Delayed calling of IndexingMapOpInterface::verifyImpl.
1253  if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1254  return failure();
1255 
1256  // Set this flag if this op has user defined maps. This is required to guard
1257  // the below error condition which assume default indexing maps.
1258  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1259  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1260  // Domain must be consistent.
1261  unsigned numLoops = linalgOp.getNumLoops();
1262  if (indexingMap.getNumDims() != numLoops)
1263  return op->emitOpError("expected indexing_map #")
1264  << opOperand.getOperandNumber() << " to have " << numLoops
1265  << " dim(s) to match the number of loops";
1266  }
1267  SmallVector<unsigned> redDims;
1268  linalgOp.getReductionDims(redDims);
1269 
1270  if (!linalgOp.getShapesToLoopsMap())
1271  return op->emitOpError("expected the shape-to-loops map to be non-null");
1272 
1273  // Check the region has exactly one block.
1274  if (linalgOp->getNumRegions() != 1 ||
1275  !llvm::hasSingleElement(linalgOp->getRegion(0)))
1276  return op->emitOpError("expects to have 1 region with 1 block");
1277 
1278  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
1279  // types.
1280  // TODO: once ranked shape types are plugged in, we may want to drop the
1281  // corresponding bbargs, that can never be read from. This will be subject to
1282  // consistency discussions (i.e. what to do with output tensors whose bbarg is
1283  // not used).
1284  Block &block = linalgOp->getRegion(0).front();
1285 
1286  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1287  return op->emitOpError("expected as many non-induction variable region "
1288  "arguments as the number of input/output operands");
1289 
1290  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1291  Type elementType = opOperand->get().getType();
1292  if (isa<MemRefType, RankedTensorType>(elementType))
1293  elementType = getElementTypeOrSelf(opOperand->get().getType());
1294  Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1295  if (elementType != argType)
1296  return op->emitOpError("expected type of bb argument #")
1297  << opOperand->getOperandNumber() << " (" << argType << ")"
1298  << " to match element or self type of the corresponding operand ("
1299  << elementType << ")";
1300  }
1301 
1302  return success();
1303 }
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:62
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
MatchFillResult
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
Affine binary operation expression.
Definition: AffineExpr.h:214
AffineExpr getLHS() const
Definition: AffineExpr.cpp:340
AffineExpr getRHS() const
Definition: AffineExpr.cpp:343
An integer constant appearing in affine expression.
Definition: AffineExpr.h:239
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:223
unsigned getPosition() const
Definition: AffineExpr.cpp:348
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
MLIRContext * getContext() const
Definition: AffineExpr.cpp:33
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:655
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:611
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:552
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:231
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:748
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:228
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:757
unsigned getNumOperands()
Definition: Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1376
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalgelementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:103
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:94
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Definition: AffineMap.cpp:829
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:784
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.