// MLIR 22.0.0git — LinalgInterfaces.cpp (extracted source listing).
1 //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 /// Include the definitions of the copy operation interface.
34 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
35 
36 //===----------------------------------------------------------------------===//
37 // Interface utility functions
38 //===----------------------------------------------------------------------===//
39 
41  linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
42  SmallVector<AffineMap> indexingMaps;
43  for (auto &opOperand : linalgOp->getOpOperands()) {
44  if (llvm::is_contained(droppedOperands, &opOperand))
45  continue;
46  indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
47  }
48  if (indexingMaps.empty()) {
49  // If there are no indexing maps, the operand can only be dropped
50  // if the op has no loops.
51  return linalgOp.getNumLoops() == 0;
52  }
54  indexingMaps, linalgOp.getContext())) != AffineMap();
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // CopyOpInterface implementation
59 //===----------------------------------------------------------------------===//
60 
61 bool linalg::isaCopyOpInterface(LinalgOp op) {
62  // Check all loops are parallel and linalgOp is single input and output.
63  if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
64  return false;
65 
66  auto mapRange = op.getIndexingMapsArray();
67  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
68  !mapRange.back().isIdentity()) {
69  return false;
70  }
71  // Check yield first block argument.
72  Block *body = op.getBlock();
73  if (body->getOperations().size() != 1)
74  return false;
75  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
76  if (!yieldOp || yieldOp.getNumOperands() != 1)
77  return false;
78  return yieldOp->getOperand(0) == body->getArgument(0);
79 }
80 
81 //===----------------------------------------------------------------------===//
82 // FillOpInterface implementation
83 //===----------------------------------------------------------------------===//
84 /// Detects if a linalg.generic operation represents a fill with an inlined
85 /// constant. If so, returns the constant value. Otherwise, returns
86 /// std::nullopt.
87 static std::optional<Value> isaInlinedFillOp(GenericOp op) {
88  if (!op.isAllParallelLoops() || op.getNumDpsInits() != 1 ||
89  op.getNumDpsInputs() != 0)
90  return std::nullopt;
91 
92  // Init should not be referenced.
93  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
94  return std::nullopt;
95 
96  Block *body = op.getBody();
97  if (body->getOperations().size() != 1)
98  return std::nullopt;
99 
100  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
101  if (!yieldOp || yieldOp.getNumOperands() != 1)
102  return std::nullopt;
103 
104  Value yieldOperand = yieldOp->getOperand(0);
105  if (!yieldOperand.getDefiningOp<arith::ConstantOp>() &&
106  !yieldOperand.getDefiningOp<complex::ConstantOp>())
107  return std::nullopt;
108 
109  return yieldOperand;
110 }
111 
112 /// Detects if a linalg.generic operation represents an external scalar input.
113 /// If so, returns the constant value. Otherwise, returns std::nullopt.
114 static std::optional<Value> isaExternalFillOp(GenericOp op) {
115  // Structural.
116  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
117  !op.isSingleYieldOp())
118  return std::nullopt;
119 
120  // Input should be referenced and init should not.
121  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
122  op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
123  return std::nullopt;
124 
125  OpOperand *value = op.getDpsInputOperand(0);
126  if (!op.isScalar(value))
127  return std::nullopt;
128  return value->get();
129 }
130 
131 std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
132  if (auto fillVal = isaInlinedFillOp(op))
133  return fillVal;
134  return isaExternalFillOp(op);
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // BroadcastOpInterface implementation
139 //===----------------------------------------------------------------------===//
140 std::optional<SmallVector<int64_t>>
142  // Structural.
143  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
144  !op.isSingleYieldOp())
145  return std::nullopt;
146 
147  auto srcTy = op.getDpsInputOperand(0)->get().getType();
148  auto dstTy = op.getDpsInitOperand(0)->get().getType();
149  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
150  !isa<MemRefType, RankedTensorType>(dstTy))
151  return std::nullopt;
152 
153  // Check output is identity map. Broadcast could additionally be
154  // employing permutation of indices and that would be expressible
155  // in linalg.generic but is not expressible for named broadcast op.
156  auto dstMap = op.getIndexingMapsArray()[1];
157  if (!dstMap.isIdentity())
158  return std::nullopt;
159 
160  SmallVector<int64_t> position;
161  auto srcMap = op.getIndexingMapsArray()[0];
162 
163  if (srcMap.getResults().size() >= dstMap.getResults().size())
164  return std::nullopt;
165 
166  // Check input map is monotonically increasing DimIds.
167  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
168  auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
169  if (!expr)
170  return std::nullopt;
171  int64_t pos = expr.getPosition();
172  if (i > 0 && pos <= position[i - 1])
173  return std::nullopt;
174  position.push_back(expr.getPosition());
175  }
176 
177  SmallVector<int64_t> broadcastedDims;
178  auto numDims = srcMap.getNumDims();
179  // This is quadratic but number of items is generally small.
180  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
181  if (!llvm::is_contained(position, dim))
182  broadcastedDims.push_back(dim);
183  }
184  return broadcastedDims;
185 }
186 
187 //===----------------------------------------------------------------------===//
188 // TransposeOpInterface implementation
189 //===----------------------------------------------------------------------===//
190 std::optional<SmallVector<int64_t>>
192  // To specialize as a transpose op, the genericOp must be
193  // all parallel loops, single input, single output, and its body
194  // should be just a yield op, yielding input as output as is (no compute).
195  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
196  !op.isSingleYieldOp())
197  return std::nullopt;
198 
199  auto mapRange = op.getIndexingMapsArray();
200  if (mapRange.size() != 2)
201  return std::nullopt;
202 
203  auto mapOfInput = mapRange.front();
204  auto mapOfResult = mapRange.back();
205 
206  // linalg.transpose permutes the dimensions of input using this
207  // rule: dim(result, i) = dim(input, permutation[i])
208  if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
209  return std::nullopt;
210 
211  SmallVector<int64_t> permutation(mapOfInput.getNumDims());
212  for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
213  auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
214  permutation[expr.getPosition()] = i;
215  }
216  return permutation;
217 }
218 
219 //===----------------------------------------------------------------------===//
220 // Elementwise Single Unary/Binary-OpInterface implementation
221 //===----------------------------------------------------------------------===//
222 static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
223  unsigned arity) {
224  // Check all loops are parallel.
225  if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
226  return false;
227 
228  // Check there are arity-inputs, 1-output and all are identity-maps.
229  if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
230  !llvm::all_of(op.getIndexingMapsArray(),
231  [](AffineMap map) { return map.isIdentity(); }))
232  return false;
233 
234  // Init should not be referenced for elementwise operations.
235  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
236  return false;
237 
238  // A linalg.generic could be series of elementwise ops e.g. exp(neg(x)) such
239  // as resulting from producer-consumer fusion. Here, we restrict to two ops in
240  // the body, where the first is the elementwise single op and the second a
241  // yield.
242  Block *body = op.getBody();
243  if (body->getOperations().size() != 2)
244  return false;
245 
246  Operation *oper = &body->front();
247  if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
248  return false;
249 
250  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
251  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
252  yieldOp->getOperand(0).getDefiningOp() != oper)
253  return false;
254  return true;
255 }
256 
257 bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
258  // All basic elemwise checks.
260  return false;
261 
262  // Check input is actully used.
263  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
264  return false;
265  return true;
266 }
267 
268 bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
270  return false;
271 
272  // Check both inputs are used (elementwise).
273  OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
274  OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
275  if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
276  !op.payloadUsesValueFromOperand(inputOpOperand1))
277  return false;
278  return true;
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // ContractionOpInterface implementation
283 //===----------------------------------------------------------------------===//
284 
285 /// If the value is defined by a chain of unary side effect-free, go up the
286 /// use-def chain until the first value that isn't defined by such an op.
287 // TODO: relax to multi-operands with constants, which are technically unary ops
288 // as needed (e.g. add5).
290  Operation *op = value.getDefiningOp();
291  while (op && op->getNumOperands() == 1) {
292  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
293  if (!iface || !iface.hasNoEffect())
294  break;
295  value = op->getOperand(0);
296  op = value.getDefiningOp();
297  }
298  return value;
299 }
300 
302  Block &block, function_ref<bool(Operation *, Operation *)> isaPair,
303  llvm::raw_ostream &errs) {
304  if (block.empty() || !block.back().mightHaveTrait<OpTrait::IsTerminator>()) {
305  errs << "no terminator in the block";
306  return false;
307  }
308 
309  if (block.getNumArguments() != 3) {
310  errs << "expected block with 3 arguments";
311  return false;
312  }
313 
314  Operation *terminator = block.getTerminator();
315  if (terminator->getNumOperands() != 1) {
316  errs << "expected terminator with 1 operand";
317  return false;
318  }
319 
320  Value yielded = getSourceSkipUnary(terminator->getOperand(0));
321  Operation *reductionOp = yielded.getDefiningOp();
322  if (reductionOp->getNumResults() != 1 || reductionOp->getNumOperands() != 2) {
323  errs << "expected reduction op to be binary";
324  return false;
325  }
326 
327  Value reductionLHS = getSourceSkipUnary(reductionOp->getOperand(0));
328  Value reductionRHS = getSourceSkipUnary(reductionOp->getOperand(1));
329 
330  if (reductionLHS != block.getArgument(2) &&
331  reductionRHS != block.getArgument(2)) {
332  errs << "expected reduction to take block argument #2 as one of the "
333  "operands (modulo unary casts)";
334  return false;
335  }
336 
337  Value contributed = getSourceSkipUnary(
338  isa<BlockArgument>(reductionLHS) ? reductionRHS : reductionLHS);
339  Operation *elementwiseOp = contributed.getDefiningOp();
340  if (!elementwiseOp || elementwiseOp->getNumResults() != 1 ||
341  elementwiseOp->getNumOperands() != 2) {
342  errs << "expected elementwise op to be binary";
343  return false;
344  }
345 
346  if (!isaPair(elementwiseOp, reductionOp)) {
347  errs << "expected reduction/elementwise op kind not satisfied";
348  return false;
349  }
350 
351  Value elementwiseLHS = getSourceSkipUnary(elementwiseOp->getOperand(0));
352  Value elementwiseRHS = getSourceSkipUnary(elementwiseOp->getOperand(1));
353  if ((elementwiseLHS == block.getArgument(0) &&
354  elementwiseRHS == block.getArgument(1)) ||
355  (elementwiseLHS == block.getArgument(1) &&
356  elementwiseRHS == block.getArgument(0))) {
357  return true;
358  }
359 
360  errs << "expected elementwise op to apply to block arguments (modulo unary "
361  "casts)";
362  return false;
363 }
364 
365 /// Returns true if the two operations are of the kinds specified by a pair of
366 /// consecutive template arguments.
367 template <typename AddOpTy, typename MulOpTy, typename... Args>
368 static bool isPairTemplateImpl(Operation *add, Operation *mul) {
369  static_assert(sizeof...(Args) % 2 == 0,
370  "expected an even number of template arguments");
371  if (isa<AddOpTy>(add) && isa<MulOpTy>(mul))
372  return true;
373 
374  if constexpr (sizeof...(Args) > 0)
375  return isPairTemplateImpl<Args...>(add, mul);
376  else
377  return false;
378 }
379 
380 /// Returns true if the block is a body of a contraction with the kinds of
381 /// operations given pairwise by template arguments.
382 template <typename... Args>
383 static bool isContractionBody(Block &block) {
384  return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
385 }
386 
387 /// Given an `indexingMap` and its corresponding `iterators`, returns
388 /// the positions of the iterators of type `iter` that are indexed by
389 /// the `indexingMap` as a permutation. This is useful to infer various
390 /// subcomputations on a `LinalgOp`. This is performed by looking up
391 /// each result in the `indexingMap` and determining whether:
392 /// - It is a single AffineDimExpr.
393 /// - It is the only result involving this AffineDimExpr.
394 static llvm::SmallDenseSet<int64_t>
397  utils::IteratorType iter) {
398  assert(iterators.size() == indexingMap.getNumDims());
399  llvm::SmallDenseSet<int64_t> res;
400  for (AffineExpr e : indexingMap.getResults()) {
401  if (auto d = dyn_cast<AffineDimExpr>(e)) {
402  if (iterators[d.getPosition()] == iter &&
403  llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
404  return e.isFunctionOfDim(d.getPosition());
405  }) == 1)
406  res.insert(d.getPosition());
407  }
408  }
409  return res;
410 }
411 
namespace {
// Short aliases for the iterator kinds used by the inference helpers below.
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
416 
417 /// Infer the iterator types from the init affine map. This looks at which dims
418 /// are present in the map results, and returns an iterator types array with
419 /// parallel types for dims that are present, and reduction types for dims that
420 /// are not present.
421 static FailureOr<SmallVector<utils::IteratorType>>
423  if (!map.isProjectedPermutation())
424  return failure();
425  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
426  for (auto expr : map.getResults())
427  if (auto dim = dyn_cast<AffineDimExpr>(expr))
428  iterators[dim.getPosition()] = par;
429  return iterators;
430 }
431 
432 /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
433 /// a matmul subcomputation within `linalgOp`. These dimensions are such that:
434 /// 1. The m dimension is involved in an outer-product along LHS
435 /// (i.e. it is a permutation on RES and LHS and does not appear in RHS).
436 /// 2. The n dimension is involved in an outer-product along RHS
437 /// (i.e. it is a permutation on RES and RHS and does not appear in LHS).
438 /// 3. The k dimension appears as a permutation on LHS and RHS.
439 /// 4. m, n and k appear only once in any given indexing.
440 /// 5. Optional batch dimensions that appear in all operands are captured.
441 /// This allows e.g. detecting that some contraction is embedded within
442 /// `linalgOp` with some orthogonal heuristic.
443 static FailureOr<ContractionDimensions>
445  ArrayRef<utils::IteratorType> iterators) {
446  llvm::SmallDenseSet<int64_t> a =
447  findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
448  llvm::SmallDenseSet<int64_t> b =
449  findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
450  llvm::SmallDenseSet<int64_t> c =
451  findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
452 
453  // A & C - B are the iterators involved in an outer-product along A (the LHS).
454  llvm::SmallDenseSet<int64_t> ac = a;
455  llvm::set_intersect(ac, c);
456  llvm::set_subtract(ac, b);
457  // B & C - A are the iterators involved in an outer-product along B (the RHS).
458  llvm::SmallDenseSet<int64_t> bc = b;
459  llvm::set_intersect(bc, c);
460  llvm::set_subtract(bc, a);
461  // A & B & C are the "batch" dimensions.
462  llvm::SmallDenseSet<int64_t> batches = a;
463  llvm::set_intersect(batches, b);
464  llvm::set_intersect(batches, c);
465 
466  // A & B red are the reduction dimensions.
467  llvm::SmallDenseSet<int64_t> ra =
468  findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
469  llvm::SmallDenseSet<int64_t> rb =
470  findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
471  llvm::set_intersect(ra, rb);
472 
473  // Return each set in sorted order.
474  ContractionDimensions dimensions{
475  SmallVector<unsigned, 2>(batches.begin(), batches.end()),
476  SmallVector<unsigned, 2>(ac.begin(), ac.end()),
477  SmallVector<unsigned, 2>(bc.begin(), bc.end()),
478  SmallVector<unsigned, 2>(ra.begin(), ra.end())};
479  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
480  llvm::sort(dimensions.m.begin(), dimensions.m.end());
481  llvm::sort(dimensions.n.begin(), dimensions.n.end());
482  llvm::sort(dimensions.k.begin(), dimensions.k.end());
483  return dimensions;
484 }
485 
486 FailureOr<ContractionDimensions>
488  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
489  return failure();
490  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
491  linalgOp.getIteratorTypesArray());
492 }
493 
494 FailureOr<ContractionDimensions>
496  if (indexingMaps.size() != 3)
497  return failure();
498  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
499  if (failed(iterators))
500  return failure();
501  return inferContractionDimsImpl(indexingMaps, iterators.value());
502 }
503 
namespace mlir::linalg::detail {
/// Outcome of matching a LinalgOp against the contraction interface; each
/// non-Success value names the first failed structural check.
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
} // namespace mlir::linalg::detail
514 
518  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
519  if (!linalgOp)
520  return MatchContractionResult::NotLinalgOp;
521  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
522  return MatchContractionResult::WrongNumOperands;
523  auto mapRange = linalgOp.getIndexingMapsArray();
524  if (linalgOp.getNumReductionLoops() == 0)
525  return MatchContractionResult::NoReduction;
526  if (llvm::any_of(mapRange,
527  [](AffineMap m) { return !m.isProjectedPermutation(); }))
528  return MatchContractionResult::NotProjectedPermutations;
529  // TODO: more fields than add/mul.
530  // clang-format off
531  if (!::isContractionBody<
532  arith::MulFOp, arith::AddFOp,
533  arith::MulIOp, arith::AddIOp,
534  complex::MulOp, complex::AddOp,
535  arith::AndIOp, arith::OrIOp>(
536  *linalgOp.getBlock())) {
537  return MatchContractionResult::NotAddMul;
538  }
539  // clang-format on
540 
541  if (dimensions) {
542  FailureOr<ContractionDimensions> res = inferContractionDims(linalgOp);
543  assert(succeeded(res) && "unexpected failure to infer contraction dims");
544  *dimensions = *res;
545  }
546  return MatchContractionResult::Success;
547 }
548 
549 StringRef
551  switch (res) {
552  case MatchContractionResult::NotLinalgOp:
553  return "expected a LinalgOp";
554  case MatchContractionResult::WrongNumOperands:
555  return "expected op with 2 inputs and 1 output";
556  case MatchContractionResult::NoReduction:
557  return "expected at least 1 reduction";
558  case MatchContractionResult::NotProjectedPermutations:
559  return "expected indexing maps to be projected permutations";
560  case MatchContractionResult::NotAddMul:
561  return "expected add/mul op in the body";
562  case MatchContractionResult::Success:
563  return "";
564  }
565  llvm_unreachable("unhandled MatchContractionResult case");
566 }
567 
569  if (!linalgOp)
570  return false;
571  Operation *op = linalgOp.getOperation();
572  return isa<ContractionOpInterface>(op) ||
575 }
576 
577 /// Verify that a LinalgOp `op` is a contraction.
578 /// A Linalg contraction is defined in general terms:
579 /// 1. Has 2 input and 1 output shapes.
580 /// 2. Has at least one reduction dimension.
581 /// 3. Has only projected permutation indexing maps.
582 /// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
583 /// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
584 /// operations that may change the type (e.g. for mixed-precision).
585 /// As a consequence, when vectorization of such an op occurs, the only special
586 /// behavior is that the (unique) MulOpType is vectorized into a
587 /// `vector.contract`. All other ops are handled in a generic fashion.
588 /// In the future, we may wish to allow more input arguments and elementwise and
589 /// constant operations that do not involve the reduction dimension(s).
591  auto res = isContractionInterfaceImpl(op);
592  if (res != MatchContractionResult::Success)
593  return op->emitError(getMatchContractionMessage(res));
594  return success();
595 }
596 
597 //===----------------------------------------------------------------------===//
598 // ConvolutionOpInterface implementation
599 //===----------------------------------------------------------------------===//
600 
601 /// Of the given two expressions returns one that is of type T (`lhs` gets
602 /// preference over `rhs`)
603 template <typename T>
605  return isa<T>(lhs) ? cast<T>(lhs) : (isa<T>(rhs) ? cast<T>(rhs) : nullptr);
606 }
607 
608 namespace {
609 /// Walk the indexing expressions for input of a convolution operation to verify
610 /// its of the right form, either
611 /// - AffineDimExpr
612 /// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
613 /// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
614 ///
615 /// classifies the AffineDimExpr as convolved dimensions or unconvolved
616 /// dimensions and verifies each dimension occurs only once.
617 struct ConvAccessExprWalker
618  : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
619  // Stores dimensions used in expressions of the above form.
620  llvm::SmallDenseSet<int64_t> convolvedDims;
621  // Stores the dual mapping between LHS and RHS of convolution exprs.
622  llvm::SmallDenseMap<int64_t, int64_t> convolvedDimMapping;
623  // Stores single use dimensions used by an AffineDimExpr.
624  llvm::SmallDenseSet<int64_t> unConvolvedDims;
625  // Stores a mapping from convolved dims to their coefficient.
626  llvm::SmallDenseMap<int64_t, AffineExpr> strideAndDilationMapping;
627 
628  // Removes dims with multiple uses in the source input map from dimension
629  // sets tracked by this walker.
630  void clearMultiUseDims(AffineMap map) {
631  for (int dimPos = 0, e = map.getNumDims(); dimPos < e; ++dimPos) {
632  if (llvm::count_if(map.getResults(), [dimPos](AffineExpr e) {
633  return e.isFunctionOfDim(dimPos);
634  }) > 1) {
635  convolvedDims.erase(dimPos);
636  unConvolvedDims.erase(dimPos);
637  // If a duplicate dim is marked as convolved, the pair of the duplicate
638  // dim must be removed from the map as well.
639  auto it = convolvedDimMapping.find(dimPos);
640  if (it != convolvedDimMapping.end()) {
641  int64_t pairedDim = it->second;
642  convolvedDims.erase(pairedDim);
643  unConvolvedDims.erase(pairedDim);
644  strideAndDilationMapping.erase(pairedDim);
645  convolvedDimMapping.erase(dimPos);
646  convolvedDimMapping.erase(pairedDim);
647  }
648  }
649  }
650  }
651 
652  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
653  unsigned position = dimExpr.getPosition();
654  if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
655  return failure();
656  }
657  unConvolvedDims.insert(position);
658  return success();
659  }
660 
661  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
662 
663  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
664 
665  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
666  // In pre-order visit, top level op has to be an add op.
667  if (binaryExpr.getKind() != AffineExprKind::Add)
668  return failure();
669  auto lhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getLHS());
670  auto rhsDimPos = getDimExprOrMulExprDimPos(binaryExpr.getRHS());
671  if (failed(lhsDimPos) || failed(rhsDimPos))
672  return failure();
673  convolvedDimMapping[*lhsDimPos] = *rhsDimPos;
674  convolvedDimMapping[*rhsDimPos] = *lhsDimPos;
675  return success();
676  }
677 
678  FailureOr<int64_t> getDimExprOrMulExprDimPos(AffineExpr expr) {
679  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
680  int64_t dim = dimExpr.getPosition();
681  if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
682  return failure();
683  // Stride/dilation for this dim is implicitly 1.
684  strideAndDilationMapping[dim] =
686  convolvedDims.insert(dim);
687  return dim;
688  }
689  if (auto symbolMulExpr = dyn_cast<AffineBinaryOpExpr>(expr)) {
690  if (symbolMulExpr.getKind() != AffineExprKind::Mul)
691  return failure();
692  auto lhsExpr = symbolMulExpr.getLHS();
693  auto rhsExpr = symbolMulExpr.getRHS();
694  // Check for symbol expression.
695  AffineExpr mulExpr =
696  getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
697  // If there was no symbol expr, check for constant expression.
698  if (!mulExpr) {
699  mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
700  }
701  auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
702  if (!mulExpr || !dimExpr)
703  return failure();
704  int64_t dim = dimExpr.getPosition();
705  if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
706  return failure();
707  strideAndDilationMapping[dim] = mulExpr;
708  convolvedDims.insert(dim);
709  return dim;
710  }
711  return failure();
712  }
713 };
714 } // namespace
715 
716 static llvm::SmallDenseSet<int64_t> getPreservedDims(AffineMap map) {
717  assert(map.isProjectedPermutation() &&
718  "expected map to have projected permutations");
719  llvm::SmallDenseSet<int64_t> preservedDims;
720  for (auto expr : map.getResults())
721  preservedDims.insert(cast<AffineDimExpr>(expr).getPosition());
722  return preservedDims;
723 }
724 
728  for (auto e : exprs) {
729  auto constantExpr = dyn_cast<AffineConstantExpr>(e);
730  assert(constantExpr && "Found non-constant stride/dilation");
731  vals.push_back(constantExpr.getValue());
732  }
733  return vals;
734 }
735 
736 /// Classifies dimensions in the `linalgOp` used by a convolution
737 /// subcomputation, as captured by `inputExprWalker`. If
738 /// `allowEmptyConvolvedDims` is not set this this will fail if there is not
739 /// at least convolved dimension pair (output image + filter loop). Convolution
740 /// dimensions are specified in sorted order, and strides match the order of
741 /// the filter loop dimensions, while the dilations match the order of the
742 /// output image dimensions.
743 static FailureOr<ConvolutionDimensions>
744 inferConvolutionDimsImpl(LinalgOp linalgOp,
745  ConvAccessExprWalker &inputExprWalker,
746  bool allowEmptyConvolvedDims) {
747  auto filterMap =
748  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
749  auto outputMap =
750  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
751  llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
752  filterMap, linalgOp.getIteratorTypesArray(), par);
753  llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
754  outputMap, linalgOp.getIteratorTypesArray(), par);
755 
756  // unConvolvedDims & outputDims - filterDims are the batch iterators.
757  llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
758  llvm::set_intersect(batch, outputDims);
759  llvm::set_subtract(batch, filterDims);
760 
761  // convolvedDims & outputDims are the output image iterators.
762  llvm::SmallDenseSet<int64_t> oi = inputExprWalker.convolvedDims;
763  llvm::set_intersect(oi, outputDims);
764 
765  // filterDims & outputDims - unConvolvedDims are the output channel iterators.
766  llvm::SmallDenseSet<int64_t> oc = filterDims;
767  llvm::set_intersect(oc, outputDims);
768  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
769 
770  // filterDims & outputDims & unConvolvedDims are the depth iterators.
771  llvm::SmallDenseSet<int64_t> depth = filterDims;
772  llvm::set_intersect(depth, outputDims);
773  llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
774 
775  llvm::SmallDenseSet<int64_t> filterReducedDims =
777  linalgOp.getIteratorTypesArray(), red);
778 
779  // convolvedDims & filterReducedDims are the filter loop iterators.
780  llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
781  llvm::set_intersect(fl, filterReducedDims);
782 
783  // unConvolvedDims & filterReducedDims are the input channel iterators.
784  llvm::SmallDenseSet<int64_t> ic = inputExprWalker.unConvolvedDims;
785  llvm::set_intersect(ic, filterReducedDims);
786 
787  if (oi.empty() && !allowEmptyConvolvedDims)
788  return failure();
789 
790  // Return each set in sorted order.
791  ConvolutionDimensions dimensions{
792  SmallVector<unsigned, 2>(batch.begin(), batch.end()),
793  SmallVector<unsigned, 2>(oi.begin(), oi.end()),
794  SmallVector<unsigned, 2>(oc.begin(), oc.end()),
795  SmallVector<unsigned, 2>(fl.begin(), fl.end()),
796  SmallVector<unsigned, 2>(ic.begin(), ic.end()),
797  SmallVector<unsigned, 2>(depth.begin(), depth.end()),
798  /*strides=*/SmallVector<int64_t, 2>{},
799  /*dilations=*/SmallVector<int64_t, 2>{}};
800  llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
801  llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
802  llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
803  llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
804  llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
805  llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
806 
807  // Use the op carried strides/dilations attribute if present.
808  auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
809  if (!nativeStrides) {
810  SmallVector<AffineExpr, 2> strideExprs;
811  for (unsigned oiDim : dimensions.outputImage)
812  strideExprs.push_back(inputExprWalker.strideAndDilationMapping[oiDim]);
813  dimensions.strides = getConstantsFromExprList(strideExprs);
814  } else {
815  dimensions.strides = llvm::to_vector<2>(nativeStrides.getValues<int64_t>());
816  }
817  auto nativeDilations =
818  linalgOp->getAttrOfType<DenseIntElementsAttr>("dilations");
819  if (!nativeDilations) {
820  SmallVector<AffineExpr, 2> dilationExprs;
821  for (unsigned flDim : dimensions.filterLoop)
822  dilationExprs.push_back(inputExprWalker.strideAndDilationMapping[flDim]);
823  dimensions.dilations = getConstantsFromExprList(dilationExprs);
824  } else {
825  dimensions.dilations =
826  llvm::to_vector<2>(nativeDilations.getValues<int64_t>());
827  }
828  return dimensions;
829 }
830 
831 /// Find at least 1 parallel (output_image) and reduction (filter_loop)
832 /// dimension candidates that form a convolution subcomputation within
833 /// `linalgOp`. The LHS is assumed to be the convolution input while the
834 /// RHS is assumed as the filter.
835 /// These dimensions are such that:
836 /// 1. Optional batch dimensions that appear in the input and filter.
837 /// 2. The output_image dimension is involved in a cross-correlation along LHS
838 /// (i.e. it is a permutation on RES and LHS and has an associated
839 /// filter_loop in RHS).
840 /// 3. Optional output_channel dimension is involved in an outer-product along
841 /// RHS (i.e. it is a permutation on RES and RHS and does not appear in
842 /// LHS).
843 /// 4. Optional input_channel dimension appears as a permutation on LHS and
844 /// RHS.
845 /// 5. The filter_loop dimension appears as a permutation on the RHS and
846 /// represents the shape of the kernel cross-correlated along a
847 /// corresponding output_image dim.
848 /// 6. The input_channel dimension appears as a permutation on LHS and RHS.
849 /// 7. All dimensions appear only once in any given indexing map.
850 /// This allows e.g. detecting that some convolution is embedded within
851 /// `linalgOp` with some orthogonal heuristic.
852 /// When multiple dimension occurrences exist that match any classification
853 /// indices are returned in sorted order.
854 /// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.
855 FailureOr<ConvolutionDimensions>
// NOTE(review): the defining line was elided by the doc extractor; per the
// function index this is mlir::linalg::inferConvolutionDims(LinalgOp linalgOp).
// Confirm against the original source.
857  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
858  return failure();
859 
860  auto indexingMaps = linalgOp.getIndexingMapsArray();
861 
862  // Check the input indexing map has the right form.
863  ConvAccessExprWalker inputExprWalker;
864  for (AffineExpr expr : indexingMaps[0].getResults())
// Visit failures are deliberately discarded here (contrast with the strict
// failed(visit(...)) check in isConvolutionInterfaceImpl below).
865  (void)inputExprWalker.visit(expr);
866  inputExprWalker.clearMultiUseDims(indexingMaps[0]);
867 
// Delegate the actual classification; empty convolved dims are rejected.
868  return inferConvolutionDimsImpl(linalgOp, inputExprWalker,
869  /*allowEmptyConvolvedDims=*/false);
870 }
871 
872 namespace mlir::linalg::detail {
// Result codes for the convolution interface matcher below.
// NOTE(review): the enum header and several enumerators were elided by the
// doc extractor. The uses below show the missing members include:
// WrongNumOperands, WrongInputIndexingMap, NotProjectedPermutations,
// NonConvolutionLoop, OutputDimsNotParallel, NonOutputDimNotReduction,
// EmptyConvolvedDims -- confirm names/order against the original source.
874  Success = 0,
875  NotLinalgOp,
883 };
884 } // namespace mlir::linalg::detail
885 
// NOTE(review): the signature lines were elided by the doc extractor; per the
// function index this is
// mlir::linalg::detail::isConvolutionInterfaceImpl(Operation *op,
//     ConvolutionDimensions *dimensions, bool allowEmptyConvolvedDims).
// Confirm against the original source.
888  Operation *op, ConvolutionDimensions *dimensions,
889  bool allowEmptyConvolvedDims) {
890  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
891  if (!linalgOp)
892  return MatchConvolutionResult::NotLinalgOp;
893  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
894  return MatchConvolutionResult::WrongNumOperands;
895 
// Convention used throughout: indexingMaps[0] is the input (image) map,
// indexingMaps[1] is the filter map, indexingMaps.back() is the output map.
896  auto indexingMaps = linalgOp.getIndexingMapsArray();
897 
898  // Check the input indexing map has the right form.
899  ConvAccessExprWalker inputExprWalker;
900  if (llvm::any_of(indexingMaps[0].getResults(),
901  [&inputExprWalker](AffineExpr expr) {
902  return failed(inputExprWalker.visit(expr));
903  })) {
904  return MatchConvolutionResult::WrongInputIndexingMap;
905  }
906 
907  // Filter and output maps must be projected permutation.
908  if (!indexingMaps[1].isProjectedPermutation() ||
909  !indexingMaps.back().isProjectedPermutation())
910  return MatchConvolutionResult::NotProjectedPermutations;
911 
912  auto iteratorTypes = linalgOp.getIteratorTypesArray();
913 
914  llvm::SmallDenseSet<int64_t> outputDims =
915  getPreservedDims(indexingMaps.back());
916  llvm::SmallDenseSet<int64_t> filterDims = getPreservedDims(indexingMaps[1]);
917  // Make sure all loops are characterized as one of:
918  // - Batch loop : present in output, as non-convolved in input, not present in
919  // filter.
920  // - Output image dimension : present in output, convolved dims in input, not
921  // present in filter.
922  // - Output channel dimension : present in output, not present in input,
923  // present in filter.
924  // - Filter loop dimension : present in filter, convolved in input, not
925  // present in output.
926  // - Input channel dimension : unconvolved in input, not present in output,
927  // present in filter.
928  // - Depth multiplier : unconvolved in input, present in output, present in
929  // filter.
930  llvm::SmallDenseSet<int64_t> allLoopDims;
// First pass: classify every dimension appearing in the output map. All of
// them must be parallel iterators.
931  for (auto outputExpr : indexingMaps.back().getResults()) {
932  int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
933  if (inputExprWalker.unConvolvedDims.count(outputDim) &&
934  !filterDims.count(outputDim)) {
935  // Batch dimension.
936  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
937  return MatchConvolutionResult::OutputDimsNotParallel;
938  allLoopDims.insert(outputDim);
939  continue;
940  }
941  if (inputExprWalker.convolvedDims.count(outputDim) &&
942  !filterDims.count(outputDim)) {
943  // Output image Loop dimension.
944  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
945  return MatchConvolutionResult::OutputDimsNotParallel;
946  allLoopDims.insert(outputDim);
947  continue;
948  }
949  if (!inputExprWalker.convolvedDims.count(outputDim) &&
950  !inputExprWalker.unConvolvedDims.count(outputDim) &&
951  filterDims.count(outputDim)) {
952  // Output channel dimension.
953  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
954  return MatchConvolutionResult::OutputDimsNotParallel;
955  allLoopDims.insert(outputDim);
956  continue;
957  }
958  if (inputExprWalker.unConvolvedDims.count(outputDim) &&
959  filterDims.count(outputDim)) {
960  // Depth multiplier.
961  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
962  return MatchConvolutionResult::OutputDimsNotParallel;
963  allLoopDims.insert(outputDim);
964  continue;
965  }
966  return MatchConvolutionResult::NonConvolutionLoop;
967  }
// Second pass: classify the remaining dimensions via the filter map. Dims not
// appearing in the output must be reduction iterators, and no dim may be
// classified twice.
968  for (auto filterExpr : indexingMaps[1].getResults()) {
969  int64_t filterDim = cast<AffineDimExpr>(filterExpr).getPosition();
970  if (outputDims.count(filterDim) &&
971  !inputExprWalker.unConvolvedDims.count(filterDim) &&
972  !inputExprWalker.convolvedDims.count(filterDim)) {
973  // Output channel dimension. This is already seen, continue;
974  continue;
975  }
976  if (inputExprWalker.convolvedDims.count(filterDim) &&
977  !outputDims.count(filterDim)) {
978  // Filter loop dimension.
979  if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
980  return MatchConvolutionResult::NonOutputDimNotReduction;
981  if (allLoopDims.count(filterDim))
982  return MatchConvolutionResult::NonConvolutionLoop;
983  allLoopDims.insert(filterDim);
984  continue;
985  }
986  if (inputExprWalker.unConvolvedDims.count(filterDim) &&
987  !outputDims.count(filterDim)) {
988  // Input channel dimension.
989  if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
990  return MatchConvolutionResult::NonOutputDimNotReduction;
991  if (allLoopDims.count(filterDim))
992  return MatchConvolutionResult::NonConvolutionLoop;
993  allLoopDims.insert(filterDim);
994  continue;
995  }
996  if (inputExprWalker.unConvolvedDims.count(filterDim) &&
997  outputDims.count(filterDim)) {
998  // Depthwise loop. Already seen.
999  continue;
1000  }
1001  return MatchConvolutionResult::NonConvolutionLoop;
1002  }
1003  // All loops must be covered now.
1004  if (allLoopDims.size() != linalgOp.getNumLoops())
1005  return MatchConvolutionResult::NonConvolutionLoop;
1006 
1007  if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
1008  return MatchConvolutionResult::EmptyConvolvedDims;
1009 
// Optionally fill in the detailed dimension classification; given the checks
// above this inference is expected to succeed (asserted).
1010  if (dimensions) {
1011  FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
1012  linalgOp, inputExprWalker, allowEmptyConvolvedDims);
1013  assert(succeeded(res) && "unexpected failure to infer convolution dims");
1014  *dimensions = *res;
1015  }
1016 
1017  return MatchConvolutionResult::Success;
1018 }
1019 
1020 StringRef
// NOTE(review): the name line was elided by the doc extractor; per the
// function index this is
// mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res).
// NOTE(review): the WrongNumOperands message says "2 inputs" while the check
// in isConvolutionInterfaceImpl accepts getNumDpsInputs() >= 2 -- the message
// may be slightly stale; confirm intent before changing it.
1022  switch (res) {
1023  case MatchConvolutionResult::NotLinalgOp:
1024  return "expected a LinalgOp";
1025  case MatchConvolutionResult::WrongNumOperands:
1026  return "expected op with 2 inputs and 1 output";
1027  case MatchConvolutionResult::WrongInputIndexingMap:
1028  return "unexpected input index map for convolutions";
1029  case MatchConvolutionResult::NotProjectedPermutations:
1030  return "expected output/filter indexing maps to be projected permutations";
1031  case MatchConvolutionResult::NonConvolutionLoop:
1032  return "unexpected loop dimension for convolution op";
1033  case MatchConvolutionResult::OutputDimsNotParallel:
1034  return "expected all iterators used to access outputs to be parallel";
1035  case MatchConvolutionResult::NonOutputDimNotReduction:
1036  return "expected all iterators not used to access outputs to be reduction";
1037  case MatchConvolutionResult::EmptyConvolvedDims:
1038  return "expected convolved dim to be non-empty";
1039  case MatchConvolutionResult::Success:
1040  return "";
1041  }
1042  llvm_unreachable("unhandled MatchConvolutionResult case");
1043 }
1044 
// NOTE(review): most of this definition was elided by the doc extractor; per
// the function index it is
//   bool isaConvolutionOpInterface(LinalgOp linalgOp,
//                                  bool allowEmptyConvolvedDims);
// and the visible fragments show it compares
// isConvolutionInterfaceImpl(linalgOp.getOperation(), nullptr,
// allowEmptyConvolvedDims) against MatchConvolutionResult::Success.
// Confirm against the original source.
1046  bool allowEmptyConvolvedDims) {
1048  linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
1050 }
1051 
// NOTE(review): the signature and the computation of `res` were elided by the
// doc extractor; per the function index this is
// mlir::linalg::detail::verifyConvolutionInterface(Operation *op), and `res`
// presumably comes from isConvolutionInterfaceImpl(op) -- confirm.
1054  if (res != MatchConvolutionResult::Success)
1055  return op->emitError(getMatchConvolutionMessage(res));
1056  return success();
1057 }
1058 
1059 //===----------------------------------------------------------------------===//
1060 // FillOpInterface implementation
1061 //===----------------------------------------------------------------------===//
1062 
// Result codes for the fill interface matcher below.
1063 enum class MatchFillResult {
1064  Success = 0,
1065  NotLinalgOp,
1066  WrongNumOperands,
// NOTE(review): one enumerator was elided by the doc extractor; the matching
// diagnostic in verifyFillInterface is "expected op with scalar input"
// (likely named NotScalarInput -- confirm against the original source).
1068 };
1069 
// NOTE(review): the signature line was elided by the doc extractor; per the
// function index this is: static MatchFillResult isFillInterfaceImpl(
// Operation *op). The elided early-return lines presumably return the
// corresponding MatchFillResult error codes -- confirm against the original.
1071  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
1072  if (!linalgOp)
1074  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
1076 
// A fill's single input must be a scalar (not a shaped operand).
1077  OpOperand *value = linalgOp.getDpsInputOperand(0);
1078  if (!linalgOp.isScalar(value))
1080 
1081  return MatchFillResult::Success;
1082 }
1083 
// NOTE(review): the signature line and two condition lines were elided by the
// doc extractor; per the function index this is
// mlir::linalg::detail::verifyFillInterface(Operation *op). The elided
// conditions presumably test `res` against the remaining MatchFillResult
// error codes, matching the diagnostics below -- confirm.
1085  auto res = isFillInterfaceImpl(op);
1086  if (res == MatchFillResult::NotLinalgOp)
1087  return op->emitError("expected a LinalgOp");
1089  return op->emitError("expected op with 1 input and 1 output");
1091  return op->emitError("expected op with scalar input");
1092 
1093  return success();
1094 }
1095 
1096 //===----------------------------------------------------------------------===//
1097 // StructuredOpInterface implementation
1098 //===----------------------------------------------------------------------===//
1099 
1100 SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
1101  Location loc) {
// NOTE(review): the declaration of `res` (a SmallVector<OpFoldResult>, given
// the return type) was elided by the doc extractor.
// Appends one (possibly folded) dim value per dimension of every operand, in
// operand order; callers below (createLoopRanges, reifyResultShapes) index
// this list by the results of getLoopsToShapesMap().
1103  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
1104  for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
1105  res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
1106  }
1107  return res;
1108 }
1109 
1110 SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
// NOTE(review): the declaration of `res` (a SmallVector<int64_t, 4>, given
// the return type) was elided by the doc extractor.
// Precondition: all operand shapes are static (asserted below).
1112  assert(!hasDynamicShape() && "expected operands to have static shapes");
1113  for (OpOperand &opOperand : getOperation()->getOpOperands())
1114  llvm::append_range(res, getShape(&opOperand));
1115  return res;
1116 }
1117 
1118 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
1119  AffineMap map = getLoopsToShapesMap();
1120  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
1121  auto viewSizes = createFlatListOfOperandDims(b, loc);
1122  SmallVector<Range, 4> res(numDims);
1123  for (unsigned idx = 0; idx < numRes; ++idx) {
1124  auto result = map.getResult(idx);
1125  if (auto d = dyn_cast<AffineDimExpr>(result)) {
1126  if (res[d.getPosition()].offset)
1127  continue;
1128  res[d.getPosition()] =
1129  Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
1130  }
1131  }
1132  return res;
1133 }
1134 
1135 /// Visitor to check if any of the given set of positions from AffineDimExprs
1136 /// are used within an AffineExpr.
// NOTE(review): the `struct HasAffineDimExprVisitor` header line was elided
// by the doc extractor (name confirmed by the member index).
1138  : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
1139  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
1140  : positions(std::move(positions)) {}
1141 
// NOTE(review): the visitAffineBinaryOpExpr(AffineBinaryOpExpr) signature
// line was elided by the doc extractor (see the member index).
// A binary expression matches if either side does.
1143  return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
1144  }
1145 
// A dim expression matches iff its position bit is set.
1146  bool visitDimExpr(AffineDimExpr dimExpr) {
1147  return positions.test(dimExpr.getPosition());
1148  }
1149 
// Constants and symbols never reference a dim position.
1150  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
1151 
1152  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
1153 
1154 private:
1155  llvm::SmallBitVector positions;
1156 };
1157 
1158 static std::pair<int64_t, int64_t>
// NOTE(review): the name line was elided by the doc extractor; per the
// function index this is getResultsPositionInLoopsToShapeMap(LinalgOp &op).
// Returns the half-open range that the init (output) operands occupy in the
// flattened operand-dims list: {sum of input ranks, sum of all ranks}.
1160  int64_t inputRankSum = 0;
1161  int64_t outputRankSum = 0;
1162  for (OpOperand *input : op.getDpsInputOperands())
1163  inputRankSum += op.getRank(input);
1164  for (OpOperand &output : op.getDpsInitsMutable())
1165  outputRankSum += op.getRank(&output);
1166  return {inputRankSum, inputRankSum + outputRankSum};
1167 }
1168 
1169 LogicalResult
// NOTE(review): the line `LinalgOp::reifyResultShapes(OpBuilder &b,` was
// elided by the doc extractor -- confirm against the original source.
1171  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
1172  // An example that helps understand the logic below.
1173  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
1174  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
1175  // This is achieved as follows.
1176  // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
1177  // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
1178  // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
1179  // resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
1180  // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
1181  AffineMap loopsToShapesMap = getLoopsToShapesMap();
1182 
1183  // Find the position in the above map that represents the shape of the
1184  // result:dim being inferred.
1185  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
1186 
1187  /// From loopsToShapesMap extract the submap that represents the shape of the
1188  /// (resultIdx, dim) needed.
1189  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
1190  resultShapesSubMapPos.first,
1191  resultShapesSubMapPos.second - resultShapesSubMapPos.first);
1192  AffineMap resultShapesFromInputShapesMap =
1193  loopToResultsShapeMap.compose(getShapesToLoopsMap());
1194 
1195  // Check that the result dim map does not contain the positions corresponding
1196  // to the outputs.
1197  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
1198  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
1199  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
1200  Location loc = getOperation()->getLoc();
1201  IRRewriter rewriter(b);
1202  SmallVector<OpFoldResult> allResultDimValues =
// NOTE(review): the call name on the elided line is
// makeComposedFoldedMultiResultAffineApply (see the function index).
1204  rewriter, loc, resultShapesFromInputShapesMap,
1205  createFlatListOfOperandDims(b, loc));
1206  int64_t pos = 0;
1207  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
1208  for (OpOperand &opOperand : getDpsInitsMutable()) {
// NOTE(review): the declaration of `shapes` (presumably a
// SmallVector<OpFoldResult>, given its uses below) was elided by the doc
// extractor.
1210  for (int64_t dim : llvm::seq<int64_t>(0, getRank(&opOperand))) {
1211  auto shapedType = llvm::cast<ShapedType>(opOperand.get().getType());
1212  if (!shapedType.isDynamicDim(dim)) {
1213  // Static dim: Return IntegerAttr.
1214  shapes.push_back(b.getIndexAttr(shapedType.getDimSize(dim)));
1215  } else {
1216  // Dynamic dim: Return Value.
// If the computed expression still references an output dim (checked by the
// visitor), fall back to a dim op on the init operand itself.
1217  OpFoldResult ofr = checkDimExpr.visit(shapeExprs[pos])
1218  ? createOrFoldDimOp(b, loc, opOperand.get(), dim)
1219  : allResultDimValues[pos];
1220  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
1221  }
1222  pos++;
1223  }
1224  reifiedReturnShapes.emplace_back(std::move(shapes));
1225  }
1226  return success();
1227 }
1228 
1229 /// Return the index in the indexingMaps vector that corresponds to this
1230 /// `opOperand`.
1231 int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
1232  auto operandNumber = opOperand->getOperandNumber();
1233  auto dpsIface = cast<DestinationStyleOpInterface>(*this->getOperation());
1234  if (!dpsIface.isDpsInput(opOperand))
1235  return operandNumber;
1236  unsigned start = dpsIface.getDpsInits().getBeginOperandIndex();
1237  assert(!dpsIface.isDpsInit(opOperand));
1238  // Account for potential inputs that are not DPS and may not appear in
1239  // `indexingMaps`.
1240  return cast<DestinationStyleOpInterface>(*this->getOperation())
1241  .getNumDpsInputs() +
1242  operandNumber - start;
1243 }
1244 
// NOTE(review): the signature line was elided by the doc extractor; per the
// function index this is
// mlir::linalg::detail::verifyStructuredOpInterface(Operation *op).
1246  LinalgOp linalgOp = cast<LinalgOp>(op);
1247  // Mixed tensor/buffer operands are not allowed.
1248  if (!linalgOp.hasPureTensorSemantics() &&
1249  !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1250  return op->emitOpError("expected to have pure tensor or buffer semantics");
1251 
1252  // Before checking indexing maps, we need to make sure the attributes
1253  // referenced by it are valid.
1254  if (linalgOp.hasDynamicIndexingMaps())
1255  if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
1256  return failure();
1257 
1258  // Delayed calling of IndexingMapOpInterface::verifyImpl.
1259  if (failed(cast<IndexingMapOpInterface>(op).verifyImpl()))
1260  return failure();
1261 
1262  // Set this flag if this op has user defined maps. This is required to guard
1263  // the below error condition which assume default indexing maps.
// NOTE(review): the comment above mentions a flag that is not visible in this
// loop -- it appears stale; confirm against revision history.
1264  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
1265  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
1266  // Domain must be consistent.
1267  unsigned numLoops = linalgOp.getNumLoops();
1268  if (indexingMap.getNumDims() != numLoops)
1269  return op->emitOpError("expected indexing_map #")
1270  << opOperand.getOperandNumber() << " to have " << numLoops
1271  << " dim(s) to match the number of loops";
1272  }
// NOTE(review): redDims is populated but not used in the visible code below
// -- possibly dead; confirm before removing.
1273  SmallVector<unsigned> redDims;
1274  linalgOp.getReductionDims(redDims);
1275 
1276  if (!linalgOp.getShapesToLoopsMap())
1277  return op->emitOpError("expected the shape-to-loops map to be non-null");
1278 
1279  // Check the region has exactly one block.
1280  if (linalgOp->getNumRegions() != 1 || !linalgOp->getRegion(0).hasOneBlock())
1281  return op->emitOpError("expects to have 1 region with 1 block");
1282 
1283  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
1284  // types.
1285  // TODO: once ranked shape types are plugged in, we may want to drop the
1286  // corresponding bbargs, that can never be read from. This will be subject to
1287  // consistency discussions (i.e. what to do with output tensors whose bbarg is
1288  // not used).
1289  Block &block = linalgOp->getRegion(0).front();
1290 
1291  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
1292  return op->emitOpError("expected as many non-induction variable region "
1293  "arguments as the number of input/output operands");
1294 
// Each bbarg must carry the element type (for shaped operands) or the type
// itself (for scalars) of its matching operand.
1295  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
1296  Type elementType = opOperand->get().getType();
1297  if (isa<MemRefType, RankedTensorType>(elementType))
1298  elementType = getElementTypeOrSelf(opOperand->get().getType());
1299  Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
1300  if (elementType != argType)
1301  return op->emitOpError("expected type of bb argument #")
1302  << opOperand->getOperandNumber() << " (" << argType << ")"
1303  << " to match element or self type of the corresponding operand ("
1304  << elementType << ")";
1305  }
1306 
1307  return success();
1308 }
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:62
static FailureOr< ConvolutionDimensions > inferConvolutionDimsImpl(LinalgOp linalgOp, ConvAccessExprWalker &inputExprWalker, bool allowEmptyConvolvedDims)
Classifies dimensions in the linalgOp used by a convolution subcomputation, as captured by inputExprW...
static std::optional< Value > isaExternalFillOp(GenericOp op)
Detects if a linalg.generic operation represents an external scalar input.
static Value getSourceSkipUnary(Value value)
If the value is defined by a chain of unary side effect-free, go up the use-def chain until the first...
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
static bool isPairTemplateImpl(Operation *add, Operation *mul)
Returns true if the two operations are of the kinds specified by a pair of consecutive template argum...
static SmallVector< int64_t, 2 > getConstantsFromExprList(const SmallVector< AffineExpr, 2 > &exprs)
static MatchFillResult isFillInterfaceImpl(Operation *op)
static FailureOr< ContractionDimensions > inferContractionDimsImpl(ArrayRef< AffineMap > indexingMaps, ArrayRef< utils::IteratorType > iterators)
Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcomputation ...
static bool isContractionBody(Block &block)
Returns true if the block is a body of a contraction with the kinds of operations given pairwise by t...
static std::optional< Value > isaInlinedFillOp(GenericOp op)
Detects if a linalg.generic operation represents a fill with an inlined constant.
static llvm::SmallDenseSet< int64_t > getPreservedDims(AffineMap map)
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op, unsigned arity)
MatchFillResult
static llvm::SmallDenseSet< int64_t > findPermutationsIndexingOperand(AffineMap indexingMap, ArrayRef< utils::IteratorType > iterators, utils::IteratorType iter)
Given an indexingMap and its corresponding iterators, returns the positions of the iterators of type ...
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
Affine binary operation expression.
Definition: AffineExpr.h:214
AffineExpr getLHS() const
Definition: AffineExpr.cpp:338
AffineExpr getRHS() const
Definition: AffineExpr.cpp:341
An integer constant appearing in affine expression.
Definition: AffineExpr.h:239
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:223
unsigned getPosition() const
Definition: AffineExpr.cpp:346
See documentation for AffineExprVisitorBase.
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:33
MLIRContext * getContext() const
Definition: AffineExpr.cpp:31
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:655
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:611
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:552
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:231
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
An attribute that represents a reference to a dense integer vector or tensor object.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:750
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
This class helps build Operations.
Definition: Builders.h:205
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:773
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool mightHaveTrait()
Returns true if the operation might have the provided trait.
Definition: Operation.h:757
unsigned getNumOperands()
Definition: Operation.h:346
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1376
MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op, ConvolutionDimensions *dimensions=nullptr, bool allowEmptyConvolvedDims=false)
Checks whether op conforms to ConvolutionOpInterface and populates dimensions with indexes of the dif...
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
StringRef getMatchConvolutionMessage(MatchConvolutionResult res)
Returns the error message corresponding to the convolution checking return code.
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that check if given operands can be dropped, i.e.
MatchContractionResult isContractionInterfaceImpl(Operation *op, ContractionDimensions *dimensions=nullptr)
Checks whether op conforms to ContractionOpInterface and populates dimensions with indexes of the dif...
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
StringRef getMatchContractionMessage(MatchContractionResult res)
Returns the error message corresponding to the contraction checking return code.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaTransposeOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.transpose.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp)
Checks whether a given genericOp is semantically equivalent to a single linalg elementwise unary op.
bool isaCopyOpInterface(LinalgOp linalgOp)
Checks whether linalgOp is semantically equivalent to a linalg.copyOp.
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:103
bool isaConvolutionOpInterface(LinalgOp linalgOp, bool allowEmptyConvolvedDims=false)
Checks whether linalgOp conforms to ConvolutionOpInterface.
std::optional< SmallVector< int64_t > > isaBroadcastOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.broadcast.
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:94
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
std::optional< Value > isaFillOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a linalg.fill.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp)
Checks whether genericOp is semantically equivalent to a single linalg elementwise binary op e....
Include the generated interface declarations.
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps, MLIRContext *context)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Definition: AffineMap.cpp:829
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:784
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:643
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Positions of a Linalg op loops that correspond to different kinds of a contraction dimension.
Positions of a Linalg op loops that correspond to different kinds of a convolution dimension.