MLIR  16.0.0git
LinalgInterfaces.cpp
Go to the documentation of this file.
1 //===- LinalgInterfaces.cpp - Linalg interfaces implementation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 
22 using namespace mlir;
23 using namespace mlir::linalg;
24 
25 /// Include the definitions of the copy operation interface.
26 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc"
27 
28 //===----------------------------------------------------------------------===//
29 // Interface utility functions
30 //===----------------------------------------------------------------------===//
32  linalg::LinalgOp linalgOp, ArrayRef<OpOperand *> droppedOperands) {
33  SmallVector<AffineMap> indexingMaps;
34  for (auto &opOperand : linalgOp->getOpOperands()) {
35  if (llvm::is_contained(droppedOperands, &opOperand))
36  continue;
37  indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
38  }
39  if (indexingMaps.empty()) {
40  // If there are no indexing maps, the operand can only be dropped
41  // if the op has no loops.
42  return linalgOp.getNumLoops() == 0;
43  }
44  return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // ContractionOpInterface implementation
49 //===----------------------------------------------------------------------===//
50 
51 /// Return true if the use-def chain from `v` to `from` consists of 0 or more
52 /// unary single-operand operations.
53 // TODO: relax to multi-operands with constants, which are technically unary ops
54 // as needed (e.g. add5).
55 static bool isChainOfUnaryOpsFrom(Value v, Value from) {
56  while (true) {
57  if (v == from)
58  return true;
59  Operation *op = v.getDefiningOp();
60  if (!op || op->getNumOperands() != 1)
61  return false;
62  v = op->getOperand(0);
63  };
64 }
65 
66 /// Return the unique instance of OpType in `block` if it is indeed unique.
67 /// Return null if none or more than 1 instances exist.
68 template <typename OpType>
69 static OpType getSingleOpOfType(Block &block) {
70  OpType res = nullptr;
71  block.walk([&](OpType op) {
72  if (res) {
73  res = nullptr;
74  return WalkResult::interrupt();
75  }
76  res = op;
77  return WalkResult::advance();
78  });
79  return res;
80 }
81 
82 /// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))`
83 /// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent
84 /// unary operations that may change the type.
85 template <typename AddOpType, typename MulOpType>
86 static bool isAddMul(Block &block) {
87  if (block.getNumArguments() != 3)
88  return false;
89  Operation *yieldOp = block.getTerminator();
90  if (yieldOp->getNumOperands() != 1)
91  return false;
92 
93  AddOpType addOp = getSingleOpOfType<AddOpType>(block);
94  MulOpType mulOp = getSingleOpOfType<MulOpType>(block);
95  if (!addOp || !mulOp)
96  return false;
97 
98  Value argA = block.getArgument(0), argB = block.getArgument(1);
99  Value a = mulOp->getOperand(0), b = mulOp->getOperand(1);
100  Value mul = mulOp->getResult(0);
101  Value argC = block.getArgument(2);
102  Value c1 = addOp->getOperand(0), c2 = addOp->getOperand(1);
103  Value add = addOp->getResult(0);
104  Value res = yieldOp->getOperand(0);
105  // Result traces back to add.
106  auto un = isChainOfUnaryOpsFrom;
107  bool success = un(res, add);
108  // One of the operands of add traces back to argC, the other to the mul.
109  success |= (un(c1, argC) && un(c2, mul)) || ((un(c1, mul)) && un(c2, argC));
110  // One of the operands of mul traces back to argA, the other to argB.
111  success |= (un(a, argA) && un(b, argB)) || ((un(a, argB)) && un(b, argA));
112  return success;
113 }
114 
/// Result codes for structurally matching a LinalgOp as a contraction; each
/// non-Success value corresponds to one diagnostic emitted by
/// verifyContractionInterface below.
enum class MatchContractionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NoReduction,
  NotProjectedPermutations,
  NotAddMul
};
124  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
125  if (!linalgOp)
127  if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
129  auto mapRange = linalgOp.getIndexingMapsArray();
130  if (linalgOp.getNumReductionLoops() == 0)
132  if (llvm::any_of(mapRange,
133  [](AffineMap m) { return !m.isProjectedPermutation(); }))
135  // TODO: more fields than add/mul.
136  if (!isAddMul<arith::AddFOp, arith::MulFOp>(linalgOp->getRegion(0).front()) &&
137  !isAddMul<arith::AddIOp, arith::MulIOp>(linalgOp->getRegion(0).front()) &&
138  !isAddMul<complex::AddOp, complex::MulOp>(
139  linalgOp->getRegion(0).front()) &&
140  !isAddMul<arith::OrIOp, arith::AndIOp>(linalgOp->getRegion(0).front()))
143 }
144 
146  if (!linalgOp)
147  return false;
148  Operation *op = linalgOp.getOperation();
149  return isa<ContractionOpInterface>(op) ||
151 }
152 
153 /// Verify that a LinalgOp `op` is a contraction.
154 /// A Linalg contraction is defined in general terms:
155 /// 1. Has 2 input and 1 output shapes.
156 /// 2. Has at least one reduction dimension.
157 /// 3. Has only projected permutation indexing maps.
158 /// 4. its body computes `u5(u1(c) + u2(u3(a) * u4(b)))` on some field
159 /// (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent scalar unary
160 /// operations that may change the type (e.g. for mixed-precision).
161 /// As a consequence, when vectorization of such an op occurs, the only special
162 /// behavior is that the (unique) MulOpType is vectorized into a
163 /// `vector.contract`. All other ops are handled in a generic fashion.
164 /// In the future, we may wish to allow more input arguments and elementwise and
165 /// constant operations that do not involve the reduction dimension(s).
167  auto res = isContractionInterfaceImpl(op);
169  return op->emitError("expected a LinalgOp");
171  return op->emitError("expected op with 2 inputs and 1 outputs");
173  return op->emitError("expected at least a reduction loop");
175  return op->emitError("expected all indexings to be projected permutations");
177  return op->emitError("(add, mul) operations not found");
178  return success();
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // ConvolutionOpInterface implementation
183 //===----------------------------------------------------------------------===//
184 
185 /// Of the given two expressions returns one that is of type T (`lhs` gets
186 /// preference over `rhs`)
187 template <typename T>
189  return lhs.isa<T>() ? lhs.cast<T>()
190  : (rhs.isa<T>() ? rhs.cast<T>() : nullptr);
191 }
192 
193 namespace {
194 /// Walk the indexing expressions for input of a convolution operation to verify
195 /// its of the right form, either
196 /// - AffineDimExpr
197 /// - AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?
198 /// (`+` AffineDimExpr (`*` (AffineSymbolExpr | AffineConstantExpr))?)*
199 ///
200 /// classifies the AffineDimExpr as convolved dimensions or unconvolved
201 /// dimensions and verifies each dimension occurs only once.
202 struct ConvAccessExprWalker
203  : public AffineExprVisitor<ConvAccessExprWalker, LogicalResult> {
204  llvm::SmallDenseSet<unsigned> convolvedDims;
205  llvm::SmallDenseSet<unsigned> unConvolvedDims;
206 
207  LogicalResult visitDimExpr(AffineDimExpr dimExpr) {
208  unsigned position = dimExpr.getPosition();
209  if (unConvolvedDims.count(position) || convolvedDims.count(position)) {
210  return failure();
211  }
212  unConvolvedDims.insert(position);
213  return success();
214  }
215 
216  LogicalResult visitSymbolExpr(AffineSymbolExpr expr) { return failure(); }
217 
218  LogicalResult visitConstantExpr(AffineConstantExpr expr) { return failure(); }
219 
220  LogicalResult visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryExpr) {
221  // In pre-order visit, top level op has to be an add op.
222  if (binaryExpr.getKind() != AffineExprKind::Add)
223  return failure();
224  return success(succeeded(isDimExprOrMulExpr(binaryExpr.getLHS())) &&
225  succeeded(isDimExprOrMulExpr(binaryExpr.getRHS())));
226  }
227 
228  LogicalResult isDimExprOrMulExpr(AffineExpr expr) {
229  if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
230  unsigned dim = dimExpr.getPosition();
231  if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
232  return failure();
233  convolvedDims.insert(dim);
234  return success();
235  }
236  if (auto symbolMulExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
237  if (symbolMulExpr.getKind() != AffineExprKind::Mul)
238  return failure();
239  auto lhsExpr = symbolMulExpr.getLHS();
240  auto rhsExpr = symbolMulExpr.getRHS();
241  // Check for symbol expression.
242  AffineExpr mulExpr =
243  getAffineExprOfType<AffineSymbolExpr>(lhsExpr, rhsExpr);
244  // If there was no symbol expr, check for constant expression.
245  if (!mulExpr) {
246  mulExpr = getAffineExprOfType<AffineConstantExpr>(lhsExpr, rhsExpr);
247  }
248  auto dimExpr = getAffineExprOfType<AffineDimExpr>(lhsExpr, rhsExpr);
249  if (!mulExpr || !dimExpr)
250  return failure();
251  unsigned dim = dimExpr.getPosition();
252  if (convolvedDims.count(dim) || unConvolvedDims.count(dim))
253  return failure();
254  convolvedDims.insert(dim);
255  return success();
256  }
257  return failure();
258  }
259 };
260 } // namespace
261 
262 static llvm::SmallDenseSet<unsigned> getPreservedDims(AffineMap map) {
263  assert(map.isProjectedPermutation() &&
264  "expected map to have projected permutations");
265  llvm::SmallDenseSet<unsigned> preservedDims;
266  for (auto expr : map.getResults())
267  preservedDims.insert(expr.cast<AffineDimExpr>().getPosition());
268  return preservedDims;
269 }
270 
/// Result codes for structurally matching a LinalgOp as a convolution; each
/// non-Success value corresponds to one diagnostic emitted by
/// verifyConvolutionInterface below.
enum class MatchConvolutionResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  WrongInputIndexingMap,
  NotProjectedPermutations,
  NonConvolutionLoop,
  OutputDimsNotParallel,
  NonOutputDimNotReduction
};
281 
283  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
284  if (!linalgOp)
286  if (linalgOp.getNumDpsInputs() < 2 || linalgOp.getNumDpsInits() != 1)
288 
289  auto indexingMaps = linalgOp.getIndexingMapsArray();
290 
291  // Check the input indexing map has the right form.
292  ConvAccessExprWalker inputExprWalker;
293  if (llvm::any_of(indexingMaps[0].getResults(),
294  [&inputExprWalker](AffineExpr expr) {
295  return failed(inputExprWalker.visit(expr));
296  })) {
298  }
299 
300  // Filter and output maps must be projected permutation.
301  if (!indexingMaps[1].isProjectedPermutation() ||
302  !indexingMaps.back().isProjectedPermutation())
304 
305  auto iteratorTypes = linalgOp.getIteratorTypesArray();
306 
307  llvm::SmallDenseSet<unsigned> outputDims =
308  getPreservedDims(indexingMaps.back());
309  llvm::SmallDenseSet<unsigned> filterDims = getPreservedDims(indexingMaps[1]);
310  // Make sure all loops are charecterized as one of:
311  // - Batch loop : present in output, as non-convolved in input, not present in
312  // filter.
313  // - Output image dimension : present in output, convolved dims in input, not
314  // present in filter.
315  // - Output channel dimension : present in output, not present in input,
316  // present in filter.
317  // - Filter loop dimension : present in filter, convolved in input, not
318  // present in output.
319  // - Input channel dimension : unconvolved in input, not present in output,
320  // present in filter.
321  // - Depth multiplier : unconvolved in input, present in output, present in
322  // filter.
323  llvm::SmallDenseSet<unsigned> allLoopDims;
324  for (auto outputExpr : indexingMaps.back().getResults()) {
325  unsigned outputDim = outputExpr.cast<AffineDimExpr>().getPosition();
326  if (inputExprWalker.unConvolvedDims.count(outputDim) &&
327  !filterDims.count(outputDim)) {
328  // Batch dimension.
329  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
331  allLoopDims.insert(outputDim);
332  continue;
333  }
334  if (inputExprWalker.convolvedDims.count(outputDim) &&
335  !filterDims.count(outputDim)) {
336  // Output image Loop dimension.
337  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
339  allLoopDims.insert(outputDim);
340  continue;
341  }
342  if (!inputExprWalker.convolvedDims.count(outputDim) &&
343  !inputExprWalker.unConvolvedDims.count(outputDim) &&
344  filterDims.count(outputDim)) {
345  // Output channel dimension.
346  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
348  allLoopDims.insert(outputDim);
349  continue;
350  }
351  if (inputExprWalker.unConvolvedDims.count(outputDim) &&
352  filterDims.count(outputDim)) {
353  // Depth multiplier.
354  if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
356  allLoopDims.insert(outputDim);
357  continue;
358  }
360  }
361  for (auto filterExpr : indexingMaps[1].getResults()) {
362  unsigned filterDim = filterExpr.cast<AffineDimExpr>().getPosition();
363  if (outputDims.count(filterDim) &&
364  !inputExprWalker.unConvolvedDims.count(filterDim) &&
365  !inputExprWalker.convolvedDims.count(filterDim)) {
366  // Output channel dimension. THis is already seen, continue;
367  continue;
368  }
369  if (inputExprWalker.convolvedDims.count(filterDim) &&
370  !outputDims.count(filterDim)) {
371  // Filter loop dimension.
372  if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
374  if (allLoopDims.count(filterDim))
376  allLoopDims.insert(filterDim);
377  continue;
378  }
379  if (inputExprWalker.unConvolvedDims.count(filterDim) &&
380  !outputDims.count(filterDim)) {
381  // Input channel dimension.
382  if (iteratorTypes[filterDim] != utils::IteratorType::reduction)
384  if (allLoopDims.count(filterDim))
386  allLoopDims.insert(filterDim);
387  continue;
388  }
389  if (inputExprWalker.unConvolvedDims.count(filterDim) &&
390  outputDims.count(filterDim)) {
391  // Depthwise loop. Already seen.
392  continue;
393  }
395  }
396  // All loops must be covered now.
397  if (allLoopDims.size() != linalgOp.getNumLoops())
399 
401 }
402 
404  auto res = isConvolutionInterfaceImpl(op);
406  return op->emitError("expected a LinalgOp");
408  return op->emitError("expected op with 2 inputs and 1 output");
410  return op->emitError("unexpected input index map for convolutions");
412  return op->emitError(
413  "expected output/filter indexing maps to be projected permutations");
414  }
416  return op->emitError("unexpected loop dimension for convolution op");
417  }
419  return op->emitError(
420  "expected all iterators used to access outputs to be parallel");
421  }
423  return op->emitError(
424  "expected all iterators not used to access outputs to be reduction");
425  }
426  return success();
427 }
428 
429 //===----------------------------------------------------------------------===//
430 // FillOpInterface implementation
431 //===----------------------------------------------------------------------===//
432 
/// Result codes for structurally matching a LinalgOp as a fill; each
/// non-Success value corresponds to one diagnostic in verifyFillInterface.
enum class MatchFillResult {
  Success = 0,
  NotLinalgOp,
  WrongNumOperands,
  NotScalarInput
};
439 
441  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
442  if (!linalgOp)
444  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
446 
447  OpOperand *value = linalgOp.getDpsInputOperand(0);
448  if (!linalgOp.isScalar(value))
450 
452 }
453 
455  auto res = isFillInterfaceImpl(op);
456  if (res == MatchFillResult::NotLinalgOp)
457  return op->emitError("expected a LinalgOp");
459  return op->emitError("expected op with 1 input and 1 output");
461  return op->emitError("expected op with scalar input");
462 
463  return success();
464 }
465 
466 //===----------------------------------------------------------------------===//
467 // StructuredOpInterface implementation
468 //===----------------------------------------------------------------------===//
469 
470 /// Helper function that creates a memref::DimOp or tensor::DimOp depending on
471 /// the type of `source`.
473  int64_t dim) {
474  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
475  return b.createOrFold<memref::DimOp>(loc, source, dim);
476  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
477  return b.createOrFold<tensor::DimOp>(loc, source, dim);
478  llvm_unreachable("Expected MemRefType or TensorType");
479 }
481  int64_t dim) {
482  auto shapedType = source.getType().cast<ShapedType>();
483  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
484  return createOrFoldDimOp(b, loc, source, dim);
485  return b.getIndexAttr(shapedType.getDimSize(dim));
486 }
487 
488 SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
489  Location loc) {
491  for (OpOperand &opOperand : getOperation()->getOpOperands()) {
492  for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
493  res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
494  }
495  return res;
496 }
497 
498 SmallVector<int64_t, 4> LinalgOp::createFlatListOfOperandStaticDims() {
500  assert(!hasDynamicShape() && "expected operands to have static shapes");
501  for (OpOperand &opOperand : getOperation()->getOpOperands())
502  llvm::append_range(res, getShape(&opOperand));
503  return res;
504 }
505 
506 SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) {
507  AffineMap map = getLoopsToShapesMap();
508  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
509  auto viewSizes = createFlatListOfOperandDims(b, loc);
510  SmallVector<Range, 4> res(numDims);
511  for (unsigned idx = 0; idx < numRes; ++idx) {
512  auto result = map.getResult(idx);
513  if (auto d = result.dyn_cast<AffineDimExpr>()) {
514  if (res[d.getPosition()].offset)
515  continue;
516  res[d.getPosition()] =
517  Range{b.getIndexAttr(0), viewSizes[idx], b.getIndexAttr(1)};
518  }
519  }
520  return res;
521 }
522 
523 SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
524  AffineMap map = getLoopsToShapesMap();
525  unsigned numDims = map.getNumDims(), numRes = map.getNumResults();
526  SmallVector<int64_t, 4> allShapeSizes = createFlatListOfOperandStaticDims();
527  SmallVector<int64_t, 4> res(numDims, 0);
528  for (unsigned idx = 0; idx < numRes; ++idx) {
529  auto result = map.getResult(idx);
530  if (auto d = result.dyn_cast<AffineDimExpr>())
531  res[d.getPosition()] = allShapeSizes[idx];
532  }
533  return res;
534 }
535 
536 /// Visitor to check if any of the given set of positions from AffineDimExprs
537 /// are used within an AffineExpr.
539  : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
540  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
541  : positions(std::move(positions)) {}
542 
544  return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
545  }
546 
547  bool visitDimExpr(AffineDimExpr dimExpr) {
548  return positions.test(dimExpr.getPosition());
549  }
550 
551  bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
552 
553  bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
554 
555 private:
556  llvm::SmallBitVector positions;
557 };
558 
559 static std::pair<int64_t, int64_t>
561  int64_t inputRankSum = 0;
562  int64_t outputRankSum = 0;
563  for (OpOperand *input : op.getDpsInputOperands())
564  inputRankSum += op.getRank(input);
565  for (OpOperand *output : op.getDpsInitOperands())
566  outputRankSum += op.getRank(output);
567  return {inputRankSum, inputRankSum + outputRankSum};
568 }
569 
571 LinalgOp::reifyResultShapes(OpBuilder &b,
572  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
573  // An example that helps understand the logic below.
574  // Consider the following expression O(i+j, j) += A(i,k) * B(k, j)
575  // We want to express the shape of dim 0 of O in terms of shape of the inputs.
576  // This is achieved as follows.
577  // loopsToShapesMap = (d0, d1, d2) -> (d0, d2, d2, d1, d0 + d1, d1)
578  // subMapOfResultShapes = (d0, d1, d2) -> (d0 + d1, d1)
579  // shapesToLoopsMap = (d0, d2, d2, d3, d4, d5) -> (d0, d3, d2)
580  // resultShapesFromInputShapes = subMapOfResultDim.compose(shapesToLoopMap)
581  // = (d0, d1, d2, d3, d4, d5) -> (d0 + d1, d1)
582  AffineMap loopsToShapesMap = getLoopsToShapesMap();
583 
584  // Find the position in the above map that represents the shape of the
585  // result:dim being inferred.
586  auto resultShapesSubMapPos = getResultsPositionInLoopsToShapeMap(*this);
587 
588  /// From loopsToShapesMap extract the submap that represents the shape of the
589  /// (resultIdx, dim) needed.
590  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
591  resultShapesSubMapPos.first,
592  resultShapesSubMapPos.second - resultShapesSubMapPos.first);
593  AffineMap resultShapesFromInputShapesMap =
594  loopToResultsShapeMap.compose(getShapesToLoopsMap());
595 
596  // Check that the result dim map does not contain the positions corresponding
597  // to the outputs.
598  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
599  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
600  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
601  Location loc = getOperation()->getLoc();
602  IRRewriter rewriter(b);
603  SmallVector<OpFoldResult> allResultDimValues =
605  rewriter, loc, resultShapesFromInputShapesMap,
606  createFlatListOfOperandDims(b, loc));
607  int64_t pos = 0;
608  ArrayRef<AffineExpr> shapeExprs = resultShapesFromInputShapesMap.getResults();
609  for (OpOperand *opOperand : getDpsInitOperands()) {
610  SmallVector<Value> shapes;
611  for (int64_t dim : llvm::seq<int64_t>(0, getRank(opOperand))) {
612  if (checkDimExpr.visit(shapeExprs[pos]))
613  shapes.push_back(createOrFoldDimOp(b, loc, opOperand->get(), dim));
614  else
615  shapes.push_back(
616  getValueOrCreateConstantIndexOp(b, loc, allResultDimValues[pos]));
617  pos++;
618  }
619  reifiedReturnShapes.emplace_back(std::move(shapes));
620  }
621  return success();
622 }
623 
625  LinalgOp linalgOp = cast<LinalgOp>(op);
626 
627  // Before checking indexing maps, we need to make sure the attributes
628  // referenced by it are valid.
629  if (linalgOp.hasDynamicIndexingMaps())
630  if (failed(linalgOp.verifyIndexingMapRequiredAttributes()))
631  return failure();
632 
633  // All input/output operands must be indexed.
634  if (static_cast<int64_t>(linalgOp.getIndexingMapsArray().size()) !=
635  linalgOp->getNumOperands())
636  return op->emitOpError("expected the number of indexing_map (")
637  << linalgOp.getIndexingMapsArray().size()
638  << ") to be equal to the number of input/output operands ("
639  << linalgOp->getNumOperands() << ")";
640 
641  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
642  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
643 
644  // Symbols disallowed.
645  if (indexingMap.getNumSymbols() != 0)
646  return op->emitOpError("unexpected symbols in indexing_map #")
647  << opOperand.getOperandNumber();
648 
649  // Domain must be consistent.
650  unsigned numLoops = linalgOp.getNumLoops();
651  if (indexingMap.getNumDims() != numLoops)
652  return op->emitOpError("expected indexing_map #")
653  << opOperand.getOperandNumber() << " to have " << numLoops
654  << " dim(s) to match the number of loops";
655 
656  int64_t rank = linalgOp.getRank(&opOperand);
657  if (indexingMap.getNumResults() != rank)
658  return op->emitOpError("expected operand rank (")
659  << rank << ") to match the result rank of indexing_map #"
660  << opOperand.getOperandNumber() << " ("
661  << indexingMap.getNumResults() << ")";
662  }
663 
664  SmallVector<unsigned> redDims;
665  linalgOp.getReductionDims(redDims);
666 
667  if (!linalgOp.getShapesToLoopsMap())
668  return op->emitOpError("expected the shape-to-loops map to be non-null");
669 
670  // Check if given shapes match to inferred shapes.
671  SmallVector<int64_t, 4> endLoopRangeValues = linalgOp.getStaticLoopRanges();
672  SmallVector<int64_t, 4> startLoopRangeValues(endLoopRangeValues.size(), 0);
673 
674  // Verify only static cases since we can't get exact dimension sizes and loop
675  // ranges for dynamic cases in this stage.
676  if (llvm::none_of(endLoopRangeValues, ShapedType::isDynamic)) {
677  for (int64_t &range : endLoopRangeValues)
678  range -= 1;
679  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
680  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
681  SmallVector<int64_t, 4> startIndices =
682  indexingMap.compose(startLoopRangeValues);
683  SmallVector<int64_t, 4> endIndices =
684  indexingMap.compose(endLoopRangeValues);
685  ArrayRef<int64_t> shape = linalgOp.getShape(&opOperand);
686  for (auto dim : llvm::seq<int64_t>(0, shape.size())) {
687  // Ignore dynamic dimension or the case that the dimension size is 0
688  if (ShapedType::isDynamic(shape[dim]) || shape[dim] == 0)
689  continue;
690 
691  // The first index or last index should be the maximum or the minimum in
692  // the inferred index ranges since the range is increasing or
693  // decreasing. The size of dimensions of input/output operands and the
694  // maximum value + 1 in the inferred range should be the same. But, for
695  // now we check if the inferred ranges are in boundary of input/output
696  // operands' size or not in case that Affine Expressions are complicated
697  // such as d0 * 3
698  // + d1 since it is not easy to handle the issues.
699  // Found the case that this solution can't check, for example, (d0, d1)
700  // -> (d1 - d0)
701  int64_t inferredDimSize =
702  std::max(startIndices[dim], endIndices[dim]) + 1;
703  if (std::min(startIndices[dim], endIndices[dim]) < 0) {
704  std::string mapStr;
705  {
706  llvm::raw_string_ostream os(mapStr);
707  os << indexingMap;
708  }
709  return op->emitOpError(
710  "unexpected result less than 0 at expression #")
711  << dim << " in " << mapStr;
712  }
713  if (indexingMap.getResult(dim).dyn_cast<AffineDimExpr>()) {
714  if (inferredDimSize != shape[dim]) {
715  return op->emitOpError("inferred input/output operand #")
716  << opOperand.getOperandNumber() << " has shape's dimension #"
717  << dim << " to be " << inferredDimSize << ", but found "
718  << shape[dim];
719  }
720  } else {
721  if (inferredDimSize > shape[dim]) {
722  return op->emitOpError("inferred input/output operand #")
723  << opOperand.getOperandNumber() << " has shape's dimension #"
724  << dim << " to be greater than or equal to "
725  << inferredDimSize << ", but found " << shape[dim];
726  }
727  }
728  }
729  }
730  }
731 
732  // Check the region has exactly one block.
733  if (linalgOp->getNumRegions() != 1 ||
734  !llvm::hasSingleElement(linalgOp->getRegion(0)))
735  return op->emitOpError("expects to have 1 region with 1 block");
736 
737  // Simplifying assumption: bbargs match 1-1 with shape operands elemental
738  // types.
739  // TODO: once ranked shape types are plugged in, we may want to drop the
740  // corresponding bbargs, that can never be read from. This will be subject to
741  // consistency discussions (i.e. what to do with output tensors whose bbarg is
742  // not used).
743  Block &block = linalgOp->getRegion(0).front();
744 
745  if (linalgOp.getOpOperandsMatchingBBargs().size() != block.getNumArguments())
746  return op->emitOpError("expected as many non-induction variable region "
747  "arguments as the number of input/output operands");
748 
749  for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) {
750  Type elementType = getElementTypeOrSelf(opOperand->get());
751  Type argType = block.getArgument(opOperand->getOperandNumber()).getType();
752  if (elementType != argType)
753  return op->emitOpError("expected type of bb argument #")
754  << opOperand->getOperandNumber() << " (" << argType << ")"
755  << " to match element or self type of the corresponding operand ("
756  << elementType << ")";
757  }
758 
759  return success();
760 }
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:62
static constexpr const bool value
static llvm::SmallDenseSet< unsigned > getPreservedDims(AffineMap map)
static T getAffineExprOfType(AffineExpr lhs, AffineExpr rhs)
Of the given two expressions returns one that is of type T (lhs gets preference over rhs)
MatchContractionResult
static MatchFillResult isFillInterfaceImpl(Operation *op)
static bool isAddMul(Block &block)
Detect whether res is any permutation of u5(u1(c) + u2(u3(a) * u4(b))) on the field (AddOpType,...
static bool isChainOfUnaryOpsFrom(Value v, Value from)
Return true if the use-def chain from v to from consists of 0 or more unary single-operand operations...
MatchConvolutionResult
static MatchContractionResult isContractionInterfaceImpl(Operation *op)
static OpType getSingleOpOfType(Block &block)
Return the unique instance of OpType in block if it is indeed unique.
static MatchConvolutionResult isConvolutionInterfaceImpl(Operation *op)
MatchFillResult
static std::pair< int64_t, int64_t > getResultsPositionInLoopsToShapeMap(LinalgOp &op)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
Affine binary operation expression.
Definition: AffineExpr.h:207
AffineExpr getLHS() const
Definition: AffineExpr.cpp:303
AffineExpr getRHS() const
Definition: AffineExpr.cpp:306
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
unsigned getPosition() const
Definition: AffineExpr.cpp:311
Base class for AffineExpr visitors/walkers.
Base type for affine expression.
Definition: AffineExpr.h:68
U cast() const
Definition: AffineExpr.h:291
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:26
U dyn_cast() const
Definition: AffineExpr.h:281
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:42
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:538
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:494
unsigned getNumSymbols() const
Definition: AffineMap.cpp:310
unsigned getNumDims() const
Definition: AffineMap.cpp:306
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:319
unsigned getNumResults() const
Definition: AffineMap.cpp:314
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:323
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:455
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:224
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:118
unsigned getNumArguments()
Definition: Block.h:117
RetT walk(FnT &&callback)
Walk the operations in this block.
Definition: Block.h:271
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
Operation & front()
Definition: Block.h:142
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:109
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:589
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:64
This class helps build Operations.
Definition: Builders.h:198
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:472
This class represents a single result from folding an operation.
Definition: OpDefinition.h:233
This class represents an operand of an operation.
Definition: Value.h:247
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Value getOperand(unsigned idx)
Definition: Operation.h:267
unsigned getNumOperands()
Definition: Operation.h:263
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:225
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:512
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
U cast() const
Definition: Types.h:280
bool isa() const
Definition: Types.h:260
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
bool isa() const
Definition: Value.h:90
Type getType() const
Return the type of this value.
Definition: Value.h:114
U cast() const
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static WalkResult advance()
Definition: Visitors.h:51
static WalkResult interrupt()
Definition: Visitors.h:50
bool canOpOperandsBeDroppedImpl(linalg::LinalgOp linalgOp, ArrayRef< OpOperand * > droppedOperands)
Implementation of the method that that check if given operands can be dropped, i.e.
LogicalResult verifyContractionInterface(Operation *op)
Verify that op conforms to ContractionOpInterface.
LogicalResult verifyFillInterface(Operation *op)
Verify that op conforms to the FillOpInterface.
LogicalResult verifyStructuredOpInterface(Operation *op)
Verify that op conforms to the invariants of StructuredOpInterface.
LogicalResult verifyConvolutionInterface(Operation *op)
Verify that op conforms to the ConvolutionOpInterface.
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Helper function that creates a memref::DimOp or tensor::DimOp depending on the type of source.
Definition: Utils.cpp:200
bool isaContractionOpInterface(LinalgOp linalgOp)
Checks whether linalgOp conforms to ContractionOpInterface.
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source, int64_t dim)
Definition: Utils.cpp:208
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:669
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
AffineMap concatAffineMaps(ArrayRef< AffineMap > maps)
Concatenates a list of maps into a single AffineMap, stepping over potentially empty maps.
Definition: AffineMap.cpp:714
@ Mul
RHS of mul is always a constant or a symbolic expression.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1037
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:53
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
Visitor to check if any of the given set of positions from AffineDimExprs are used within an AffineEx...
HasAffineDimExprVisitor(llvm::SmallBitVector positions)
bool visitDimExpr(AffineDimExpr dimExpr)
bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr)
bool visitSymbolExpr(AffineSymbolExpr symbolExpr)
bool visitConstantExpr(AffineConstantExpr constExpr)
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...