MLIR  17.0.0git
AffineOps.cpp
Go to the documentation of this file.
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/IRMapping.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/PatternMatch.h"
21 #include "llvm/ADT/ScopeExit.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 #include "llvm/ADT/SmallVectorExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/Debug.h"
26 #include <numeric>
27 #include <optional>
28 
29 using namespace mlir;
30 using namespace mlir::affine;
31 
32 #define DEBUG_TYPE "affine-ops"
33 
34 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
35 
36 /// A utility function to check if a value is defined at the top level of
37 /// `region` or is an argument of `region`. A value of index type defined at the
38 /// top level of an `AffineScope` region is always a valid symbol for all
39 /// uses in that region.
// NOTE(review): the declarator line of this definition is absent from this
// listing; only the body is visible. The body reads `value` and `region`.
// A block argument qualifies when its parent region is exactly `region`.
41  if (auto arg = llvm::dyn_cast<BlockArgument>(value))
42  return arg.getParentRegion() == region;
// Otherwise compare the region that directly holds the defining op.
43  return value.getDefiningOp()->getParentRegion() == region;
44 }
45 
46 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
47 /// region remains legal if the operation that uses it is inlined into `dest`
48 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
49 /// `isValidSymbol`, depending on the value being required to remain a valid
50 /// dimension or symbol.
// NOTE(review): the declarator line carrying the function name and its leading
// parameters is absent from this listing. The body reads `value`, `src` and
// `dest` in addition to the two parameters shown below.
51 static bool
53  const IRMapping &mapping,
54  function_ref<bool(Value, Region *)> legalityCheck) {
55  // If the value is a valid dimension for any other reason than being
56  // a top-level value, it will remain valid: constants get inlined
57  // with the function, transitive affine applies also get inlined and
58  // will be checked themselves, etc.
59  if (!isTopLevelValue(value, src))
60  return true;
61 
62  // If it's a top-level value because it's a block operand, i.e. a
63  // function argument, check whether the value replacing it after
64  // inlining is a valid dimension in the new region.
// The replacement value is looked up in `mapping` and then checked against
// `dest` with the caller-supplied `legalityCheck`.
65  if (llvm::isa<BlockArgument>(value))
66  return legalityCheck(mapping.lookup(value), dest);
67 
68  // If it's a top-level value because it's defined in the region,
69  // it can only be inlined if the defining op is a constant or a
70  // `dim`, which can appear anywhere and be valid, since the defining
71  // op won't be top-level anymore after inlining.
// `operandCst` only serves as the capture slot for `m_Constant`; its value is
// not used afterwards.
72  Attribute operandCst;
73  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
74  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
75  isDimLikeOp;
76 }
77 
78 /// Checks if all values known to be legal affine dimensions or symbols in `src`
79 /// remain so if their respective users are inlined into `dest`.
// NOTE(review): the declarator line carrying the function name and its leading
// parameters is absent from this listing. The body reads `values`, `src` and
// `dest` in addition to the two parameters shown below.
80 static bool
82  const IRMapping &mapping,
83  function_ref<bool(Value, Region *)> legalityCheck) {
// Applies the five-argument `remainsLegalAfterInline` to each element of
// `values`; `llvm::all_of` yields true only if every element passes.
84  return llvm::all_of(values, [&](Value v) {
85  return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
86  });
87 }
88 
89 /// Checks if an affine read or write operation remains legal after inlining
90 /// from `src` to `dest`.
// `OpTy` is restricted by the static_assert below to `AffineReadOpInterface`
// or `AffineWriteOpInterface`.
91 template <typename OpTy>
92 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
93  const IRMapping &mapping) {
94  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
95  AffineWriteOpInterface>::value,
96  "only ops with affine read/write interface are supported");
97 
// Split the map operands: the leading `getNumDims()` entries are treated as
// dimensions and the trailing `getNumSymbols()` entries as symbols.
98  AffineMap map = op.getAffineMap();
99  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
100  ValueRange symbolOperands =
101  op.getMapOperands().take_back(map.getNumSymbols());
// NOTE(review): the line opening the dimension check (the `if` head) is absent
// from this listing; only its argument list is visible below.
103  dimOperands, src, dest, mapping,
104  static_cast<bool (*)(Value, Region *)>(isValidDim)))
105  return false;
// NOTE(review): the line opening the symbol check is likewise absent from this
// listing. The static_casts select the (Value, Region *) overloads.
107  symbolOperands, src, dest, mapping,
108  static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
109  return false;
110  return true;
111 }
112 
113 /// Checks if an affine apply operation remains legal after inlining from `src`
114 /// to `dest`.
115 // Use "unused attribute" marker to silence clang-tidy warning stemming from
116 // the inability to see through "llvm::TypeSwitch".
117 template <>
118 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
119  Region *src, Region *dest,
120  const IRMapping &mapping) {
121  // If it's a valid dimension, we need to check that it remains so.
// NOTE(review): the line that begins the returned call is absent from this
// listing; only its argument list is visible below.
122  if (isValidDim(op.getResult(), src))
124  op.getMapOperands(), src, dest, mapping,
125  static_cast<bool (*)(Value, Region *)>(isValidDim));
126 
127  // Otherwise it must be a valid symbol, check that it remains so.
// NOTE(review): the line that begins this second call is absent as well.
129  op.getMapOperands(), src, dest, mapping,
130  static_cast<bool (*)(Value, Region *)>(isValidSymbol));
131 }
132 
133 //===----------------------------------------------------------------------===//
134 // AffineDialect Interfaces
135 //===----------------------------------------------------------------------===//
136 
137 namespace {
138 /// This class defines the interface for handling inlining with affine
139 /// operations.
// NOTE(review): the first member line of the struct (listing line 141) is
// absent from this listing.
140 struct AffineInlinerInterface : public DialectInlinerInterface {
142 
143  //===--------------------------------------------------------------------===//
144  // Analysis Hooks
145  //===--------------------------------------------------------------------===//
146 
147  /// Returns true if the given region 'src' can be inlined into the region
148  /// 'dest' that is attached to an operation registered to the current dialect.
149  /// 'wouldBeCloned' is set if the region is cloned into its new location
150  /// rather than moved, indicating there may be other users.
// `wouldBeCloned` is not read by this implementation.
151  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
152  IRMapping &valueMapping) const final {
153  // We can inline into affine loops and conditionals if this doesn't break
154  // affine value categorization rules.
155  Operation *destOp = dest->getParentOp();
156  if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
157  return false;
158 
159  // Multi-block regions cannot be inlined into affine constructs, all of
160  // which require single-block regions.
161  if (!llvm::hasSingleElement(*src))
162  return false;
163 
164  // Side-effecting operations that the affine dialect cannot understand
165  // should not be inlined.
166  Block &srcBlock = src->front();
167  for (Operation &op : srcBlock) {
168  // Ops with no side effects are fine,
// Ops that do not implement MemoryEffectOpInterface fall through to the
// per-op check below.
169  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
170  if (iface.hasNoEffect())
171  continue;
172  }
173 
174  // Assuming the inlined region is valid, we only need to check if the
175  // inlining would change it.
// NOTE(review): the line that starts the dispatch expression (listing line
// 177) is absent from this listing; the `.Case`/`.Default` chain follows.
176  bool remainsValid =
178  .Case<AffineApplyOp, AffineReadOpInterface,
179  AffineWriteOpInterface>([&](auto op) {
180  return remainsLegalAfterInline(op, src, dest, valueMapping);
181  })
182  .Default([](Operation *) {
183  // Conservatively disallow inlining ops we cannot reason about.
184  return false;
185  });
186 
187  if (!remainsValid)
188  return false;
189  }
190 
191  return true;
192  }
193 
194  /// Returns true if the given operation 'op', that is registered to this
195  /// dialect, can be inlined into the given region, false otherwise.
// Only the parent op of `region` is inspected; `op`, `wouldBeCloned` and
// `valueMapping` are not read here.
196  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
197  IRMapping &valueMapping) const final {
198  // Always allow inlining affine operations into a region that is marked as
199  // affine scope, or into affine loops and conditionals. There are some edge
200  // cases when inlining *into* affine structures, but that is handled in the
201  // other 'isLegalToInline' hook above.
202  Operation *parentOp = region->getParentOp();
203  return parentOp->hasTrait<OpTrait::AffineScope>() ||
204  isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
205  }
206 
207  /// Affine regions should be analyzed recursively.
208  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
209 };
210 } // namespace
211 
212 //===----------------------------------------------------------------------===//
213 // AffineDialect
214 //===----------------------------------------------------------------------===//
215 
// Registers the dialect's operations and its inliner interface.
216 void AffineDialect::initialize() {
// The two DMA ops are listed explicitly; the remaining template arguments come
// from the included .inc file, which is expanded with GET_OP_LIST defined.
217  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
218 #define GET_OP_LIST
219 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
220  >();
221  addInterfaces<AffineInlinerInterface>();
222 }
223 
224 /// Materialize a single constant operation from a given attribute value with
225 /// the desired resultant type.
// NOTE(review): the declarator line (return type, name and first parameter) is
// absent from this listing; the body also reads a `builder`.
227  Attribute value, Type type,
228  Location loc) {
// Delegates entirely to `arith::ConstantOp::materialize`.
229  return arith::ConstantOp::materialize(builder, value, type, loc);
230 }
231 
232 /// A utility function to check if a value is defined at the top level of an
233 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
234 /// conservatively assume it is not top-level. A value of index type defined at
235 /// the top level is always a valid symbol.
// NOTE(review): the declarator line of this definition is absent from this
// listing; the body reads a single `value`.
237  if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
238  // The block owning the argument may be unlinked, e.g. when the surrounding
239  // region has not yet been attached to an Op, at which point the parent Op
240  // is null.
241  Operation *parentOp = arg.getOwner()->getParentOp();
242  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
243  }
244  // The defining Op may live in an unlinked block so its parent Op may be null.
245  Operation *parentOp = value.getDefiningOp()->getParentOp();
246  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
247 }
248 
249 /// Returns the closest region enclosing `op` that is held by an operation with
250 /// trait `AffineScope`; `nullptr` if there is no such region.
// NOTE(review): the declarator line of this definition is absent from this
// listing; the body reads a single `op`.
// Walks up the parent-op chain. When a parent with the trait is found, the
// region of that parent which contains the current op is returned.
252  auto *curOp = op;
253  while (auto *parentOp = curOp->getParentOp()) {
254  if (parentOp->hasTrait<OpTrait::AffineScope>())
255  return curOp->getParentRegion();
256  curOp = parentOp;
257  }
258  return nullptr;
259 }
260 
261 // A Value can be used as a dimension id iff it meets one of the following
262 // conditions:
263 // *) It is valid as a symbol.
264 // *) It is an induction variable.
265 // *) It is the result of affine apply operation with dimension id arguments.
// NOTE(review): the declarator line of this definition is absent from this
// listing; the body reads a single `value`.
267  // The value must be an index type.
268  if (!value.getType().isIndex())
269  return false;
270 
// For op results, defer to the two-argument overload, using the affine scope
// of the defining op as the region.
271  if (auto *defOp = value.getDefiningOp())
272  return isValidDim(value, getAffineScope(defOp));
273 
274  // This value has to be a block argument for an op that has the
275  // `AffineScope` trait or for an affine.for or affine.parallel.
// A null parent op yields false rather than being passed to `isa`.
276  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
277  return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
278  isa<AffineForOp, AffineParallelOp>(parentOp));
279 }
280 
281 // Value can be used as a dimension id iff it meets one of the following
282 // conditions:
283 // *) It is valid as a symbol.
284 // *) It is an induction variable.
285 // *) It is the result of an affine apply operation with dimension id operands.
286 bool mlir::affine::isValidDim(Value value, Region *region) {
287  // The value must be an index type.
288  if (!value.getType().isIndex())
289  return false;
290 
291  // All valid symbols are okay.
292  if (isValidSymbol(value, region))
293  return true;
294 
295  auto *op = value.getDefiningOp();
296  if (!op) {
297  // This value has to be a block argument for an affine.for or an
298  // affine.parallel.
299  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
300  return isa<AffineForOp, AffineParallelOp>(parentOp);
301  }
302 
303  // Affine apply operation is ok if all of its operands are ok.
304  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
305  return applyOp.isValidDim(region);
306  // The dim op is okay if its operand memref/tensor is defined at the top
307  // level.
308  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
309  return isTopLevelValue(dimOp.getShapedValue());
310  return false;
311 }
312 
313 /// Returns true if the 'index' dimension of the `memref` defined by
314 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
315 /// for `region`.
316 template <typename AnyMemRefDefOp>
317 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
318  Region *region) {
319  auto memRefType = memrefDefOp.getType();
320  // Statically shaped.
321  if (!memRefType.isDynamicDim(index))
322  return true;
323  // Get the position of the dimension among dynamic dimensions;
324  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
325  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
326  region);
327 }
328 
329 /// Returns true if the result of the dim op is a valid symbol for `region`.
330 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
331  // The dim op is okay if its source is defined at the top level.
332  if (isTopLevelValue(dimOp.getShapedValue()))
333  return true;
334 
335  // Conservatively handle remaining BlockArguments as non-valid symbols.
336  // E.g. scf.for iterArgs.
337  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
338  return false;
339 
340  // The dim op is also okay if its operand memref is a view/subview whose
341  // corresponding size is a valid symbol.
342  std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
343 
344  // Be conservative if we can't understand the dimension.
345  if (!index.has_value())
346  return false;
347 
348  int64_t i = index.value();
349  return TypeSwitch<Operation *, bool>(dimOp.getShapedValue().getDefiningOp())
350  .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
351  [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
352  .Default([](Operation *) { return false; });
353 }
354 
355 // A value can be used as a symbol (at all its use sites) iff it meets one of
356 // the following conditions:
357 // *) It is a constant.
358 // *) Its defining op or block arg appearance is immediately enclosed by an op
359 // with `AffineScope` trait.
360 // *) It is the result of an affine.apply operation with symbol operands.
361 // *) It is a result of the dim op on a memref whose corresponding size is a
362 // valid symbol.
// NOTE(review): the declarator line of this definition is absent from this
// listing; the body reads a single `value`.
// A null value is rejected up front.
364  if (!value)
365  return false;
366 
367  // The value must be an index type.
368  if (!value.getType().isIndex())
369  return false;
370 
371  // Check that the value is a top level value.
372  if (isTopLevelValue(value))
373  return true;
374 
// For op results, defer to the two-argument overload, using the affine scope
// of the defining op as the region.
375  if (auto *defOp = value.getDefiningOp())
376  return isValidSymbol(value, getAffineScope(defOp));
377 
// A block argument that is not top-level is not a valid symbol here.
378  return false;
379 }
380 
381 /// A value can be used as a symbol for `region` iff it meets one of the
382 /// following conditions:
383 /// *) It is a constant.
384 /// *) It is the result of an affine apply operation with symbol arguments.
385 /// *) It is a result of the dim op on a memref whose corresponding size is
386 /// a valid symbol.
387 /// *) It is defined at the top level of 'region' or is its argument.
388 /// *) It dominates `region`'s parent op.
389 /// If `region` is null, conservatively assume the symbol definition scope does
390 /// not exist and only accept the values that would be symbols regardless of
391 /// the surrounding region structure, i.e. the first three cases above.
// NOTE(review): the declarator line of this definition is absent from this
// listing; the body reads `value` and `region`.
393  // The value must be an index type.
394  if (!value.getType().isIndex())
395  return false;
396 
397  // A top-level value is a valid symbol.
398  if (region && ::isTopLevelValue(value, region))
399  return true;
400 
401  auto *defOp = value.getDefiningOp();
402  if (!defOp) {
403  // A block argument that is not a top-level value is a valid symbol if it
404  // dominates region's parent op.
// The check recurses outward through the region holding `region`'s parent
// op, and stops at an op that is isolated from above.
405  Operation *regionOp = region ? region->getParentOp() : nullptr;
406  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
407  if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
408  return isValidSymbol(value, parentOpRegion);
409  return false;
410  }
411 
412  // Constant operation is ok.
413  Attribute operandCst;
414  if (matchPattern(defOp, m_Constant(&operandCst)))
415  return true;
416 
417  // Affine apply operation is ok if all of its operands are ok.
418  if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
419  return applyOp.isValidSymbol(region);
420 
421  // Dim op results could be valid symbols at any level.
422  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
423  return isDimOpValidSymbol(dimOp, region);
424 
425  // Check for values dominating `region`'s parent op.
// Same outward recursion as for block arguments above.
426  Operation *regionOp = region ? region->getParentOp() : nullptr;
427  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
428  if (auto *parentRegion = region->getParentOp()->getParentRegion())
429  return isValidSymbol(value, parentRegion);
430 
431  return false;
432 }
433 
434 // Returns true if 'value' is a valid index to an affine operation (e.g.
435 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
436 // `region` provides the polyhedral symbol scope. Returns false otherwise.
437 static bool isValidAffineIndexOperand(Value value, Region *region) {
438  return isValidDim(value, region) || isValidSymbol(value, region);
439 }
440 
441 /// Prints dimension and symbol list.
// NOTE(review): the declarator lines (name plus the `begin`/`end` parameters
// read by the body) are absent from this listing.
444  unsigned numDims, OpAsmPrinter &printer) {
445  OperandRange operands(begin, end);
// The first `numDims` operands are printed in parentheses.
446  printer << '(' << operands.take_front(numDims) << ')';
// Square brackets with the remaining operands are only emitted when there are
// more than `numDims` operands.
447  if (operands.size() > numDims)
448  printer << '[' << operands.drop_front(numDims) << ']';
449 }
450 
451 /// Parses dimension and symbol list and returns true if parsing failed.
// NOTE(review): three lines of this definition are absent from this listing:
// the declarator, the declaration of `opInfos`, and the first line of the
// argument list of the second `parseOperandList` call.
453  OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
// The parenthesized list is parsed first; its length becomes `numDims`.
455  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
456  return failure();
457  // Store number of dimensions for validation by caller.
458  numDims = opInfos.size();
459 
460  // Parse the optional symbol operands.
// All parsed operands are then resolved with index type into `operands`.
461  auto indexTy = parser.getBuilder().getIndexType();
462  return failure(parser.parseOperandList(
464  parser.resolveOperands(opInfos, indexTy, operands));
465 }
466 
467 /// Utility function to verify that a set of operands are valid dimension and
468 /// symbol identifiers. The operands should be laid out such that the dimension
469 /// operands are before the symbol operands. This function returns failure if
470 /// there was an invalid operand. An operation is provided to emit any necessary
471 /// errors.
472 template <typename OpTy>
// NOTE(review): the declarator line carrying the function name and its leading
// parameters is absent from this listing; the body reads `op` and `operands`.
473 static LogicalResult
475  unsigned numDims) {
// `opIt` counts operands seen so far: the first `numDims` are checked as
// dimensions, all later ones as symbols, each against `getAffineScope(op)`.
476  unsigned opIt = 0;
477  for (auto operand : operands) {
478  if (opIt++ < numDims) {
479  if (!isValidDim(operand, getAffineScope(op)))
480  return op.emitOpError("operand cannot be used as a dimension id");
481  } else if (!isValidSymbol(operand, getAffineScope(op))) {
482  return op.emitOpError("operand cannot be used as a symbol");
483  }
484  }
485  return success();
486 }
487 
488 //===----------------------------------------------------------------------===//
489 // AffineApplyOp
490 //===----------------------------------------------------------------------===//
491 
492 AffineValueMap AffineApplyOp::getAffineValueMap() {
493  return AffineValueMap(getAffineMap(), getOperands(), getResult());
494 }
495 
496 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
497  auto &builder = parser.getBuilder();
498  auto indexTy = builder.getIndexType();
499 
500  AffineMapAttr mapAttr;
501  unsigned numDims;
502  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
503  parseDimAndSymbolList(parser, result.operands, numDims) ||
504  parser.parseOptionalAttrDict(result.attributes))
505  return failure();
506  auto map = mapAttr.getValue();
507 
508  if (map.getNumDims() != numDims ||
509  numDims + map.getNumSymbols() != result.operands.size()) {
510  return parser.emitError(parser.getNameLoc(),
511  "dimension or symbol index mismatch");
512  }
513 
514  result.types.append(map.getNumResults(), indexTy);
515  return success();
516 }
517 
// NOTE(review): the declarator line of this definition is absent from this
// listing; the body writes to a printer named `p`.
// Prints the map attribute, then the operands split into dimension and symbol
// lists, then the remaining attributes with "map" elided.
519  p << " " << getMapAttr();
520  printDimAndSymbolList(operand_begin(), operand_end(),
521  getAffineMap().getNumDims(), p);
522  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
523 }
524 
// NOTE(review): the declarator line of this definition is absent from this
// listing.
526  // Check input and output dimensions match.
527  AffineMap affineMap = getMap();
528 
529  // Verify that operand count matches affine map dimension and symbol count.
530  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
531  return emitOpError(
532  "operand count and affine map dimension and symbol count must match");
533 
534  // Verify that the map only produces one result.
535  if (affineMap.getNumResults() != 1)
536  return emitOpError("mapping must produce one value");
537 
538  return success();
539 }
540 
541 // The result of the affine apply operation can be used as a dimension id if all
542 // its operands are valid dimension ids.
// NOTE(review): the declarator line of this definition is absent from this
// listing. Each operand is checked with the one-argument
// `affine::isValidDim`.
544  return llvm::all_of(getOperands(),
545  [](Value op) { return affine::isValidDim(op); });
546 }
547 
548 // The result of the affine apply operation can be used as a dimension id if all
549 // its operands are valid dimension ids with the parent operation of `region`
550 // defining the polyhedral scope for symbols.
551 bool AffineApplyOp::isValidDim(Region *region) {
552  return llvm::all_of(getOperands(),
553  [&](Value op) { return ::isValidDim(op, region); });
554 }
555 
556 // The result of the affine apply operation can be used as a symbol if all its
557 // operands are symbols.
// NOTE(review): the declarator line of this definition is absent from this
// listing. Each operand is checked with the one-argument
// `affine::isValidSymbol`.
559  return llvm::all_of(getOperands(),
560  [](Value op) { return affine::isValidSymbol(op); });
561 }
562 
563 // The result of the affine apply operation can be used as a symbol in `region`
564 // if all its operands are symbols in `region`.
565 bool AffineApplyOp::isValidSymbol(Region *region) {
566  return llvm::all_of(getOperands(), [&](Value operand) {
567  return affine::isValidSymbol(operand, region);
568  });
569 }
570 
// Folds this op either to one of its own operands (when the map result is a
// bare dimension or symbol) or to the first entry produced by constant-folding
// the map; returns an empty result when constant folding fails.
571 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
572  auto map = getAffineMap();
573 
574  // Fold dims and symbols to existing values.
575  auto expr = map.getResult(0);
576  if (auto dim = expr.dyn_cast<AffineDimExpr>())
577  return getOperand(dim.getPosition());
// Symbol operands follow the dimension operands, hence the offset.
578  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
579  return getOperand(map.getNumDims() + sym.getPosition());
580 
581  // Otherwise, default to folding the map.
// NOTE(review): the declaration of `result` (listing line 582) is absent from
// this listing.
583  if (failed(map.constantFold(adaptor.getMapOperands(), result)))
584  return {};
585  return result[0];
586 }
587 
588 /// Returns the largest known divisor of `e`. Exploits information from the
589 /// values in `operands`.
590 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
591  // This method isn't aware of `operands`.
592  int64_t div = e.getLargestKnownDivisor();
593 
594  // We now make use of operands for the case `e` is a dim expression.
595  // TODO: More powerful simplification would have to modify
596  // getLargestKnownDivisor to take `operands` and exploit that information as
597  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
598  // can't be part of the IR library but of the `Analysis` library. The IR
599  // library can only really depend on simple O(1) checks.
600  auto dimExpr = e.dyn_cast<AffineDimExpr>();
601  // If it's not a dim expr, `div` is the best we have.
602  if (!dimExpr)
603  return div;
604 
605  // We simply exploit information from loop IVs.
606  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
607  // desired simplifications are expected to be part of other
608  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
609  // LoopAnalysis library.
610  Value operand = operands[dimExpr.getPosition()];
611  int64_t operandDivisor = 1;
612  // TODO: With the right accessors, this can be extended to
613  // LoopLikeOpInterface.
614  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
615  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
616  operandDivisor = forOp.getStep();
617  } else {
618  uint64_t lbLargestKnownDivisor =
619  forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
620  operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStep());
621  }
622  }
623  return operandDivisor;
624 }
625 
626 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
627 /// being an affine dim expression or a constant.
// NOTE(review): the declarator line carrying the function name and its leading
// parameters is absent from this listing; the body reads `e` and `operands`.
629  int64_t k) {
// A constant expression is checked directly against the range [0, k).
630  if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
631  int64_t constVal = constExpr.getValue();
632  return constVal >= 0 && constVal < k;
633  }
634  auto dimExpr = e.dyn_cast<AffineDimExpr>();
635  if (!dimExpr)
636  return false;
637  Value operand = operands[dimExpr.getPosition()];
638  // TODO: With the right accessors, this can be extended to
639  // LoopLikeOpInterface.
// For an affine.for IV, accept when the loop has a constant lower bound >= 0
// and a constant upper bound <= k.
640  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
641  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
642  forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
643  return true;
644  }
645  }
646 
647  // We don't consider other cases like `operand` being defined by a constant or
648  // an affine.apply op since such cases will already be handled by other
649  // patterns and propagation of loop IVs or constant would happen.
650  return false;
651 }
652 
653 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
654 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
655 /// expression is in that form.
656 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
657  AffineExpr &quotientTimesDiv, AffineExpr &rem) {
658  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
659  if (!bin || bin.getKind() != AffineExprKind::Add)
660  return false;
661 
662  AffineExpr llhs = bin.getLHS();
663  AffineExpr rlhs = bin.getRHS();
664  div = getLargestKnownDivisor(llhs, operands);
665  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
666  quotientTimesDiv = llhs;
667  rem = rlhs;
668  return true;
669  }
670  div = getLargestKnownDivisor(rlhs, operands);
671  if (isNonNegativeBoundedBy(llhs, operands, div)) {
672  quotientTimesDiv = rlhs;
673  rem = llhs;
674  return true;
675  }
676  return false;
677 }
678 
679 /// Gets the constant lower bound on an `iv`.
680 static std::optional<int64_t> getLowerBound(Value iv) {
681  AffineForOp forOp = getForInductionVarOwner(iv);
682  if (forOp && forOp.hasConstantLowerBound())
683  return forOp.getConstantLowerBound();
684  return std::nullopt;
685 }
686 
687 /// Gets the constant upper bound on an affine.for `iv`.
688 static std::optional<int64_t> getUpperBound(Value iv) {
689  AffineForOp forOp = getForInductionVarOwner(iv);
690  if (!forOp || !forOp.hasConstantUpperBound())
691  return std::nullopt;
692 
693  // If its lower bound is also known, we can get a more precise bound
694  // whenever the step is not one.
695  if (forOp.hasConstantLowerBound()) {
696  return forOp.getConstantUpperBound() - 1 -
697  (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
698  forOp.getStep();
699  }
700  return forOp.getConstantUpperBound() - 1;
701 }
702 
703 /// Get a lower or upper (depending on `isUpper`) bound for `expr` while using
704 /// the constant lower and upper bounds for its inputs provided in
705 /// `constLowerBounds` and `constUpperBounds`. Return std::nullopt if such a
706 /// bound can't be computed. This method only handles simple sum of product
707 /// expressions (w.r.t constant coefficients) so as to not depend on anything
708 /// heavyweight in `Analysis`. Expressions of the form: c0*d0 + c1*d1 + c2*s0 +
709 /// ... + c_n are handled. Expressions involving floordiv, ceildiv, mod or
710 /// semi-affine ones will lead std::nullopt being returned.
711 static std::optional<int64_t>
712 getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols,
713  ArrayRef<std::optional<int64_t>> constLowerBounds,
714  ArrayRef<std::optional<int64_t>> constUpperBounds,
715  bool isUpper) {
716  // Handle divs and mods.
717  if (auto binOpExpr = expr.dyn_cast<AffineBinaryOpExpr>()) {
718  // If the LHS of a floor or ceil is bounded and the RHS is a constant, we
719  // can compute an upper bound.
720  if (binOpExpr.getKind() == AffineExprKind::FloorDiv) {
721  auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
722  if (!rhsConst || rhsConst.getValue() < 1)
723  return std::nullopt;
724  auto bound = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
725  constLowerBounds, constUpperBounds, isUpper);
726  if (!bound)
727  return std::nullopt;
728  return mlir::floorDiv(*bound, rhsConst.getValue());
729  }
730  if (binOpExpr.getKind() == AffineExprKind::CeilDiv) {
731  auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
732  if (rhsConst && rhsConst.getValue() >= 1) {
733  auto bound =
734  getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
735  constLowerBounds, constUpperBounds, isUpper);
736  if (!bound)
737  return std::nullopt;
738  return mlir::ceilDiv(*bound, rhsConst.getValue());
739  }
740  return std::nullopt;
741  }
742  if (binOpExpr.getKind() == AffineExprKind::Mod) {
743  // lhs mod c is always <= c - 1 and non-negative. In addition, if `lhs` is
744  // bounded such that lb <= lhs <= ub and lb floordiv c == ub floordiv c
745  // (same "interval"), then lb mod c <= lhs mod c <= ub mod c.
746  auto rhsConst = binOpExpr.getRHS().dyn_cast<AffineConstantExpr>();
747  if (rhsConst && rhsConst.getValue() >= 1) {
748  int64_t rhsConstVal = rhsConst.getValue();
749  auto lb = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
750  constLowerBounds, constUpperBounds,
751  /*isUpper=*/false);
752  auto ub = getBoundForExpr(binOpExpr.getLHS(), numDims, numSymbols,
753  constLowerBounds, constUpperBounds, isUpper);
754  if (ub && lb &&
755  floorDiv(*lb, rhsConstVal) == floorDiv(*ub, rhsConstVal))
756  return isUpper ? mod(*ub, rhsConstVal) : mod(*lb, rhsConstVal);
757  return isUpper ? rhsConstVal - 1 : 0;
758  }
759  }
760  }
761  // Flatten the expression.
762  SimpleAffineExprFlattener flattener(numDims, numSymbols);
763  flattener.walkPostOrder(expr);
764  ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
765  // TODO: Handle local variables. We can get hold of flattener.localExprs and
766  // get bound on the local expr recursively.
767  if (flattener.numLocals > 0)
768  return std::nullopt;
769  int64_t bound = 0;
770  // Substitute the constant lower or upper bound for the dimensional or
771  // symbolic input depending on `isUpper` to determine the bound.
772  for (unsigned i = 0, e = numDims + numSymbols; i < e; ++i) {
773  if (flattenedExpr[i] > 0) {
774  auto &constBound = isUpper ? constUpperBounds[i] : constLowerBounds[i];
775  if (!constBound)
776  return std::nullopt;
777  bound += *constBound * flattenedExpr[i];
778  } else if (flattenedExpr[i] < 0) {
779  auto &constBound = isUpper ? constLowerBounds[i] : constUpperBounds[i];
780  if (!constBound)
781  return std::nullopt;
782  bound += *constBound * flattenedExpr[i];
783  }
784  }
785  // Constant term.
786  bound += flattenedExpr.back();
787  return bound;
788 }
789 
790 /// Determine a constant upper bound for `expr` if one exists while exploiting
791 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
792 /// is guaranteed to be less than or equal to it.
793 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
794  unsigned numSymbols,
795  ArrayRef<Value> operands) {
796  // Get the constant lower or upper bounds on the operands.
797  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
798  constLowerBounds.reserve(operands.size());
799  constUpperBounds.reserve(operands.size());
800  for (Value operand : operands) {
801  constLowerBounds.push_back(getLowerBound(operand));
802  constUpperBounds.push_back(getUpperBound(operand));
803  }
804 
805  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>())
806  return constExpr.getValue();
807 
808  return getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
809  constUpperBounds,
810  /*isUpper=*/true);
811 }
812 
813 /// Determine a constant lower bound for `expr` if one exists while exploiting
814 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
815 /// is guaranteed to be less than or equal to it.
816 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
817  unsigned numSymbols,
818  ArrayRef<Value> operands) {
819  // Get the constant lower or upper bounds on the operands.
820  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
821  constLowerBounds.reserve(operands.size());
822  constUpperBounds.reserve(operands.size());
823  for (Value operand : operands) {
824  constLowerBounds.push_back(getLowerBound(operand));
825  constUpperBounds.push_back(getUpperBound(operand));
826  }
827 
828  std::optional<int64_t> lowerBound;
829  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
830  lowerBound = constExpr.getValue();
831  } else {
832  lowerBound = getBoundForExpr(expr, numDims, numSymbols, constLowerBounds,
833  constUpperBounds,
834  /*isUpper=*/false);
835  }
836  return lowerBound;
837 }
838 
839 /// Simplify `expr` while exploiting information from the values in `operands`.
840 static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
841  unsigned numSymbols,
842  ArrayRef<Value> operands) {
843  // We do this only for certain floordiv/mod expressions.
844  auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
845  if (!binExpr)
846  return;
847 
848  // Simplify the child expressions first.
849  AffineExpr lhs = binExpr.getLHS();
850  AffineExpr rhs = binExpr.getRHS();
851  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
852  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
853  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
854 
855  binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
856  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
857  expr.getKind() != AffineExprKind::CeilDiv &&
858  expr.getKind() != AffineExprKind::Mod)) {
859  return;
860  }
861 
862  // The `lhs` and `rhs` may be different post construction of simplified expr.
863  lhs = binExpr.getLHS();
864  rhs = binExpr.getRHS();
865  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
866  if (!rhsConst)
867  return;
868 
869  int64_t rhsConstVal = rhsConst.getValue();
870  // Undefined exprsessions aren't touched; IR can still be valid with them.
871  if (rhsConstVal <= 0)
872  return;
873 
874  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
875  MLIRContext *context = expr.getContext();
876  std::optional<int64_t> lhsLbConst =
877  getLowerBound(lhs, numDims, numSymbols, operands);
878  std::optional<int64_t> lhsUbConst =
879  getUpperBound(lhs, numDims, numSymbols, operands);
880  if (lhsLbConst && lhsUbConst) {
881  int64_t lhsLbConstVal = *lhsLbConst;
882  int64_t lhsUbConstVal = *lhsUbConst;
883  // lhs floordiv c is a single value lhs is bounded in a range `c` that has
884  // the same quotient.
885  if (binExpr.getKind() == AffineExprKind::FloorDiv &&
886  floorDiv(lhsLbConstVal, rhsConstVal) ==
887  floorDiv(lhsUbConstVal, rhsConstVal)) {
888  expr =
889  getAffineConstantExpr(floorDiv(lhsLbConstVal, rhsConstVal), context);
890  return;
891  }
892  // lhs ceildiv c is a single value if the entire range has the same ceil
893  // quotient.
894  if (binExpr.getKind() == AffineExprKind::CeilDiv &&
895  ceilDiv(lhsLbConstVal, rhsConstVal) ==
896  ceilDiv(lhsUbConstVal, rhsConstVal)) {
897  expr =
898  getAffineConstantExpr(ceilDiv(lhsLbConstVal, rhsConstVal), context);
899  return;
900  }
901  // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
902  if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
903  lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
904  expr = lhs;
905  return;
906  }
907  }
908 
909  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
910  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
911  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
912  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
913  AffineExpr quotientTimesDiv, rem;
914  int64_t divisor;
915  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
916  if (rhsConstVal % divisor == 0 &&
917  binExpr.getKind() == AffineExprKind::FloorDiv) {
918  expr = quotientTimesDiv.floorDiv(rhsConst);
919  } else if (divisor % rhsConstVal == 0 &&
920  binExpr.getKind() == AffineExprKind::Mod) {
921  expr = rem % rhsConst;
922  }
923  return;
924  }
925 
926  // Handle the simple case when the LHS expression can be either upper
927  // bounded or is a known multiple of RHS constant.
928  // lhs floordiv c -> 0 if 0 <= lhs < c,
929  // lhs mod c -> 0 if lhs % c = 0.
930  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
931  binExpr.getKind() == AffineExprKind::FloorDiv) ||
932  (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
933  binExpr.getKind() == AffineExprKind::Mod)) {
934  expr = getAffineConstantExpr(0, expr.getContext());
935  }
936 }
937 
938 /// Simplify the expressions in `map` while making use of lower or upper bounds
939 /// of its operands. If `isMax` is true, the map is to be treated as a max of
940 /// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
941 /// d1) can be simplified to (8) if the operands are respectively lower bounded
942 /// by 2 and 0 (the second expression can't be lower than 8).
944  ArrayRef<Value> operands,
945  bool isMax) {
946  // Can't simplify.
947  if (operands.empty())
948  return;
949 
950  // Get the upper or lower bound on an affine.for op IV using its range.
951  // Get the constant lower or upper bounds on the operands.
952  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
953  constLowerBounds.reserve(operands.size());
954  constUpperBounds.reserve(operands.size());
955  for (Value operand : operands) {
956  constLowerBounds.push_back(getLowerBound(operand));
957  constUpperBounds.push_back(getUpperBound(operand));
958  }
959 
960  // We will compute the lower and upper bounds on each of the expressions
961  // Then, we will check (depending on max or min) as to whether a specific
962  // bound is redundant by checking if its highest (in case of max) and its
963  // lowest (in the case of min) value is already lower than (or higher than)
964  // the lower bound (or upper bound in the case of min) of another bound.
965  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
966  lowerBounds.reserve(map.getNumResults());
967  upperBounds.reserve(map.getNumResults());
968  for (AffineExpr e : map.getResults()) {
969  if (auto constExpr = e.dyn_cast<AffineConstantExpr>()) {
970  lowerBounds.push_back(constExpr.getValue());
971  upperBounds.push_back(constExpr.getValue());
972  } else {
973  lowerBounds.push_back(getBoundForExpr(e, map.getNumDims(),
974  map.getNumSymbols(),
975  constLowerBounds, constUpperBounds,
976  /*isUpper=*/false));
977  upperBounds.push_back(getBoundForExpr(e, map.getNumDims(),
978  map.getNumSymbols(),
979  constLowerBounds, constUpperBounds,
980  /*isUpper=*/true));
981  }
982  }
983 
984  // Collect expressions that are not redundant.
985  SmallVector<AffineExpr, 4> irredundantExprs;
986  for (auto exprEn : llvm::enumerate(map.getResults())) {
987  AffineExpr e = exprEn.value();
988  unsigned i = exprEn.index();
989  // Some expressions can be turned into constants.
990  if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
991  e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
992 
993  // Check if the expression is redundant.
994  if (isMax) {
995  if (!upperBounds[i]) {
996  irredundantExprs.push_back(e);
997  continue;
998  }
999  // If there exists another expression such that its lower bound is greater
1000  // than this expression's upper bound, it's redundant.
1001  if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
1002  auto otherLowerBound = en.value();
1003  unsigned pos = en.index();
1004  if (pos == i || !otherLowerBound)
1005  return false;
1006  if (*otherLowerBound > *upperBounds[i])
1007  return true;
1008  if (*otherLowerBound < *upperBounds[i])
1009  return false;
1010  // Equality case. When both expressions are considered redundant, we
1011  // don't want to get both of them. We keep the one that appears
1012  // first.
1013  if (upperBounds[pos] && lowerBounds[i] &&
1014  lowerBounds[i] == upperBounds[i] &&
1015  otherLowerBound == *upperBounds[pos] && i < pos)
1016  return false;
1017  return true;
1018  }))
1019  irredundantExprs.push_back(e);
1020  } else {
1021  if (!lowerBounds[i]) {
1022  irredundantExprs.push_back(e);
1023  continue;
1024  }
1025  // Likewise for the `min` case. Use the complement of the condition above.
1026  if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
1027  auto otherUpperBound = en.value();
1028  unsigned pos = en.index();
1029  if (pos == i || !otherUpperBound)
1030  return false;
1031  if (*otherUpperBound < *lowerBounds[i])
1032  return true;
1033  if (*otherUpperBound > *lowerBounds[i])
1034  return false;
1035  if (lowerBounds[pos] && upperBounds[i] &&
1036  lowerBounds[i] == upperBounds[i] &&
1037  otherUpperBound == lowerBounds[pos] && i < pos)
1038  return false;
1039  return true;
1040  }))
1041  irredundantExprs.push_back(e);
1042  }
1043  }
1044 
1045  // Create the map without the redundant expressions.
1046  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
1047  map.getContext());
1048 }
1049 
1050 /// Simplify the map while exploiting information on the values in `operands`.
1051 // Use "unused attribute" marker to silence warning stemming from the inability
1052 // to see through the template expansion.
1053 static void LLVM_ATTRIBUTE_UNUSED
1055  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1056  SmallVector<AffineExpr> newResults;
1057  newResults.reserve(map.getNumResults());
1058  for (AffineExpr expr : map.getResults()) {
1060  operands);
1061  newResults.push_back(expr);
1062  }
1063  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1064  map.getContext());
1065 }
1066 
1067 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
1068 /// defining AffineApplyOp expression and operands.
1069 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
1070 /// When `dimOrSymbolPosition >= dims.size()`,
1071 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
1072 /// Mutate `map`,`dims` and `syms` in place as follows:
1073 /// 1. `dims` and `syms` are only appended to.
1074 /// 2. `map` dim and symbols are gradually shifted to higher positions.
1075 /// 3. Old `dim` and `sym` entries are replaced by nullptr
1076 /// This avoids the need for any bookkeeping.
1078  unsigned dimOrSymbolPosition,
1079  SmallVectorImpl<Value> &dims,
1080  SmallVectorImpl<Value> &syms) {
1081  MLIRContext *ctx = map->getContext();
1082  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1083  unsigned pos = isDimReplacement ? dimOrSymbolPosition
1084  : dimOrSymbolPosition - dims.size();
1085  Value &v = isDimReplacement ? dims[pos] : syms[pos];
1086  if (!v)
1087  return failure();
1088 
1089  auto affineApply = v.getDefiningOp<AffineApplyOp>();
1090  if (!affineApply)
1091  return failure();
1092 
1093  // At this point we will perform a replacement of `v`, set the entry in `dim`
1094  // or `sym` to nullptr immediately.
1095  v = nullptr;
1096 
1097  // Compute the map, dims and symbols coming from the AffineApplyOp.
1098  AffineMap composeMap = affineApply.getAffineMap();
1099  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1100  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1101  affineApply.getMapOperands().end());
1102  // Canonicalize the map to promote dims to symbols when possible. This is to
1103  // avoid generating invalid maps.
1104  canonicalizeMapAndOperands(&composeMap, &composeOperands);
1105  AffineExpr replacementExpr =
1106  composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
1107  ValueRange composeDims =
1108  ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
1109  ValueRange composeSyms =
1110  ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
1111  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
1112  : getAffineSymbolExpr(pos, ctx);
1113 
1114  // Append the dims and symbols where relevant and perform the replacement.
1115  dims.append(composeDims.begin(), composeDims.end());
1116  syms.append(composeSyms.begin(), composeSyms.end());
1117  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1118 
1119  return success();
1120 }
1121 
1122 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1123 /// iteratively. Perform canonicalization of map and operands as well as
1124 /// AffineMap simplification. `map` and `operands` are mutated in place.
1126  SmallVectorImpl<Value> *operands) {
1127  if (map->getNumResults() == 0) {
1128  canonicalizeMapAndOperands(map, operands);
1129  *map = simplifyAffineMap(*map);
1130  return;
1131  }
1132 
1133  MLIRContext *ctx = map->getContext();
1134  SmallVector<Value, 4> dims(operands->begin(),
1135  operands->begin() + map->getNumDims());
1136  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1137  operands->end());
1138 
1139  // Iterate over dims and symbols coming from AffineApplyOp and replace until
1140  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1141  // and `syms` can only increase by construction.
1142  // The implementation uses a `while` loop to support the case of symbols
1143  // that may be constructed from dims; this may be overkill.
1144  while (true) {
1145  bool changed = false;
1146  for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1147  if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
1148  break;
1149  if (!changed)
1150  break;
1151  }
1152 
1153  // Clear operands so we can fill them anew.
1154  operands->clear();
1155 
1156  // At this point we may have introduced null operands, prune them out before
1157  // canonicalizing map and operands.
1158  unsigned nDims = 0, nSyms = 0;
1159  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1160  dimReplacements.reserve(dims.size());
1161  symReplacements.reserve(syms.size());
1162  for (auto *container : {&dims, &syms}) {
1163  bool isDim = (container == &dims);
1164  auto &repls = isDim ? dimReplacements : symReplacements;
1165  for (const auto &en : llvm::enumerate(*container)) {
1166  Value v = en.value();
1167  if (!v) {
1168  assert(isDim ? !map->isFunctionOfDim(en.index())
1169  : !map->isFunctionOfSymbol(en.index()) &&
1170  "map is function of unexpected expr@pos");
1171  repls.push_back(getAffineConstantExpr(0, ctx));
1172  continue;
1173  }
1174  repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1175  : getAffineSymbolExpr(nSyms++, ctx));
1176  operands->push_back(v);
1177  }
1178  }
1179  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1180  nSyms);
1181 
1182  // Canonicalize and simplify before returning.
1183  canonicalizeMapAndOperands(map, operands);
1184  *map = simplifyAffineMap(*map);
1185 }
1186 
1188  AffineMap *map, SmallVectorImpl<Value> *operands) {
1189  while (llvm::any_of(*operands, [](Value v) {
1190  return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1191  })) {
1192  composeAffineMapAndOperands(map, operands);
1193  }
1194 }
1195 
1196 /// Given a list of `OpFoldResult`, build the necessary operations to populate
1197 /// `actualValues` with values produced by operations. In particular, for any
1198 /// attribute-typed element in `values`, call the constant materializer
1199 /// associated with the Affine dialect to produce an operation. Do NOT notify
1200 /// the builder listener about the constant ops being created as they are
1201 /// intended to be removed after being folded into affine constructs; this is
1202 /// not suitable for use beyond the Affine dialect.
1204  ArrayRef<OpFoldResult> values,
1205  SmallVectorImpl<Operation *> &constants,
1206  SmallVectorImpl<Value> &actualValues) {
1207  OpBuilder::Listener *listener = b.getListener();
1208  b.setListener(nullptr);
1209  auto listenerResetter =
1210  llvm::make_scope_exit([listener, &b] { b.setListener(listener); });
1211 
1212  actualValues.reserve(values.size());
1213  auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
1214  for (OpFoldResult ofr : values) {
1215  if (auto value = llvm::dyn_cast_if_present<Value>(ofr)) {
1216  actualValues.push_back(value);
1217  continue;
1218  }
1219  // Since we are directly specifying `index` as the result type, we need to
1220  // ensure the provided attribute is also an index type. Otherwise, the
1221  // AffineDialect materializer will create invalid `arith.constant`
1222  // operations if the provided Attribute is any other kind of integer.
1223  constants.push_back(dialect->materializeConstant(
1224  b,
1225  b.getIndexAttr(llvm::cast<IntegerAttr>(ofr.get<Attribute>()).getInt()),
1226  b.getIndexType(), loc));
1227  actualValues.push_back(constants.back()->getResult(0));
1228  }
1229 }
1230 
1231 /// Create an operation of the type provided as template argument and attempt to
1232 /// fold it immediately. The operation is expected to have a builder taking
1233 /// arbitrary `leadingArguments`, followed by a list of Value-typed `operands`.
1234 /// The operation is also expected to always produce a single result. Return an
1235 /// `OpFoldResult` containing the Attribute representing the folded constant if
1236 /// complete folding was possible and a Value produced by the created operation
1237 /// otherwise.
1238 template <typename OpTy, typename... Args>
1239 static std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(),
1240  OpFoldResult>
1242  Args &&...leadingArguments) {
1243  // Identify the constant operands and extract their values as attributes.
1244  // Note that we cannot use the original values directly because the list of
1245  // operands may have changed due to canonicalization and composition.
1246  SmallVector<Attribute> constantOperands;
1247  constantOperands.reserve(operands.size());
1248  for (Value operand : operands) {
1249  IntegerAttr attr;
1250  if (matchPattern(operand, m_Constant(&attr)))
1251  constantOperands.push_back(attr);
1252  else
1253  constantOperands.push_back(nullptr);
1254  }
1255 
1256  // Create the operation and immediately attempt to fold it. On success,
1257  // delete the operation and prepare the (unmaterialized) value for being
1258  // returned. On failure, return the operation result value. Temporarily remove
1259  // the listener to avoid notifying it when the op is created as it may be
1260  // removed immediately and there is no way of notifying the caller about that
1261  // without resorting to RewriterBase.
1262  //
1263  // TODO: arguably, the main folder (createOrFold) API should support this use
1264  // case instead of indiscriminately materializing constants.
1265  OpBuilder::Listener *listener = b.getListener();
1266  b.setListener(nullptr);
1267  auto listenerResetter =
1268  llvm::make_scope_exit([listener, &b] { b.setListener(listener); });
1269  OpTy op =
1270  b.create<OpTy>(loc, std::forward<Args>(leadingArguments)..., operands);
1271  SmallVector<OpFoldResult, 1> foldResults;
1272  if (succeeded(op->fold(constantOperands, foldResults)) &&
1273  !foldResults.empty()) {
1274  op->erase();
1275  return foldResults.front();
1276  }
1277 
1278  // Notify the listener now that we definitely know that the operation will
1279  // persist. Use the original listener stored in the variable.
1280  if (listener)
1281  listener->notifyOperationInserted(op);
1282  return op->getResult(0);
1283 }
1284 
1286  AffineMap map,
1287  ValueRange operands) {
1288  AffineMap normalizedMap = map;
1289  SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
1290  composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
1291  assert(normalizedMap);
1292  return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
1293 }
1294 
1296  AffineExpr e,
1297  ValueRange values) {
1298  return makeComposedAffineApply(
1300  values);
1301 }
1302 
1303 /// Composes the given affine map with the given list of operands, pulling in
1304 /// the maps from any affine.apply operations that supply the operands.
1306  SmallVectorImpl<Value> &operands) {
1307  // Compose and canonicalize each expression in the map individually because
1308  // composition only applies to single-result maps, collecting potentially
1309  // duplicate operands in a single list with shifted dimensions and symbols.
1310  SmallVector<Value> dims, symbols;
1312  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1313  SmallVector<Value> submapOperands(operands.begin(), operands.end());
1314  AffineMap submap = map.getSubMap({i});
1315  fullyComposeAffineMapAndOperands(&submap, &submapOperands);
1316  canonicalizeMapAndOperands(&submap, &submapOperands);
1317  unsigned numNewDims = submap.getNumDims();
1318  submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1319  llvm::append_range(dims,
1320  ArrayRef<Value>(submapOperands).take_front(numNewDims));
1321  llvm::append_range(symbols,
1322  ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1323  exprs.push_back(submap.getResult(0));
1324  }
1325 
1326  // Canonicalize the map created from composed expressions to deduplicate the
1327  // dimension and symbol operands.
1328  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1329  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1330  canonicalizeMapAndOperands(&map, &operands);
1331 }
1332 
1335  AffineMap map,
1336  ArrayRef<OpFoldResult> operands) {
1337  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1338 
1339  SmallVector<Operation *> constants;
1340  SmallVector<Value> actualValues;
1341  materializeConstants(b, loc, operands, constants, actualValues);
1342  composeAffineMapAndOperands(&map, &actualValues);
1343  OpFoldResult result = createOrFold<AffineApplyOp>(b, loc, actualValues, map);
1344 
1345  // Constants are always folded into the affine.apply map because they can be
1346  // represented as constant expressions, so delete them.
1347  for (Operation *op : constants)
1348  op->erase();
1349  return result;
1350 }
1351 
1354  AffineExpr expr,
1355  ArrayRef<OpFoldResult> operands) {
1357  b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}).front(),
1358  operands);
1359 }
1360 
1363  OpBuilder &b, Location loc, AffineMap map,
1364  ArrayRef<OpFoldResult> operands) {
1365  return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
1366  [&](unsigned i) {
1367  return makeComposedFoldedAffineApply(
1368  b, loc, map.getSubMap({i}), operands);
1369  });
1370 }
1371 
1373  AffineMap map, ValueRange operands) {
1374  SmallVector<Value> allOperands = llvm::to_vector(operands);
1375  composeMultiResultAffineMap(map, allOperands);
1376  return b.createOrFold<AffineMinOp>(loc, b.getIndexType(), map, allOperands);
1377 }
1378 
1379 template <typename OpTy>
1381  AffineMap map,
1382  ArrayRef<OpFoldResult> operands) {
1383  SmallVector<Operation *> constants;
1384  SmallVector<Value> actualValues;
1385  materializeConstants(b, loc, operands, constants, actualValues);
1386  composeMultiResultAffineMap(map, actualValues);
1387  OpFoldResult result =
1388  createOrFold<OpTy>(b, loc, actualValues, b.getIndexType(), map);
1389 
1390  // Constants are always folded into affine min/max because they can be
1391  // represented as constant expressions, so delete them.
1392  for (Operation *op : constants)
1393  op->erase();
1394  return result;
1395 }
1396 
1399  AffineMap map,
1400  ArrayRef<OpFoldResult> operands) {
1401  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1402 }
1403 
1406  AffineMap map,
1407  ArrayRef<OpFoldResult> operands) {
1408  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1409 }
1410 
1411 /// Fully compose map with operands and canonicalize the result.
1412 /// Return the `createOrFold`'ed AffineApply op.
1414  AffineMap map,
1415  ValueRange operandsRef) {
1416  SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end());
1417  fullyComposeAffineMapAndOperands(&map, &operands);
1418  canonicalizeMapAndOperands(&map, &operands);
1419  return b.createOrFold<AffineApplyOp>(loc, map, operands);
1420 }
1421 
1423  AffineMap map,
1424  ValueRange values) {
1426  res.reserve(map.getNumResults());
1427  unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols();
1428  // For each `expr` in `map`, applies the `expr` to the values extracted from
1429  // ranges. If the resulting application can be folded into a Value, the
1430  // folding occurs eagerly.
1431  for (auto expr : map.getResults()) {
1432  AffineMap map = AffineMap::get(numDims, numSym, expr);
1433  res.push_back(createFoldedComposedAffineApply(b, loc, map, values));
1434  }
1435  return res;
1436 }
1437 
1438 // A symbol may appear as a dim in affine.apply operations. This function
1439 // canonicalizes dims that are valid symbols into actual symbols.
1440 template <class MapOrSet>
1441 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1442  SmallVectorImpl<Value> *operands) {
1443  if (!mapOrSet || operands->empty())
1444  return;
1445 
1446  assert(mapOrSet->getNumInputs() == operands->size() &&
1447  "map/set inputs must match number of operands");
1448 
1449  auto *context = mapOrSet->getContext();
1450  SmallVector<Value, 8> resultOperands;
1451  resultOperands.reserve(operands->size());
1452  SmallVector<Value, 8> remappedSymbols;
1453  remappedSymbols.reserve(operands->size());
1454  unsigned nextDim = 0;
1455  unsigned nextSym = 0;
1456  unsigned oldNumSyms = mapOrSet->getNumSymbols();
1457  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1458  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1459  if (i < mapOrSet->getNumDims()) {
1460  if (isValidSymbol((*operands)[i])) {
1461  // This is a valid symbol that appears as a dim, canonicalize it.
1462  dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1463  remappedSymbols.push_back((*operands)[i]);
1464  } else {
1465  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1466  resultOperands.push_back((*operands)[i]);
1467  }
1468  } else {
1469  resultOperands.push_back((*operands)[i]);
1470  }
1471  }
1472 
1473  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1474  *operands = resultOperands;
1475  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1476  oldNumSyms + nextSym);
1477 
1478  assert(mapOrSet->getNumInputs() == operands->size() &&
1479  "map/set inputs must match number of operands");
1480 }
1481 
1482 // Works for either an affine map or an integer set.
1483 template <class MapOrSet>
1484 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1485  SmallVectorImpl<Value> *operands) {
1486  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1487  "Argument must be either of AffineMap or IntegerSet type");
1488 
1489  if (!mapOrSet || operands->empty())
1490  return;
1491 
1492  assert(mapOrSet->getNumInputs() == operands->size() &&
1493  "map/set inputs must match number of operands");
1494 
1495  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1496 
1497  // Check to see what dims are used.
1498  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1499  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1500  mapOrSet->walkExprs([&](AffineExpr expr) {
1501  if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
1502  usedDims[dimExpr.getPosition()] = true;
1503  else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
1504  usedSyms[symExpr.getPosition()] = true;
1505  });
1506 
1507  auto *context = mapOrSet->getContext();
1508 
1509  SmallVector<Value, 8> resultOperands;
1510  resultOperands.reserve(operands->size());
1511 
1512  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1513  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1514  unsigned nextDim = 0;
1515  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1516  if (usedDims[i]) {
1517  // Remap dim positions for duplicate operands.
1518  auto it = seenDims.find((*operands)[i]);
1519  if (it == seenDims.end()) {
1520  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1521  resultOperands.push_back((*operands)[i]);
1522  seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1523  } else {
1524  dimRemapping[i] = it->second;
1525  }
1526  }
1527  }
1528  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1529  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1530  unsigned nextSym = 0;
1531  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1532  if (!usedSyms[i])
1533  continue;
1534  // Handle constant operands (only needed for symbolic operands since
1535  // constant operands in dimensional positions would have already been
1536  // promoted to symbolic positions above).
1537  IntegerAttr operandCst;
1538  if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1539  m_Constant(&operandCst))) {
1540  symRemapping[i] =
1541  getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1542  continue;
1543  }
1544  // Remap symbol positions for duplicate operands.
1545  auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1546  if (it == seenSymbols.end()) {
1547  symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1548  resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1549  seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1550  symRemapping[i]));
1551  } else {
1552  symRemapping[i] = it->second;
1553  }
1554  }
1555  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1556  nextDim, nextSym);
1557  *operands = resultOperands;
1558 }
1559 
1561  AffineMap *map, SmallVectorImpl<Value> *operands) {
1562  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1563 }
1564 
1566  IntegerSet *set, SmallVectorImpl<Value> *operands) {
1567  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1568 }
1569 
1570 namespace {
1571 /// Rewrite pattern that simplifies an affine op by composing its map with the
1572 /// maps that supply its operands, then canonicalizing and simplifying the
1573 /// resulting map/operand pair.
/// Accepted op types are those listed in the static_assert inside
/// matchAndRewrite: affine load/store, vector load/store, apply, prefetch,
/// min and max.
1574 template <typename AffineOpTy>
1575 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
  // NOTE(review): one line is absent from this listing here (presumably the
  // inherited-constructor declaration) — confirm against the original file.
1577 
1578  /// Replace the affine op with another instance of it with the supplied
1579  /// map and mapOperands.
1580  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1581  AffineMap map, ArrayRef<Value> mapOperands) const;
1582 
  /// Composes, canonicalizes and simplifies the op's map and operands.
  /// Returns failure() when neither the map nor the operands changed;
  /// otherwise replaces the op via replaceAffineOp and returns success().
1583  LogicalResult matchAndRewrite(AffineOpTy affineOp,
1584  PatternRewriter &rewriter) const override {
1585  static_assert(
1586  llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1587  AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1588  AffineVectorStoreOp, AffineVectorLoadOp>::value,
1589  "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1590  "expected");
  // Keep the original map and operands around for the change check below;
  // all rewriting is done on the local copies `map` / `resultOperands`.
1591  auto map = affineOp.getAffineMap();
1592  AffineMap oldMap = map;
1593  auto oldOperands = affineOp.getMapOperands();
1594  SmallVector<Value, 8> resultOperands(oldOperands);
1595  composeAffineMapAndOperands(&map, &resultOperands);
1596  canonicalizeMapAndOperands(&map, &resultOperands);
1597  simplifyMapWithOperands(map, resultOperands);
  // Nothing changed: report failure so the op is left untouched.
  // NOTE(review): the three-iterator std::equal only bounds the comparison by
  // oldOperands; this presumably relies on equal maps implying equal operand
  // counts — confirm.
1598  if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1599  resultOperands.begin()))
1600  return failure();
1601 
1602  replaceAffineOp(rewriter, affineOp, map, resultOperands);
1603  return success();
1604  }
1605 };
1606 
1607 // Specialize the template to account for the different build signatures for
1608 // affine load, store, and apply ops.
1609 template <>
1610 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1611  PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1612  ArrayRef<Value> mapOperands) const {
1613  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1614  mapOperands);
1615 }
1616 template <>
1617 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1618  PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1619  ArrayRef<Value> mapOperands) const {
1620  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1621  prefetch, prefetch.getMemref(), map, mapOperands,
1622  prefetch.getLocalityHint(), prefetch.getIsWrite(),
1623  prefetch.getIsDataCache());
1624 }
1625 template <>
1626 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1627  PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1628  ArrayRef<Value> mapOperands) const {
1629  rewriter.replaceOpWithNewOp<AffineStoreOp>(
1630  store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1631 }
1632 template <>
1633 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1634  PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1635  ArrayRef<Value> mapOperands) const {
1636  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1637  vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1638  mapOperands);
1639 }
1640 template <>
1641 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1642  PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1643  ArrayRef<Value> mapOperands) const {
1644  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1645  vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1646  mapOperands);
1647 }
1648 
1649 // Generic version for ops that don't have extra operands.
1650 template <typename AffineOpTy>
1651 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1652  PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1653  ArrayRef<Value> mapOperands) const {
1654  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1655 }
1656 } // namespace
1657 
1658 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1659  MLIRContext *context) {
1660  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1661 }
1662 
1663 //===----------------------------------------------------------------------===//
1664 // AffineDmaStartOp
1665 //===----------------------------------------------------------------------===//
1666 
1667 // TODO: Check that map operands are loop IVs or symbols.
1668 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1669  Value srcMemRef, AffineMap srcMap,
1670  ValueRange srcIndices, Value destMemRef,
1671  AffineMap dstMap, ValueRange destIndices,
1672  Value tagMemRef, AffineMap tagMap,
1673  ValueRange tagIndices, Value numElements,
1674  Value stride, Value elementsPerStride) {
1675  result.addOperands(srcMemRef);
1676  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1677  result.addOperands(srcIndices);
1678  result.addOperands(destMemRef);
1679  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1680  result.addOperands(destIndices);
1681  result.addOperands(tagMemRef);
1682  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1683  result.addOperands(tagIndices);
1684  result.addOperands(numElements);
1685  if (stride) {
1686  result.addOperands({stride, elementsPerStride});
1687  }
1688 }
1689 
  // Prints the custom assembly form: the source, destination and tag memrefs,
  // each followed by its affine map applied to its SSA operands in square
  // brackets, then the number of elements, the optional stride pair, and the
  // three memref types after a colon.
  // NOTE(review): the function header line is absent from this listing.
1691  p << " " << getSrcMemRef() << '[';
1692  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1693  p << "], " << getDstMemRef() << '[';
1694  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1695  p << "], " << getTagMemRef() << '[';
1696  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1697  p << "], " << getNumElements();
  // The stride and elements-per-stride operands only exist on strided DMAs.
1698  if (isStrided()) {
1699  p << ", " << getStride();
1700  p << ", " << getNumElementsPerStride();
1701  }
1702  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1703  << getTagMemRefType();
1704 }
1705 
1706 // Parse AffineDmaStartOp.
1707 // Ex:
1708 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1709 // %stride, %num_elt_per_stride
1710 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1711 //
// Operands are pushed onto `result.operands` in the order: src memref, src map
// operands, dst memref, dst map operands, tag memref, tag map operands, number
// of elements, then the optional two stride operands.
1712 ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
1713  OperationState &result) {
  // NOTE(review): the declarations of srcMapOperands, dstMapOperands,
  // tagMapOperands and strideInfo are absent from this listing — confirm
  // against the original file.
1714  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1715  AffineMapAttr srcMapAttr;
1717  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1718  AffineMapAttr dstMapAttr;
1720  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1721  AffineMapAttr tagMapAttr;
1723  OpAsmParser::UnresolvedOperand numElementsInfo;
1725 
1726  SmallVector<Type, 3> types;
1727  auto indexType = parser.getBuilder().getIndexType();
1728 
1729  // Parse and resolve the following list of operands:
1730  // *) src memref followed by its affine map operands (in square brackets).
1731  // *) dst memref followed by its affine map operands (in square brackets).
1732  // *) tag memref followed by its affine map operands (in square brackets).
1733  // *) number of elements transferred by DMA operation.
1734  if (parser.parseOperand(srcMemRefInfo) ||
1735  parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1736  getSrcMapAttrStrName(),
1737  result.attributes) ||
1738  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1739  parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1740  getDstMapAttrStrName(),
1741  result.attributes) ||
1742  parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1743  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1744  getTagMapAttrStrName(),
1745  result.attributes) ||
1746  parser.parseComma() || parser.parseOperand(numElementsInfo))
1747  return failure();
1748 
1749  // Parse optional stride and elements per stride.
1750  if (parser.parseTrailingOperandList(strideInfo))
1751  return failure();
1752 
  // The trailing list must be either empty (non-strided) or exactly the
  // (stride, elements-per-stride) pair.
1753  if (!strideInfo.empty() && strideInfo.size() != 2) {
1754  return parser.emitError(parser.getNameLoc(),
1755  "expected two stride related operands");
1756  }
1757  bool isStrided = strideInfo.size() == 2;
1758 
1759  if (parser.parseColonTypeList(types))
1760  return failure();
1761 
  // One type per memref: src, dst, tag (in that order).
1762  if (types.size() != 3)
1763  return parser.emitError(parser.getNameLoc(), "expected three types");
1764 
  // Memrefs are resolved against the parsed types; every other operand is
  // resolved as `index`.
1765  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1766  parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1767  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1768  parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1769  parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1770  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1771  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1772  return failure();
1773 
1774  if (isStrided) {
1775  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1776  return failure();
1777  }
1778 
1779  // Check that src/dst/tag operand counts match their map.numInputs.
1780  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1781  dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1782  tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1783  return parser.emitError(parser.getNameLoc(),
1784  "memref operand count not equal to map.numInputs");
1785  return success();
1786 }
1787 
1788 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1789  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1790  return emitOpError("expected DMA source to be of memref type");
1791  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1792  return emitOpError("expected DMA destination to be of memref type");
1793  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1794  return emitOpError("expected DMA tag to be of memref type");
1795 
1796  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1797  getDstMap().getNumInputs() +
1798  getTagMap().getNumInputs();
1799  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1800  getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1801  return emitOpError("incorrect number of operands");
1802  }
1803 
1804  Region *scope = getAffineScope(*this);
1805  for (auto idx : getSrcIndices()) {
1806  if (!idx.getType().isIndex())
1807  return emitOpError("src index to dma_start must have 'index' type");
1808  if (!isValidAffineIndexOperand(idx, scope))
1809  return emitOpError("src index must be a dimension or symbol identifier");
1810  }
1811  for (auto idx : getDstIndices()) {
1812  if (!idx.getType().isIndex())
1813  return emitOpError("dst index to dma_start must have 'index' type");
1814  if (!isValidAffineIndexOperand(idx, scope))
1815  return emitOpError("dst index must be a dimension or symbol identifier");
1816  }
1817  for (auto idx : getTagIndices()) {
1818  if (!idx.getType().isIndex())
1819  return emitOpError("tag index to dma_start must have 'index' type");
1820  if (!isValidAffineIndexOperand(idx, scope))
1821  return emitOpError("tag index must be a dimension or symbol identifier");
1822  }
1823  return success();
1824 }
1825 
1826 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1827  SmallVectorImpl<OpFoldResult> &results) {
1828  /// dma_start(memrefcast) -> dma_start
1829  return memref::foldMemRefCast(*this);
1830 }
1831 
// Reports the memory effects of affine.dma_start: a read of the source memref,
// a write to the destination memref, and a read of the tag memref.
// NOTE(review): several lines (the parameter type and the trailing argument of
// each emplace_back) are absent from this listing — confirm against the
// original file.
1832 void AffineDmaStartOp::getEffects(
1834  &effects) {
1835  effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
1837  effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
1839  effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
1841 }
1842 
1843 //===----------------------------------------------------------------------===//
1844 // AffineDmaWaitOp
1845 //===----------------------------------------------------------------------===//
1846 
1847 // TODO: Check that map operands are loop IVs or symbols.
1848 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1849  Value tagMemRef, AffineMap tagMap,
1850  ValueRange tagIndices, Value numElements) {
1851  result.addOperands(tagMemRef);
1852  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1853  result.addOperands(tagIndices);
1854  result.addOperands(numElements);
1855 }
1856 
  // Prints the tag memref followed by its affine map applied to the tag
  // indices in square brackets, then the tag memref type after a colon.
  // NOTE(review): the function header and the statement between `"], "` and
  // `" : "` are absent from this listing — confirm against the original file.
1858  p << " " << getTagMemRef() << '[';
1859  SmallVector<Value, 2> operands(getTagIndices());
1860  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1861  p << "], ";
1863  p << " : " << getTagMemRef().getType();
1864 }
1865 
1866 // Parse AffineDmaWaitOp.
1867 // Eg:
1868 // affine.dma_wait %tag[%index], %num_elements
1869 // : memref<1 x i32, (d0) -> (d0), 4>
1870 //
// Operands are pushed onto `result.operands` in the order: tag memref, tag map
// operands, number of elements.
1871 ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1872  OperationState &result) {
1873  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1874  AffineMapAttr tagMapAttr;
  // NOTE(review): the declaration of tagMapOperands is absent from this
  // listing — confirm against the original file.
1876  Type type;
1877  auto indexType = parser.getBuilder().getIndexType();
1878  OpAsmParser::UnresolvedOperand numElementsInfo;
1879 
1880  // Parse tag memref, its map operands, and dma size.
1881  if (parser.parseOperand(tagMemRefInfo) ||
1882  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1883  getTagMapAttrStrName(),
1884  result.attributes) ||
1885  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1886  parser.parseColonType(type) ||
1887  parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1888  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1889  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1890  return failure();
1891 
  // The single trailing type is the tag's type and must be a memref.
1892  if (!llvm::isa<MemRefType>(type))
1893  return parser.emitError(parser.getNameLoc(),
1894  "expected tag to be of memref type");
1895 
  // The number of parsed map operands must equal the tag map's input count.
1896  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1897  return parser.emitError(parser.getNameLoc(),
1898  "tag memref operand count != to map.numInputs");
1899  return success();
1900 }
1901 
1902 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1903  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1904  return emitOpError("expected DMA tag to be of memref type");
1905  Region *scope = getAffineScope(*this);
1906  for (auto idx : getTagIndices()) {
1907  if (!idx.getType().isIndex())
1908  return emitOpError("index to dma_wait must have 'index' type");
1909  if (!isValidAffineIndexOperand(idx, scope))
1910  return emitOpError("index must be a dimension or symbol identifier");
1911  }
1912  return success();
1913 }
1914 
1915 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1916  SmallVectorImpl<OpFoldResult> &results) {
1917  /// dma_wait(memrefcast) -> dma_wait
1918  return memref::foldMemRefCast(*this);
1919 }
1920 
// Reports the memory effect of affine.dma_wait: a read of the tag memref.
// NOTE(review): the parameter type line and the trailing argument of
// emplace_back are absent from this listing — confirm against the original
// file.
1921 void AffineDmaWaitOp::getEffects(
1923  &effects) {
1924  effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
1926 }
1927 
1928 //===----------------------------------------------------------------------===//
1929 // AffineForOp
1930 //===----------------------------------------------------------------------===//
1931 
1932 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1933 /// bodyBuilder are empty/null, we include default terminator op.
1934 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1935  ValueRange lbOperands, AffineMap lbMap,
1936  ValueRange ubOperands, AffineMap ubMap, int64_t step,
1937  ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1938  assert(((!lbMap && lbOperands.empty()) ||
1939  lbOperands.size() == lbMap.getNumInputs()) &&
1940  "lower bound operand count does not match the affine map");
1941  assert(((!ubMap && ubOperands.empty()) ||
1942  ubOperands.size() == ubMap.getNumInputs()) &&
1943  "upper bound operand count does not match the affine map");
1944  assert(step > 0 && "step has to be a positive integer constant");
1945 
1946  for (Value val : iterArgs)
1947  result.addTypes(val.getType());
1948 
1949  // Add an attribute for the step.
1950  result.addAttribute(getStepAttrStrName(),
1951  builder.getIntegerAttr(builder.getIndexType(), step));
1952 
1953  // Add the lower bound.
1954  result.addAttribute(getLowerBoundAttrStrName(), AffineMapAttr::get(lbMap));
1955  result.addOperands(lbOperands);
1956 
1957  // Add the upper bound.
1958  result.addAttribute(getUpperBoundAttrStrName(), AffineMapAttr::get(ubMap));
1959  result.addOperands(ubOperands);
1960 
1961  result.addOperands(iterArgs);
1962  // Create a region and a block for the body. The argument of the region is
1963  // the loop induction variable.
1964  Region *bodyRegion = result.addRegion();
1965  bodyRegion->push_back(new Block);
1966  Block &bodyBlock = bodyRegion->front();
1967  Value inductionVar =
1968  bodyBlock.addArgument(builder.getIndexType(), result.location);
1969  for (Value val : iterArgs)
1970  bodyBlock.addArgument(val.getType(), val.getLoc());
1971 
1972  // Create the default terminator if the builder is not provided and if the
1973  // iteration arguments are not provided. Otherwise, leave this to the caller
1974  // because we don't know which values to return from the loop.
1975  if (iterArgs.empty() && !bodyBuilder) {
1976  ensureTerminator(*bodyRegion, builder, result.location);
1977  } else if (bodyBuilder) {
1978  OpBuilder::InsertionGuard guard(builder);
1979  builder.setInsertionPointToStart(&bodyBlock);
1980  bodyBuilder(builder, result.location, inductionVar,
1981  bodyBlock.getArguments().drop_front());
1982  }
1983 }
1984 
1985 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1986  int64_t ub, int64_t step, ValueRange iterArgs,
1987  BodyBuilderFn bodyBuilder) {
1988  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1989  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1990  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1991  bodyBuilder);
1992 }
1993 
// Verifies the loop body and bounds: the body must start with an index-typed
// induction variable argument, bound operands are checked when the bound maps
// take inputs, and for loops with results the number of iter operands and of
// region iter args must both equal the number of results.
1994 LogicalResult AffineForOp::verifyRegions() {
1995  // Check that the body defines a single block argument for the induction
1996  // variable.
1997  auto *body = getBody();
1998  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1999  return emitOpError("expected body to have a single index argument for the "
2000  "induction variable");
2001 
2002  // Verify that the bound operands are valid dimension/symbols.
  // NOTE(review): the first line of each of the two checks below (the call
  // whose trailing argument is visible) is absent from this listing — confirm
  // against the original file.
2003  /// Lower bound.
2004  if (getLowerBoundMap().getNumInputs() > 0)
2006  getLowerBoundMap().getNumDims())))
2007  return failure();
2008  /// Upper bound.
2009  if (getUpperBoundMap().getNumInputs() > 0)
2011  getUpperBoundMap().getNumDims())))
2012  return failure();
2013 
  // A loop without results has no loop-carried values to cross-check.
2014  unsigned opNumResults = getNumResults();
2015  if (opNumResults == 0)
2016  return success();
2017 
2018  // If ForOp defines values, check that the number and types of the defined
2019  // values match ForOp initial iter operands and backedge basic block
2020  // arguments.
  // NOTE(review): only the counts are compared here; no type comparison is
  // visible in this function.
2021  if (getNumIterOperands() != opNumResults)
2022  return emitOpError(
2023  "mismatch between the number of loop-carried values and results");
2024  if (getNumRegionIterArgs() != opNumResults)
2025  return emitOpError(
2026  "mismatch between the number of basic block args and results");
2027 
2028  return success();
2029 }
2030 
2031 /// Parse a for operation loop bounds.
/// `isLower` selects the lower bound (optional `max` prefix) or the upper
/// bound (optional `min` prefix). Three forms are accepted:
///   - a single SSA value, recorded as a symbol identity map with one operand;
///   - an affine map followed by its dim and symbol operand lists;
///   - an integer constant, recorded as a constant affine map.
/// The bound map is stored in `result.attributes` under the lower/upper bound
/// attribute name and the bound operands are appended to `result.operands`.
2032 static ParseResult parseBound(bool isLower, OperationState &result,
2033  OpAsmParser &p) {
2034  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
2035  // the map has multiple results.
2036  bool failedToParsedMinMax =
2037  failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
2038 
2039  auto &builder = p.getBuilder();
2040  auto boundAttrStrName = isLower ? AffineForOp::getLowerBoundAttrStrName()
2041  : AffineForOp::getUpperBoundAttrStrName();
2042 
2043  // Parse ssa-id as identity map.
  // NOTE(review): the declaration of boundOpInfos is absent from this listing
  // — confirm against the original file.
2045  if (p.parseOperandList(boundOpInfos))
2046  return failure();
2047 
2048  if (!boundOpInfos.empty()) {
2049  // Check that only one operand was parsed.
2050  if (boundOpInfos.size() > 1)
2051  return p.emitError(p.getNameLoc(),
2052  "expected only one loop bound operand");
2053 
2054  // TODO: improve error message when SSA value is not of index type.
2055  // Currently it is 'use of value ... expects different type than prior uses'
2056  if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
2057  result.operands))
2058  return failure();
2059 
2060  // Create an identity map using symbol id. This representation is optimized
2061  // for storage. Analysis passes may expand it into a multi-dimensional map
2062  // if desired.
2063  AffineMap map = builder.getSymbolIdentityMap();
2064  result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
2065  return success();
2066  }
2067 
2068  // Get the attribute location.
2069  SMLoc attrLoc = p.getCurrentLocation();
2070 
  // No SSA operand was given: the bound is an attribute, either an affine map
  // or an integer. It is parsed straight into result.attributes under the
  // bound's attribute name.
2071  Attribute boundAttr;
2072  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
2073  result.attributes))
2074  return failure();
2075 
2076  // Parse full form - affine map followed by dim and symbol list.
2077  if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
    // Remember how many operands existed so the number added by the dim and
    // symbol lists can be computed below.
2078  unsigned currentNumOperands = result.operands.size();
2079  unsigned numDims;
2080  if (parseDimAndSymbolList(p, result.operands, numDims))
2081  return failure();
2082 
2083  auto map = affineMapAttr.getValue();
2084  if (map.getNumDims() != numDims)
2085  return p.emitError(
2086  p.getNameLoc(),
2087  "dim operand count and affine map dim count must match");
2088 
2089  unsigned numDimAndSymbolOperands =
2090  result.operands.size() - currentNumOperands;
2091  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2092  return p.emitError(
2093  p.getNameLoc(),
2094  "symbol operand count and affine map symbol count must match");
2095 
2096  // If the map has multiple results, make sure that we parsed the min/max
2097  // prefix.
2098  if (map.getNumResults() > 1 && failedToParsedMinMax) {
2099  if (isLower) {
2100  return p.emitError(attrLoc, "lower loop bound affine map with "
2101  "multiple results requires 'max' prefix");
2102  }
2103  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
2104  "results requires 'min' prefix");
2105  }
2106  return success();
2107  }
2108 
2109  // Parse custom assembly form.
2110  if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
    // Replace the integer attribute that parseAttribute just appended with an
    // equivalent constant affine map stored under the same name.
2111  result.attributes.pop_back();
2112  result.addAttribute(
2113  boundAttrStrName,
2114  AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2115  return success();
2116  }
2117 
2118  return p.emitError(
2119  p.getNameLoc(),
2120  "expected valid affine map representation for loop bounds");
2121 }
2122 
2123 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2124  auto &builder = parser.getBuilder();
2125  OpAsmParser::Argument inductionVariable;
2126  inductionVariable.type = builder.getIndexType();
2127  // Parse the induction variable followed by '='.
2128  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2129  return failure();
2130 
2131  // Parse loop bounds.
2132  if (parseBound(/*isLower=*/true, result, parser) ||
2133  parser.parseKeyword("to", " between bounds") ||
2134  parseBound(/*isLower=*/false, result, parser))
2135  return failure();
2136 
2137  // Parse the optional loop step, we default to 1 if one is not present.
2138  if (parser.parseOptionalKeyword("step")) {
2139  result.addAttribute(
2140  AffineForOp::getStepAttrStrName(),
2141  builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2142  } else {
2143  SMLoc stepLoc = parser.getCurrentLocation();
2144  IntegerAttr stepAttr;
2145  if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2146  AffineForOp::getStepAttrStrName().data(),
2147  result.attributes))
2148  return failure();
2149 
2150  if (stepAttr.getValue().isNegative())
2151  return parser.emitError(
2152  stepLoc,
2153  "expected step to be representable as a positive signed integer");
2154  }
2155 
2156  // Parse the optional initial iteration arguments.
2159 
2160  // Induction variable.
2161  regionArgs.push_back(inductionVariable);
2162 
2163  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2164  // Parse assignment list and results type list.
2165  if (parser.parseAssignmentList(regionArgs, operands) ||
2166  parser.parseArrowTypeList(result.types))
2167  return failure();
2168  // Resolve input operands.
2169  for (auto argOperandType :
2170  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2171  Type type = std::get<2>(argOperandType);
2172  std::get<0>(argOperandType).type = type;
2173  if (parser.resolveOperand(std::get<1>(argOperandType), type,
2174  result.operands))
2175  return failure();
2176  }
2177  }
2178 
2179  // Parse the body region.
2180  Region *body = result.addRegion();
2181  if (regionArgs.size() != result.types.size() + 1)
2182  return parser.emitError(
2183  parser.getNameLoc(),
2184  "mismatch between the number of loop-carried values and results");
2185  if (parser.parseRegion(*body, regionArgs))
2186  return failure();
2187 
2188  AffineForOp::ensureTerminator(*body, builder, result.location);
2189 
2190  // Parse the optional attribute list.
2191  return parser.parseOptionalAttrDict(result.attributes);
2192 }
2193 
2194 static void printBound(AffineMapAttr boundMap,
2195  Operation::operand_range boundOperands,
2196  const char *prefix, OpAsmPrinter &p) {
2197  AffineMap map = boundMap.getValue();
2198 
2199  // Check if this bound should be printed using custom assembly form.
2200  // The decision to restrict printing custom assembly form to trivial cases
2201  // comes from the will to roundtrip MLIR binary -> text -> binary in a
2202  // lossless way.
2203  // Therefore, custom assembly form parsing and printing is only supported for
2204  // zero-operand constant maps and single symbol operand identity maps.
2205  if (map.getNumResults() == 1) {
2206  AffineExpr expr = map.getResult(0);
2207 
2208  // Print constant bound.
2209  if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2210  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
2211  p << constExpr.getValue();
2212  return;
2213  }
2214  }
2215 
2216  // Print bound that consists of a single SSA symbol if the map is over a
2217  // single symbol.
2218  if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2219  if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
2220  p.printOperand(*boundOperands.begin());
2221  return;
2222  }
2223  }
2224  } else {
2225  // Map has multiple results. Print 'min' or 'max' prefix.
2226  p << prefix << ' ';
2227  }
2228 
2229  // Print the map and its operands.
2230  p << boundMap;
2231  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2232  map.getNumDims(), p);
2233 }
2234 
2235 unsigned AffineForOp::getNumIterOperands() {
2236  AffineMap lbMap = getLowerBoundMapAttr().getValue();
2237  AffineMap ubMap = getUpperBoundMapAttr().getValue();
2238 
2239  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2240 }
2241 
  // Prints the custom assembly form: induction variable, lower and upper
  // bounds, the step when it is not 1, the optional iter_args list with result
  // types, the body region, and the remaining attributes.
  // NOTE(review): the function header line is absent from this listing.
2243  p << ' ';
2244  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2245  /*omitType=*/true);
2246  p << " = ";
  // The prefix is only emitted by printBound for multi-result maps.
2247  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2248  p << " to ";
2249  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2250 
  // A step of 1 is the default and is left implicit.
2251  if (getStep() != 1)
2252  p << " step " << getStep();
2253 
  // Terminators are only printed when the loop carries values.
2254  bool printBlockTerminators = false;
2255  if (getNumIterOperands() > 0) {
2256  p << " iter_args(";
2257  auto regionArgs = getRegionIterArgs();
2258  auto operands = getIterOperands();
2259 
2260  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2261  p << std::get<0>(it) << " = " << std::get<1>(it);
2262  });
2263  p << ") -> (" << getResultTypes() << ")";
2264  printBlockTerminators = true;
2265  }
2266 
2267  p << ' ';
2268  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2269  printBlockTerminators);
  // Bounds and step were already printed in custom form above, so they are
  // elided from the attribute dictionary.
2270  p.printOptionalAttrDict((*this)->getAttrs(),
2271  /*elidedAttrs=*/{getLowerBoundAttrStrName(),
2272  getUpperBoundAttrStrName(),
2273  getStepAttrStrName()});
2274 }
2275 
2276 /// Fold the constant bounds of a loop.
2277 static LogicalResult foldLoopBounds(AffineForOp forOp) {
2278  auto foldLowerOrUpperBound = [&forOp](bool lower) {
2279  // Check to see if each of the operands is the result of a constant. If
2280  // so, get the value. If not, ignore it.
2281  SmallVector<Attribute, 8> operandConstants;
2282  auto boundOperands =
2283  lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2284  for (auto operand : boundOperands) {
2285  Attribute operandCst;
2286  matchPattern(operand, m_Constant(&operandCst));
2287  operandConstants.push_back(operandCst);
2288  }
2289 
2290  AffineMap boundMap =
2291  lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2292  assert(boundMap.getNumResults() >= 1 &&
2293  "bound maps should have at least one result");
2294  SmallVector<Attribute, 4> foldedResults;
2295  if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2296  return failure();
2297 
2298  // Compute the max or min as applicable over the results.
2299  assert(!foldedResults.empty() && "bounds should have at least one result");
2300  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2301  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2302  auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2303  maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2304  : llvm::APIntOps::smin(maxOrMin, foldedResult);
2305  }
2306  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2307  : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2308  return success();
2309  };
2310 
2311  // Try to fold the lower bound.
2312  bool folded = false;
2313  if (!forOp.hasConstantLowerBound())
2314  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2315 
2316  // Try to fold the upper bound.
2317  if (!forOp.hasConstantUpperBound())
2318  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2319  return success(folded);
2320 }
2321 
2322 /// Canonicalize the bounds of the given loop.
2323 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2324  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2325  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2326 
2327  auto lbMap = forOp.getLowerBoundMap();
2328  auto ubMap = forOp.getUpperBoundMap();
2329  auto prevLbMap = lbMap;
2330  auto prevUbMap = ubMap;
2331 
2332  composeAffineMapAndOperands(&lbMap, &lbOperands);
2333  canonicalizeMapAndOperands(&lbMap, &lbOperands);
2334  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
2335  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2336  lbMap = removeDuplicateExprs(lbMap);
2337 
2338  composeAffineMapAndOperands(&ubMap, &ubOperands);
2339  canonicalizeMapAndOperands(&ubMap, &ubOperands);
2340  ubMap = removeDuplicateExprs(ubMap);
2341 
2342  // Any canonicalization change always leads to updated map(s).
2343  if (lbMap == prevLbMap && ubMap == prevUbMap)
2344  return failure();
2345 
2346  if (lbMap != prevLbMap)
2347  forOp.setLowerBound(lbOperands, lbMap);
2348  if (ubMap != prevUbMap)
2349  forOp.setUpperBound(ubOperands, ubMap);
2350  return success();
2351 }
2352 
2353 namespace {
2354 /// Returns constant trip count in trivial cases.
2355 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2356  int64_t step = forOp.getStep();
2357  if (!forOp.hasConstantBounds() || step <= 0)
2358  return std::nullopt;
2359  int64_t lb = forOp.getConstantLowerBound();
2360  int64_t ub = forOp.getConstantUpperBound();
2361  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2362 }
2363 
2364 /// This is a pattern to fold trivially empty loop bodies.
/// A loop whose body holds a single operation (the yield) is replaced by the
/// values it would produce: the init operands when the trip count is known to
/// be zero, otherwise the values reached through the yield (init operands for
/// yielded iter args, the value itself for values defined outside the loop).
2365 /// TODO: This should be moved into the folding hook.
2366 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
  // NOTE(review): one line is absent from this listing here (presumably the
  // inherited-constructor declaration) — confirm against the original file.
2368 
2369  LogicalResult matchAndRewrite(AffineForOp forOp,
2370  PatternRewriter &rewriter) const override {
2371  // Check that the body only contains a yield.
2372  if (!llvm::hasSingleElement(*forOp.getBody()))
2373  return failure();
  // NOTE(review): this path reports success without performing any rewrite on
  // a result-less loop — confirm this is intended.
2374  if (forOp.getNumResults() == 0)
2375  return success();
2376  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2377  if (tripCount && *tripCount == 0) {
2378  // The initial values of the iteration arguments would be the op's
2379  // results.
2380  rewriter.replaceOp(forOp, forOp.getIterOperands());
2381  return success();
2382  }
  // Build one replacement per yielded value, tracking the two situations that
  // make the replacement depend on the trip count: a yielded value defined
  // outside the loop, and iter args yielded at a different position.
2383  SmallVector<Value, 4> replacements;
2384  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2385  auto iterArgs = forOp.getRegionIterArgs();
2386  bool hasValDefinedOutsideLoop = false;
2387  bool iterArgsNotInOrder = false;
2388  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2389  Value val = yieldOp.getOperand(i);
2390  auto *iterArgIt = llvm::find(iterArgs, val);
2391  if (iterArgIt == iterArgs.end()) {
2392  // `val` is defined outside of the loop.
2393  assert(forOp.isDefinedOutsideOfLoop(val) &&
2394  "must be defined outside of the loop");
2395  hasValDefinedOutsideLoop = true;
2396  replacements.push_back(val);
2397  } else {
      // `val` is an iter arg: forward the init operand at the same position
      // as that iter arg.
2398  unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2399  if (pos != i)
2400  iterArgsNotInOrder = true;
2401  replacements.push_back(forOp.getIterOperands()[pos]);
2402  }
2403  }
2404  // Bail out when the trip count is unknown and the loop returns any value
2405  // defined outside of the loop or any iterArg out of order.
2406  if (!tripCount.has_value() &&
2407  (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2408  return failure();
2409  // Bail out when the loop iterates more than once and it returns any iterArg
2410  // out of order.
2411  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2412  return failure();
2413  rewriter.replaceOp(forOp, replacements);
2414  return success();
2415  }
2416 };
2417 } // namespace
2418 
2419 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2420  MLIRContext *context) {
2421  results.add<AffineForEmptyLoopFolder>(context);
2422 }
2423 
2424 /// Return operands used when entering the region at 'index'. These operands
2425 /// correspond to the loop iterator operands, i.e., those excluding the
2426 /// induction variable. AffineForOp only has one region, so zero is the only
2427 /// valid value for `index`.
// NOTE(review): the line carrying the return type (listing line 2428) is
// absent from this listing -- confirm against the original file.
2429 AffineForOp::getSuccessorEntryOperands(std::optional<unsigned> index) {
// An empty `index` is accepted as well as region 0.
2430  assert((!index || *index == 0) && "invalid region index");
2431 
2432  // The initial operands map to the loop arguments after the induction
2433  // variable or are forwarded to the results when the trip count is zero.
2434  return getIterOperands();
2435 }
2436 
2437 /// Given the region at `index`, or the parent operation if `index` is None,
2438 /// return the successor regions. These are the regions that may be selected
2439 /// during the flow of control. `operands` is a set of optional attributes that
2440 /// correspond to a constant value for each operand, or null if that operand is
2441 /// not a constant.
2442 void AffineForOp::getSuccessorRegions(
2443  std::optional<unsigned> index, ArrayRef<Attribute> operands,
// NOTE(review): listing line 2444 (the remainder of the parameter list, which
// presumably declares the `regions` output used below) is absent from this
// listing -- confirm against the original file.
2445  assert((!index.has_value() || index.value() == 0) && "expected loop region");
2446  // The loop may typically branch back to its body or to the parent operation.
2447  // If the predecessor is the parent op and the trip count is known to be at
2448  // least one, branch into the body using the iterator arguments. And in cases
2449  // we know the trip count is zero, it can only branch back to its parent.
// `tripCount` is empty unless both bounds are constant and the step is
// positive; `operands` is not consulted here.
2450  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2451  if (!index.has_value() && tripCount.has_value()) {
2452  if (tripCount.value() > 0) {
2453  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
2454  return;
2455  }
2456  if (tripCount.value() == 0) {
2457  regions.push_back(RegionSuccessor(getResults()));
2458  return;
2459  }
2460  }
2461 
2462  // From the loop body, if the trip count is one, we can only branch back to
2463  // the parent.
2464  if (index && tripCount && *tripCount == 1) {
2465  regions.push_back(RegionSuccessor(getResults()));
2466  return;
2467  }
2468 
2469  // In all other cases, the loop may branch back to itself or the parent
2470  // operation.
2471  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
2472  regions.push_back(RegionSuccessor(getResults()));
2473 }
2474 
2475 /// Returns true if the affine.for has zero iterations in trivial cases.
2476 static bool hasTrivialZeroTripCount(AffineForOp op) {
2477  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2478  return tripCount && *tripCount == 0;
2479 }
2480 
2481 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2482  SmallVectorImpl<OpFoldResult> &results) {
2483  bool folded = succeeded(foldLoopBounds(*this));
2484  folded |= succeeded(canonicalizeLoopBounds(*this));
2485  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2486  // The initial values of the loop-carried variables (iter_args) are the
2487  // results of the op. But this must be avoided for an affine.for op that
2488  // does not return any results. Since ops that do not return results cannot
2489  // be folded away, we would enter an infinite loop of folds on the same
2490  // affine.for op.
2491  results.assign(getIterOperands().begin(), getIterOperands().end());
2492  folded = true;
2493  }
2494  return success(folded);
2495 }
2496 
// Builds an AffineBound for this loop from the lower-bound map, passing 0 and
// the map's input count (presumably the operand index range of the bound).
// NOTE(review): the signature line (listing line 2497) is absent from this
// listing -- confirm against the original file.
2498  auto lbMap = getLowerBoundMap();
2499  return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
2500 }
2501 
// Builds an AffineBound for this loop from the upper-bound map. The two index
// arguments start right after the lower-bound map inputs and span the
// upper-bound map inputs (presumably the operand index range of the bound).
// NOTE(review): the signature line (listing line 2502) is absent from this
// listing -- confirm against the original file.
2503  auto lbMap = getLowerBoundMap();
2504  auto ubMap = getUpperBoundMap();
2505  return AffineBound(AffineForOp(*this), lbMap.getNumInputs(),
2506  lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap);
2507 }
2508 
2509 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2510  assert(lbOperands.size() == map.getNumInputs());
2511  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2512 
2513  SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());
2514 
2515  auto ubOperands = getUpperBoundOperands();
2516  newOperands.append(ubOperands.begin(), ubOperands.end());
2517  auto iterOperands = getIterOperands();
2518  newOperands.append(iterOperands.begin(), iterOperands.end());
2519  (*this)->setOperands(newOperands);
2520 
2521  (*this)->setAttr(getLowerBoundAttrStrName(), AffineMapAttr::get(map));
2522 }
2523 
2524 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2525  assert(ubOperands.size() == map.getNumInputs());
2526  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2527 
2529  newOperands.append(ubOperands.begin(), ubOperands.end());
2530  auto iterOperands = getIterOperands();
2531  newOperands.append(iterOperands.begin(), iterOperands.end());
2532  (*this)->setOperands(newOperands);
2533 
2534  (*this)->setAttr(getUpperBoundAttrStrName(), AffineMapAttr::get(map));
2535 }
2536 
2537 void AffineForOp::setLowerBoundMap(AffineMap map) {
2538  auto lbMap = getLowerBoundMap();
2539  assert(lbMap.getNumDims() == map.getNumDims() &&
2540  lbMap.getNumSymbols() == map.getNumSymbols());
2541  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2542  (void)lbMap;
2543  (*this)->setAttr(getLowerBoundAttrStrName(), AffineMapAttr::get(map));
2544 }
2545 
2546 void AffineForOp::setUpperBoundMap(AffineMap map) {
2547  auto ubMap = getUpperBoundMap();
2548  assert(ubMap.getNumDims() == map.getNumDims() &&
2549  ubMap.getNumSymbols() == map.getNumSymbols());
2550  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2551  (void)ubMap;
2552  (*this)->setAttr(getUpperBoundAttrStrName(), AffineMapAttr::get(map));
2553 }
2554 
2555 bool AffineForOp::hasConstantLowerBound() {
2556  return getLowerBoundMap().isSingleConstant();
2557 }
2558 
2559 bool AffineForOp::hasConstantUpperBound() {
2560  return getUpperBoundMap().isSingleConstant();
2561 }
2562 
2563 int64_t AffineForOp::getConstantLowerBound() {
2564  return getLowerBoundMap().getSingleConstantResult();
2565 }
2566 
2567 int64_t AffineForOp::getConstantUpperBound() {
2568  return getUpperBoundMap().getSingleConstantResult();
2569 }
2570 
2571 void AffineForOp::setConstantLowerBound(int64_t value) {
2572  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2573 }
2574 
2575 void AffineForOp::setConstantUpperBound(int64_t value) {
2576  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2577 }
2578 
2579 AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
2580  return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
2581 }
2582 
2583 AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
2584  return {operand_begin() + getLowerBoundMap().getNumInputs(),
2585  operand_begin() + getLowerBoundMap().getNumInputs() +
2586  getUpperBoundMap().getNumInputs()};
2587 }
2588 
2589 AffineForOp::operand_range AffineForOp::getControlOperands() {
2590  return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs() +
2591  getUpperBoundMap().getNumInputs()};
2592 }
2593 
2594 bool AffineForOp::matchingBoundOperandList() {
2595  auto lbMap = getLowerBoundMap();
2596  auto ubMap = getUpperBoundMap();
2597  if (lbMap.getNumDims() != ubMap.getNumDims() ||
2598  lbMap.getNumSymbols() != ubMap.getNumSymbols())
2599  return false;
2600 
2601  unsigned numOperands = lbMap.getNumInputs();
2602  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2603  // Compare Value 's.
2604  if (getOperand(i) != getOperand(numOperands + i))
2605  return false;
2606  }
2607  return true;
2608 }
2609 
2610 Region &AffineForOp::getLoopBody() { return getRegion(); }
2611 
2612 std::optional<Value> AffineForOp::getSingleInductionVar() {
2613  return getInductionVar();
2614 }
2615 
2616 std::optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
2617  if (!hasConstantLowerBound())
2618  return std::nullopt;
2619  OpBuilder b(getContext());
2620  return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()));
2621 }
2622 
2623 std::optional<OpFoldResult> AffineForOp::getSingleStep() {
2624  OpBuilder b(getContext());
2625  return OpFoldResult(b.getI64IntegerAttr(getStep()));
2626 }
2627 
2628 std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
2629  if (!hasConstantUpperBound())
2630  return std::nullopt;
2631  OpBuilder b(getContext());
2632  return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
2633 }
2634 
// Reports whether the loop may be speculatively executed, based solely on its
// step.
2635 Speculation::Speculatability AffineForOp::getSpeculatability() {
2636  // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2637  // Start and End.
2638  //
2639  // For Step != 1, the loop may not terminate. We can add more smarts here if
2640  // needed.
2641  return getStep() == 1 ? Speculation::RecursivelySpeculatable
// NOTE(review): listing line 2642 (the else-branch of this conditional
// expression) is absent from this listing -- confirm against the original
// file.
2643 }
2644 
2645 /// Returns true if the provided value is the induction variable of a
2646 /// AffineForOp.
// NOTE(review): the signature line (listing line 2647) is absent from this
// listing -- confirm against the original file.
// A default-constructed AffineForOp is what the owner lookup yields when `val`
// is not such an induction variable.
2648  return getForInductionVarOwner(val) != AffineForOp();
2649 }
2650 
// True when `val` has an owning affine.parallel op according to
// getAffineParallelInductionVarOwner.
// NOTE(review): the signature line (listing line 2651) is absent from this
// listing -- confirm against the original file.
2652  return getAffineParallelInductionVarOwner(val) != nullptr;
2653 }
2654 
2657 }
2658 
// Looks up the affine.for op whose induction variable is `val`. Yields a
// default-constructed (null) AffineForOp when `val` is not a block argument,
// has no owner block, or is not the induction variable of the op that
// contains its block.
// NOTE(review): the signature line (listing line 2659) is absent from this
// listing -- confirm against the original file.
2660  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2661  if (!ivArg || !ivArg.getOwner())
2662  return AffineForOp();
// Walk from the argument's block to its region, then to the op holding it.
2663  auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
2664  if (auto forOp = dyn_cast<AffineForOp>(containingInst))
2665  // Check to make sure `val` is the induction variable, not an iter_arg.
2666  return forOp.getInductionVar() == val ? forOp : AffineForOp();
2667  return AffineForOp();
2668 }
2669 
// Looks up the affine.parallel op that has `val` among its induction
// variables. Yields nullptr when `val` is not a block argument, has no owner
// block, or is not one of the IVs of the op that contains its block.
// NOTE(review): the signature line (listing line 2670) is absent from this
// listing -- confirm against the original file.
2671  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2672  if (!ivArg || !ivArg.getOwner())
2673  return nullptr;
2674  Operation *containingOp = ivArg.getOwner()->getParentOp();
2675  auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2676  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2677  return parallelOp;
2678  return nullptr;
2679 }
2680 
2681 /// Extracts the induction variables from a list of AffineForOps and returns
2682 /// them.
// The induction variables are appended to `*ivs` in the order of `forInsts`.
// NOTE(review): the first signature line (listing line 2683, presumably
// declaring `forInsts`) is absent from this listing -- confirm against the
// original file.
2684  SmallVectorImpl<Value> *ivs) {
2685  ivs->reserve(forInsts.size());
2686  for (auto forInst : forInsts)
2687  ivs->push_back(forInst.getInductionVar());
2688 }
2689 
// Appends to `ivs` the induction variables of every op in `affineOps`: the
// single IV of an affine.for, or all body block arguments of an
// affine.parallel. Ops of any other kind contribute nothing.
// NOTE(review): the signature lines (listing lines 2690-2691) are absent from
// this listing -- confirm against the original file.
2692  ivs.reserve(affineOps.size());
2693  for (Operation *op : affineOps) {
2694  // Add constraints from forOp's bounds.
2695  if (auto forOp = dyn_cast<AffineForOp>(op))
2696  ivs.push_back(forOp.getInductionVar());
2697  else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2698  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2699  ivs.push_back(parallelOp.getBody()->getArgument(i));
2700  }
2701 }
2702 
2703 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2704 /// operations.
// One loop is created per (lb, ub, step) triple, each nested in the previous
// one. `bodyBuilderFn`, when provided, is invoked once: in the innermost loop
// with all induction variables, or at the current insertion point with no
// induction variables when the bound lists are empty. The builder's insertion
// point is restored on exit.
2705 template <typename BoundListTy, typename LoopCreatorTy>
// NOTE(review): listing line 2706 (presumably the return type and function
// name) is absent from this listing -- confirm against the original file.
2707  OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2708  ArrayRef<int64_t> steps,
2709  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2710  LoopCreatorTy &&loopCreatorFn) {
2711  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2712  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2713 
2714  // If there are no loops to be constructed, construct the body anyway.
2715  OpBuilder::InsertionGuard guard(builder);
2716  if (lbs.empty()) {
2717  if (bodyBuilderFn)
2718  bodyBuilderFn(builder, loc, ValueRange());
2719  return;
2720  }
2721 
2722  // Create the loops iteratively and store the induction variables.
// NOTE(review): listing line 2723 (presumably the declaration of `ivs`) is
// absent from this listing -- confirm against the original file.
2724  ivs.reserve(lbs.size());
2725  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2726  // Callback for creating the loop body, always creates the terminator.
2727  auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
2728  ValueRange iterArgs) {
2729  ivs.push_back(iv);
2730  // In the innermost loop, call the body builder.
2731  if (i == e - 1 && bodyBuilderFn) {
2732  OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
2733  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2734  }
2735  nestedBuilder.create<AffineYieldOp>(nestedLoc);
2736  };
2737 
2738  // Delegate actual loop creation to the callback in order to dispatch
2739  // between constant- and variable-bound loops.
2740  auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
// The next (inner) loop is created at the start of this loop's body.
2741  builder.setInsertionPointToStart(loop.getBody());
2742  }
2743 }
2744 
2745 /// Creates an affine loop from the bounds known to be constants.
// The loop is created without iter_args.
2746 static AffineForOp
// NOTE(review): listing line 2747 (the function name and leading parameters,
// presumably `builder`, `loc` and `lb`) is absent from this listing -- confirm
// against the original file.
2748  int64_t ub, int64_t step,
2749  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2750  return builder.create<AffineForOp>(loc, lb, ub, step,
2751  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2752 }
2753 
2754 /// Creates an affine loop from the bounds that may or may not be constants.
// When both bounds are defined by arith.constant index ops, the loop is built
// with constant bounds; otherwise each bound is the dim identity map applied
// to the corresponding value. The loop is created without iter_args.
2755 static AffineForOp
// NOTE(review): listing line 2756 (the function name and leading parameters,
// presumably `builder`, `loc`, `lb` and `ub`) is absent from this listing --
// confirm against the original file.
2757  int64_t step,
2758  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2759  auto lbConst = lb.getDefiningOp<arith::ConstantIndexOp>();
2760  auto ubConst = ub.getDefiningOp<arith::ConstantIndexOp>();
2761  if (lbConst && ubConst)
2762  return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2763  ubConst.value(), step, bodyBuilderFn);
2764  return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2765  builder.getDimIdentityMap(), step,
2766  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2767 }
2768 
// Loop-nest builder taking constant lower bounds (`lbs`); forwards to
// buildAffineLoopNestImpl.
// NOTE(review): listing lines 2769, 2771 and 2774 (the function name, the
// parameters that presumably declare `ubs` and `steps`, and the final argument
// of the call) are absent from this listing -- confirm against the original
// file.
2770  OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2772  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2773  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2775 }
2776 
// Loop-nest builder taking SSA-value bounds (`lbs`, `ubs`); forwards to
// buildAffineLoopNestImpl.
// NOTE(review): listing lines 2777 and 2782 (the function name and the final
// argument of the call) are absent from this listing -- confirm against the
// original file.
2778  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2779  ArrayRef<int64_t> steps,
2780  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2781  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2783 }
2784 
// Creates, right before `loop`, a new affine.for with the same bounds and step
// whose iter operands are those of `loop` followed by `newIterOperands`. The
// body of `loop` is moved into the new loop, one block argument is added per
// entry of `newIterArgs`, and `newYieldedValues` are appended to the yield.
// When `replaceLoopResults` is set, uses of the old results are redirected to
// the leading results of the new loop. The old loop is not erased here.
// NOTE(review): listing lines 2785 (function name and first parameter,
// presumably the builder `b`) and 2794 are absent from this listing -- confirm
// against the original file.
2786  AffineForOp loop,
2787  ValueRange newIterOperands,
2788  ValueRange newYieldedValues,
2789  ValueRange newIterArgs,
2790  bool replaceLoopResults) {
2791  assert(newIterOperands.size() == newYieldedValues.size() &&
2792  "newIterOperands must be of the same size as newYieldedValues");
2793  // Create a new loop before the existing one, with the extra operands.
2795  b.setInsertionPoint(loop);
2796  auto operands = llvm::to_vector<4>(loop.getIterOperands());
2797  operands.append(newIterOperands.begin(), newIterOperands.end());
2798  SmallVector<Value, 4> lbOperands(loop.getLowerBoundOperands());
2799  SmallVector<Value, 4> ubOperands(loop.getUpperBoundOperands());
// NOTE(review): `steps` is not referenced again in the visible body; it looks
// like a dead local -- confirm and consider removing it.
2800  SmallVector<Value, 4> steps(loop.getStep());
2801  auto lbMap = loop.getLowerBoundMap();
2802  auto ubMap = loop.getUpperBoundMap();
2803  AffineForOp newLoop =
2804  b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
2805  loop.getStep(), operands);
2806  // Take the body of the original parent loop.
2807  newLoop.getLoopBody().takeBody(loop.getLoopBody());
2808  for (Value val : newIterArgs)
2809  newLoop.getLoopBody().addArgument(val.getType(), val.getLoc());
2810 
2811  // Update yield operation with new values to be added.
2812  if (!newYieldedValues.empty()) {
2813  auto yield = cast<AffineYieldOp>(newLoop.getBody()->getTerminator());
2814  b.setInsertionPoint(yield);
2815  auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
2816  yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
2817  b.create<AffineYieldOp>(yield.getLoc(), yieldOperands);
2818  yield.erase();
2819  }
2820  if (replaceLoopResults) {
// Only the first `loop.getNumResults()` results of the new loop correspond to
// the old results.
2821  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
2822  loop.getNumResults()))) {
2823  std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
2824  }
2825  }
2826  return newLoop;
2827 }
2828 
2829 //===----------------------------------------------------------------------===//
2830 // AffineIfOp
2831 //===----------------------------------------------------------------------===//
2832 
2833 namespace {
2834 /// Remove else blocks that have nothing other than a zero value yield.
// The pattern applies only when the op has no results and its else block holds
// a single operation.
2835 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
// NOTE(review): listing line 2836 is absent from this listing -- confirm
// against the original file.
2837 
2838  LogicalResult matchAndRewrite(AffineIfOp ifOp,
2839  PatternRewriter &rewriter) const override {
2840  if (ifOp.getElseRegion().empty() ||
2841  !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2842  return failure();
2843 
// The else block is erased inside a root update of `ifOp`.
2844  rewriter.startRootUpdate(ifOp);
2845  rewriter.eraseBlock(ifOp.getElseBlock());
2846  rewriter.finalizeRootUpdate(ifOp);
2847  return success();
2848  }
2849 };
2850 
2851 /// Removes affine.if cond if the condition is always true or false in certain
2852 /// trivial cases. Promotes the then/else block in the parent operation block.
2853 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
// NOTE(review): listing line 2854 is absent from this listing -- confirm
// against the original file.
2855 
2856  LogicalResult matchAndRewrite(AffineIfOp op,
2857  PatternRewriter &rewriter) const override {
2858 
// Trivially false: the condition is the empty integer set.
2859  auto isTriviallyFalse = [](IntegerSet iSet) {
2860  return iSet.isEmptyIntegerSet();
2861  };
2862 
// Trivially true: the only constraint is the single equality `0 == 0`.
2863  auto isTriviallyTrue = [](IntegerSet iSet) {
2864  return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2865  iSet.getConstraint(0) == 0);
2866  };
2867 
2868  IntegerSet affineIfConditions = op.getIntegerSet();
// Assigned on every path that does not return below.
2869  Block *blockToMove;
2870  if (isTriviallyFalse(affineIfConditions)) {
2871  // The absence, or equivalently, the emptiness of the else region need not
2872  // be checked when affine.if is returning results because if an affine.if
2873  // operation is returning results, it always has a non-empty else region.
2874  if (op.getNumResults() == 0 && !op.hasElse()) {
2875  // If the else region is absent, or equivalently, empty, remove the
2876  // affine.if operation (which is not returning any results).
2877  rewriter.eraseOp(op);
2878  return success();
2879  }
2880  blockToMove = op.getElseBlock();
2881  } else if (isTriviallyTrue(affineIfConditions)) {
2882  blockToMove = op.getThenBlock();
2883  } else {
2884  return failure();
2885  }
// Grab the terminator before the block is inlined into the parent block.
2886  Operation *blockToMoveTerminator = blockToMove->getTerminator();
2887  // Promote the "blockToMove" block to the parent operation block between the
2888  // prologue and epilogue of "op".
2889  rewriter.inlineBlockBefore(blockToMove, op);
2890  // Replace the "op" operation with the operands of the
2891  // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2892  // the affine.yield operation present in the "blockToMove" block. It has no
2893  // operands when affine.if is not returning results and therefore, in that
2894  // case, replaceOp just erases "op". When affine.if is not returning
2895  // results, the affine.yield operation can be omitted. It gets inserted
2896  // implicitly.
2897  rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2898  // Erase the "blockToMoveTerminator" operation since it is now in the parent
2899  // operation block, which already has its own terminator.
2900  rewriter.eraseOp(blockToMoveTerminator);
2901  return success();
2902  }
2903 };
2904 } // namespace
2905 
2906 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
2907 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
// `operands` is not consulted: both regions are reported regardless of the
// condition's value.
2908 void AffineIfOp::getSuccessorRegions(
2909  std::optional<unsigned> index, ArrayRef<Attribute> operands,
// NOTE(review): listing line 2910 (the remainder of the parameter list, which
// presumably declares the `regions` output used below) is absent from this
// listing -- confirm against the original file.
2911  // If the predecessor is an AffineIfOp, then branching into both `then` and
2912  // `else` region is valid.
2913  if (!index.has_value()) {
2914  regions.reserve(2);
2915  regions.push_back(
2916  RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
2917  // Don't consider the else region if it is empty.
2918  if (!getElseRegion().empty())
2919  regions.push_back(
2920  RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
2921  return;
2922  }
2923 
2924  // If the predecessor is the `else`/`then` region, then branching into parent
2925  // op is valid.
2926  regions.push_back(RegionSuccessor(getResults()));
2927 }
2928 
// Verifies an affine.if: the 'condition' integer-set attribute must be
// present, the operand count must equal the set's input count, and the
// operands must be valid dimension/symbol identifiers.
// NOTE(review): the signature line (listing line 2929) is absent from this
// listing -- confirm against the original file.
2930  // Verify that we have a condition attribute.
2931  // FIXME: This should be specified in the arguments list in ODS.
2932  auto conditionAttr =
2933  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2934  if (!conditionAttr)
2935  return emitOpError("requires an integer set attribute named 'condition'");
2936 
2937  // Verify that there are enough operands for the condition.
2938  IntegerSet condition = conditionAttr.getValue();
2939  if (getNumOperands() != condition.getNumInputs())
2940  return emitOpError("operand count and condition integer set dimension and "
2941  "symbol count must match");
2942 
2943  // Verify that the operands are valid dimension/symbols.
2944  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2945  condition.getNumDims())))
2946  return failure();
2947 
2948  return success();
2949 }
2950 
2951 ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
2952  // Parse the condition attribute set.
2953  IntegerSetAttr conditionAttr;
2954  unsigned numDims;
2955  if (parser.parseAttribute(conditionAttr,
2956  AffineIfOp::getConditionAttrStrName(),
2957  result.attributes) ||
2958  parseDimAndSymbolList(parser, result.operands, numDims))
2959  return failure();
2960 
2961  // Verify the condition operands.
2962  auto set = conditionAttr.getValue();
2963  if (set.getNumDims() != numDims)
2964  return parser.emitError(
2965  parser.getNameLoc(),
2966  "dim operand count and integer set dim count must match");
2967  if (numDims + set.getNumSymbols() != result.operands.size())
2968  return parser.emitError(
2969  parser.getNameLoc(),
2970  "symbol operand count and integer set symbol count must match");
2971 
2972  if (parser.parseOptionalArrowTypeList(result.types))
2973  return failure();
2974 
2975  // Create the regions for 'then' and 'else'. The latter must be created even
2976  // if it remains empty for the validity of the operation.
2977  result.regions.reserve(2);
2978  Region *thenRegion = result.addRegion();
2979  Region *elseRegion = result.addRegion();
2980 
2981  // Parse the 'then' region.
2982  if (parser.parseRegion(*thenRegion, {}, {}))
2983  return failure();
2984  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2985  result.location);
2986 
2987  // If we find an 'else' keyword then parse the 'else' region.
2988  if (!parser.parseOptionalKeyword("else")) {
2989  if (parser.parseRegion(*elseRegion, {}, {}))
2990  return failure();
2991  AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2992  result.location);
2993  }
2994 
2995  // Parse the optional attribute list.
2996  if (parser.parseOptionalAttrDict(result.attributes))
2997  return failure();
2998 
2999  return success();
3000 }
3001 
// Prints the custom assembly form of affine.if: the condition set with its
// operands, the optional result types, the `then` region, the `else` region
// when it has blocks, and the remaining attributes.
// NOTE(review): the signature line (listing line 3002) is absent from this
// listing -- confirm against the original file.
3003  auto conditionAttr =
3004  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3005  p << " " << conditionAttr;
3006  printDimAndSymbolList(operand_begin(), operand_end(),
3007  conditionAttr.getValue().getNumDims(), p);
3008  p.printOptionalArrowTypeList(getResultTypes());
3009  p << ' ';
// Block terminators are printed only when the op has results (a non-zero
// result count converts to true).
3010  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
3011  /*printBlockTerminators=*/getNumResults());
3012 
3013  // Print the 'else' regions if it has any blocks.
3014  auto &elseRegion = this->getElseRegion();
3015  if (!elseRegion.empty()) {
3016  p << " else ";
3017  p.printRegion(elseRegion,
3018  /*printEntryBlockArgs=*/false,
3019  /*printBlockTerminators=*/getNumResults());
3020  }
3021 
3022  // Print the attribute list.
// The condition attribute was already printed above, so it is elided here.
3023  p.printOptionalAttrDict((*this)->getAttrs(),
3024  /*elidedAttrs=*/getConditionAttrStrName());
3025 }
3026 
3027 IntegerSet AffineIfOp::getIntegerSet() {
3028  return (*this)
3029  ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
3030  .getValue();
3031 }
3032 
3033 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
3034  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
3035 }
3036 
3037 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
3038  setIntegerSet(set);
3039  (*this)->setOperands(operands);
3040 }
3041 
3042 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3043  TypeRange resultTypes, IntegerSet set, ValueRange args,
3044  bool withElseRegion) {
3045  assert(resultTypes.empty() || withElseRegion);
3046  result.addTypes(resultTypes);
3047  result.addOperands(args);
3048  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
3049 
3050  Region *thenRegion = result.addRegion();
3051  thenRegion->push_back(new Block());
3052  if (resultTypes.empty())
3053  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
3054 
3055  Region *elseRegion = result.addRegion();
3056  if (withElseRegion) {
3057  elseRegion->push_back(new Block());
3058  if (resultTypes.empty())
3059  AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
3060  }
3061 }
3062 
3063 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3064  IntegerSet set, ValueRange args, bool withElseRegion) {
3065  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
3066  withElseRegion);
3067 }
3068 
3069 /// Compose any affine.apply ops feeding into `operands` of the integer set
3070 /// `set` by composing the maps of such affine.apply ops with the integer
3071 /// set constraints.
// `set` and `operands` are updated in place; both are left unchanged when no
// operand is defined by an affine.apply op.
// NOTE(review): the first signature line (listing line 3072, presumably
// declaring `set`) is absent from this listing -- confirm against the original
// file.
3073  SmallVectorImpl<Value> &operands) {
3074  // We will simply reuse the API of the map composition by viewing the LHSs of
3075  // the equalities and inequalities of `set` as the affine exprs of an affine
3076  // map. Convert to equivalent map, compose, and convert back to set.
3077  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
3078  set.getConstraints(), set.getContext());
3079  // Check if any composition is possible.
3080  if (llvm::none_of(operands,
3081  [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
3082  return;
3083 
3084  composeAffineMapAndOperands(&map, &operands);
// The equality flags of the original set are reused for the composed
// constraints.
3085  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
3086  set.getEqFlags());
3087 }
3088 
3089 /// Canonicalize an affine if op's conditional (integer set + operands).
// The op is updated in place; failure is returned when neither the set nor the
// operands changed.
3090 LogicalResult AffineIfOp::fold(FoldAdaptor,
// NOTE(review): listing line 3091 (the remainder of the parameter list) is
// absent from this listing -- confirm against the original file.
3092  auto set = getIntegerSet();
3093  SmallVector<Value, 4> operands(getOperands());
3094  composeSetAndOperands(set, operands);
3095  canonicalizeSetAndOperands(&set, &operands);
3096 
3097  // Check if the canonicalization or composition led to any change.
3098  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
3099  return failure();
3100 
3101  setConditional(set, operands);
3102  return success();
3103 }
3104 
3105 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
3106  MLIRContext *context) {
3107  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
3108 }
3109 
3110 //===----------------------------------------------------------------------===//
3111 // AffineLoadOp
3112 //===----------------------------------------------------------------------===//
3113 
3114 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3115  AffineMap map, ValueRange operands) {
3116  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3117  result.addOperands(operands);
3118  if (map)
3119  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3120  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
3121  result.types.push_back(memrefType.getElementType());
3122 }
3123 
3124 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3125  Value memref, AffineMap map, ValueRange mapOperands) {
3126  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3127  result.addOperands(memref);
3128  result.addOperands(mapOperands);
3129  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3130  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3131  result.types.push_back(memrefType.getElementType());
3132 }
3133 
3134 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3135  Value memref, ValueRange indices) {
3136  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3137  int64_t rank = memrefType.getRank();
3138  // Create identity map for memrefs with at least one dimension or () -> ()
3139  // for zero-dimensional memrefs.
3140  auto map =
3141  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3142  build(builder, result, memref, map, indices);
3143 }
3144 
// Parses the custom assembly form of affine.load: the memref operand, the
// affine map of SSA ids, an optional attribute dictionary and the memref type.
// The map operands are resolved as `index` values and the result type is the
// memref's element type.
3145 ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
3146  auto &builder = parser.getBuilder();
3147  auto indexTy = builder.getIndexType();
3148 
3149  MemRefType type;
3150  OpAsmParser::UnresolvedOperand memrefInfo;
3151  AffineMapAttr mapAttr;
// NOTE(review): listing line 3152 (presumably the declaration of
// `mapOperands`) is absent from this listing -- confirm against the original
// file.
// The chain below short-circuits: parsing stops at the first failing step.
3153  return failure(
3154  parser.parseOperand(memrefInfo) ||
3155  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3156  AffineLoadOp::getMapAttrStrName(),
3157  result.attributes) ||
3158  parser.parseOptionalAttrDict(result.attributes) ||
3159  parser.parseColonType(type) ||
3160  parser.resolveOperand(memrefInfo, type, result.operands) ||
3161  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3162  parser.addTypeToList(type.getElementType(), result.types));
3163 }
3164 
// Prints the custom assembly form of affine.load: the memref, the bracketed
// affine map applied to its operands (when a map attribute is present), the
// remaining attributes and the memref type.
// NOTE(review): the signature line (listing line 3165) is absent from this
// listing -- confirm against the original file.
3166  p << " " << getMemRef() << '[';
3167  if (AffineMapAttr mapAttr =
3168  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3169  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3170  p << ']';
// The map attribute is already printed inside the brackets, so it is elided.
3171  p.printOptionalAttrDict((*this)->getAttrs(),
3172  /*elidedAttrs=*/{getMapAttrStrName()});
3173  p << " : " << getMemRefType();
3174 }
3175 
3176 /// Verify common indexing invariants of affine.load, affine.store,
3177 /// affine.vector_load and affine.vector_store.
3178 static LogicalResult
3179 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
3180  Operation::operand_range mapOperands,
3181  MemRefType memrefType, unsigned numIndexOperands) {
3182  if (mapAttr) {
3183  AffineMap map = mapAttr.getValue();
3184  if (map.getNumResults() != memrefType.getRank())
3185  return op->emitOpError("affine map num results must equal memref rank");
3186  if (map.getNumInputs() != numIndexOperands)
3187  return op->emitOpError("expects as many subscripts as affine map inputs");
3188  } else {
3189  if (memrefType.getRank() != numIndexOperands)
3190  return op->emitOpError(
3191  "expects the number of subscripts to be equal to memref rank");
3192  }
3193 
3194  Region *scope = getAffineScope(op);
3195  for (auto idx : mapOperands) {
3196  if (!idx.getType().isIndex())
3197  return op->emitOpError("index to load must have 'index' type");
3198  if (!isValidAffineIndexOperand(idx, scope))
3199  return op->emitOpError("index must be a dimension or symbol identifier");
3200  }
3201 
3202  return success();
3203 }
3204 
3206  auto memrefType = getMemRefType();
3207  if (getType() != memrefType.getElementType())
3208  return emitOpError("result type must match element type of memref");
3209 
3211  getOperation(),
3212  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3213  getMapOperands(), memrefType,
3214  /*numIndexOperands=*/getNumOperands() - 1)))
3215  return failure();
3216 
3217  return success();
3218 }
3219 
3220 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3221  MLIRContext *context) {
3222  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3223 }
3224 
3225 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3226  /// load(memrefcast) -> load
3227  if (succeeded(memref::foldMemRefCast(*this)))
3228  return getResult();
3229 
3230  // Fold load from a global constant memref.
3231  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3232  if (!getGlobalOp)
3233  return {};
3234  // Get to the memref.global defining the symbol.
3235  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3236  if (!symbolTableOp)
3237  return {};
3238  auto global = dyn_cast_or_null<memref::GlobalOp>(
3239  SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
3240  if (!global)
3241  return {};
3242 
3243  // Check if the global memref is a constant.
3244  auto cstAttr =
3245  llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3246  if (!cstAttr)
3247  return {};
3248  // If it's a splat constant, we can fold irrespective of indices.
3249  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3250  return splatAttr.getSplatValue<Attribute>();
3251  // Otherwise, we can fold only if we know the indices.
3252  if (!getAffineMap().isConstant())
3253  return {};
3254  auto indices = llvm::to_vector<4>(
3255  llvm::map_range(getAffineMap().getConstantResults(),
3256  [](int64_t v) -> uint64_t { return v; }));
3257  return cstAttr.getValues<Attribute>()[indices];
3258 }
3259 
3260 //===----------------------------------------------------------------------===//
3261 // AffineStoreOp
3262 //===----------------------------------------------------------------------===//
3263 
3264 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3265  Value valueToStore, Value memref, AffineMap map,
3266  ValueRange mapOperands) {
3267  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3268  result.addOperands(valueToStore);
3269  result.addOperands(memref);
3270  result.addOperands(mapOperands);
3271  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3272 }
3273 
3274 // Use identity map.
3275 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3276  Value valueToStore, Value memref,
3277  ValueRange indices) {
3278  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3279  int64_t rank = memrefType.getRank();
3280  // Create identity map for memrefs with at least one dimension or () -> ()
3281  // for zero-dimensional memrefs.
3282  auto map =
3283  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3284  build(builder, result, valueToStore, memref, map, indices);
3285 }
3286 
3287 ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3288  auto indexTy = parser.getBuilder().getIndexType();
3289 
3290  MemRefType type;
3291  OpAsmParser::UnresolvedOperand storeValueInfo;
3292  OpAsmParser::UnresolvedOperand memrefInfo;
3293  AffineMapAttr mapAttr;
3295  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3296  parser.parseOperand(memrefInfo) ||
3297  parser.parseAffineMapOfSSAIds(
3298  mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3299  result.attributes) ||
3300  parser.parseOptionalAttrDict(result.attributes) ||
3301  parser.parseColonType(type) ||
3302  parser.resolveOperand(storeValueInfo, type.getElementType(),
3303  result.operands) ||
3304  parser.resolveOperand(memrefInfo, type, result.operands) ||
3305  parser.resolveOperands(mapOperands, indexTy, result.operands));
3306 }
3307 
3309  p << " " << getValueToStore();
3310  p << ", " << getMemRef() << '[';
3311  if (AffineMapAttr mapAttr =
3312  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3313  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3314  p << ']';
3315  p.printOptionalAttrDict((*this)->getAttrs(),
3316  /*elidedAttrs=*/{getMapAttrStrName()});
3317  p << " : " << getMemRefType();
3318 }
3319 
3321  // The value to store must have the same type as memref element type.
3322  auto memrefType = getMemRefType();
3323  if (getValueToStore().getType() != memrefType.getElementType())
3324  return emitOpError(
3325  "value to store must have the same type as memref element type");
3326 
3328  getOperation(),
3329  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3330  getMapOperands(), memrefType,
3331  /*numIndexOperands=*/getNumOperands() - 2)))
3332  return failure();
3333 
3334  return success();
3335 }
3336 
3337 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3338  MLIRContext *context) {
3339  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3340 }
3341 
3342 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3343  SmallVectorImpl<OpFoldResult> &results) {
3344  /// store(memrefcast) -> store
3345  return memref::foldMemRefCast(*this, getValueToStore());
3346 }
3347 
3348 //===----------------------------------------------------------------------===//
3349 // AffineMinMaxOpBase
3350 //===----------------------------------------------------------------------===//
3351 
3352 template <typename T>
3354  // Verify that operand count matches affine map dimension and symbol count.
3355  if (op.getNumOperands() !=
3356  op.getMap().getNumDims() + op.getMap().getNumSymbols())
3357  return op.emitOpError(
3358  "operand count and affine map dimension and symbol count must match");
3359  return success();
3360 }
3361 
3362 template <typename T>
3363 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3364  p << ' ' << op->getAttr(T::getMapAttrStrName());
3365  auto operands = op.getOperands();
3366  unsigned numDims = op.getMap().getNumDims();
3367  p << '(' << operands.take_front(numDims) << ')';
3368 
3369  if (operands.size() != numDims)
3370  p << '[' << operands.drop_front(numDims) << ']';
3372  /*elidedAttrs=*/{T::getMapAttrStrName()});
3373 }
3374 
3375 template <typename T>
3377  OperationState &result) {
3378  auto &builder = parser.getBuilder();
3379  auto indexType = builder.getIndexType();
3382  AffineMapAttr mapAttr;
3383  return failure(
3384  parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3385  result.attributes) ||
3386  parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
3387  parser.parseOperandList(symInfos,
3389  parser.parseOptionalAttrDict(result.attributes) ||
3390  parser.resolveOperands(dimInfos, indexType, result.operands) ||
3391  parser.resolveOperands(symInfos, indexType, result.operands) ||
3392  parser.addTypeToList(indexType, result.types));
3393 }
3394 
3395 /// Fold an affine min or max operation with the given operands. The operand
3396 /// list may contain nulls, which are interpreted as the operand not being a
3397 /// constant.
3398 template <typename T>
3400  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3401  "expected affine min or max op");
3402 
3403  // Fold the affine map.
3404  // TODO: Fold more cases:
3405  // min(some_affine, some_affine + constant, ...), etc.
3406  SmallVector<int64_t, 2> results;
3407  auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3408 
3409  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3410  return op.getOperand(0);
3411 
3412  // If some of the map results are not constant, try changing the map in-place.
3413  if (results.empty()) {
3414  // If the map is the same, report that folding did not happen.
3415  if (foldedMap == op.getMap())
3416  return {};
3417  op->setAttr("map", AffineMapAttr::get(foldedMap));
3418  return op.getResult();
3419  }
3420 
3421  // Otherwise, completely fold the op into a constant.
3422  auto resultIt = std::is_same<T, AffineMinOp>::value
3423  ? std::min_element(results.begin(), results.end())
3424  : std::max_element(results.begin(), results.end());
3425  if (resultIt == results.end())
3426  return {};
3427  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3428 }
3429 
3430 /// Remove duplicated expressions in affine min/max ops.
3431 template <typename T>
3434 
3436  PatternRewriter &rewriter) const override {
3437  AffineMap oldMap = affineOp.getAffineMap();
3438 
3439  SmallVector<AffineExpr, 4> newExprs;
3440  for (AffineExpr expr : oldMap.getResults()) {
3441  // This is a linear scan over newExprs, but it should be fine given that
3442  // we typically just have a few expressions per op.
3443  if (!llvm::is_contained(newExprs, expr))
3444  newExprs.push_back(expr);
3445  }
3446 
3447  if (newExprs.size() == oldMap.getNumResults())
3448  return failure();
3449 
3450  auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3451  newExprs, rewriter.getContext());
3452  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3453 
3454  return success();
3455  }
3456 };
3457 
3458 /// Merge an affine min/max op to its consumers if its consumer is also an
3459 /// affine min/max op.
3460 ///
3461 /// This pattern requires the producer affine min/max op is bound to a
3462 /// dimension/symbol that is used as a standalone expression in the consumer
3463 /// affine op's map.
3464 ///
3465 /// For example, a pattern like the following:
3466 ///
3467 /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3468 /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3469 ///
3470 /// Can be turned into:
3471 ///
3472 /// %1 = affine.min affine_map<
3473 /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3474 template <typename T>
3477 
3479  PatternRewriter &rewriter) const override {
3480  AffineMap oldMap = affineOp.getAffineMap();
3481  ValueRange dimOperands =
3482  affineOp.getMapOperands().take_front(oldMap.getNumDims());
3483  ValueRange symOperands =
3484  affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3485 
3486  auto newDimOperands = llvm::to_vector<8>(dimOperands);
3487  auto newSymOperands = llvm::to_vector<8>(symOperands);
3488  SmallVector<AffineExpr, 4> newExprs;
3489  SmallVector<T, 4> producerOps;
3490 
3491  // Go over each expression to see whether it's a single dimension/symbol
3492  // with the corresponding operand which is the result of another affine
3493  // min/max op. If So it can be merged into this affine op.
3494  for (AffineExpr expr : oldMap.getResults()) {
3495  if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
3496  Value symValue = symOperands[symExpr.getPosition()];
3497  if (auto producerOp = symValue.getDefiningOp<T>()) {
3498  producerOps.push_back(producerOp);
3499  continue;
3500  }
3501  } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
3502  Value dimValue = dimOperands[dimExpr.getPosition()];
3503  if (auto producerOp = dimValue.getDefiningOp<T>()) {
3504  producerOps.push_back(producerOp);
3505  continue;
3506  }
3507  }
3508  // For the above cases we will remove the expression by merging the
3509  // producer affine min/max's affine expressions. Otherwise we need to
3510  // keep the existing expression.
3511  newExprs.push_back(expr);
3512  }
3513 
3514  if (producerOps.empty())
3515  return failure();
3516 
3517  unsigned numUsedDims = oldMap.getNumDims();
3518  unsigned numUsedSyms = oldMap.getNumSymbols();
3519 
3520  // Now go over all producer affine ops and merge their expressions.
3521  for (T producerOp : producerOps) {
3522  AffineMap producerMap = producerOp.getAffineMap();
3523  unsigned numProducerDims = producerMap.getNumDims();
3524  unsigned numProducerSyms = producerMap.getNumSymbols();
3525 
3526  // Collect all dimension/symbol values.
3527  ValueRange dimValues =
3528  producerOp.getMapOperands().take_front(numProducerDims);
3529  ValueRange symValues =
3530  producerOp.getMapOperands().take_back(numProducerSyms);
3531  newDimOperands.append(dimValues.begin(), dimValues.end());
3532  newSymOperands.append(symValues.begin(), symValues.end());
3533 
3534  // For expressions we need to shift to avoid overlap.
3535  for (AffineExpr expr : producerMap.getResults()) {
3536  newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3537  .shiftSymbols(numProducerSyms, numUsedSyms));
3538  }
3539 
3540  numUsedDims += numProducerDims;
3541  numUsedSyms += numProducerSyms;
3542  }
3543 
3544  auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3545  rewriter.getContext());
3546  auto newOperands =
3547  llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3548  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3549 
3550  return success();
3551  }
3552 };
3553 
3554 /// Canonicalize the result expression order of an affine map and return success
3555 /// if the order changed.
3556 ///
3557 /// The function flattens the map's affine expressions to coefficient arrays and
3558 /// sorts them in lexicographic order. A coefficient array contains a multiplier
3559 /// for every dimension/symbol and a constant term. The canonicalization fails
3560 /// if a result expression is not pure or if the flattening requires local
3561 /// variables that, unlike dimensions and symbols, have no global order.
3563  SmallVector<SmallVector<int64_t>> flattenedExprs;
3564  for (const AffineExpr &resultExpr : map.getResults()) {
3565  // Fail if the expression is not pure.
3566  if (!resultExpr.isPureAffine())
3567  return failure();
3568 
3569  SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3570  flattener.walkPostOrder(resultExpr);
3571 
3572  // Fail if the flattened expression has local variables.
3573  if (flattener.operandExprStack.back().size() !=
3574  map.getNumDims() + map.getNumSymbols() + 1)
3575  return failure();
3576 
3577  flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3578  flattener.operandExprStack.back().end());
3579  }
3580 
3581  // Fail if sorting is not necessary.
3582  if (llvm::is_sorted(flattenedExprs))
3583  return failure();
3584 
3585  // Reorder the result expressions according to their flattened form.
3586  SmallVector<unsigned> resultPermutation =
3587  llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3588  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3589  return flattenedExprs[lhs] < flattenedExprs[rhs];
3590  });
3591  SmallVector<AffineExpr> newExprs;
3592  for (unsigned idx : resultPermutation)
3593  newExprs.push_back(map.getResult(idx));
3594 
3595  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3596  map.getContext());
3597  return success();
3598 }
3599 
3600 /// Canonicalize the affine map result expression order of an affine min/max
3601 /// operation.
3602 ///
3603 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3604 /// expressions and replaces the operation if the order changed.
3605 ///
3606 /// For example, the following operation:
3607 ///
3608 /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3609 ///
3610 /// Turns into:
3611 ///
3612 /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3613 template <typename T>
3616 
3618  PatternRewriter &rewriter) const override {
3619  AffineMap map = affineOp.getAffineMap();
3621  return failure();
3622  rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3623  return success();
3624  }
3625 };
3626 
3627 template <typename T>
3630 
3632  PatternRewriter &rewriter) const override {
3633  if (affineOp.getMap().getNumResults() != 1)
3634  return failure();
3635  rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3636  affineOp.getOperands());
3637  return success();
3638  }
3639 };
3640 
3641 //===----------------------------------------------------------------------===//
3642 // AffineMinOp
3643 //===----------------------------------------------------------------------===//
3644 //
3645 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3646 //
3647 
3648 OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3649  return foldMinMaxOp(*this, adaptor.getOperands());
3650 }
3651 
3652 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3653  MLIRContext *context) {
3656  MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3658  context);
3659 }
3660 
3662 
3663 ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
3664  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3665 }
3666 
3667 void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3668 
3669 //===----------------------------------------------------------------------===//
3670 // AffineMaxOp
3671 //===----------------------------------------------------------------------===//
3672 //
3673 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3674 //
3675 
3676 OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3677  return foldMinMaxOp(*this, adaptor.getOperands());
3678 }
3679 
3680 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3681  MLIRContext *context) {
3684  MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3686  context);
3687 }
3688 
3690 
3691 ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
3692  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3693 }
3694 
3695 void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3696 
3697 //===----------------------------------------------------------------------===//
3698 // AffinePrefetchOp
3699 //===----------------------------------------------------------------------===//
3700 
3701 //
3702 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3703 //
3704 ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3705  OperationState &result) {
3706  auto &builder = parser.getBuilder();
3707  auto indexTy = builder.getIndexType();
3708 
3709  MemRefType type;
3710  OpAsmParser::UnresolvedOperand memrefInfo;
3711  IntegerAttr hintInfo;
3712  auto i32Type = parser.getBuilder().getIntegerType(32);
3713  StringRef readOrWrite, cacheType;
3714 
3715  AffineMapAttr mapAttr;
3717  if (parser.parseOperand(memrefInfo) ||
3718  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3719  AffinePrefetchOp::getMapAttrStrName(),
3720  result.attributes) ||
3721  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3722  parser.parseComma() || parser.parseKeyword("locality") ||
3723  parser.parseLess() ||
3724  parser.parseAttribute(hintInfo, i32Type,
3725  AffinePrefetchOp::getLocalityHintAttrStrName(),
3726  result.attributes) ||
3727  parser.parseGreater() || parser.parseComma() ||
3728  parser.parseKeyword(&cacheType) ||
3729  parser.parseOptionalAttrDict(result.attributes) ||
3730  parser.parseColonType(type) ||
3731  parser.resolveOperand(memrefInfo, type, result.operands) ||
3732  parser.resolveOperands(mapOperands, indexTy, result.operands))
3733  return failure();
3734 
3735  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
3736  return parser.emitError(parser.getNameLoc(),
3737  "rw specifier has to be 'read' or 'write'");
3738  result.addAttribute(
3739  AffinePrefetchOp::getIsWriteAttrStrName(),
3740  parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
3741 
3742  if (!cacheType.equals("data") && !cacheType.equals("instr"))
3743  return parser.emitError(parser.getNameLoc(),
3744  "cache type has to be 'data' or 'instr'");
3745 
3746  result.addAttribute(
3747  AffinePrefetchOp::getIsDataCacheAttrStrName(),
3748  parser.getBuilder().getBoolAttr(cacheType.equals("data")));
3749 
3750  return success();
3751 }
3752 
3754  p << " " << getMemref() << '[';
3755  AffineMapAttr mapAttr =
3756  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3757  if (mapAttr)
3758  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3759  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3760  << "locality<" << getLocalityHint() << ">, "
3761  << (getIsDataCache() ? "data" : "instr");
3763  (*this)->getAttrs(),
3764  /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3765  getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3766  p << " : " << getMemRefType();
3767 }
3768 
3770  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3771  if (mapAttr) {
3772  AffineMap map = mapAttr.getValue();
3773  if (map.getNumResults() != getMemRefType().getRank())
3774  return emitOpError("affine.prefetch affine map num results must equal"
3775  " memref rank");
3776  if (map.getNumInputs() + 1 != getNumOperands())
3777  return emitOpError("too few operands");
3778  } else {
3779  if (getNumOperands() != 1)
3780  return emitOpError("too few operands");
3781  }
3782 
3783  Region *scope = getAffineScope(*this);
3784  for (auto idx : getMapOperands()) {
3785  if (!isValidAffineIndexOperand(idx, scope))
3786  return emitOpError("index must be a dimension or symbol identifier");
3787  }
3788  return success();
3789 }
3790 
3791 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3792  MLIRContext *context) {
3793  // prefetch(memrefcast) -> prefetch
3794  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3795 }
3796 
3797 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3798  SmallVectorImpl<OpFoldResult> &results) {
3799  /// prefetch(memrefcast) -> prefetch
3800  return memref::foldMemRefCast(*this);
3801 }
3802 
3803 //===----------------------------------------------------------------------===//
3804 // AffineParallelOp
3805 //===----------------------------------------------------------------------===//
3806 
3807 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3808  TypeRange resultTypes,
3809  ArrayRef<arith::AtomicRMWKind> reductions,
3810  ArrayRef<int64_t> ranges) {
3811  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3812  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3813  return builder.getConstantAffineMap(value);
3814  }));
3815  SmallVector<int64_t> steps(ranges.size(), 1);
3816  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3817  /*ubArgs=*/{}, steps);
3818 }
3819 
3820 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3821  TypeRange resultTypes,
3822  ArrayRef<arith::AtomicRMWKind> reductions,
3823  ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
3824  ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
3825  ArrayRef<int64_t> steps) {
3826  assert(llvm::all_of(lbMaps,
3827  [lbMaps](AffineMap m) {
3828  return m.getNumDims() == lbMaps[0].getNumDims() &&
3829  m.getNumSymbols() == lbMaps[0].getNumSymbols();
3830  }) &&
3831  "expected all lower bounds maps to have the same number of dimensions "
3832  "and symbols");
3833  assert(llvm::all_of(ubMaps,
3834  [ubMaps](AffineMap m) {
3835  return m.getNumDims() == ubMaps[0].getNumDims() &&
3836  m.getNumSymbols() == ubMaps[0].getNumSymbols();
3837  }) &&
3838  "expected all upper bounds maps to have the same number of dimensions "
3839  "and symbols");
3840  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3841  "expected lower bound maps to have as many inputs as lower bound "
3842  "operands");
3843  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3844  "expected upper bound maps to have as many inputs as upper bound "
3845  "operands");
3846 
3847  result.addTypes(resultTypes);
3848 
3849  // Convert the reductions to integer attributes.
3850  SmallVector<Attribute, 4> reductionAttrs;
3851  for (arith::AtomicRMWKind reduction : reductions)
3852  reductionAttrs.push_back(
3853  builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
3854  result.addAttribute(getReductionsAttrStrName(),
3855  builder.getArrayAttr(reductionAttrs));
3856 
3857  // Concatenates maps defined in the same input space (same dimensions and
3858  // symbols), assumes there is at least one map.
3859  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
3860  SmallVectorImpl<int32_t> &groups) {
3861  if (maps.empty())
3862  return AffineMap::get(builder.getContext());
3864  groups.reserve(groups.size() + maps.size());
3865  exprs.reserve(maps.size());
3866  for (AffineMap m : maps) {
3867  llvm::append_range(exprs, m.getResults());
3868  groups.push_back(m.getNumResults());
3869  }
3870  return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3871  maps[0].getContext());
3872  };
3873 
3874  // Set up the bounds.
3875  SmallVector<int32_t> lbGroups, ubGroups;
3876  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3877  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3878  result.addAttribute(getLowerBoundsMapAttrStrName(),
3879  AffineMapAttr::get(lbMap));
3880  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
3881  builder.getI32TensorAttr(lbGroups));
3882  result.addAttribute(getUpperBoundsMapAttrStrName(),
3883  AffineMapAttr::get(ubMap));
3884  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
3885  builder.getI32TensorAttr(ubGroups));
3886  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
3887  result.addOperands(lbArgs);
3888  result.addOperands(ubArgs);
3889 
3890  // Create a region and a block for the body.
3891  auto *bodyRegion = result.addRegion();
3892  auto *body = new Block();
3893  // Add all the block arguments.
3894  for (unsigned i = 0, e = steps.size(); i < e; ++i)
3895  body->addArgument(IndexType::get(builder.getContext()), result.location);
3896  bodyRegion->push_back(body);
3897  if (resultTypes.empty())
3898  ensureTerminator(*bodyRegion, builder, result.location);
3899 }
3900 
3901 Region &AffineParallelOp::getLoopBody() { return getRegion(); }
3902 
3903 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3904 
3905 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3906  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3907 }
3908 
3909 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3910  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3911 }
3912 
3913 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3914  auto values = getLowerBoundsGroups().getValues<int32_t>();
3915  unsigned start = 0;
3916  for (unsigned i = 0; i < pos; ++i)
3917  start += values[i];
3918  return getLowerBoundsMap().getSliceMap(start, values[pos]);
3919 }
3920 
3921 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3922  auto values = getUpperBoundsGroups().getValues<int32_t>();
3923  unsigned start = 0;
3924  for (unsigned i = 0; i < pos; ++i)
3925  start += values[i];
3926  return getUpperBoundsMap().getSliceMap(start, values[pos]);
3927 }
3928 
3929 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
3930  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3931 }
3932 
3933 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
3934  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3935 }
3936 
3937 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3938  if (hasMinMaxBounds())
3939  return std::nullopt;
3940 
3941  // Try to convert all the ranges to constant expressions.
3943  AffineValueMap rangesValueMap;
3944  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3945  &rangesValueMap);
3946  out.reserve(rangesValueMap.getNumResults());
3947  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3948  auto expr = rangesValueMap.getResult(i);
3949  auto cst = expr.dyn_cast<AffineConstantExpr>();
3950  if (!cst)
3951  return std::nullopt;
3952  out.push_back(cst.getValue());
3953  }
3954  return out;
3955 }
3956 
3957 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3958 
3959 OpBuilder AffineParallelOp::getBodyBuilder() {
3960  return OpBuilder(getBody(), std::prev(getBody()->end()));
3961 }
3962 
3963 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3964  assert(lbOperands.size() == map.getNumInputs() &&
3965  "operands to map must match number of inputs");
3966 
3967  auto ubOperands = getUpperBoundsOperands();
3968 
3969  SmallVector<Value, 4> newOperands(lbOperands);
3970  newOperands.append(ubOperands.begin(), ubOperands.end());
3971  (*this)->setOperands(newOperands);
3972 
3973  setLowerBoundsMapAttr(AffineMapAttr::get(map));
3974 }
3975 
3976 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3977  assert(ubOperands.size() == map.getNumInputs() &&
3978  "operands to map must match number of inputs");
3979 
3980  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3981  newOperands.append(ubOperands.begin(), ubOperands.end());
3982  (*this)->setOperands(newOperands);
3983 
3984  setUpperBoundsMapAttr(AffineMapAttr::get(map));
3985 }
3986 
3987 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
3988  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3989 }
3990 
3992  auto numDims = getNumDims();
3993  if (getLowerBoundsGroups().getNumElements() != numDims ||
3994  getUpperBoundsGroups().getNumElements() != numDims ||
3995  getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3996  return emitOpError() << "the number of region arguments ("
3997  << getBody()->getNumArguments()
3998  << ") and the number of map groups for lower ("
3999  << getLowerBoundsGroups().getNumElements()
4000  << ") and upper bound ("
4001  << getUpperBoundsGroups().getNumElements()
4002  << "), and the number of steps (" << getSteps().size()
4003  << ") must all match";
4004  }
4005 
4006  unsigned expectedNumLBResults = 0;
4007  for (APInt v : getLowerBoundsGroups())
4008  expectedNumLBResults += v.getZExtValue();
4009  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
4010  return emitOpError() << "expected lower bounds map to have "
4011  << expectedNumLBResults << " results";
4012  unsigned expectedNumUBResults = 0;
4013  for (APInt v : getUpperBoundsGroups())
4014  expectedNumUBResults += v.getZExtValue();
4015  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
4016  return emitOpError() << "expected upper bounds map to have "
4017  << expectedNumUBResults << " results";
4018 
4019  if (getReductions().size() != getNumResults())
4020  return emitOpError("a reduction must be specified for each output");
4021 
4022  // Verify reduction ops are all valid
4023  for (Attribute attr : getReductions()) {
4024  auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
4025  if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
4026  return emitOpError("invalid reduction attribute");
4027  }
4028 
4029  // Verify that the bound operands are valid dimension/symbols.
4030  /// Lower bounds.
4031  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
4032  getLowerBoundsMap().getNumDims())))
4033  return failure();
4034  /// Upper bounds.
4035  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
4036  getUpperBoundsMap().getNumDims())))
4037  return failure();
4038  return success();
4039 }
4040 
4041 LogicalResult AffineValueMap::canonicalize() {
4042  SmallVector<Value, 4> newOperands{operands};
4043  auto newMap = getAffineMap();
4044  composeAffineMapAndOperands(&newMap, &newOperands);
4045  if (newMap == getAffineMap() && newOperands == operands)
4046  return failure();
4047  reset(newMap, newOperands);
4048  return success();
4049 }
4050 
4051 /// Canonicalize the bounds of the given loop.
4052 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
4053  AffineValueMap lb = op.getLowerBoundsValueMap();
4054  bool lbCanonicalized = succeeded(lb.canonicalize());
4055 
4056  AffineValueMap ub = op.getUpperBoundsValueMap();
4057  bool ubCanonicalized = succeeded(ub.canonicalize());
4058 
4059  // Any canonicalization change always leads to updated map(s).
4060  if (!lbCanonicalized && !ubCanonicalized)
4061  return failure();
4062 
4063  if (lbCanonicalized)
4064  op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
4065  if (ubCanonicalized)
4066  op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
4067 
4068  return success();
4069 }
4070 
4071 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
4072  SmallVectorImpl<OpFoldResult> &results) {
4073  return canonicalizeLoopBounds(*this);
4074 }
4075 
4076 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
4077 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
4078 /// identifies which of the those expressions form max/min groups. `operands`
4079 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
4080 /// or "max".
4081 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
4082  DenseIntElementsAttr group, ValueRange operands,
4083  StringRef keyword) {
4084  AffineMap map = mapAttr.getValue();
4085  unsigned numDims = map.getNumDims();
4086  ValueRange dimOperands = operands.take_front(numDims);
4087  ValueRange symOperands = operands.drop_front(numDims);
4088  unsigned start = 0;
4089  for (llvm::APInt groupSize : group) {
4090  if (start != 0)
4091  p << ", ";
4092 
4093  unsigned size = groupSize.getZExtValue();
4094  if (size == 1) {
4095  p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4096  ++start;
4097  } else {
4098  p << keyword << '(';
4099  AffineMap submap = map.getSliceMap(start, size);
4100  p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4101  p << ')';
4102  start += size;
4103  }
4104  }
4105 }
4106 
// Body of the custom assembly printer for the affine parallel loop. It emits
// the induction variables, the lower and upper bounds, an optional step list,
// an optional reduce clause with result types, the region, and finally the
// attributes that are not already part of the custom syntax.
// NOTE(review): the function header (embedded listing line 4107) is absent
// from this listing; presumably `AffineParallelOp::print(OpAsmPrinter &p)` -
// confirm against the real file.
4108  p << " (" << getBody()->getArguments() << ") = (";
// Lower bounds use "max" for multi-expression groups, upper bounds use "min".
4109  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4110  getLowerBoundsOperands(), "max");
4111  p << ") to (";
4112  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4113  getUpperBoundsOperands(), "min");
4114  p << ')';
// The step list is omitted when every step equals 1.
4115  SmallVector<int64_t, 8> steps = getSteps();
4116  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4117  if (!elideSteps) {
4118  p << " step (";
4119  llvm::interleaveComma(steps, p);
4120  p << ')';
4121  }
// The reduce clause and result types are printed only when the op has results.
4122  if (getNumResults()) {
4123  p << " reduce (";
// Each reduction is stored as an integer attribute and printed as the quoted
// name of the corresponding arith::AtomicRMWKind enumerator.
4124  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4125  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4126  llvm::cast<IntegerAttr>(attr).getInt());
4127  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4128  });
4129  p << ") -> (" << getResultTypes() << ")";
4130  }
4131 
4132  p << ' ';
// Entry block arguments were already printed as the induction variables;
// block terminators are printed only when the op has results.
4133  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4134  /*printBlockTerminators=*/getNumResults());
// NOTE(review): embedded listing line 4135, which opens the call that the
// arguments below belong to, is missing from this listing. The elided
// attributes are exactly those already expressed by the syntax printed above.
4136  (*this)->getAttrs(),
4137  /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4138  AffineParallelOp::getLowerBoundsMapAttrStrName(),
4139  AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4140  AffineParallelOp::getUpperBoundsMapAttrStrName(),
4141  AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4142  AffineParallelOp::getStepsAttrStrName()});
4143 }
4144 
4145 /// Given a list of lists of parsed operands, resolves every operand to an
4146 /// index-typed value and populates `uniqueOperands` with the distinct values
4147 /// in order of first appearance. Also populates `replacements` with one affine
4148 /// expression of `kind` per parsed operand, referring to the position of that
/// operand in `uniqueOperands`; these can be used to update affine maps that
/// previously accepted `operands` so that they accept `uniqueOperands`
/// instead. `kind` must be `DimId` or `SymbolId`. Returns failure if an operand
/// cannot be resolved.
// NOTE(review): embedded listing lines 4149 (return type and function name)
// and 4151 (the `operands` parameter) are missing from this listing.
4150  OpAsmParser &parser,
4152  SmallVectorImpl<Value> &uniqueOperands,
4153  SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4154  assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4155  "expected operands to be dim or symbol expression");
4156 
4157  Type indexType = parser.getBuilder().getIndexType();
4158  for (const auto &list : operands) {
4159  SmallVector<Value> valueOperands;
4160  if (parser.resolveOperands(list, indexType, valueOperands))
4161  return failure();
4162  for (Value operand : valueOperands) {
// `pos` is the index of `operand` in `uniqueOperands`, or
// `uniqueOperands.size()` when the value has not been seen yet; in the latter
// case the push_back below places it at exactly that index.
4163  unsigned pos = std::distance(uniqueOperands.begin(),
4164  llvm::find(uniqueOperands, operand));
4165  if (pos == uniqueOperands.size())
4166  uniqueOperands.push_back(operand);
4167  replacements.push_back(
4168  kind == AffineExprKind::DimId
4169  ? getAffineDimExpr(pos, parser.getContext())
4170  : getAffineSymbolExpr(pos, parser.getContext()));
4171  }
4172  }
4173  return success();
4174 }
4175 
// Selects which kind of bound `parseAffineMapWithMinMax` parses: `Max` is
// passed for the lower bounds of an affine parallel loop (groups spelled
// `max(...)`) and `Min` for its upper bounds (groups spelled `min(...)`).
4176 namespace {
4177 enum class MinMaxKind { Min, Max };
4178 } // namespace
4179 
4180 /// Parses an affine map that can contain a min/max for groups of its results,
4181 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4182 /// `result` attributes with the map (flat list of expressions) and the grouping
4183 /// (list of integers that specify how many expressions to put into each
4184 /// min/max) attributes. Deduplicates repeated operands and appends the
/// deduplicated dimension operands followed by the symbol operands to
/// `result.operands`. `kind` selects the accepted keyword and the attribute
/// names: `Min` parses `min` groups into the upper bound attributes, `Max`
/// parses `max` groups into the lower bound attributes.
4185 ///
4186 /// parallel-bound ::= `(` parallel-group-list `)`
4187 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4188 /// parallel-group ::= simple-group | min-max-group
4189 /// simple-group ::= expr-of-ssa-ids
4190 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4191 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-ids-list)?
4192 ///
4193 /// Examples:
4194 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4195 /// (%0, max(%1 - 2 * %2))
// NOTE(review): embedded listing line 4196 (return type, function name and
// the `parser` parameter) is missing from this listing.
4197  OperationState &result,
4198  MinMaxKind kind) {
4199  // Using `const` not `constexpr` below to work around an MSVC optimizer bug,
4200  // see: https://reviews.llvm.org/D134227#3821753
4201  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4202 
// `min` groups describe upper bounds, `max` groups describe lower bounds.
4203  StringRef mapName = kind == MinMaxKind::Min
4204  ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4205  : AffineParallelOp::getLowerBoundsMapAttrStrName();
4206  StringRef groupsName =
4207  kind == MinMaxKind::Min
4208  ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4209  : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4210 
4211  if (failed(parser.parseLParen()))
4212  return failure();
4213 
// An empty bound list `()` yields an empty map and an empty group list.
4214  if (succeeded(parser.parseOptionalRParen())) {
4215  result.addAttribute(
4216  mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4217  result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4218  return success();
4219  }
4220 
// `flatExprs` collects one expression per map result across all groups.
// NOTE(review): embedded listing lines 4222, 4223 and 4225 are missing here;
// presumably they declare `flatDimOperands`, `flatSymOperands` and
// `mapOperands`, which are used below - confirm against the real file.
4221  SmallVector<AffineExpr> flatExprs;
4224  SmallVector<int32_t> numMapsPerGroup;
// Parses one group: either `min(...)`/`max(...)` or a single expression.
4226  auto parseOperands = [&]() {
4227  if (succeeded(parser.parseOptionalKeyword(
4228  kind == MinMaxKind::Min ? "min" : "max"))) {
4229  mapOperands.clear();
4230  AffineMapAttr map;
// NOTE(review): embedded listing line 4233 (the remaining arguments of this
// call) is missing from this listing.
4231  if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4232  result.attributes,
4234  return failure();
// The temporary attribute only served as a vehicle for parsing; remove it.
4235  result.attributes.erase(tmpAttrStrName);
4236  llvm::append_range(flatExprs, map.getValue().getResults());
4237  auto operandsRef = llvm::ArrayRef(mapOperands);
4238  auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4239  SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef.begin(),
4240  dimsRef.end());
4241  auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4242  SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef.begin(),
4243  symsRef.end());
// Every result of the group gets its own copy of the group's dimension and
// symbol operand lists, so that the operand lists stay parallel to
// `flatExprs` (one entry per expression).
4244  flatDimOperands.append(map.getValue().getNumResults(), dims);
4245  flatSymOperands.append(map.getValue().getNumResults(), syms);
4246  numMapsPerGroup.push_back(map.getValue().getNumResults());
4247  } else {
// A bare expression forms a group of size one.
4248  if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4249  flatSymOperands.emplace_back(),
4250  flatExprs.emplace_back())))
4251  return failure();
4252  numMapsPerGroup.push_back(1);
4253  }
4254  return success();
4255  };
4256  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4257  return failure();
4258 
// Each expression was parsed against its own operand lists. Renumber its
// dimensions and symbols by the number of dimensions/symbols contributed by
// all earlier expressions (presumably what shiftDims/shiftSymbols do with
// their second argument - confirm), so that all expressions address one
// concatenated operand list.
4259  unsigned totalNumDims = 0;
4260  unsigned totalNumSyms = 0;
4261  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4262  unsigned numDims = flatDimOperands[i].size();
4263  unsigned numSyms = flatSymOperands[i].size();
4264  flatExprs[i] = flatExprs[i]
4265  .shiftDims(numDims, totalNumDims)
4266  .shiftSymbols(numSyms, totalNumSyms);
4267  totalNumDims += numDims;
4268  totalNumSyms += numSyms;
4269  }
4270 
4271  // Deduplicate map operands.
4272  SmallVector<Value> dimOperands, symOperands;
4273  SmallVector<AffineExpr> dimRplacements, symRepacements;
4274  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4275  dimRplacements, AffineExprKind::DimId) ||
4276  deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4277  symRepacements, AffineExprKind::SymbolId))
4278  return failure();
4279 
// Operand order on the op: all unique dimensions first, then all unique
// symbols.
4280  result.operands.append(dimOperands.begin(), dimOperands.end());
4281  result.operands.append(symOperands.begin(), symOperands.end());
4282 
// Build the map over the concatenated operand lists, then rewrite it in terms
// of the deduplicated operands.
4283  Builder &builder = parser.getBuilder();
4284  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4285  parser.getContext());
4286  flatMap = flatMap.replaceDimsAndSymbols(
4287  dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4288 
4289  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4290  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4291  return success();
4292 }
4293 
4294 //
4295 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4296 // `to` parallel-bound steps? reduce? (`->` type-list)? region attr-dict?
4297 // steps ::= `step` `(` integer-literals `)`
// reduce ::= `reduce` `(` string-literals `)`
4298 //
// Parses the custom assembly form described above. The lower bounds are
// parsed with `max` groups and the upper bounds with `min` groups. When the
// `step` clause is absent every step defaults to 1. The reductions attribute
// is always added, possibly empty.
4299 ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4300  OperationState &result) {
4301  auto &builder = parser.getBuilder();
4302  auto indexType = builder.getIndexType();
// NOTE(review): embedded listing lines 4303-4304 are missing here; presumably
// they declare `ivs` and open the `if (` that parses the induction variable
// list and that the conditions below continue - confirm against the real file.
4305  parser.parseEqual() ||
4306  parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4307  parser.parseKeyword("to") ||
4308  parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4309  return failure();
4310 
// The steps are parsed as an affine map into the scratch list `stepsAttrs`
// rather than into `result.attributes`; the attribute stored on the op is the
// I64 array built below.
4311  AffineMapAttr stepsMapAttr;
4312  NamedAttrList stepsAttrs;
// NOTE(review): embedded listing line 4313 is missing here; presumably it
// declares `stepsMapOperands` - confirm.
4314  if (failed(parser.parseOptionalKeyword("step"))) {
// No `step` clause: one step of 1 per induction variable.
4315  SmallVector<int64_t, 4> steps(ivs.size(), 1);
4316  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4317  builder.getI64ArrayAttr(steps));
4318  } else {
// NOTE(review): embedded listing line 4322 (the remaining arguments of this
// call) is missing from this listing.
4319  if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4320  AffineParallelOp::getStepsAttrStrName(),
4321  stepsAttrs,
4323  return failure();
4324 
4325  // Convert steps from an AffineMap into an I64ArrayAttr.
// NOTE(review): embedded listing line 4326 is missing here; presumably it
// declares the `steps` vector filled below - confirm.
4327  auto stepsMap = stepsMapAttr.getValue();
// NOTE(review): the loop variable `result` shadows the `OperationState
// &result` parameter inside this loop only.
4328  for (const auto &result : stepsMap.getResults()) {
4329  auto constExpr = result.dyn_cast<AffineConstantExpr>();
4330  if (!constExpr)
4331  return parser.emitError(parser.getNameLoc(),
4332  "steps must be constant integers");
4333  steps.push_back(constExpr.getValue());
4334  }
4335  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4336  builder.getI64ArrayAttr(steps));
4337  }
4338 
4339  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4340  // quoted strings are a member of the enum AtomicRMWKind.
4341  SmallVector<Attribute, 4> reductions;
4342  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4343  if (parser.parseLParen())
4344  return failure();
4345  auto parseAttributes = [&]() -> ParseResult {
4346  // Parse a single quoted string via the attribute parsing, and then
4347  // verify it is a member of the enum and convert to its integer
4348  // representation.
4349  StringAttr attrVal;
4350  NamedAttrList attrStorage;
4351  auto loc = parser.getCurrentLocation();
4352  if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4353  attrStorage))
4354  return failure();
4355  std::optional<arith::AtomicRMWKind> reduction =
4356  arith::symbolizeAtomicRMWKind(attrVal.getValue());
4357  if (!reduction)
4358  return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4359  reductions.push_back(
4360  builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4361  // This element is done; the comma-separated list parser below drives
// the parsing of any further elements.
4362  return success();
4363  };
4364  if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4365  return failure();
4366  }
4367  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4368  builder.getArrayAttr(reductions));
4369 
4370  // Parse return types of reductions (if any)
4371  if (parser.parseOptionalArrowTypeList(result.types))
4372  return failure();
4373 
4374  // Now parse the body.
4375  Region *body = result.addRegion();
// The induction variables become the region's entry block arguments, all of
// index type.
4376  for (auto &iv : ivs)
4377  iv.type = indexType;
4378  if (parser.parseRegion(*body, ivs) ||
4379  parser.parseOptionalAttrDict(result.attributes))
4380  return failure();
4381 
4382  // Add a terminator if none was parsed.
4383  AffineParallelOp::ensureTerminator(*body, builder, result.location);
4384  return success();
4385 }
4386 
4387 //===----------------------------------------------------------------------===//
4388 // AffineYieldOp
4389 //===----------------------------------------------------------------------===//
4390 
// Body of the yield verifier. It checks that the parent operation is an
// affine if/for/parallel op, that the parent has as many results as the yield
// has operands, and that result and operand types agree pairwise.
// NOTE(review): the function header (embedded listing line 4391) is absent
// from this listing; presumably `LogicalResult AffineYieldOp::verify() {` -
// confirm against the real file.
4392  auto *parentOp = (*this)->getParentOp();
4393  auto results = parentOp->getResults();
4394  auto operands = getOperands();
4395 
4396  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4397  return emitOpError() << "only terminates affine.if/for/parallel regions";
4398  if (parentOp->getNumResults() != getNumOperands())
4399  return emitOpError() << "parent of yield must have same number of "
4400  "results as the yield operands";
// The counts are known to match at this point, so the zip covers every pair.
4401  for (auto it : llvm::zip(results, operands)) {
4402  if (std::get<0>(it).getType() != std::get<1>(it).getType())
4403  return emitOpError() << "types mismatch between yield op and its parent";
4404  }
4405 
4406  return success();
4407 }
4408 
4409 //===----------------------------------------------------------------------===//
4410 // AffineVectorLoadOp
4411 //===----------------------------------------------------------------------===//
4412 
4413 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4414  VectorType resultType, AffineMap map,
4415  ValueRange operands) {
4416  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4417  result.addOperands(operands);
4418  if (map)
4419  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4420  result.types.push_back(resultType);
4421 }
4422 
4423 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4424  VectorType resultType, Value memref,
4425  AffineMap map, ValueRange mapOperands) {
4426  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4427  result.addOperands(memref);
4428  result.addOperands(mapOperands);
4429  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4430  result.types.push_back(resultType);
4431 }
4432 
4433 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4434  VectorType resultType, Value memref,
4435  ValueRange indices) {
4436  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4437  int64_t rank = memrefType.getRank();
4438  // Create identity map for memrefs with at least one dimension or () -> ()
4439  // for zero-dimensional memrefs.
4440  auto map =
4441  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4442  build(builder, result, resultType, memref, map, indices);
4443 }
4444 
4445 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4446  MLIRContext *context) {
4447  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4448 }
4449 
// Parses the custom assembly form of the vector load: the memref operand, the
// affine map of SSA ids used for indexing, an optional attribute dictionary,
// then `:` memref-type `,` vector-type. The memref is resolved against the
// parsed memref type, all map operands against the index type, and the vector
// type becomes the single result type.
4450 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4451  OperationState &result) {
4452  auto &builder = parser.getBuilder();
4453  auto indexTy = builder.getIndexType();
4454 
4455  MemRefType memrefType;
4456  VectorType resultType;
4457  OpAsmParser::UnresolvedOperand memrefInfo;
4458  AffineMapAttr mapAttr;
// NOTE(review): embedded listing line 4459 is missing here; presumably it
// declares `mapOperands`, which is used below - confirm against the real file.
// The chain below stops at the first step that fails.
4460  return failure(
4461  parser.parseOperand(memrefInfo) ||
4462  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4463  AffineVectorLoadOp::getMapAttrStrName(),
4464  result.attributes) ||
4465  parser.parseOptionalAttrDict(result.attributes) ||
4466  parser.parseColonType(memrefType) || parser.parseComma() ||
4467  parser.parseType(resultType) ||
4468  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4469  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4470  parser.addTypeToList(resultType, result.types));
4471 }
4472 
// Body of the custom assembly printer for the vector load. It prints the
// memref, the bracketed index expressions, the attributes other than the map,
// and finally the memref and result types.
// NOTE(review): the function header (embedded listing line 4473) is absent
// from this listing; presumably `AffineVectorLoadOp::print(OpAsmPrinter &p)` -
// confirm against the real file.
4474  p << " " << getMemRef() << '[';
// The index list stays empty when the op carries no map attribute.
4475  if (AffineMapAttr mapAttr =
4476  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4477  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4478  p << ']';
// The map is already part of the bracketed syntax, so it is elided here.
4479  p.printOptionalAttrDict((*this)->getAttrs(),
4480  /*elidedAttrs=*/{getMapAttrStrName()});
4481  p << " : " << getMemRefType() << ", " << getType();
4482 }
4483 
4484 /// Verify common invariants of affine.vector_load and affine.vector_store.
4485 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4486  VectorType vectorType) {
4487  // Check that memref and vector element types match.
4488  if (memrefType.getElementType() != vectorType.getElementType())
4489  return op->emitOpError(
4490  "requires memref and vector types of the same elemental type");
4491  return success();
4492 }
4493 
// Body of the vector load verifier: first the indexing is checked, then the
// element types of the memref and the vector are compared.
// NOTE(review): embedded listing lines 4494 (the function header) and 4496
// (the start of the `if (failed(...` call that the arguments below belong to)
// are missing from this listing. The call is presumably to
// `verifyMemoryOpIndexing`, whose parameter list matches these arguments -
// confirm against the real file.
4495  MemRefType memrefType = getMemRefType();
4497  getOperation(),
4498  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4499  getMapOperands(), memrefType,
// Every operand except the memref is an index operand.
4500  /*numIndexOperands=*/getNumOperands() - 1)))
4501  return failure();
4502 
4503  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4504  return failure();
4505 
4506  return success();
4507 }
4508 
4509 //===----------------------------------------------------------------------===//
4510 // AffineVectorStoreOp
4511 //===----------------------------------------------------------------------===//
4512 
4513 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4514  Value valueToStore, Value memref, AffineMap map,
4515  ValueRange mapOperands) {
4516  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4517  result.addOperands(valueToStore);
4518  result.addOperands(memref);
4519  result.addOperands(mapOperands);
4520  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4521 }
4522 
4523 // Use identity map.
4524 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4525  Value valueToStore, Value memref,
4526  ValueRange indices) {
4527  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4528  int64_t rank = memrefType.getRank();
4529  // Create identity map for memrefs with at least one dimension or () -> ()
4530  // for zero-dimensional memrefs.
4531  auto map =
4532  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4533  build(builder, result, valueToStore, memref, map, indices);
4534 }
4535 void AffineVectorStoreOp::getCanonicalizationPatterns(
4536  RewritePatternSet &results, MLIRContext *context) {
4537  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4538 }
4539 
// Parses the custom assembly form of the vector store: the value to store,
// `,`, the memref operand, the affine map of SSA ids used for indexing, an
// optional attribute dictionary, then `:` memref-type `,` vector-type. The
// stored value is resolved against the parsed vector type, the memref against
// the parsed memref type, and all map operands against the index type. The
// op has no result types.
4540 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4541  OperationState &result) {
4542  auto indexTy = parser.getBuilder().getIndexType();
4543 
4544  MemRefType memrefType;
// Despite its name, `resultType` holds the type of the stored vector.
4545  VectorType resultType;
4546  OpAsmParser::UnresolvedOperand storeValueInfo;
4547  OpAsmParser::UnresolvedOperand memrefInfo;
4548  AffineMapAttr mapAttr;
// NOTE(review): embedded listing line 4549 is missing here; presumably it
// declares `mapOperands`, which is used below - confirm against the real file.
// The chain below stops at the first step that fails.
4550  return failure(
4551  parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4552  parser.parseOperand(memrefInfo) ||
4553  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4554  AffineVectorStoreOp::getMapAttrStrName(),
4555  result.attributes) ||
4556  parser.parseOptionalAttrDict(result.attributes) ||
4557  parser.parseColonType(memrefType) || parser.parseComma() ||
4558  parser.parseType(resultType) ||
4559  parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4560  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4561  parser.resolveOperands(mapOperands, indexTy, result.operands));
4562 }
4563 
// Body of the custom assembly printer for the vector store. It prints the
// stored value, the memref, the bracketed index expressions, the attributes
// other than the map, and finally the memref type and the stored value's type.
// NOTE(review): the function header (embedded listing line 4564) is absent
// from this listing; presumably `AffineVectorStoreOp::print(OpAsmPrinter &p)`
// - confirm against the real file.
4565  p << " " << getValueToStore();
4566  p << ", " << getMemRef() << '[';
// The index list stays empty when the op carries no map attribute.
4567  if (AffineMapAttr mapAttr =
4568  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4569  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4570  p << ']';
// The map is already part of the bracketed syntax, so it is elided here.
4571  p.printOptionalAttrDict((*this)->getAttrs(),
4572  /*elidedAttrs=*/{getMapAttrStrName()});
4573  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4574 }
4575 
// Body of the vector store verifier: first the indexing is checked, then the
// element types of the memref and the vector are compared.
// NOTE(review): embedded listing lines 4576 (the function header) and 4578
// (the start of the `if (failed(...` call that the arguments below belong to)
// are missing from this listing. The call is presumably to
// `verifyMemoryOpIndexing`, whose parameter list matches these arguments -
// confirm against the real file.
4577  MemRefType memrefType = getMemRefType();
4579  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4580  getMapOperands(), memrefType,
// Every operand except the stored value and the memref is an index operand.
4581  /*numIndexOperands=*/getNumOperands() - 2)))
4582  return failure();
4583 
4584  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4585  return failure();
4586 
4587  return success();
4588 }
4589 
4590 //===----------------------------------------------------------------------===//
4591 // DelinearizeIndexOp
4592 //===----------------------------------------------------------------------===//
4593 
4594 void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
4595  Value linearIndex,
4596  ArrayRef<OpFoldResult> basis) {
4597  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
4598  result.addOperands(linearIndex);
4599  SmallVector<Value> basisValues =
4600  llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value {
4601  std::optional<int64_t> staticDim = getConstantIntValue(ofr);
4602  if (staticDim.has_value())
4603  return builder.create<arith::ConstantIndexOp>(result.location,
4604  *staticDim);
4605  return llvm::dyn_cast_if_present<Value>(ofr);
4606  });
4607  result.addOperands(basisValues);
4608 }
4609 
// Body of the delinearize verifier: the basis must be non-empty and the op
// must have exactly one result per basis element.
// NOTE(review): the function header (embedded listing line 4610) is absent
// from this listing; presumably
// `LogicalResult AffineDelinearizeIndexOp::verify() {` - confirm against the
// real file.
4611  if (getBasis().empty())
4612  return emitOpError("basis should not be empty");
4613  if (getNumResults() != getBasis().size())
4614  return emitOpError("should return an index for each basis element");
4615  return success();
4616 }
4617 
4618 //===----------------------------------------------------------------------===//
4619 // TableGen'd op method definitions
4620 //===----------------------------------------------------------------------===//
4621 
4622 #define GET_OP_CLASSES
4623 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
Definition: AffineOps.cpp:2747
static void materializeConstants(OpBuilder &b, Location loc, ArrayRef< OpFoldResult > values, SmallVectorImpl< Operation * > &constants, SmallVectorImpl< Value > &actualValues)
Given a list of OpFoldResult, build the necessary operations to populate actualValues with values pro...
Definition: AffineOps.cpp:1203
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
Definition: AffineOps.cpp:2476
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
Definition: AffineOps.cpp:1305
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
Definition: AffineOps.cpp:3363
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
Definition: AffineOps.cpp:52
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
Definition: AffineOps.cpp:4081
static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
Definition: AffineOps.cpp:1054
static std::optional< int64_t > getBoundForExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
Definition: AffineOps.cpp:712
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
Definition: AffineOps.cpp:3399
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
Definition: AffineOps.cpp:2323
static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ValueRange operandsRef)
Fully compose map with operands and canonicalize the result.
Definition: AffineOps.cpp:1413
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
Definition: AffineOps.cpp:840
static bool isValidAffineIndexOperand(Value value, Region *region)
Definition: AffineOps.cpp:437
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1484
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
Definition: AffineOps.cpp:1125
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:688
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
Definition: AffineOps.cpp:2032
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:680
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
Definition: AffineOps.cpp:3072
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
Definition: AffineOps.cpp:1077
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1441
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
Definition: AffineOps.cpp:4485
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
Definition: AffineOps.cpp:943
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
Definition: AffineOps.cpp:3376
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
Definition: AffineOps.cpp:317
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
Definition: AffineOps.cpp:628
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
Definition: AffineOps.cpp:4196
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
Definition: AffineOps.cpp:2756
static std::enable_if_t< OpTy::template hasTrait< OpTrait::OneResult >), OpFoldResult > createOrFold(OpBuilder &b, Location loc, ValueRange operands, Args &&...leadingArguments)
Create an operation of the type provided as template argument and attempt to fold it immediately.
Definition: AffineOps.cpp:1241
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
Definition: AffineOps.cpp:442
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
Definition: AffineOps.cpp:590
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
Definition: AffineOps.cpp:2706
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
Definition: AffineOps.cpp:2277
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
Definition: AffineOps.cpp:474
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1380
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
Definition: AffineOps.cpp:330
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
Definition: AffineOps.cpp:656
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Definition: AffineOps.cpp:4149
static LogicalResult verifyAffineMinMaxOp(T op)
Definition: AffineOps.cpp:3353
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
Definition: AffineOps.cpp:2194
static LogicalResult verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
Definition: AffineOps.cpp:3179
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
Definition: AffineOps.cpp:3562
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:76
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:81
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1296
Affine binary operation expression.
Definition: AffineExpr.h:207
An integer constant appearing in affine expression.
Definition: AffineExpr.h:232
int64_t getValue() const
Definition: AffineExpr.cpp:519
A dimensional identifier appearing in an affine expression.
Definition: AffineExpr.h:216
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:778
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:27
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:220
MLIRContext * getContext() const
Definition: AffineExpr.cpp:25
U dyn_cast() const
Definition: AffineExpr.h:281
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:44
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:565
MLIRContext * getContext() const
Definition: AffineMap.cpp:271
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
Definition: AffineMap.h:181
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:227
unsigned getNumSymbols() const
Definition: AffineMap.cpp:328
unsigned getNumDims() const
Definition: AffineMap.cpp:324
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:337
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
Definition: AffineMap.h:188
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
Definition: AffineMap.cpp:365
unsigned getNumResults() const
Definition: AffineMap.cpp:332
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineMap.cpp:426
unsigned getNumInputs() const
Definition: AffineMap.cpp:333
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineMap.h:240
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:242
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:341
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:441
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
Definition: AffineMap.cpp:102
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:557
A symbolic identifier appearing in an affine expression.
Definition: AffineExpr.h:224
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:67
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:30
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:232
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:141
BlockArgListType getArguments()
Definition: Block.h:76
Operation & front()
Definition: Block.h:142
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:122
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:225
AffineMap getDimIdentityMap()
Definition: Builders.cpp:359
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:363
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition: Builders.cpp:182
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:126
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:85
NoneType getNoneType()
Definition: Builders.cpp:102
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:114
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
Definition: Builders.cpp:352
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition: Builders.cpp:354
MLIRContext * getContext() const
Definition: Builders.h:55
AffineMap getSymbolIdentityMap()
Definition: Builders.cpp:372
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:260
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:275
IndexType getIndexType()
Definition: Builders.cpp:69
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:43
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:45
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44
unsigned getNumDims() const
Definition: IntegerSet.cpp:15
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
Definition: IntegerSet.cpp:57
unsigned getNumInputs() const
Definition: IntegerSet.cpp:17
ArrayRef< AffineExpr > getConstraints() const
Definition: IntegerSet.cpp:41
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
Definition: IntegerSet.cpp:51
unsigned getNumSymbols() const
Definition: IntegerSet.cpp:16
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:329
This class helps build Operations.
Definition: Builders.h:202
void setListener(Listener *newListener)
Sets the listener of this builder to the one provided.
Definition: Builders.h:297
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:412
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:379
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:301
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:501
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:433
This class represents a single result from folding an operation.
Definition: OpDefinition.h:265
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:400
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
LogicalResult fold(ArrayRef< Attribute > operands, SmallVectorImpl< OpFoldResult > &results)
Attempt to fold this operation with the specified constant operand values.
Definition: Operation.cpp:611
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:690
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:495
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:469
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:543
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
operand_range::iterator operand_iterator
Definition: Operation.h:367
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:632
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:539
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:700
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
void push_back(Block *block)
Definition: Region.h:61
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Region.h:98
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void finalizeRootUpdate(Operation *op)
This method is used to signal the end of a root update on the given operation.
virtual void startRootUpdate(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:566
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:514
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:55
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
AffineBound represents a lower or upper bound in the for operation.
Definition: AffineOps.h:514
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
Definition: AffineOps.h:92
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Definition: AffineOps.h:289
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
Definition: AffineOps.cpp:4041
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
unsigned getNumResults() const
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
Definition: AffineOps.cpp:2769
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
Definition: AffineOps.cpp:1187
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
Definition: AffineOps.cpp:2683
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
Definition: AffineOps.cpp:266
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1362
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
Definition: AffineOps.cpp:2655
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ValueRange operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1285
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: AffineOps.cpp:2659
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1560
Value makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ValueRange operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
Definition: AffineOps.cpp:1372
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1405
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Definition: AffineOps.cpp:2647
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1398
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Definition: AffineOps.cpp:236
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
Definition: AffineOps.cpp:1565
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
Definition: AffineOps.cpp:2690
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Definition: AffineOps.cpp:363
AffineForOp replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop, ValueRange newIterOperands, ValueRange newYieldedValues, ValueRange newIterArgs, bool replaceLoopResults=true)
Replace loop with a new loop where newIterOperands are appended with new initialization values and ne...
Definition: AffineOps.cpp:2785
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1334
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp that owns the provided value when it is one of that op's induction variables.
Definition: AffineOps.cpp:2670
SmallVector< Value, 4 > applyMapToValues(OpBuilder &b, Location loc, AffineMap map, ValueRange values)
Returns the values obtained by applying map to the list of values.
Definition: AffineOps.cpp:1422
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
Definition: AffineOps.cpp:251
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
Definition: AffineOps.cpp:452
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
Definition: AffineOps.cpp:2651
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:262
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:88
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt gcd(const MPInt &a, const MPInt &b)
Definition: MPInt.h:399
This header declares functions that assist transformations in the MemRef dialect.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:648
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:378
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
Definition: AffineMap.cpp:658
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
int64_t floorDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's floordiv operation on constants.
Definition: MathExtras.h:33
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
Definition: MathExtras.h:23
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
AffineExprKind
Definition: AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:47
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:527
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isStrided(MemRefType t)
Return true if the layout for t is compatible with strided semantics.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:287
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:502
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:374
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:512
int64_t mod(int64_t lhs, int64_t rhs)
Returns MLIR's mod operation on constants.
Definition: MathExtras.h:45
Canonicalize the affine map result expression order of an affine min/max operation.
Definition: AffineOps.cpp:3614
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3617
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3631
Remove duplicated expressions in affine min/max ops.
Definition: AffineOps.cpp:3432
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3435
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
Definition: AffineOps.cpp:3475
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3478
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:279
virtual void notifyOperationInserted(Operation *op)
Notification handler for when an operation is inserted into the builder.
Definition: Builders.h:286
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.