MLIR  20.0.0git
AffineOps.cpp
Go to the documentation of this file.
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 #include "llvm/ADT/SmallVectorExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/MathExtras.h"
28 #include <numeric>
29 #include <optional>
30 
31 using namespace mlir;
32 using namespace mlir::affine;
33 
34 using llvm::divideCeilSigned;
35 using llvm::divideFloorSigned;
36 using llvm::mod;
37 
38 #define DEBUG_TYPE "affine-ops"
39 
40 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
41 
42 /// A utility function to check if a value is defined at the top level of
43 /// `region` or is an argument of `region`. A value of index type defined at the
44 /// top level of a `AffineScope` region is always a valid symbol for all
45 /// uses in that region.
47  if (auto arg = llvm::dyn_cast<BlockArgument>(value))
48  return arg.getParentRegion() == region;
49  return value.getDefiningOp()->getParentRegion() == region;
50 }
51 
52 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
53 /// region remains legal if the operation that uses it is inlined into `dest`
54 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
55 /// `isValidSymbol`, depending on the value being required to remain a valid
56 /// dimension or symbol.
57 static bool
59  const IRMapping &mapping,
60  function_ref<bool(Value, Region *)> legalityCheck) {
61  // If the value is a valid dimension for any other reason than being
62  // a top-level value, it will remain valid: constants get inlined
63  // with the function, transitive affine applies also get inlined and
64  // will be checked themselves, etc.
65  if (!isTopLevelValue(value, src))
66  return true;
67 
68  // If it's a top-level value because it's a block operand, i.e. a
69  // function argument, check whether the value replacing it after
70  // inlining is a valid dimension in the new region.
71  if (llvm::isa<BlockArgument>(value))
72  return legalityCheck(mapping.lookup(value), dest);
73 
74  // If it's a top-level value because it's defined in the region,
75  // it can only be inlined if the defining op is a constant or a
76  // `dim`, which can appear anywhere and be valid, since the defining
77  // op won't be top-level anymore after inlining.
78  Attribute operandCst;
79  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
80  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
81  isDimLikeOp;
82 }
83 
84 /// Checks if all values known to be legal affine dimensions or symbols in `src`
85 /// remain so if their respective users are inlined into `dest`.
86 static bool
88  const IRMapping &mapping,
89  function_ref<bool(Value, Region *)> legalityCheck) {
90  return llvm::all_of(values, [&](Value v) {
91  return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
92  });
93 }
94 
95 /// Checks if an affine read or write operation remains legal after inlining
96 /// from `src` to `dest`.
97 template <typename OpTy>
98 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
99  const IRMapping &mapping) {
100  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
101  AffineWriteOpInterface>::value,
102  "only ops with affine read/write interface are supported");
103 
104  AffineMap map = op.getAffineMap();
105  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
106  ValueRange symbolOperands =
107  op.getMapOperands().take_back(map.getNumSymbols());
109  dimOperands, src, dest, mapping,
110  static_cast<bool (*)(Value, Region *)>(isValidDim)))
111  return false;
113  symbolOperands, src, dest, mapping,
114  static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
115  return false;
116  return true;
117 }
118 
119 /// Checks if an affine apply operation remains legal after inlining from `src`
120 /// to `dest`.
121 // Use "unused attribute" marker to silence clang-tidy warning stemming from
122 // the inability to see through "llvm::TypeSwitch".
123 template <>
124 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
125  Region *src, Region *dest,
126  const IRMapping &mapping) {
127  // If it's a valid dimension, we need to check that it remains so.
128  if (isValidDim(op.getResult(), src))
130  op.getMapOperands(), src, dest, mapping,
131  static_cast<bool (*)(Value, Region *)>(isValidDim));
132 
133  // Otherwise it must be a valid symbol, check that it remains so.
135  op.getMapOperands(), src, dest, mapping,
136  static_cast<bool (*)(Value, Region *)>(isValidSymbol));
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // AffineDialect Interfaces
141 //===----------------------------------------------------------------------===//
142 
143 namespace {
144 /// This class defines the interface for handling inlining with affine
145 /// operations.
146 struct AffineInlinerInterface : public DialectInlinerInterface {
148 
149  //===--------------------------------------------------------------------===//
150  // Analysis Hooks
151  //===--------------------------------------------------------------------===//
152 
153  /// Returns true if the given region 'src' can be inlined into the region
154  /// 'dest' that is attached to an operation registered to the current dialect.
155  /// 'wouldBeCloned' is set if the region is cloned into its new location
156  /// rather than moved, indicating there may be other users.
157  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
158  IRMapping &valueMapping) const final {
159  // We can inline into affine loops and conditionals if this doesn't break
160  // affine value categorization rules.
161  Operation *destOp = dest->getParentOp();
162  if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
163  return false;
164 
165  // Multi-block regions cannot be inlined into affine constructs, all of
166  // which require single-block regions.
167  if (!llvm::hasSingleElement(*src))
168  return false;
169 
170  // Side-effecting operations that the affine dialect cannot understand
171  // should not be inlined.
172  Block &srcBlock = src->front();
173  for (Operation &op : srcBlock) {
174  // Ops with no side effects are fine,
175  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
176  if (iface.hasNoEffect())
177  continue;
178  }
179 
180  // Assuming the inlined region is valid, we only need to check if the
181  // inlining would change it.
182  bool remainsValid =
184  .Case<AffineApplyOp, AffineReadOpInterface,
185  AffineWriteOpInterface>([&](auto op) {
186  return remainsLegalAfterInline(op, src, dest, valueMapping);
187  })
188  .Default([](Operation *) {
189  // Conservatively disallow inlining ops we cannot reason about.
190  return false;
191  });
192 
193  if (!remainsValid)
194  return false;
195  }
196 
197  return true;
198  }
199 
200  /// Returns true if the given operation 'op', that is registered to this
201  /// dialect, can be inlined into the given region, false otherwise.
202  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
203  IRMapping &valueMapping) const final {
204  // Always allow inlining affine operations into a region that is marked as
205  // affine scope, or into affine loops and conditionals. There are some edge
206  // cases when inlining *into* affine structures, but that is handled in the
207  // other 'isLegalToInline' hook above.
208  Operation *parentOp = region->getParentOp();
209  return parentOp->hasTrait<OpTrait::AffineScope>() ||
210  isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
211  }
212 
213  /// Affine regions should be analyzed recursively.
214  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
215 };
216 } // namespace
217 
218 //===----------------------------------------------------------------------===//
219 // AffineDialect
220 //===----------------------------------------------------------------------===//
221 
222 void AffineDialect::initialize() {
223  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
224 #define GET_OP_LIST
225 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
226  >();
227  addInterfaces<AffineInlinerInterface>();
228  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
229  AffineMinOp>();
230 }
231 
232 /// Materialize a single constant operation from a given attribute value with
233 /// the desired resultant type.
235  Attribute value, Type type,
236  Location loc) {
237  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
238  return builder.create<ub::PoisonOp>(loc, type, poison);
239  return arith::ConstantOp::materialize(builder, value, type, loc);
240 }
241 
242 /// A utility function to check if a value is defined at the top level of an
243 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
244 /// conservatively assume it is not top-level. A value of index type defined at
245 /// the top level is always a valid symbol.
247  if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
248  // The block owning the argument may be unlinked, e.g. when the surrounding
249  // region has not yet been attached to an Op, at which point the parent Op
250  // is null.
251  Operation *parentOp = arg.getOwner()->getParentOp();
252  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
253  }
254  // The defining Op may live in an unlinked block so its parent Op may be null.
255  Operation *parentOp = value.getDefiningOp()->getParentOp();
256  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
257 }
258 
259 /// Returns the closest region enclosing `op` that is held by an operation with
260 /// trait `AffineScope`; `nullptr` if there is no such region.
262  auto *curOp = op;
263  while (auto *parentOp = curOp->getParentOp()) {
264  if (parentOp->hasTrait<OpTrait::AffineScope>())
265  return curOp->getParentRegion();
266  curOp = parentOp;
267  }
268  return nullptr;
269 }
270 
271 // A Value can be used as a dimension id iff it meets one of the following
272 // conditions:
273 // *) It is valid as a symbol.
274 // *) It is an induction variable.
275 // *) It is the result of affine apply operation with dimension id arguments.
277  // The value must be an index type.
278  if (!value.getType().isIndex())
279  return false;
280 
281  if (auto *defOp = value.getDefiningOp())
282  return isValidDim(value, getAffineScope(defOp));
283 
284  // This value has to be a block argument for an op that has the
285  // `AffineScope` trait or for an affine.for or affine.parallel.
286  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
287  return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
288  isa<AffineForOp, AffineParallelOp>(parentOp));
289 }
290 
291 // Value can be used as a dimension id iff it meets one of the following
292 // conditions:
293 // *) It is valid as a symbol.
294 // *) It is an induction variable.
295 // *) It is the result of an affine apply operation with dimension id operands.
296 bool mlir::affine::isValidDim(Value value, Region *region) {
297  // The value must be an index type.
298  if (!value.getType().isIndex())
299  return false;
300 
301  // All valid symbols are okay.
302  if (isValidSymbol(value, region))
303  return true;
304 
305  auto *op = value.getDefiningOp();
306  if (!op) {
307  // This value has to be a block argument for an affine.for or an
308  // affine.parallel.
309  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
310  return isa<AffineForOp, AffineParallelOp>(parentOp);
311  }
312 
313  // Affine apply operation is ok if all of its operands are ok.
314  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
315  return applyOp.isValidDim(region);
316  // The dim op is okay if its operand memref/tensor is defined at the top
317  // level.
318  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
319  return isTopLevelValue(dimOp.getShapedValue());
320  return false;
321 }
322 
323 /// Returns true if the 'index' dimension of the `memref` defined by
324 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
325 /// for `region`.
326 template <typename AnyMemRefDefOp>
327 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
328  Region *region) {
329  MemRefType memRefType = memrefDefOp.getType();
330 
331  // Dimension index is out of bounds.
332  if (index >= memRefType.getRank()) {
333  return false;
334  }
335 
336  // Statically shaped.
337  if (!memRefType.isDynamicDim(index))
338  return true;
339  // Get the position of the dimension among dynamic dimensions;
340  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
341  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
342  region);
343 }
344 
345 /// Returns true if the result of the dim op is a valid symbol for `region`.
346 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
347  // The dim op is okay if its source is defined at the top level.
348  if (isTopLevelValue(dimOp.getShapedValue()))
349  return true;
350 
351  // Conservatively handle remaining BlockArguments as non-valid symbols.
352  // E.g. scf.for iterArgs.
353  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
354  return false;
355 
356  // The dim op is also okay if its operand memref is a view/subview whose
357  // corresponding size is a valid symbol.
358  std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
359 
360  // Be conservative if we can't understand the dimension.
361  if (!index.has_value())
362  return false;
363 
364  // Skip over all memref.cast ops (if any).
365  Operation *op = dimOp.getShapedValue().getDefiningOp();
366  while (auto castOp = dyn_cast<memref::CastOp>(op)) {
367  // Bail on unranked memrefs.
368  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
369  return false;
370  op = castOp.getSource().getDefiningOp();
371  if (!op)
372  return false;
373  }
374 
375  int64_t i = index.value();
377  .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
378  [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
379  .Default([](Operation *) { return false; });
380 }
381 
382 // A value can be used as a symbol (at all its use sites) iff it meets one of
383 // the following conditions:
384 // *) It is a constant.
385 // *) Its defining op or block arg appearance is immediately enclosed by an op
386 // with `AffineScope` trait.
387 // *) It is the result of an affine.apply operation with symbol operands.
388 // *) It is a result of the dim op on a memref whose corresponding size is a
389 // valid symbol.
391  if (!value)
392  return false;
393 
394  // The value must be an index type.
395  if (!value.getType().isIndex())
396  return false;
397 
398  // Check that the value is a top level value.
399  if (isTopLevelValue(value))
400  return true;
401 
402  if (auto *defOp = value.getDefiningOp())
403  return isValidSymbol(value, getAffineScope(defOp));
404 
405  return false;
406 }
407 
408 /// A value can be used as a symbol for `region` iff it meets one of the
409 /// following conditions:
410 /// *) It is a constant.
411 /// *) It is the result of an affine apply operation with symbol arguments.
412 /// *) It is a result of the dim op on a memref whose corresponding size is
413 /// a valid symbol.
414 /// *) It is defined at the top level of 'region' or is its argument.
415 /// *) It dominates `region`'s parent op.
416 /// If `region` is null, conservatively assume the symbol definition scope does
417 /// not exist and only accept the values that would be symbols regardless of
418 /// the surrounding region structure, i.e. the first three cases above.
420  // The value must be an index type.
421  if (!value.getType().isIndex())
422  return false;
423 
424  // A top-level value is a valid symbol.
425  if (region && ::isTopLevelValue(value, region))
426  return true;
427 
428  auto *defOp = value.getDefiningOp();
429  if (!defOp) {
430  // A block argument that is not a top-level value is a valid symbol if it
431  // dominates region's parent op.
432  Operation *regionOp = region ? region->getParentOp() : nullptr;
433  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
434  if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
435  return isValidSymbol(value, parentOpRegion);
436  return false;
437  }
438 
439  // Constant operation is ok.
440  Attribute operandCst;
441  if (matchPattern(defOp, m_Constant(&operandCst)))
442  return true;
443 
444  // Affine apply operation is ok if all of its operands are ok.
445  if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
446  return applyOp.isValidSymbol(region);
447 
448  // Dim op results could be valid symbols at any level.
449  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
450  return isDimOpValidSymbol(dimOp, region);
451 
452  // Check for values dominating `region`'s parent op.
453  Operation *regionOp = region ? region->getParentOp() : nullptr;
454  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
455  if (auto *parentRegion = region->getParentOp()->getParentRegion())
456  return isValidSymbol(value, parentRegion);
457 
458  return false;
459 }
460 
461 // Returns true if 'value' is a valid index to an affine operation (e.g.
462 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
463 // `region` provides the polyhedral symbol scope. Returns false otherwise.
464 static bool isValidAffineIndexOperand(Value value, Region *region) {
465  return isValidDim(value, region) || isValidSymbol(value, region);
466 }
467 
468 /// Prints dimension and symbol list.
471  unsigned numDims, OpAsmPrinter &printer) {
472  OperandRange operands(begin, end);
473  printer << '(' << operands.take_front(numDims) << ')';
474  if (operands.size() > numDims)
475  printer << '[' << operands.drop_front(numDims) << ']';
476 }
477 
478 /// Parses dimension and symbol list and returns true if parsing failed.
480  OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
482  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
483  return failure();
484  // Store number of dimensions for validation by caller.
485  numDims = opInfos.size();
486 
487  // Parse the optional symbol operands.
488  auto indexTy = parser.getBuilder().getIndexType();
489  return failure(parser.parseOperandList(
491  parser.resolveOperands(opInfos, indexTy, operands));
492 }
493 
494 /// Utility function to verify that a set of operands are valid dimension and
495 /// symbol identifiers. The operands should be laid out such that the dimension
496 /// operands are before the symbol operands. This function returns failure if
497 /// there was an invalid operand. An operation is provided to emit any necessary
498 /// errors.
499 template <typename OpTy>
500 static LogicalResult
502  unsigned numDims) {
503  unsigned opIt = 0;
504  for (auto operand : operands) {
505  if (opIt++ < numDims) {
506  if (!isValidDim(operand, getAffineScope(op)))
507  return op.emitOpError("operand cannot be used as a dimension id");
508  } else if (!isValidSymbol(operand, getAffineScope(op))) {
509  return op.emitOpError("operand cannot be used as a symbol");
510  }
511  }
512  return success();
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // AffineApplyOp
517 //===----------------------------------------------------------------------===//
518 
519 AffineValueMap AffineApplyOp::getAffineValueMap() {
520  return AffineValueMap(getAffineMap(), getOperands(), getResult());
521 }
522 
523 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
524  auto &builder = parser.getBuilder();
525  auto indexTy = builder.getIndexType();
526 
527  AffineMapAttr mapAttr;
528  unsigned numDims;
529  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
530  parseDimAndSymbolList(parser, result.operands, numDims) ||
531  parser.parseOptionalAttrDict(result.attributes))
532  return failure();
533  auto map = mapAttr.getValue();
534 
535  if (map.getNumDims() != numDims ||
536  numDims + map.getNumSymbols() != result.operands.size()) {
537  return parser.emitError(parser.getNameLoc(),
538  "dimension or symbol index mismatch");
539  }
540 
541  result.types.append(map.getNumResults(), indexTy);
542  return success();
543 }
544 
546  p << " " << getMapAttr();
547  printDimAndSymbolList(operand_begin(), operand_end(),
548  getAffineMap().getNumDims(), p);
549  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
550 }
551 
552 LogicalResult AffineApplyOp::verify() {
553  // Check input and output dimensions match.
554  AffineMap affineMap = getMap();
555 
556  // Verify that operand count matches affine map dimension and symbol count.
557  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
558  return emitOpError(
559  "operand count and affine map dimension and symbol count must match");
560 
561  // Verify that the map only produces one result.
562  if (affineMap.getNumResults() != 1)
563  return emitOpError("mapping must produce one value");
564 
565  return success();
566 }
567 
568 // The result of the affine apply operation can be used as a dimension id if all
569 // its operands are valid dimension ids.
571  return llvm::all_of(getOperands(),
572  [](Value op) { return affine::isValidDim(op); });
573 }
574 
575 // The result of the affine apply operation can be used as a dimension id if all
576 // its operands are valid dimension ids with the parent operation of `region`
577 // defining the polyhedral scope for symbols.
578 bool AffineApplyOp::isValidDim(Region *region) {
579  return llvm::all_of(getOperands(),
580  [&](Value op) { return ::isValidDim(op, region); });
581 }
582 
583 // The result of the affine apply operation can be used as a symbol if all its
584 // operands are symbols.
586  return llvm::all_of(getOperands(),
587  [](Value op) { return affine::isValidSymbol(op); });
588 }
589 
590 // The result of the affine apply operation can be used as a symbol in `region`
591 // if all its operands are symbols in `region`.
592 bool AffineApplyOp::isValidSymbol(Region *region) {
593  return llvm::all_of(getOperands(), [&](Value operand) {
594  return affine::isValidSymbol(operand, region);
595  });
596 }
597 
598 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
599  auto map = getAffineMap();
600 
601  // Fold dims and symbols to existing values.
602  auto expr = map.getResult(0);
603  if (auto dim = dyn_cast<AffineDimExpr>(expr))
604  return getOperand(dim.getPosition());
605  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
606  return getOperand(map.getNumDims() + sym.getPosition());
607 
608  // Otherwise, default to folding the map.
610  bool hasPoison = false;
611  auto foldResult =
612  map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
613  if (hasPoison)
615  if (failed(foldResult))
616  return {};
617  return result[0];
618 }
619 
620 /// Returns the largest known divisor of `e`. Exploits information from the
621 /// values in `operands`.
622 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
623  // This method isn't aware of `operands`.
624  int64_t div = e.getLargestKnownDivisor();
625 
626  // We now make use of operands for the case `e` is a dim expression.
627  // TODO: More powerful simplification would have to modify
628  // getLargestKnownDivisor to take `operands` and exploit that information as
629  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
630  // can't be part of the IR library but of the `Analysis` library. The IR
631  // library can only really depend on simple O(1) checks.
632  auto dimExpr = dyn_cast<AffineDimExpr>(e);
633  // If it's not a dim expr, `div` is the best we have.
634  if (!dimExpr)
635  return div;
636 
637  // We simply exploit information from loop IVs.
638  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
639  // desired simplifications are expected to be part of other
640  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
641  // LoopAnalysis library.
642  Value operand = operands[dimExpr.getPosition()];
643  int64_t operandDivisor = 1;
644  // TODO: With the right accessors, this can be extended to
645  // LoopLikeOpInterface.
646  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
647  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
648  operandDivisor = forOp.getStepAsInt();
649  } else {
650  uint64_t lbLargestKnownDivisor =
651  forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
652  operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
653  }
654  }
655  return operandDivisor;
656 }
657 
658 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
659 /// being an affine dim expression or a constant.
661  int64_t k) {
662  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
663  int64_t constVal = constExpr.getValue();
664  return constVal >= 0 && constVal < k;
665  }
666  auto dimExpr = dyn_cast<AffineDimExpr>(e);
667  if (!dimExpr)
668  return false;
669  Value operand = operands[dimExpr.getPosition()];
670  // TODO: With the right accessors, this can be extended to
671  // LoopLikeOpInterface.
672  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
673  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
674  forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
675  return true;
676  }
677  }
678 
679  // We don't consider other cases like `operand` being defined by a constant or
680  // an affine.apply op since such cases will already be handled by other
681  // patterns and propagation of loop IVs or constant would happen.
682  return false;
683 }
684 
685 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
686 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
687 /// expression is in that form.
688 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
689  AffineExpr &quotientTimesDiv, AffineExpr &rem) {
690  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
691  if (!bin || bin.getKind() != AffineExprKind::Add)
692  return false;
693 
694  AffineExpr llhs = bin.getLHS();
695  AffineExpr rlhs = bin.getRHS();
696  div = getLargestKnownDivisor(llhs, operands);
697  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
698  quotientTimesDiv = llhs;
699  rem = rlhs;
700  return true;
701  }
702  div = getLargestKnownDivisor(rlhs, operands);
703  if (isNonNegativeBoundedBy(llhs, operands, div)) {
704  quotientTimesDiv = rlhs;
705  rem = llhs;
706  return true;
707  }
708  return false;
709 }
710 
711 /// Gets the constant lower bound on an `iv`.
712 static std::optional<int64_t> getLowerBound(Value iv) {
713  AffineForOp forOp = getForInductionVarOwner(iv);
714  if (forOp && forOp.hasConstantLowerBound())
715  return forOp.getConstantLowerBound();
716  return std::nullopt;
717 }
718 
719 /// Gets the constant upper bound on an affine.for `iv`.
720 static std::optional<int64_t> getUpperBound(Value iv) {
721  AffineForOp forOp = getForInductionVarOwner(iv);
722  if (!forOp || !forOp.hasConstantUpperBound())
723  return std::nullopt;
724 
725  // If its lower bound is also known, we can get a more precise bound
726  // whenever the step is not one.
727  if (forOp.hasConstantLowerBound()) {
728  return forOp.getConstantUpperBound() - 1 -
729  (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
730  forOp.getStepAsInt();
731  }
732  return forOp.getConstantUpperBound() - 1;
733 }
734 
735 /// Determine a constant upper bound for `expr` if one exists while exploiting
736 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
737 /// is guaranteed to be less than or equal to it.
738 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
739  unsigned numSymbols,
740  ArrayRef<Value> operands) {
741  // Get the constant lower or upper bounds on the operands.
742  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
743  constLowerBounds.reserve(operands.size());
744  constUpperBounds.reserve(operands.size());
745  for (Value operand : operands) {
746  constLowerBounds.push_back(getLowerBound(operand));
747  constUpperBounds.push_back(getUpperBound(operand));
748  }
749 
750  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
751  return constExpr.getValue();
752 
753  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
754  constUpperBounds,
755  /*isUpper=*/true);
756 }
757 
758 /// Determine a constant lower bound for `expr` if one exists while exploiting
759 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
760 /// is guaranteed to be less than or equal to it.
761 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
762  unsigned numSymbols,
763  ArrayRef<Value> operands) {
764  // Get the constant lower or upper bounds on the operands.
765  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
766  constLowerBounds.reserve(operands.size());
767  constUpperBounds.reserve(operands.size());
768  for (Value operand : operands) {
769  constLowerBounds.push_back(getLowerBound(operand));
770  constUpperBounds.push_back(getUpperBound(operand));
771  }
772 
773  std::optional<int64_t> lowerBound;
774  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
775  lowerBound = constExpr.getValue();
776  } else {
777  lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
778  constLowerBounds, constUpperBounds,
779  /*isUpper=*/false);
780  }
781  return lowerBound;
782 }
783 
784 /// Simplify `expr` while exploiting information from the values in `operands`.
785 static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
786  unsigned numSymbols,
787  ArrayRef<Value> operands) {
788  // We do this only for certain floordiv/mod expressions.
789  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
790  if (!binExpr)
791  return;
792 
793  // Simplify the child expressions first.
794  AffineExpr lhs = binExpr.getLHS();
795  AffineExpr rhs = binExpr.getRHS();
796  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
797  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
798  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
799 
800  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
801  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
802  expr.getKind() != AffineExprKind::CeilDiv &&
803  expr.getKind() != AffineExprKind::Mod)) {
804  return;
805  }
806 
807  // The `lhs` and `rhs` may be different post construction of simplified expr.
808  lhs = binExpr.getLHS();
809  rhs = binExpr.getRHS();
810  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
811  if (!rhsConst)
812  return;
813 
814  int64_t rhsConstVal = rhsConst.getValue();
815  // Undefined expressions aren't touched; IR can still be valid with them.
816  if (rhsConstVal <= 0)
817  return;
818 
819  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
820  MLIRContext *context = expr.getContext();
821  std::optional<int64_t> lhsLbConst =
822  getLowerBound(lhs, numDims, numSymbols, operands);
823  std::optional<int64_t> lhsUbConst =
824  getUpperBound(lhs, numDims, numSymbols, operands);
825  if (lhsLbConst && lhsUbConst) {
826  int64_t lhsLbConstVal = *lhsLbConst;
827  int64_t lhsUbConstVal = *lhsUbConst;
828  // lhs floordiv c is a single value lhs is bounded in a range `c` that has
829  // the same quotient.
830  if (binExpr.getKind() == AffineExprKind::FloorDiv &&
831  divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
832  divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
833  expr = getAffineConstantExpr(
834  divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
835  return;
836  }
837  // lhs ceildiv c is a single value if the entire range has the same ceil
838  // quotient.
839  if (binExpr.getKind() == AffineExprKind::CeilDiv &&
840  divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
841  divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
842  expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
843  context);
844  return;
845  }
846  // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
847  if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
848  lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
849  expr = lhs;
850  return;
851  }
852  }
853 
854  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
855  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
856  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
857  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
858  AffineExpr quotientTimesDiv, rem;
859  int64_t divisor;
860  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
861  if (rhsConstVal % divisor == 0 &&
862  binExpr.getKind() == AffineExprKind::FloorDiv) {
863  expr = quotientTimesDiv.floorDiv(rhsConst);
864  } else if (divisor % rhsConstVal == 0 &&
865  binExpr.getKind() == AffineExprKind::Mod) {
866  expr = rem % rhsConst;
867  }
868  return;
869  }
870 
871  // Handle the simple case when the LHS expression can be either upper
872  // bounded or is a known multiple of RHS constant.
873  // lhs floordiv c -> 0 if 0 <= lhs < c,
874  // lhs mod c -> 0 if lhs % c = 0.
875  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
876  binExpr.getKind() == AffineExprKind::FloorDiv) ||
877  (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
878  binExpr.getKind() == AffineExprKind::Mod)) {
879  expr = getAffineConstantExpr(0, expr.getContext());
880  }
881 }
882 
883 /// Simplify the expressions in `map` while making use of lower or upper bounds
884 /// of its operands. If `isMax` is true, the map is to be treated as a max of
885 /// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
886 /// d1) can be simplified to (8) if the operands are respectively lower bounded
887 /// by 2 and 0 (the second expression can't be lower than 8).
// NOTE(review): listing line 888 is absent from this extract; presumably it
// holds the start of the signature (return type, function name and the `map`
// parameter used below). TODO confirm against the original file.
889  ArrayRef<Value> operands,
890  bool isMax) {
891  // Can't simplify.
892  if (operands.empty())
893  return;
894
895  // Get the upper or lower bound on an affine.for op IV using its range.
896  // Get the constant lower or upper bounds on the operands.
897  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
898  constLowerBounds.reserve(operands.size());
899  constUpperBounds.reserve(operands.size());
900  for (Value operand : operands) {
901  constLowerBounds.push_back(getLowerBound(operand));
902  constUpperBounds.push_back(getUpperBound(operand));
903  }
904
905  // We will compute the lower and upper bounds on each of the expressions
906  // Then, we will check (depending on max or min) as to whether a specific
907  // bound is redundant by checking if its highest (in case of max) and its
908  // lowest (in the case of min) value is already lower than (or higher than)
909  // the lower bound (or upper bound in the case of min) of another bound.
910  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
911  lowerBounds.reserve(map.getNumResults());
912  upperBounds.reserve(map.getNumResults());
913  for (AffineExpr e : map.getResults()) {
914  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
  // A constant result is its own lower and upper bound.
915  lowerBounds.push_back(constExpr.getValue());
916  upperBounds.push_back(constExpr.getValue());
917  } else {
918  lowerBounds.push_back(
  // NOTE(review): listing line 919 is missing; presumably the start of a
  // bound-computing call taking `e` and the map's dim/symbol counts - confirm.
920  constLowerBounds, constUpperBounds,
921  /*isUpper=*/false));
922  upperBounds.push_back(
  // NOTE(review): listing line 923 is missing; presumably the same call as for
  // the lower bound, here with /*isUpper=*/true - confirm.
924  constLowerBounds, constUpperBounds,
925  /*isUpper=*/true));
926  }
927  }
928
929  // Collect expressions that are not redundant.
930  SmallVector<AffineExpr, 4> irredundantExprs;
931  for (auto exprEn : llvm::enumerate(map.getResults())) {
932  AffineExpr e = exprEn.value();
933  unsigned i = exprEn.index();
934  // Some expressions can be turned into constants.
935  if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
936  e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
937
938  // Check if the expression is redundant.
939  if (isMax) {
  // Without a known upper bound the expression can never be shown redundant.
940  if (!upperBounds[i]) {
941  irredundantExprs.push_back(e);
942  continue;
943  }
944  // If there exists another expression such that its lower bound is greater
945  // than this expression's upper bound, it's redundant.
946  if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
947  auto otherLowerBound = en.value();
948  unsigned pos = en.index();
949  if (pos == i || !otherLowerBound)
950  return false;
951  if (*otherLowerBound > *upperBounds[i])
952  return true;
953  if (*otherLowerBound < *upperBounds[i])
954  return false;
955  // Equality case. When both expressions are considered redundant, we
956  // don't want to get both of them. We keep the one that appears
957  // first.
  // That is: when both expressions are constants of the same value, the one
  // with the smaller index (`i < pos`) is not treated as redundant.
958  if (upperBounds[pos] && lowerBounds[i] &&
959  lowerBounds[i] == upperBounds[i] &&
960  otherLowerBound == *upperBounds[pos] && i < pos)
961  return false;
962  return true;
963  }))
964  irredundantExprs.push_back(e);
965  } else {
  // Without a known lower bound the expression can never be shown redundant.
966  if (!lowerBounds[i]) {
967  irredundantExprs.push_back(e);
968  continue;
969  }
970  // Likewise for the `min` case. Use the complement of the condition above.
971  if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
972  auto otherUpperBound = en.value();
973  unsigned pos = en.index();
974  if (pos == i || !otherUpperBound)
975  return false;
976  if (*otherUpperBound < *lowerBounds[i])
977  return true;
978  if (*otherUpperBound > *lowerBounds[i])
979  return false;
  // Equality case, mirrored from the `max` branch: of two equal constant
  // expressions, the earlier one is kept.
980  if (lowerBounds[pos] && upperBounds[i] &&
981  lowerBounds[i] == upperBounds[i] &&
982  otherUpperBound == lowerBounds[pos] && i < pos)
983  return false;
984  return true;
985  }))
986  irredundantExprs.push_back(e);
987  }
988  }
989
990  // Create the map without the redundant expressions.
991  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
992  map.getContext());
993 }
994 
995 /// Simplify the map while exploiting information on the values in `operands`.
/// Each result expression is simplified on its own and `map` is rebuilt from
/// the simplified results with the same number of dims and symbols.
996 // Use "unused attribute" marker to silence warning stemming from the inability
997 // to see through the template expansion.
998 static void LLVM_ATTRIBUTE_UNUSED
// NOTE(review): listing line 999 is absent from this extract; presumably it
// holds the function name and the `map` / `operands` parameters used below.
// TODO confirm against the original file.
1000  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1001  SmallVector<AffineExpr> newResults;
1002  newResults.reserve(map.getNumResults());
1003  for (AffineExpr expr : map.getResults()) {
  // NOTE(review): listing line 1004 is missing; presumably the start of the
  // call that simplifies `expr` in place, ending with `operands);` - confirm.
1005  operands);
1006  newResults.push_back(expr);
1007  }
1008  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1009  map.getContext());
1010 }
1011 
1012 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
1013 /// defining AffineApplyOp expression and operands.
1014 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
1015 /// When `dimOrSymbolPosition >= dims.size()`,
1016 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
1017 /// Mutate `map`,`dims` and `syms` in place as follows:
1018 /// 1. `dims` and `syms` are only appended to.
1019 /// 2. `map` dim and symbols are gradually shifted to higher positions.
1020 /// 3. Old `dim` and `sym` entries are replaced by nullptr
1021 /// This avoids the need for any bookkeeping.
1022 static LogicalResult replaceDimOrSym(AffineMap *map,
1023  unsigned dimOrSymbolPosition,
1024  SmallVectorImpl<Value> &dims,
1025  SmallVectorImpl<Value> &syms) {
1026  MLIRContext *ctx = map->getContext();
1027  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1028  unsigned pos = isDimReplacement ? dimOrSymbolPosition
1029  : dimOrSymbolPosition - dims.size();
1030  Value &v = isDimReplacement ? dims[pos] : syms[pos];
1031  if (!v)
1032  return failure();
1033 
1034  auto affineApply = v.getDefiningOp<AffineApplyOp>();
1035  if (!affineApply)
1036  return failure();
1037 
1038  // At this point we will perform a replacement of `v`, set the entry in `dim`
1039  // or `sym` to nullptr immediately.
1040  v = nullptr;
1041 
1042  // Compute the map, dims and symbols coming from the AffineApplyOp.
1043  AffineMap composeMap = affineApply.getAffineMap();
1044  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1045  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1046  affineApply.getMapOperands().end());
1047  // Canonicalize the map to promote dims to symbols when possible. This is to
1048  // avoid generating invalid maps.
1049  canonicalizeMapAndOperands(&composeMap, &composeOperands);
1050  AffineExpr replacementExpr =
1051  composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
1052  ValueRange composeDims =
1053  ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
1054  ValueRange composeSyms =
1055  ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
1056  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
1057  : getAffineSymbolExpr(pos, ctx);
1058 
1059  // Append the dims and symbols where relevant and perform the replacement.
1060  dims.append(composeDims.begin(), composeDims.end());
1061  syms.append(composeSyms.begin(), composeSyms.end());
1062  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1063 
1064  return success();
1065 }
1066 
1067 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1068 /// iteratively. Perform canonicalization of map and operands as well as
1069 /// AffineMap simplification. `map` and `operands` are mutated in place.
// NOTE(review): listing line 1070 is absent from this extract; presumably it
// holds the start of the signature (return type, name and the `map` pointer
// parameter used below). TODO confirm against the original file.
1071  SmallVectorImpl<Value> *operands) {
  // A map without results has nothing to compose into; only canonicalize and
  // simplify it.
1072  if (map->getNumResults() == 0) {
1073  canonicalizeMapAndOperands(map, operands);
1074  *map = simplifyAffineMap(*map);
1075  return;
1076  }
1077
1078  MLIRContext *ctx = map->getContext();
  // Split the operand list into its dim part and its symbol part.
1079  SmallVector<Value, 4> dims(operands->begin(),
1080  operands->begin() + map->getNumDims());
1081  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1082  operands->end());
1083
1084  // Iterate over dims and symbols coming from AffineApplyOp and replace until
1085  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1086  // and `syms` can only increase by construction.
1087  // The implementation uses a `while` loop to support the case of symbols
1088  // that may be constructed from dims; this may be overkill.
1089  while (true) {
1090  bool changed = false;
  // Restart the scan after every successful replacement, since `dims` and
  // `syms` have grown.
1091  for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1092  if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
1093  break;
1094  if (!changed)
1095  break;
1096  }
1097
1098  // Clear operands so we can fill them anew.
1099  operands->clear();
1100
1101  // At this point we may have introduced null operands, prune them out before
1102  // canonicalizing map and operands.
1103  unsigned nDims = 0, nSyms = 0;
1104  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1105  dimReplacements.reserve(dims.size());
1106  symReplacements.reserve(syms.size());
1107  for (auto *container : {&dims, &syms}) {
1108  bool isDim = (container == &dims);
1109  auto &repls = isDim ? dimReplacements : symReplacements;
1110  for (const auto &en : llvm::enumerate(*container)) {
1111  Value v = en.value();
1112  if (!v) {
  // A nulled entry must no longer be referenced by the map; its position is
  // remapped to the constant 0 as a placeholder.
1113  assert(isDim ? !map->isFunctionOfDim(en.index())
1114  : !map->isFunctionOfSymbol(en.index()) &&
1115  "map is function of unexpected expr@pos");
1116  repls.push_back(getAffineConstantExpr(0, ctx));
1117  continue;
1118  }
  // Surviving entries are renumbered compactly and kept as operands.
1119  repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1120  : getAffineSymbolExpr(nSyms++, ctx));
1121  operands->push_back(v);
1122  }
1123  }
1124  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1125  nSyms);
1126
1127  // Canonicalize and simplify before returning.
1128  canonicalizeMapAndOperands(map, operands);
1129  *map = simplifyAffineMap(*map);
1130 }
1131 
// NOTE(review): listing line 1132 is absent from this extract; presumably it
// holds the return type and function name. TODO confirm against the original.
// Keeps calling composeAffineMapAndOperands for as long as any operand is
// still defined by an AffineApplyOp.
1133  AffineMap *map, SmallVectorImpl<Value> *operands) {
1134  while (llvm::any_of(*operands, [](Value v) {
  // isa_and_nonnull: getDefiningOp() may return null.
1135  return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1136  })) {
1137  composeAffineMapAndOperands(map, operands);
1138  }
1139 }
1140 
// Builds an AffineApplyOp for `map`: attribute operands are first folded into
// the map, then the map is composed with its remaining value operands.
1141 AffineApplyOp
// NOTE(review): listing line 1142 is absent from this extract; presumably it
// holds the function name and the `b`, `loc` and `map` parameters used below.
// TODO confirm against the original file.
1143  ArrayRef<OpFoldResult> operands) {
1144  SmallVector<Value> valueOperands;
1145  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1146  composeAffineMapAndOperands(&map, &valueOperands);
1147  assert(map);
1148  return b.create<AffineApplyOp>(loc, map, valueOperands);
1149 }
1150 
// Convenience overload that forwards to makeComposedAffineApply with a map
// argument built on the missing line noted below.
1151 AffineApplyOp
// NOTE(review): listing line 1152 is absent from this extract; presumably it
// holds the function name and its leading parameters. TODO confirm.
1153  ArrayRef<OpFoldResult> operands) {
1154  return makeComposedAffineApply(
1155  b, loc,
  // NOTE(review): listing line 1156 is missing; presumably the expression that
  // yields a container of maps, of which `.front()` is taken - confirm.
1157  .front(),
1158  operands);
1159 }
1160 
1161 /// Composes the given affine map with the given list of operands, pulling in
1162 /// the maps from any affine.apply operations that supply the operands.
// NOTE(review): listing line 1163 is absent from this extract; presumably it
// holds the start of the signature (return type, name and the `map` parameter
// that is reassigned below). TODO confirm against the original file.
1164  SmallVectorImpl<Value> &operands) {
1165  // Compose and canonicalize each expression in the map individually because
1166  // composition only applies to single-result maps, collecting potentially
1167  // duplicate operands in a single list with shifted dimensions and symbols.
1168  SmallVector<Value> dims, symbols;
  // NOTE(review): listing line 1169 is missing; presumably it declares `exprs`,
  // the list of composed result expressions filled in the loop - confirm.
1170  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
  // Each result is composed against its own copy of the operand list.
1171  SmallVector<Value> submapOperands(operands.begin(), operands.end());
1172  AffineMap submap = map.getSubMap({i});
1173  fullyComposeAffineMapAndOperands(&submap, &submapOperands);
1174  canonicalizeMapAndOperands(&submap, &submapOperands);
1175  unsigned numNewDims = submap.getNumDims();
  // Shift past the dims/symbols collected so far so positions do not collide.
1176  submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1177  llvm::append_range(dims,
1178  ArrayRef<Value>(submapOperands).take_front(numNewDims));
1179  llvm::append_range(symbols,
1180  ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1181  exprs.push_back(submap.getResult(0));
1182  }
1183
1184  // Canonicalize the map created from composed expressions to deduplicate the
1185  // dimension and symbol operands.
1186  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1187  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1188  canonicalizeMapAndOperands(&map, &operands);
1189 }
1190 
// NOTE(review): listing lines 1191-1192 are absent from this extract;
// presumably they hold the return type, the function name and the `b` / `loc`
// parameters used below. TODO confirm against the original file.
// Creates a composed affine.apply for the single-result `map` and tries to
// fold it right away; returns the fold result if folding succeeds and the
// op's result otherwise.
1193  AffineMap map,
1194  ArrayRef<OpFoldResult> operands) {
1195  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1196
1197  // Create new builder without a listener, so that no notification is
1198  // triggered if the op is folded.
1199  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1200  // workaround is no longer needed.
1201  OpBuilder newBuilder(b.getContext());
  // NOTE(review): listing line 1202 is missing; presumably it positions
  // `newBuilder` at the insertion point of `b` - confirm.
1203
1204  // Create op.
1205  AffineApplyOp applyOp =
1206  makeComposedAffineApply(newBuilder, loc, map, operands);
1207
1208  // Get constant operands.
  // Entries whose operand does not match m_Constant are left default-constructed.
1209  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1210  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1211  matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1212
1213  // Try to fold the operation.
1214  SmallVector<OpFoldResult> foldResults;
1215  if (failed(applyOp->fold(constOperands, foldResults)) ||
1216  foldResults.empty()) {
  // Folding did not happen: the op stays, so report its insertion to the
  // listener of the original builder, if any.
1217  if (OpBuilder::Listener *listener = b.getListener())
1218  listener->notifyOperationInserted(applyOp, /*previous=*/{});
1219  return applyOp.getResult();
1220  }
1221
  // Folded: the temporary op is no longer needed.
1222  applyOp->erase();
1223  assert(foldResults.size() == 1 && "expected 1 folded result");
1224  return foldResults.front();
1225 }
1226 
// NOTE(review): listing lines 1227-1228 (presumably return type, function name
// and leading parameters), 1231 (presumably the start of the forwarded call)
// and 1233 (presumably the expression whose `.front()` is taken) are absent
// from this extract. TODO confirm against the original file.
1229  AffineExpr expr,
1230  ArrayRef<OpFoldResult> operands) {
1232  b, loc,
1234  .front(),
1235  operands);
1236 }
1237 
// NOTE(review): listing lines 1238-1239 are absent from this extract;
// presumably they hold the return type and the function name. TODO confirm.
// Calls makeComposedFoldedAffineApply once per result of `map`, each time on
// the single-result sub-map, and collects the results in order.
1240  OpBuilder &b, Location loc, AffineMap map,
1241  ArrayRef<OpFoldResult> operands) {
1242  return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
1243  [&](unsigned i) {
1244  return makeComposedFoldedAffineApply(
1245  b, loc, map.getSubMap({i}), operands);
1246  });
1247 }
1248 
// Builds an `OpTy` (instantiated below in this file with AffineMinOp and
// AffineMaxOp) of index type: attribute operands are folded into `map`, then
// every result of the map is composed with its value operands.
1249 template <typename OpTy>
// NOTE(review): listing line 1250 is absent from this extract; presumably it
// holds the return type, the function name and the `b`, `loc` and `map`
// parameters used below. TODO confirm against the original file.
1251  ArrayRef<OpFoldResult> operands) {
1252  SmallVector<Value> valueOperands;
1253  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1254  composeMultiResultAffineMap(map, valueOperands);
1255  return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1256 }
1257 
// Thin wrapper: instantiates makeComposedMinMax for AffineMinOp.
1258 AffineMinOp
// NOTE(review): listing line 1259 is absent from this extract; presumably it
// holds the function name and the `b`, `loc` and `map` parameters. TODO confirm.
1260  ArrayRef<OpFoldResult> operands) {
1261  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1262 }
1263 
// Creates a composed `OpTy` min/max op and tries to fold it right away;
// returns the fold result if folding succeeds and the op's result otherwise.
1264 template <typename OpTy>
// NOTE(review): listing line 1265 is absent from this extract; presumably it
// holds the return type, the function name and the `b` / `loc` parameters used
// below. TODO confirm against the original file.
1266  AffineMap map,
1267  ArrayRef<OpFoldResult> operands) {
1268  // Create new builder without a listener, so that no notification is
1269  // triggered if the op is folded.
1270  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1271  // workaround is no longer needed.
1272  OpBuilder newBuilder(b.getContext());
  // NOTE(review): listing line 1273 is missing; presumably it positions
  // `newBuilder` at the insertion point of `b` - confirm.
1274
1275  // Create op.
1276  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1277
1278  // Get constant operands.
  // Entries whose operand does not match m_Constant are left default-constructed.
1279  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1280  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1281  matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1282
1283  // Try to fold the operation.
1284  SmallVector<OpFoldResult> foldResults;
1285  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1286  foldResults.empty()) {
  // Folding did not happen: the op stays, so report its insertion to the
  // listener of the original builder, if any.
1287  if (OpBuilder::Listener *listener = b.getListener())
1288  listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1289  return minMaxOp.getResult();
1290  }
1291
  // Folded: the temporary op is no longer needed.
1292  minMaxOp->erase();
1293  assert(foldResults.size() == 1 && "expected 1 folded result");
1294  return foldResults.front();
1295 }
1296 
// NOTE(review): listing lines 1297-1298 are absent from this extract;
// presumably they hold the return type, the function name and the `b` / `loc`
// parameters. TODO confirm against the original file.
// Thin wrapper: instantiates makeComposedFoldedMinMax for AffineMinOp.
1299  AffineMap map,
1300  ArrayRef<OpFoldResult> operands) {
1301  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1302 }
1303 
// NOTE(review): listing lines 1304-1305 are absent from this extract;
// presumably they hold the return type, the function name and the `b` / `loc`
// parameters. TODO confirm against the original file.
// Thin wrapper: instantiates makeComposedFoldedMinMax for AffineMaxOp.
1306  AffineMap map,
1307  ArrayRef<OpFoldResult> operands) {
1308  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1309 }
1310 
1311 // A symbol may appear as a dim in affine.apply operations. This function
1312 // canonicalizes dims that are valid symbols into actual symbols.
1313 template <class MapOrSet>
1314 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1315  SmallVectorImpl<Value> *operands) {
1316  if (!mapOrSet || operands->empty())
1317  return;
1318 
1319  assert(mapOrSet->getNumInputs() == operands->size() &&
1320  "map/set inputs must match number of operands");
1321 
1322  auto *context = mapOrSet->getContext();
1323  SmallVector<Value, 8> resultOperands;
1324  resultOperands.reserve(operands->size());
1325  SmallVector<Value, 8> remappedSymbols;
1326  remappedSymbols.reserve(operands->size());
1327  unsigned nextDim = 0;
1328  unsigned nextSym = 0;
1329  unsigned oldNumSyms = mapOrSet->getNumSymbols();
1330  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1331  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1332  if (i < mapOrSet->getNumDims()) {
1333  if (isValidSymbol((*operands)[i])) {
1334  // This is a valid symbol that appears as a dim, canonicalize it.
1335  dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1336  remappedSymbols.push_back((*operands)[i]);
1337  } else {
1338  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1339  resultOperands.push_back((*operands)[i]);
1340  }
1341  } else {
1342  resultOperands.push_back((*operands)[i]);
1343  }
1344  }
1345 
1346  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1347  *operands = resultOperands;
1348  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1349  oldNumSyms + nextSym);
1350 
1351  assert(mapOrSet->getNumInputs() == operands->size() &&
1352  "map/set inputs must match number of operands");
1353 }
1354 
1355 // Works for either an affine map or an integer set.
1356 template <class MapOrSet>
1357 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1358  SmallVectorImpl<Value> *operands) {
1359  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1360  "Argument must be either of AffineMap or IntegerSet type");
1361 
1362  if (!mapOrSet || operands->empty())
1363  return;
1364 
1365  assert(mapOrSet->getNumInputs() == operands->size() &&
1366  "map/set inputs must match number of operands");
1367 
1368  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1369 
1370  // Check to see what dims are used.
1371  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1372  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1373  mapOrSet->walkExprs([&](AffineExpr expr) {
1374  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1375  usedDims[dimExpr.getPosition()] = true;
1376  else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1377  usedSyms[symExpr.getPosition()] = true;
1378  });
1379 
1380  auto *context = mapOrSet->getContext();
1381 
1382  SmallVector<Value, 8> resultOperands;
1383  resultOperands.reserve(operands->size());
1384 
1385  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1386  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1387  unsigned nextDim = 0;
1388  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1389  if (usedDims[i]) {
1390  // Remap dim positions for duplicate operands.
1391  auto it = seenDims.find((*operands)[i]);
1392  if (it == seenDims.end()) {
1393  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1394  resultOperands.push_back((*operands)[i]);
1395  seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1396  } else {
1397  dimRemapping[i] = it->second;
1398  }
1399  }
1400  }
1401  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1402  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1403  unsigned nextSym = 0;
1404  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1405  if (!usedSyms[i])
1406  continue;
1407  // Handle constant operands (only needed for symbolic operands since
1408  // constant operands in dimensional positions would have already been
1409  // promoted to symbolic positions above).
1410  IntegerAttr operandCst;
1411  if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1412  m_Constant(&operandCst))) {
1413  symRemapping[i] =
1414  getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1415  continue;
1416  }
1417  // Remap symbol positions for duplicate operands.
1418  auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1419  if (it == seenSymbols.end()) {
1420  symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1421  resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1422  seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1423  symRemapping[i]));
1424  } else {
1425  symRemapping[i] = it->second;
1426  }
1427  }
1428  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1429  nextDim, nextSym);
1430  *operands = resultOperands;
1431 }
1432 
// NOTE(review): listing line 1433 is absent from this extract; presumably it
// holds the return type and function name. TODO confirm against the original.
// AffineMap entry point: forwards to the shared map-or-set implementation.
1434  AffineMap *map, SmallVectorImpl<Value> *operands) {
1435  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1436 }
1437 
// NOTE(review): listing line 1438 is absent from this extract; presumably it
// holds the return type and function name. TODO confirm against the original.
// IntegerSet entry point: forwards to the shared map-or-set implementation.
1439  IntegerSet *set, SmallVectorImpl<Value> *operands) {
1440  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1441 }
1442 
1443 namespace {
1444 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1445 /// maps that supply results into them.
/// The static_assert in matchAndRewrite lists the full set of op types this
/// pattern accepts (it also covers prefetch, min, max and vector load/store).
1446 ///
1447 template <typename AffineOpTy>
1448 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
  // NOTE(review): listing line 1449 is absent from this extract; presumably a
  // using-declaration that inherits the base-class constructors - confirm.
1450
1451  /// Replace the affine op with another instance of it with the supplied
1452  /// map and mapOperands.
1453  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1454  AffineMap map, ArrayRef<Value> mapOperands) const;
1455
  /// Composes, canonicalizes and simplifies the op's map and map operands.
  /// Fails when nothing changed; otherwise replaces the op through
  /// replaceAffineOp and succeeds.
1456  LogicalResult matchAndRewrite(AffineOpTy affineOp,
1457  PatternRewriter &rewriter) const override {
1458  static_assert(
1459  llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1460  AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1461  AffineVectorStoreOp, AffineVectorLoadOp>::value,
1462  "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1463  "expected");
1464  auto map = affineOp.getAffineMap();
1465  AffineMap oldMap = map;
1466  auto oldOperands = affineOp.getMapOperands();
1467  SmallVector<Value, 8> resultOperands(oldOperands);
1468  composeAffineMapAndOperands(&map, &resultOperands);
1469  canonicalizeMapAndOperands(&map, &resultOperands);
1470  simplifyMapWithOperands(map, resultOperands);
  // Bail out when neither the map nor the operands changed, so that the
  // pattern does not report a rewrite that did nothing.
1471  if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1472  resultOperands.begin()))
1473  return failure();
1474
1475  replaceAffineOp(rewriter, affineOp, map, resultOperands);
1476  return success();
1477  }
1478 };
1479 
1480 // Specialize the template to account for the different build signatures for
1481 // affine load, store, and apply ops.
1482 template <>
1483 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1484  PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1485  ArrayRef<Value> mapOperands) const {
1486  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1487  mapOperands);
1488 }
1489 template <>
1490 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1491  PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1492  ArrayRef<Value> mapOperands) const {
1493  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1494  prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1495  prefetch.getLocalityHint(), prefetch.getIsDataCache());
1496 }
1497 template <>
1498 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1499  PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1500  ArrayRef<Value> mapOperands) const {
1501  rewriter.replaceOpWithNewOp<AffineStoreOp>(
1502  store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1503 }
1504 template <>
1505 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1506  PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1507  ArrayRef<Value> mapOperands) const {
1508  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1509  vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1510  mapOperands);
1511 }
1512 template <>
1513 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1514  PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1515  ArrayRef<Value> mapOperands) const {
1516  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1517  vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1518  mapOperands);
1519 }
1520 
1521 // Generic version for ops that don't have extra operands.
1522 template <typename AffineOpTy>
1523 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1524  PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1525  ArrayRef<Value> mapOperands) const {
1526  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1527 }
1528 } // namespace
1529 
1530 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1531  MLIRContext *context) {
1532  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1533 }
1534 
1535 //===----------------------------------------------------------------------===//
1536 // AffineDmaStartOp
1537 //===----------------------------------------------------------------------===//
1538 
1539 // TODO: Check that map operands are loop IVs or symbols.
1540 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1541  Value srcMemRef, AffineMap srcMap,
1542  ValueRange srcIndices, Value destMemRef,
1543  AffineMap dstMap, ValueRange destIndices,
1544  Value tagMemRef, AffineMap tagMap,
1545  ValueRange tagIndices, Value numElements,
1546  Value stride, Value elementsPerStride) {
1547  result.addOperands(srcMemRef);
1548  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1549  result.addOperands(srcIndices);
1550  result.addOperands(destMemRef);
1551  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1552  result.addOperands(destIndices);
1553  result.addOperands(tagMemRef);
1554  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1555  result.addOperands(tagIndices);
1556  result.addOperands(numElements);
1557  if (stride) {
1558  result.addOperands({stride, elementsPerStride});
1559  }
1560 }
1561 
// NOTE(review): listing line 1562 is absent from this extract; presumably it
// holds the signature of the printer, which provides the stream `p` used
// below. TODO confirm against the original file.
// Printed form: the source, destination and tag memrefs, each followed by its
// affine map applied to its indices in square brackets; then the number of
// elements; then the two stride operands if the op is strided; and finally
// the three memref types after a colon.
1563  p << " " << getSrcMemRef() << '[';
1564  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1565  p << "], " << getDstMemRef() << '[';
1566  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1567  p << "], " << getTagMemRef() << '[';
1568  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1569  p << "], " << getNumElements();
1570  if (isStrided()) {
1571  p << ", " << getStride();
1572  p << ", " << getNumElementsPerStride();
1573  }
1574  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1575  << getTagMemRefType();
1576 }
1577 
1578 // Parse AffineDmaStartOp.
1579 // Ex:
1580 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1581 // %stride, %num_elt_per_stride
1582 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1583 //
// NOTE(review): listing line 1584 is absent from this extract; presumably it
// holds the start of the signature (return type, name and the `parser`
// parameter used below). TODO confirm against the original file.
1585  OperationState &result) {
1586  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1587  AffineMapAttr srcMapAttr;
  // NOTE(review): listing line 1588 is missing; presumably it declares
  // `srcMapOperands`, which is used below - confirm.
1589  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1590  AffineMapAttr dstMapAttr;
  // NOTE(review): listing line 1591 is missing; presumably it declares
  // `dstMapOperands`, which is used below - confirm.
1592  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1593  AffineMapAttr tagMapAttr;
  // NOTE(review): listing line 1594 is missing; presumably it declares
  // `tagMapOperands`, which is used below - confirm.
1595  OpAsmParser::UnresolvedOperand numElementsInfo;
  // NOTE(review): listing line 1596 is missing; presumably it declares
  // `strideInfo`, which is used below - confirm.
1597
1598  SmallVector<Type, 3> types;
1599  auto indexType = parser.getBuilder().getIndexType();
1600
1601  // Parse and resolve the following list of operands:
1602  // *) src memref followed by its affine map operands (in square brackets).
1603  // *) dst memref followed by its affine map operands (in square brackets).
1604  // *) tag memref followed by its affine map operands (in square brackets).
1605  // *) number of elements transferred by DMA operation.
1606  if (parser.parseOperand(srcMemRefInfo) ||
1607  parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1608  getSrcMapAttrStrName(),
1609  result.attributes) ||
1610  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1611  parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1612  getDstMapAttrStrName(),
1613  result.attributes) ||
1614  parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1615  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1616  getTagMapAttrStrName(),
1617  result.attributes) ||
1618  parser.parseComma() || parser.parseOperand(numElementsInfo))
1619  return failure();
1620
1621  // Parse optional stride and elements per stride.
1622  if (parser.parseTrailingOperandList(strideInfo))
1623  return failure();
1624
  // The stride operands are all-or-nothing: either none or exactly two.
1625  if (!strideInfo.empty() && strideInfo.size() != 2) {
1626  return parser.emitError(parser.getNameLoc(),
1627  "expected two stride related operands");
1628  }
1629  bool isStrided = strideInfo.size() == 2;
1630
1631  if (parser.parseColonTypeList(types))
1632  return failure();
1633
  // One type per memref, in the order src, dst, tag.
1634  if (types.size() != 3)
1635  return parser.emitError(parser.getNameLoc(), "expected three types");
1636
  // Resolve in the operand order of the op: each memref followed by its map
  // operands (all of index type), then the number of elements.
1637  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1638  parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1639  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1640  parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1641  parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1642  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1643  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1644  return failure();
1645
1646  if (isStrided) {
1647  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1648  return failure();
1649  }
1650
1651  // Check that src/dst/tag operand counts match their map.numInputs.
1652  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1653  dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1654  tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1655  return parser.emitError(parser.getNameLoc(),
1656  "memref operand count not equal to map.numInputs");
1657  return success();
1658 }
1659 
1660 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1661  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1662  return emitOpError("expected DMA source to be of memref type");
1663  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1664  return emitOpError("expected DMA destination to be of memref type");
1665  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1666  return emitOpError("expected DMA tag to be of memref type");
1667 
1668  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1669  getDstMap().getNumInputs() +
1670  getTagMap().getNumInputs();
1671  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1672  getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1673  return emitOpError("incorrect number of operands");
1674  }
1675 
1676  Region *scope = getAffineScope(*this);
1677  for (auto idx : getSrcIndices()) {
1678  if (!idx.getType().isIndex())
1679  return emitOpError("src index to dma_start must have 'index' type");
1680  if (!isValidAffineIndexOperand(idx, scope))
1681  return emitOpError(
1682  "src index must be a valid dimension or symbol identifier");
1683  }
1684  for (auto idx : getDstIndices()) {
1685  if (!idx.getType().isIndex())
1686  return emitOpError("dst index to dma_start must have 'index' type");
1687  if (!isValidAffineIndexOperand(idx, scope))
1688  return emitOpError(
1689  "dst index must be a valid dimension or symbol identifier");
1690  }
1691  for (auto idx : getTagIndices()) {
1692  if (!idx.getType().isIndex())
1693  return emitOpError("tag index to dma_start must have 'index' type");
1694  if (!isValidAffineIndexOperand(idx, scope))
1695  return emitOpError(
1696  "tag index must be a valid dimension or symbol identifier");
1697  }
1698  return success();
1699 }
1700 
1701 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1702  SmallVectorImpl<OpFoldResult> &results) {
1703  /// dma_start(memrefcast) -> dma_start
1704  return memref::foldMemRefCast(*this);
1705 }
1706 
1707 void AffineDmaStartOp::getEffects(
1709  &effects) {
1710  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
1712  effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
1714  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1716 }
1717 
1718 //===----------------------------------------------------------------------===//
1719 // AffineDmaWaitOp
1720 //===----------------------------------------------------------------------===//
1721 
1722 // TODO: Check that map operands are loop IVs or symbols.
1723 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1724  Value tagMemRef, AffineMap tagMap,
1725  ValueRange tagIndices, Value numElements) {
1726  result.addOperands(tagMemRef);
1727  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1728  result.addOperands(tagIndices);
1729  result.addOperands(numElements);
1730 }
1731 
1733  p << " " << getTagMemRef() << '[';
1734  SmallVector<Value, 2> operands(getTagIndices());
1735  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1736  p << "], ";
1738  p << " : " << getTagMemRef().getType();
1739 }
1740 
1741 // Parse AffineDmaWaitOp.
1742 // Eg:
1743 // affine.dma_wait %tag[%index], %num_elements
1744 // : memref<1 x i32, (d0) -> (d0), 4>
1745 //
1747  OperationState &result) {
1748  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1749  AffineMapAttr tagMapAttr;
1751  Type type;
1752  auto indexType = parser.getBuilder().getIndexType();
1753  OpAsmParser::UnresolvedOperand numElementsInfo;
1754 
1755  // Parse tag memref, its map operands, and dma size.
1756  if (parser.parseOperand(tagMemRefInfo) ||
1757  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1758  getTagMapAttrStrName(),
1759  result.attributes) ||
1760  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1761  parser.parseColonType(type) ||
1762  parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1763  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1764  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1765  return failure();
1766 
1767  if (!llvm::isa<MemRefType>(type))
1768  return parser.emitError(parser.getNameLoc(),
1769  "expected tag to be of memref type");
1770 
1771  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1772  return parser.emitError(parser.getNameLoc(),
1773  "tag memref operand count != to map.numInputs");
1774  return success();
1775 }
1776 
1777 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1778  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1779  return emitOpError("expected DMA tag to be of memref type");
1780  Region *scope = getAffineScope(*this);
1781  for (auto idx : getTagIndices()) {
1782  if (!idx.getType().isIndex())
1783  return emitOpError("index to dma_wait must have 'index' type");
1784  if (!isValidAffineIndexOperand(idx, scope))
1785  return emitOpError(
1786  "index must be a valid dimension or symbol identifier");
1787  }
1788  return success();
1789 }
1790 
1791 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1792  SmallVectorImpl<OpFoldResult> &results) {
1793  /// dma_wait(memrefcast) -> dma_wait
1794  return memref::foldMemRefCast(*this);
1795 }
1796 
1797 void AffineDmaWaitOp::getEffects(
1799  &effects) {
1800  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1802 }
1803 
1804 //===----------------------------------------------------------------------===//
1805 // AffineForOp
1806 //===----------------------------------------------------------------------===//
1807 
1808 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1809 /// bodyBuilder are empty/null, we include default terminator op.
1810 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1811  ValueRange lbOperands, AffineMap lbMap,
1812  ValueRange ubOperands, AffineMap ubMap, int64_t step,
1813  ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1814  assert(((!lbMap && lbOperands.empty()) ||
1815  lbOperands.size() == lbMap.getNumInputs()) &&
1816  "lower bound operand count does not match the affine map");
1817  assert(((!ubMap && ubOperands.empty()) ||
1818  ubOperands.size() == ubMap.getNumInputs()) &&
1819  "upper bound operand count does not match the affine map");
1820  assert(step > 0 && "step has to be a positive integer constant");
1821 
1822  OpBuilder::InsertionGuard guard(builder);
1823 
1824  // Set variadic segment sizes.
1825  result.addAttribute(
1826  getOperandSegmentSizeAttr(),
1827  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
1828  static_cast<int32_t>(ubOperands.size()),
1829  static_cast<int32_t>(iterArgs.size())}));
1830 
1831  for (Value val : iterArgs)
1832  result.addTypes(val.getType());
1833 
1834  // Add an attribute for the step.
1835  result.addAttribute(getStepAttrName(result.name),
1836  builder.getIntegerAttr(builder.getIndexType(), step));
1837 
1838  // Add the lower bound.
1839  result.addAttribute(getLowerBoundMapAttrName(result.name),
1840  AffineMapAttr::get(lbMap));
1841  result.addOperands(lbOperands);
1842 
1843  // Add the upper bound.
1844  result.addAttribute(getUpperBoundMapAttrName(result.name),
1845  AffineMapAttr::get(ubMap));
1846  result.addOperands(ubOperands);
1847 
1848  result.addOperands(iterArgs);
1849  // Create a region and a block for the body. The argument of the region is
1850  // the loop induction variable.
1851  Region *bodyRegion = result.addRegion();
1852  Block *bodyBlock = builder.createBlock(bodyRegion);
1853  Value inductionVar =
1854  bodyBlock->addArgument(builder.getIndexType(), result.location);
1855  for (Value val : iterArgs)
1856  bodyBlock->addArgument(val.getType(), val.getLoc());
1857 
1858  // Create the default terminator if the builder is not provided and if the
1859  // iteration arguments are not provided. Otherwise, leave this to the caller
1860  // because we don't know which values to return from the loop.
1861  if (iterArgs.empty() && !bodyBuilder) {
1862  ensureTerminator(*bodyRegion, builder, result.location);
1863  } else if (bodyBuilder) {
1864  OpBuilder::InsertionGuard guard(builder);
1865  builder.setInsertionPointToStart(bodyBlock);
1866  bodyBuilder(builder, result.location, inductionVar,
1867  bodyBlock->getArguments().drop_front());
1868  }
1869 }
1870 
1871 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1872  int64_t ub, int64_t step, ValueRange iterArgs,
1873  BodyBuilderFn bodyBuilder) {
1874  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1875  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1876  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1877  bodyBuilder);
1878 }
1879 
1880 LogicalResult AffineForOp::verifyRegions() {
1881  // Check that the body defines as single block argument for the induction
1882  // variable.
1883  auto *body = getBody();
1884  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1885  return emitOpError("expected body to have a single index argument for the "
1886  "induction variable");
1887 
1888  // Verify that the bound operands are valid dimension/symbols.
1889  /// Lower bound.
1890  if (getLowerBoundMap().getNumInputs() > 0)
1892  getLowerBoundMap().getNumDims())))
1893  return failure();
1894  /// Upper bound.
1895  if (getUpperBoundMap().getNumInputs() > 0)
1897  getUpperBoundMap().getNumDims())))
1898  return failure();
1899 
1900  unsigned opNumResults = getNumResults();
1901  if (opNumResults == 0)
1902  return success();
1903 
1904  // If ForOp defines values, check that the number and types of the defined
1905  // values match ForOp initial iter operands and backedge basic block
1906  // arguments.
1907  if (getNumIterOperands() != opNumResults)
1908  return emitOpError(
1909  "mismatch between the number of loop-carried values and results");
1910  if (getNumRegionIterArgs() != opNumResults)
1911  return emitOpError(
1912  "mismatch between the number of basic block args and results");
1913 
1914  return success();
1915 }
1916 
1917 /// Parse a for operation loop bounds.
1918 static ParseResult parseBound(bool isLower, OperationState &result,
1919  OpAsmParser &p) {
1920  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1921  // the map has multiple results.
1922  bool failedToParsedMinMax =
1923  failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1924 
1925  auto &builder = p.getBuilder();
1926  auto boundAttrStrName =
1927  isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
1928  : AffineForOp::getUpperBoundMapAttrName(result.name);
1929 
1930  // Parse ssa-id as identity map.
1932  if (p.parseOperandList(boundOpInfos))
1933  return failure();
1934 
1935  if (!boundOpInfos.empty()) {
1936  // Check that only one operand was parsed.
1937  if (boundOpInfos.size() > 1)
1938  return p.emitError(p.getNameLoc(),
1939  "expected only one loop bound operand");
1940 
1941  // TODO: improve error message when SSA value is not of index type.
1942  // Currently it is 'use of value ... expects different type than prior uses'
1943  if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1944  result.operands))
1945  return failure();
1946 
1947  // Create an identity map using symbol id. This representation is optimized
1948  // for storage. Analysis passes may expand it into a multi-dimensional map
1949  // if desired.
1950  AffineMap map = builder.getSymbolIdentityMap();
1951  result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
1952  return success();
1953  }
1954 
1955  // Get the attribute location.
1956  SMLoc attrLoc = p.getCurrentLocation();
1957 
1958  Attribute boundAttr;
1959  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
1960  result.attributes))
1961  return failure();
1962 
1963  // Parse full form - affine map followed by dim and symbol list.
1964  if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
1965  unsigned currentNumOperands = result.operands.size();
1966  unsigned numDims;
1967  if (parseDimAndSymbolList(p, result.operands, numDims))
1968  return failure();
1969 
1970  auto map = affineMapAttr.getValue();
1971  if (map.getNumDims() != numDims)
1972  return p.emitError(
1973  p.getNameLoc(),
1974  "dim operand count and affine map dim count must match");
1975 
1976  unsigned numDimAndSymbolOperands =
1977  result.operands.size() - currentNumOperands;
1978  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1979  return p.emitError(
1980  p.getNameLoc(),
1981  "symbol operand count and affine map symbol count must match");
1982 
1983  // If the map has multiple results, make sure that we parsed the min/max
1984  // prefix.
1985  if (map.getNumResults() > 1 && failedToParsedMinMax) {
1986  if (isLower) {
1987  return p.emitError(attrLoc, "lower loop bound affine map with "
1988  "multiple results requires 'max' prefix");
1989  }
1990  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1991  "results requires 'min' prefix");
1992  }
1993  return success();
1994  }
1995 
1996  // Parse custom assembly form.
1997  if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
1998  result.attributes.pop_back();
1999  result.addAttribute(
2000  boundAttrStrName,
2001  AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2002  return success();
2003  }
2004 
2005  return p.emitError(
2006  p.getNameLoc(),
2007  "expected valid affine map representation for loop bounds");
2008 }
2009 
2010 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2011  auto &builder = parser.getBuilder();
2012  OpAsmParser::Argument inductionVariable;
2013  inductionVariable.type = builder.getIndexType();
2014  // Parse the induction variable followed by '='.
2015  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2016  return failure();
2017 
2018  // Parse loop bounds.
2019  int64_t numOperands = result.operands.size();
2020  if (parseBound(/*isLower=*/true, result, parser))
2021  return failure();
2022  int64_t numLbOperands = result.operands.size() - numOperands;
2023  if (parser.parseKeyword("to", " between bounds"))
2024  return failure();
2025  numOperands = result.operands.size();
2026  if (parseBound(/*isLower=*/false, result, parser))
2027  return failure();
2028  int64_t numUbOperands = result.operands.size() - numOperands;
2029 
2030  // Parse the optional loop step, we default to 1 if one is not present.
2031  if (parser.parseOptionalKeyword("step")) {
2032  result.addAttribute(
2033  getStepAttrName(result.name),
2034  builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2035  } else {
2036  SMLoc stepLoc = parser.getCurrentLocation();
2037  IntegerAttr stepAttr;
2038  if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2039  getStepAttrName(result.name).data(),
2040  result.attributes))
2041  return failure();
2042 
2043  if (stepAttr.getValue().isNegative())
2044  return parser.emitError(
2045  stepLoc,
2046  "expected step to be representable as a positive signed integer");
2047  }
2048 
2049  // Parse the optional initial iteration arguments.
2052 
2053  // Induction variable.
2054  regionArgs.push_back(inductionVariable);
2055 
2056  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2057  // Parse assignment list and results type list.
2058  if (parser.parseAssignmentList(regionArgs, operands) ||
2059  parser.parseArrowTypeList(result.types))
2060  return failure();
2061  // Resolve input operands.
2062  for (auto argOperandType :
2063  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2064  Type type = std::get<2>(argOperandType);
2065  std::get<0>(argOperandType).type = type;
2066  if (parser.resolveOperand(std::get<1>(argOperandType), type,
2067  result.operands))
2068  return failure();
2069  }
2070  }
2071 
2072  result.addAttribute(
2073  getOperandSegmentSizeAttr(),
2074  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2075  static_cast<int32_t>(numUbOperands),
2076  static_cast<int32_t>(operands.size())}));
2077 
2078  // Parse the body region.
2079  Region *body = result.addRegion();
2080  if (regionArgs.size() != result.types.size() + 1)
2081  return parser.emitError(
2082  parser.getNameLoc(),
2083  "mismatch between the number of loop-carried values and results");
2084  if (parser.parseRegion(*body, regionArgs))
2085  return failure();
2086 
2087  AffineForOp::ensureTerminator(*body, builder, result.location);
2088 
2089  // Parse the optional attribute list.
2090  return parser.parseOptionalAttrDict(result.attributes);
2091 }
2092 
2093 static void printBound(AffineMapAttr boundMap,
2094  Operation::operand_range boundOperands,
2095  const char *prefix, OpAsmPrinter &p) {
2096  AffineMap map = boundMap.getValue();
2097 
2098  // Check if this bound should be printed using custom assembly form.
2099  // The decision to restrict printing custom assembly form to trivial cases
2100  // comes from the will to roundtrip MLIR binary -> text -> binary in a
2101  // lossless way.
2102  // Therefore, custom assembly form parsing and printing is only supported for
2103  // zero-operand constant maps and single symbol operand identity maps.
2104  if (map.getNumResults() == 1) {
2105  AffineExpr expr = map.getResult(0);
2106 
2107  // Print constant bound.
2108  if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2109  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2110  p << constExpr.getValue();
2111  return;
2112  }
2113  }
2114 
2115  // Print bound that consists of a single SSA symbol if the map is over a
2116  // single symbol.
2117  if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2118  if (dyn_cast<AffineSymbolExpr>(expr)) {
2119  p.printOperand(*boundOperands.begin());
2120  return;
2121  }
2122  }
2123  } else {
2124  // Map has multiple results. Print 'min' or 'max' prefix.
2125  p << prefix << ' ';
2126  }
2127 
2128  // Print the map and its operands.
2129  p << boundMap;
2130  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2131  map.getNumDims(), p);
2132 }
2133 
2134 unsigned AffineForOp::getNumIterOperands() {
2135  AffineMap lbMap = getLowerBoundMapAttr().getValue();
2136  AffineMap ubMap = getUpperBoundMapAttr().getValue();
2137 
2138  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2139 }
2140 
2141 std::optional<MutableArrayRef<OpOperand>>
2142 AffineForOp::getYieldedValuesMutable() {
2143  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2144 }
2145 
2147  p << ' ';
2148  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2149  /*omitType=*/true);
2150  p << " = ";
2151  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2152  p << " to ";
2153  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2154 
2155  if (getStepAsInt() != 1)
2156  p << " step " << getStepAsInt();
2157 
2158  bool printBlockTerminators = false;
2159  if (getNumIterOperands() > 0) {
2160  p << " iter_args(";
2161  auto regionArgs = getRegionIterArgs();
2162  auto operands = getInits();
2163 
2164  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2165  p << std::get<0>(it) << " = " << std::get<1>(it);
2166  });
2167  p << ") -> (" << getResultTypes() << ")";
2168  printBlockTerminators = true;
2169  }
2170 
2171  p << ' ';
2172  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2173  printBlockTerminators);
2175  (*this)->getAttrs(),
2176  /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2177  getUpperBoundMapAttrName(getOperation()->getName()),
2178  getStepAttrName(getOperation()->getName()),
2179  getOperandSegmentSizeAttr()});
2180 }
2181 
2182 /// Fold the constant bounds of a loop.
2183 static LogicalResult foldLoopBounds(AffineForOp forOp) {
2184  auto foldLowerOrUpperBound = [&forOp](bool lower) {
2185  // Check to see if each of the operands is the result of a constant. If
2186  // so, get the value. If not, ignore it.
2187  SmallVector<Attribute, 8> operandConstants;
2188  auto boundOperands =
2189  lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2190  for (auto operand : boundOperands) {
2191  Attribute operandCst;
2192  matchPattern(operand, m_Constant(&operandCst));
2193  operandConstants.push_back(operandCst);
2194  }
2195 
2196  AffineMap boundMap =
2197  lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2198  assert(boundMap.getNumResults() >= 1 &&
2199  "bound maps should have at least one result");
2200  SmallVector<Attribute, 4> foldedResults;
2201  if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2202  return failure();
2203 
2204  // Compute the max or min as applicable over the results.
2205  assert(!foldedResults.empty() && "bounds should have at least one result");
2206  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2207  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2208  auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2209  maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2210  : llvm::APIntOps::smin(maxOrMin, foldedResult);
2211  }
2212  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2213  : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2214  return success();
2215  };
2216 
2217  // Try to fold the lower bound.
2218  bool folded = false;
2219  if (!forOp.hasConstantLowerBound())
2220  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2221 
2222  // Try to fold the upper bound.
2223  if (!forOp.hasConstantUpperBound())
2224  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2225  return success(folded);
2226 }
2227 
2228 /// Canonicalize the bounds of the given loop.
2229 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2230  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2231  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2232 
2233  auto lbMap = forOp.getLowerBoundMap();
2234  auto ubMap = forOp.getUpperBoundMap();
2235  auto prevLbMap = lbMap;
2236  auto prevUbMap = ubMap;
2237 
2238  composeAffineMapAndOperands(&lbMap, &lbOperands);
2239  canonicalizeMapAndOperands(&lbMap, &lbOperands);
2240  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
2241  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2242  lbMap = removeDuplicateExprs(lbMap);
2243 
2244  composeAffineMapAndOperands(&ubMap, &ubOperands);
2245  canonicalizeMapAndOperands(&ubMap, &ubOperands);
2246  ubMap = removeDuplicateExprs(ubMap);
2247 
2248  // Any canonicalization change always leads to updated map(s).
2249  if (lbMap == prevLbMap && ubMap == prevUbMap)
2250  return failure();
2251 
2252  if (lbMap != prevLbMap)
2253  forOp.setLowerBound(lbOperands, lbMap);
2254  if (ubMap != prevUbMap)
2255  forOp.setUpperBound(ubOperands, ubMap);
2256  return success();
2257 }
2258 
2259 namespace {
2260 /// Returns constant trip count in trivial cases.
2261 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2262  int64_t step = forOp.getStepAsInt();
2263  if (!forOp.hasConstantBounds() || step <= 0)
2264  return std::nullopt;
2265  int64_t lb = forOp.getConstantLowerBound();
2266  int64_t ub = forOp.getConstantUpperBound();
2267  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2268 }
2269 
2270 /// This is a pattern to fold trivially empty loop bodies.
2271 /// TODO: This should be moved into the folding hook.
2272 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
2274 
2275  LogicalResult matchAndRewrite(AffineForOp forOp,
2276  PatternRewriter &rewriter) const override {
2277  // Check that the body only contains a yield.
2278  if (!llvm::hasSingleElement(*forOp.getBody()))
2279  return failure();
2280  if (forOp.getNumResults() == 0)
2281  return success();
2282  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2283  if (tripCount && *tripCount == 0) {
2284  // The initial values of the iteration arguments would be the op's
2285  // results.
2286  rewriter.replaceOp(forOp, forOp.getInits());
2287  return success();
2288  }
2289  SmallVector<Value, 4> replacements;
2290  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2291  auto iterArgs = forOp.getRegionIterArgs();
2292  bool hasValDefinedOutsideLoop = false;
2293  bool iterArgsNotInOrder = false;
2294  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2295  Value val = yieldOp.getOperand(i);
2296  auto *iterArgIt = llvm::find(iterArgs, val);
2297  if (iterArgIt == iterArgs.end()) {
2298  // `val` is defined outside of the loop.
2299  assert(forOp.isDefinedOutsideOfLoop(val) &&
2300  "must be defined outside of the loop");
2301  hasValDefinedOutsideLoop = true;
2302  replacements.push_back(val);
2303  } else {
2304  unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2305  if (pos != i)
2306  iterArgsNotInOrder = true;
2307  replacements.push_back(forOp.getInits()[pos]);
2308  }
2309  }
2310  // Bail out when the trip count is unknown and the loop returns any value
2311  // defined outside of the loop or any iterArg out of order.
2312  if (!tripCount.has_value() &&
2313  (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2314  return failure();
2315  // Bail out when the loop iterates more than once and it returns any iterArg
2316  // out of order.
2317  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2318  return failure();
2319  rewriter.replaceOp(forOp, replacements);
2320  return success();
2321  }
2322 };
2323 } // namespace
2324 
2325 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2326  MLIRContext *context) {
2327  results.add<AffineForEmptyLoopFolder>(context);
2328 }
2329 
2330 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2331  assert((point.isParent() || point == getRegion()) && "invalid region point");
2332 
2333  // The initial operands map to the loop arguments after the induction
2334  // variable or are forwarded to the results when the trip count is zero.
2335  return getInits();
2336 }
2337 
2338 void AffineForOp::getSuccessorRegions(
2340  assert((point.isParent() || point == getRegion()) && "expected loop region");
2341  // The loop may typically branch back to its body or to the parent operation.
2342  // If the predecessor is the parent op and the trip count is known to be at
2343  // least one, branch into the body using the iterator arguments. And in cases
2344  // we know the trip count is zero, it can only branch back to its parent.
2345  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2346  if (point.isParent() && tripCount.has_value()) {
2347  if (tripCount.value() > 0) {
2348  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2349  return;
2350  }
2351  if (tripCount.value() == 0) {
2352  regions.push_back(RegionSuccessor(getResults()));
2353  return;
2354  }
2355  }
2356 
2357  // From the loop body, if the trip count is one, we can only branch back to
2358  // the parent.
2359  if (!point.isParent() && tripCount && *tripCount == 1) {
2360  regions.push_back(RegionSuccessor(getResults()));
2361  return;
2362  }
2363 
2364  // In all other cases, the loop may branch back to itself or the parent
2365  // operation.
2366  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2367  regions.push_back(RegionSuccessor(getResults()));
2368 }
2369 
2370 /// Returns true if the affine.for has zero iterations in trivial cases.
2371 static bool hasTrivialZeroTripCount(AffineForOp op) {
2372  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2373  return tripCount && *tripCount == 0;
2374 }
2375 
2376 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2377  SmallVectorImpl<OpFoldResult> &results) {
2378  bool folded = succeeded(foldLoopBounds(*this));
2379  folded |= succeeded(canonicalizeLoopBounds(*this));
2380  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2381  // The initial values of the loop-carried variables (iter_args) are the
2382  // results of the op. But this must be avoided for an affine.for op that
2383  // does not return any results. Since ops that do not return results cannot
2384  // be folded away, we would enter an infinite loop of folds on the same
2385  // affine.for op.
2386  results.assign(getInits().begin(), getInits().end());
2387  folded = true;
2388  }
2389  return success(folded);
2390 }
2391 
2393  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2394 }
2395 
2397  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2398 }
2399 
2400 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2401  assert(lbOperands.size() == map.getNumInputs());
2402  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2403  getLowerBoundOperandsMutable().assign(lbOperands);
2404  setLowerBoundMap(map);
2405 }
2406 
2407 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2408  assert(ubOperands.size() == map.getNumInputs());
2409  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2410  getUpperBoundOperandsMutable().assign(ubOperands);
2411  setUpperBoundMap(map);
2412 }
2413 
2414 bool AffineForOp::hasConstantLowerBound() {
2415  return getLowerBoundMap().isSingleConstant();
2416 }
2417 
2418 bool AffineForOp::hasConstantUpperBound() {
2419  return getUpperBoundMap().isSingleConstant();
2420 }
2421 
2422 int64_t AffineForOp::getConstantLowerBound() {
2423  return getLowerBoundMap().getSingleConstantResult();
2424 }
2425 
2426 int64_t AffineForOp::getConstantUpperBound() {
2427  return getUpperBoundMap().getSingleConstantResult();
2428 }
2429 
2430 void AffineForOp::setConstantLowerBound(int64_t value) {
2431  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2432 }
2433 
2434 void AffineForOp::setConstantUpperBound(int64_t value) {
2435  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2436 }
2437 
2438 AffineForOp::operand_range AffineForOp::getControlOperands() {
2439  return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2440  getUpperBoundOperands().size()};
2441 }
2442 
2443 bool AffineForOp::matchingBoundOperandList() {
2444  auto lbMap = getLowerBoundMap();
2445  auto ubMap = getUpperBoundMap();
2446  if (lbMap.getNumDims() != ubMap.getNumDims() ||
2447  lbMap.getNumSymbols() != ubMap.getNumSymbols())
2448  return false;
2449 
2450  unsigned numOperands = lbMap.getNumInputs();
2451  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2452  // Compare Value 's.
2453  if (getOperand(i) != getOperand(numOperands + i))
2454  return false;
2455  }
2456  return true;
2457 }
2458 
2459 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2460 
2461 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2462  return SmallVector<Value>{getInductionVar()};
2463 }
2464 
2465 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2466  if (!hasConstantLowerBound())
2467  return std::nullopt;
2468  OpBuilder b(getContext());
2470  OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2471 }
2472 
2473 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2474  OpBuilder b(getContext());
2476  OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2477 }
2478 
2479 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2480  if (!hasConstantUpperBound())
2481  return {};
2482  OpBuilder b(getContext());
2484  OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2485 }
2486 
2487 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2488  RewriterBase &rewriter, ValueRange newInitOperands,
2489  bool replaceInitOperandUsesInLoop,
2490  const NewYieldValuesFn &newYieldValuesFn) {
2491  // Create a new loop before the existing one, with the extra operands.
2492  OpBuilder::InsertionGuard g(rewriter);
2493  rewriter.setInsertionPoint(getOperation());
2494  auto inits = llvm::to_vector(getInits());
2495  inits.append(newInitOperands.begin(), newInitOperands.end());
2496  AffineForOp newLoop = rewriter.create<AffineForOp>(
2497  getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2498  getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
2499 
2500  // Generate the new yield values and append them to the scf.yield operation.
2501  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2502  ArrayRef<BlockArgument> newIterArgs =
2503  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2504  {
2505  OpBuilder::InsertionGuard g(rewriter);
2506  rewriter.setInsertionPoint(yieldOp);
2507  SmallVector<Value> newYieldedValues =
2508  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2509  assert(newInitOperands.size() == newYieldedValues.size() &&
2510  "expected as many new yield values as new iter operands");
2511  rewriter.modifyOpInPlace(yieldOp, [&]() {
2512  yieldOp.getOperandsMutable().append(newYieldedValues);
2513  });
2514  }
2515 
2516  // Move the loop body to the new op.
2517  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2518  newLoop.getBody()->getArguments().take_front(
2519  getBody()->getNumArguments()));
2520 
2521  if (replaceInitOperandUsesInLoop) {
2522  // Replace all uses of `newInitOperands` with the corresponding basic block
2523  // arguments.
2524  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2525  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2526  [&](OpOperand &use) {
2527  Operation *user = use.getOwner();
2528  return newLoop->isProperAncestor(user);
2529  });
2530  }
2531  }
2532 
2533  // Replace the old loop.
2534  rewriter.replaceOp(getOperation(),
2535  newLoop->getResults().take_front(getNumResults()));
2536  return cast<LoopLikeOpInterface>(newLoop.getOperation());
2537 }
2538 
2539 Speculation::Speculatability AffineForOp::getSpeculatability() {
2540  // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2541  // Start and End.
2542  //
2543  // For Step != 1, the loop may not terminate. We can add more smarts here if
2544  // needed.
2545  return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2547 }
2548 
2549 /// Returns true if the provided value is the induction variable of a
2550 /// AffineForOp.
// NOTE(review): the signature line of this definition is absent from this
// listing — confirm against the original file.
// The owner lookup is compared against a default-constructed AffineForOp,
// which acts as the "no owner" marker.
2552  return getForInductionVarOwner(val) != AffineForOp();
2553 }
2554 
// Returns true when the owner lookup for an affine.parallel induction variable
// yields a non-null op for `val`.
// NOTE(review): the signature line of this definition is absent from this
// listing — confirm against the original file.
2556  return getAffineParallelInductionVarOwner(val) != nullptr;
2557 }
2558 
2561 }
2562 
// Returns the AffineForOp whose induction variable is `val`, or a
// default-constructed AffineForOp when there is none.
// NOTE(review): the signature line of this definition is absent from this
// listing — confirm against the original file.
2564  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
// Bail out when `val` is not a block argument, or when its block or that
// block's parent region is missing.
2565  if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2566  return AffineForOp();
2567  if (auto forOp =
2568  ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2569  // Check to make sure `val` is the induction variable, not an iter_arg.
2570  return forOp.getInductionVar() == val ? forOp : AffineForOp();
2571  return AffineForOp();
2572 }
2573 
// Returns the AffineParallelOp that owns `val` as one of its induction
// variables, or nullptr otherwise.
// NOTE(review): the signature line of this definition is absent from this
// listing — confirm against the original file.
2575  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2576  if (!ivArg || !ivArg.getOwner())
2577  return nullptr;
2578  Operation *containingOp = ivArg.getOwner()->getParentOp();
// NOTE(review): `containingOp` is not null-checked before dyn_cast here,
// whereas the for-loop lookup above also checks the block's parent; confirm
// whether a detached block can reach this point.
2579  auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
// The value must be one of the op's IVs, not just any argument of its body.
2580  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2581  return parallelOp;
2582  return nullptr;
2583 }
2584 
2585 /// Extracts the induction variables from a list of AffineForOps and returns
2586 /// them.
// The induction variables are appended to `*ivs` in the order of `forInsts`.
// NOTE(review): the first line of the signature (function name and the
// `forInsts` parameter) is absent from this listing — confirm against the
// original file.
2588  SmallVectorImpl<Value> *ivs) {
2589  ivs->reserve(forInsts.size());
2590  for (auto forInst : forInsts)
2591  ivs->push_back(forInst.getInductionVar());
2592 }
2593 
// Appends to `ivs` the induction variables of every affine.for and
// affine.parallel op in `affineOps`; other ops are ignored.
// NOTE(review): the signature lines of this definition are absent from this
// listing — confirm against the original file.
2596  ivs.reserve(affineOps.size());
2597  for (Operation *op : affineOps) {
2598  // An affine.for contributes its single induction variable.
2599  if (auto forOp = dyn_cast<AffineForOp>(op))
2600  ivs.push_back(forOp.getInductionVar());
// An affine.parallel contributes every argument of its body block.
2601  else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2602  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2603  ivs.push_back(parallelOp.getBody()->getArgument(i));
2604  }
2605 }
2606 
2607 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2608 /// operations.
// One loop is created per entry of `lbs`/`ubs`/`steps`, each nested inside the
// previous one; `bodyBuilderFn`, when provided, is invoked once in the
// innermost loop with all induction variables. The builder's insertion point
// is restored on return.
// NOTE(review): two lines are absent from this listing — the line with the
// function name and the declaration of `ivs` before the loop — confirm against
// the original file.
2609 template <typename BoundListTy, typename LoopCreatorTy>
2611  OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2612  ArrayRef<int64_t> steps,
2613  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2614  LoopCreatorTy &&loopCreatorFn) {
2615  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2616  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2617 
2618  // If there are no loops to be constructed, construct the body anyway.
2619  OpBuilder::InsertionGuard guard(builder);
2620  if (lbs.empty()) {
2621  if (bodyBuilderFn)
2622  bodyBuilderFn(builder, loc, ValueRange());
2623  return;
2624  }
2625 
2626  // Create the loops iteratively and store the induction variables.
2628  ivs.reserve(lbs.size());
2629  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2630  // Callback for creating the loop body, always creates the terminator.
2631  auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
2632  ValueRange iterArgs) {
2633  ivs.push_back(iv);
2634  // In the innermost loop, call the body builder.
2635  if (i == e - 1 && bodyBuilderFn) {
// The guard keeps the yield below at the insertion point in use before the
// body builder ran.
2636  OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
2637  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2638  }
2639  nestedBuilder.create<AffineYieldOp>(nestedLoc);
2640  };
2641 
2642  // Delegate actual loop creation to the callback in order to dispatch
2643  // between constant- and variable-bound loops.
2644  auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
// The next loop of the nest is created at the start of this loop's body.
2645  builder.setInsertionPointToStart(loop.getBody());
2646  }
2647 }
2648 
2649 /// Creates an affine loop from the bounds known to be constants.
// NOTE(review): the first line of the parameter list (function name, builder,
// location and `lb`) is absent from this listing — confirm against the
// original file.
2650 static AffineForOp
2652  int64_t ub, int64_t step,
2653  AffineForOp::BodyBuilderFn bodyBuilderFn) {
// The loop is created without iter args.
2654  return builder.create<AffineForOp>(loc, lb, ub, step,
2655  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2656 }
2657 
2658 /// Creates an affine loop from the bounds that may or may not be constants.
// NOTE(review): the first line of the parameter list (function name, builder,
// location, `lb` and `ub`) is absent from this listing — confirm against the
// original file.
2659 static AffineForOp
2661  int64_t step,
2662  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2663  std::optional<int64_t> lbConst = getConstantIntValue(lb);
2664  std::optional<int64_t> ubConst = getConstantIntValue(ub);
// When both bounds are constant, build a constant-bound loop instead.
2665  if (lbConst && ubConst)
2666  return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2667  ubConst.value(), step, bodyBuilderFn);
// Otherwise each bound is passed as the sole operand of a dim identity map.
2668  return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2669  builder.getDimIdentityMap(), step,
2670  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2671 }
2672 
// Visible part of a loop-nest builder taking `ArrayRef<int64_t>` lower bounds;
// it forwards to buildAffineLoopNestImpl.
// NOTE(review): the line with the function name, one parameter line and the
// last argument line of the call are absent from this listing — confirm
// against the original file.
2674  OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2676  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2677  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2679 }
2680 
// Visible part of a loop-nest builder taking `ValueRange` bounds; it forwards
// to buildAffineLoopNestImpl.
// NOTE(review): the line with the function name and the last argument line of
// the call are absent from this listing — confirm against the original file.
2682  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2683  ArrayRef<int64_t> steps,
2684  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2685  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2687 }
2688 
2689 //===----------------------------------------------------------------------===//
2690 // AffineIfOp
2691 //===----------------------------------------------------------------------===//
2692 
2693 namespace {
2694 /// Remove else blocks that have nothing other than a zero value yield.
2695 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
2697 
2698  LogicalResult matchAndRewrite(AffineIfOp ifOp,
2699  PatternRewriter &rewriter) const override {
2700  if (ifOp.getElseRegion().empty() ||
2701  !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2702  return failure();
2703 
2704  rewriter.startOpModification(ifOp);
2705  rewriter.eraseBlock(ifOp.getElseBlock());
2706  rewriter.finalizeOpModification(ifOp);
2707  return success();
2708  }
2709 };
2710 
2711 /// Removes affine.if cond if the condition is always true or false in certain
2712 /// trivial cases. Promotes the then/else block in the parent operation block.
2713 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
2715 
2716  LogicalResult matchAndRewrite(AffineIfOp op,
2717  PatternRewriter &rewriter) const override {
2718 
2719  auto isTriviallyFalse = [](IntegerSet iSet) {
2720  return iSet.isEmptyIntegerSet();
2721  };
2722 
2723  auto isTriviallyTrue = [](IntegerSet iSet) {
2724  return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2725  iSet.getConstraint(0) == 0);
2726  };
2727 
2728  IntegerSet affineIfConditions = op.getIntegerSet();
2729  Block *blockToMove;
2730  if (isTriviallyFalse(affineIfConditions)) {
2731  // The absence, or equivalently, the emptiness of the else region need not
2732  // be checked when affine.if is returning results because if an affine.if
2733  // operation is returning results, it always has a non-empty else region.
2734  if (op.getNumResults() == 0 && !op.hasElse()) {
2735  // If the else region is absent, or equivalently, empty, remove the
2736  // affine.if operation (which is not returning any results).
2737  rewriter.eraseOp(op);
2738  return success();
2739  }
2740  blockToMove = op.getElseBlock();
2741  } else if (isTriviallyTrue(affineIfConditions)) {
2742  blockToMove = op.getThenBlock();
2743  } else {
2744  return failure();
2745  }
2746  Operation *blockToMoveTerminator = blockToMove->getTerminator();
2747  // Promote the "blockToMove" block to the parent operation block between the
2748  // prologue and epilogue of "op".
2749  rewriter.inlineBlockBefore(blockToMove, op);
2750  // Replace the "op" operation with the operands of the
2751  // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2752  // the affine.yield operation present in the "blockToMove" block. It has no
2753  // operands when affine.if is not returning results and therefore, in that
2754  // case, replaceOp just erases "op". When affine.if is not returning
2755  // results, the affine.yield operation can be omitted. It gets inserted
2756  // implicitly.
2757  rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2758  // Erase the "blockToMoveTerminator" operation since it is now in the parent
2759  // operation block, which already has its own terminator.
2760  rewriter.eraseOp(blockToMoveTerminator);
2761  return success();
2762  }
2763 };
2764 } // namespace
2765 
2766 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
2767 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
// NOTE(review): the parameter line declaring `point` and `regions` is absent
// from this listing — confirm against the original file.
2768 void AffineIfOp::getSuccessorRegions(
2770  // If the predecessor is an AffineIfOp, then branching into both `then` and
2771  // `else` region is valid.
2772  if (point.isParent()) {
2773  regions.reserve(2);
2774  regions.push_back(
2775  RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
2776  // If the "else" region is empty, branch back into parent.
2777  if (getElseRegion().empty()) {
2778  regions.push_back(getResults());
2779  } else {
2780  regions.push_back(
2781  RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
2782  }
2783  return;
2784  }
2785 
2786  // If the predecessor is the `else`/`then` region, then branching into parent
2787  // op is valid.
2788  regions.push_back(RegionSuccessor(getResults()));
2789 }
2790 
2791 LogicalResult AffineIfOp::verify() {
2792  // Verify that we have a condition attribute.
2793  // FIXME: This should be specified in the arguments list in ODS.
2794  auto conditionAttr =
2795  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2796  if (!conditionAttr)
2797  return emitOpError("requires an integer set attribute named 'condition'");
2798 
2799  // Verify that there are enough operands for the condition.
2800  IntegerSet condition = conditionAttr.getValue();
2801  if (getNumOperands() != condition.getNumInputs())
2802  return emitOpError("operand count and condition integer set dimension and "
2803  "symbol count must match");
2804 
2805  // Verify that the operands are valid dimension/symbols.
2806  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2807  condition.getNumDims())))
2808  return failure();
2809 
2810  return success();
2811 }
2812 
2813 ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
2814  // Parse the condition attribute set.
2815  IntegerSetAttr conditionAttr;
2816  unsigned numDims;
2817  if (parser.parseAttribute(conditionAttr,
2818  AffineIfOp::getConditionAttrStrName(),
2819  result.attributes) ||
2820  parseDimAndSymbolList(parser, result.operands, numDims))
2821  return failure();
2822 
2823  // Verify the condition operands.
2824  auto set = conditionAttr.getValue();
2825  if (set.getNumDims() != numDims)
2826  return parser.emitError(
2827  parser.getNameLoc(),
2828  "dim operand count and integer set dim count must match");
2829  if (numDims + set.getNumSymbols() != result.operands.size())
2830  return parser.emitError(
2831  parser.getNameLoc(),
2832  "symbol operand count and integer set symbol count must match");
2833 
2834  if (parser.parseOptionalArrowTypeList(result.types))
2835  return failure();
2836 
2837  // Create the regions for 'then' and 'else'. The latter must be created even
2838  // if it remains empty for the validity of the operation.
2839  result.regions.reserve(2);
2840  Region *thenRegion = result.addRegion();
2841  Region *elseRegion = result.addRegion();
2842 
2843  // Parse the 'then' region.
2844  if (parser.parseRegion(*thenRegion, {}, {}))
2845  return failure();
2846  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2847  result.location);
2848 
2849  // If we find an 'else' keyword then parse the 'else' region.
2850  if (!parser.parseOptionalKeyword("else")) {
2851  if (parser.parseRegion(*elseRegion, {}, {}))
2852  return failure();
2853  AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2854  result.location);
2855  }
2856 
2857  // Parse the optional attribute list.
2858  if (parser.parseOptionalAttrDict(result.attributes))
2859  return failure();
2860 
2861  return success();
2862 }
2863 
// Prints the condition set with its operands, the optional result types, the
// `then` region, the `else` region when it has blocks, and the remaining
// attributes.
// NOTE(review): the signature line of this definition is absent from this
// listing — confirm against the original file.
2865  auto conditionAttr =
2866  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2867  p << " " << conditionAttr;
2868  printDimAndSymbolList(operand_begin(), operand_end(),
2869  conditionAttr.getValue().getNumDims(), p);
2870  p.printOptionalArrowTypeList(getResultTypes());
2871  p << ' ';
// Block terminators are printed only when the op has results.
2872  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
2873  /*printBlockTerminators=*/getNumResults());
2874 
2875  // Print the 'else' regions if it has any blocks.
2876  auto &elseRegion = this->getElseRegion();
2877  if (!elseRegion.empty()) {
2878  p << " else ";
2879  p.printRegion(elseRegion,
2880  /*printEntryBlockArgs=*/false,
2881  /*printBlockTerminators=*/getNumResults());
2882  }
2883 
2884  // Print the attribute list, leaving out the condition printed above.
2885  p.printOptionalAttrDict((*this)->getAttrs(),
2886  /*elidedAttrs=*/getConditionAttrStrName());
2887 }
2888 
2889 IntegerSet AffineIfOp::getIntegerSet() {
2890  return (*this)
2891  ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2892  .getValue();
2893 }
2894 
2895 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2896  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
2897 }
2898 
2899 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2900  setIntegerSet(set);
2901  (*this)->setOperands(operands);
2902 }
2903 
2904 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2905  TypeRange resultTypes, IntegerSet set, ValueRange args,
2906  bool withElseRegion) {
2907  assert(resultTypes.empty() || withElseRegion);
2908  OpBuilder::InsertionGuard guard(builder);
2909 
2910  result.addTypes(resultTypes);
2911  result.addOperands(args);
2912  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
2913 
2914  Region *thenRegion = result.addRegion();
2915  builder.createBlock(thenRegion);
2916  if (resultTypes.empty())
2917  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2918 
2919  Region *elseRegion = result.addRegion();
2920  if (withElseRegion) {
2921  builder.createBlock(elseRegion);
2922  if (resultTypes.empty())
2923  AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2924  }
2925 }
2926 
2927 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2928  IntegerSet set, ValueRange args, bool withElseRegion) {
2929  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2930  withElseRegion);
2931 }
2932 
2933 /// Compose any affine.apply ops feeding into `operands` of the integer set
2934 /// `set` by composing the maps of such affine.apply ops with the integer
2935 /// set constraints.
// `set` and `operands` are updated in place; both are left untouched when no
// operand is defined by an affine.apply.
// NOTE(review): the first line of the signature (function name and the `set`
// parameter) is absent from this listing — confirm against the original file.
2937  SmallVectorImpl<Value> &operands) {
2938  // We will simply reuse the API of the map composition by viewing the LHSs of
2939  // the equalities and inequalities of `set` as the affine exprs of an affine
2940  // map. Convert to equivalent map, compose, and convert back to set.
2941  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
2942  set.getConstraints(), set.getContext());
2943  // Check if any composition is possible.
2944  if (llvm::none_of(operands,
2945  [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
2946  return;
2947 
2948  composeAffineMapAndOperands(&map, &operands);
// The equality flags of the original set are reused for the composed results.
2949  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
2950  set.getEqFlags());
2951 }
2952 
2953 /// Canonicalize an affine if op's conditional (integer set + operands).
2954 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2955  auto set = getIntegerSet();
2956  SmallVector<Value, 4> operands(getOperands());
2957  composeSetAndOperands(set, operands);
2958  canonicalizeSetAndOperands(&set, &operands);
2959 
2960  // Check if the canonicalization or composition led to any change.
2961  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2962  return failure();
2963 
2964  setConditional(set, operands);
2965  return success();
2966 }
2967 
2968 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2969  MLIRContext *context) {
2970  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2971 }
2972 
2973 //===----------------------------------------------------------------------===//
2974 // AffineLoadOp
2975 //===----------------------------------------------------------------------===//
2976 
2977 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2978  AffineMap map, ValueRange operands) {
2979  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2980  result.addOperands(operands);
2981  if (map)
2982  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2983  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
2984  result.types.push_back(memrefType.getElementType());
2985 }
2986 
2987 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2988  Value memref, AffineMap map, ValueRange mapOperands) {
2989  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2990  result.addOperands(memref);
2991  result.addOperands(mapOperands);
2992  auto memrefType = llvm::cast<MemRefType>(memref.getType());
2993  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2994  result.types.push_back(memrefType.getElementType());
2995 }
2996 
2997 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2998  Value memref, ValueRange indices) {
2999  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3000  int64_t rank = memrefType.getRank();
3001  // Create identity map for memrefs with at least one dimension or () -> ()
3002  // for zero-dimensional memrefs.
3003  auto map =
3004  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3005  build(builder, result, memref, map, indices);
3006 }
3007 
3008 ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
3009  auto &builder = parser.getBuilder();
3010  auto indexTy = builder.getIndexType();
3011 
3012  MemRefType type;
3013  OpAsmParser::UnresolvedOperand memrefInfo;
3014  AffineMapAttr mapAttr;
3016  return failure(
3017  parser.parseOperand(memrefInfo) ||
3018  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3019  AffineLoadOp::getMapAttrStrName(),
3020  result.attributes) ||
3021  parser.parseOptionalAttrDict(result.attributes) ||
3022  parser.parseColonType(type) ||
3023  parser.resolveOperand(memrefInfo, type, result.operands) ||
3024  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3025  parser.addTypeToList(type.getElementType(), result.types));
3026 }
3027 
// Prints the memref, the bracketed subscript map applied to its operands, the
// remaining attributes and the memref type.
// NOTE(review): the signature line of this definition is absent from this
// listing — confirm against the original file.
3029  p << " " << getMemRef() << '[';
// The subscripts are printed only when the map attribute is present.
3030  if (AffineMapAttr mapAttr =
3031  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3032  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3033  p << ']';
3034  p.printOptionalAttrDict((*this)->getAttrs(),
3035  /*elidedAttrs=*/{getMapAttrStrName()});
3036  p << " : " << getMemRefType();
3037 }
3038 
3039 /// Verify common indexing invariants of affine.load, affine.store,
3040 /// affine.vector_load and affine.vector_store.
3041 static LogicalResult
3042 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
3043  Operation::operand_range mapOperands,
3044  MemRefType memrefType, unsigned numIndexOperands) {
3045  AffineMap map = mapAttr.getValue();
3046  if (map.getNumResults() != memrefType.getRank())
3047  return op->emitOpError("affine map num results must equal memref rank");
3048  if (map.getNumInputs() != numIndexOperands)
3049  return op->emitOpError("expects as many subscripts as affine map inputs");
3050 
3051  Region *scope = getAffineScope(op);
3052  for (auto idx : mapOperands) {
3053  if (!idx.getType().isIndex())
3054  return op->emitOpError("index to load must have 'index' type");
3055  if (!isValidAffineIndexOperand(idx, scope))
3056  return op->emitOpError(
3057  "index must be a valid dimension or symbol identifier");
3058  }
3059 
3060  return success();
3061 }
3062 
3063 LogicalResult AffineLoadOp::verify() {
3064  auto memrefType = getMemRefType();
3065  if (getType() != memrefType.getElementType())
3066  return emitOpError("result type must match element type of memref");
3067 
3068  if (failed(verifyMemoryOpIndexing(
3069  getOperation(),
3070  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3071  getMapOperands(), memrefType,
3072  /*numIndexOperands=*/getNumOperands() - 1)))
3073  return failure();
3074 
3075  return success();
3076 }
3077 
3078 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3079  MLIRContext *context) {
3080  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3081 }
3082 
3083 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3084  /// load(memrefcast) -> load
3085  if (succeeded(memref::foldMemRefCast(*this)))
3086  return getResult();
3087 
3088  // Fold load from a global constant memref.
3089  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3090  if (!getGlobalOp)
3091  return {};
3092  // Get to the memref.global defining the symbol.
3093  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3094  if (!symbolTableOp)
3095  return {};
3096  auto global = dyn_cast_or_null<memref::GlobalOp>(
3097  SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
3098  if (!global)
3099  return {};
3100 
3101  // Check if the global memref is a constant.
3102  auto cstAttr =
3103  llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3104  if (!cstAttr)
3105  return {};
3106  // If it's a splat constant, we can fold irrespective of indices.
3107  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3108  return splatAttr.getSplatValue<Attribute>();
3109  // Otherwise, we can fold only if we know the indices.
3110  if (!getAffineMap().isConstant())
3111  return {};
3112  auto indices = llvm::to_vector<4>(
3113  llvm::map_range(getAffineMap().getConstantResults(),
3114  [](int64_t v) -> uint64_t { return v; }));
3115  return cstAttr.getValues<Attribute>()[indices];
3116 }
3117 
3118 //===----------------------------------------------------------------------===//
3119 // AffineStoreOp
3120 //===----------------------------------------------------------------------===//
3121 
3122 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3123  Value valueToStore, Value memref, AffineMap map,
3124  ValueRange mapOperands) {
3125  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3126  result.addOperands(valueToStore);
3127  result.addOperands(memref);
3128  result.addOperands(mapOperands);
3129  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3130 }
3131 
3132 // Use identity map.
3133 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3134  Value valueToStore, Value memref,
3135  ValueRange indices) {
3136  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3137  int64_t rank = memrefType.getRank();
3138  // Create identity map for memrefs with at least one dimension or () -> ()
3139  // for zero-dimensional memrefs.
3140  auto map =
3141  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3142  build(builder, result, valueToStore, memref, map, indices);
3143 }
3144 
3145 ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3146  auto indexTy = parser.getBuilder().getIndexType();
3147 
3148  MemRefType type;
3149  OpAsmParser::UnresolvedOperand storeValueInfo;
3150  OpAsmParser::UnresolvedOperand memrefInfo;
3151  AffineMapAttr mapAttr;
3153  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3154  parser.parseOperand(memrefInfo) ||
3155  parser.parseAffineMapOfSSAIds(
3156  mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3157  result.attributes) ||
3158  parser.parseOptionalAttrDict(result.attributes) ||
3159  parser.parseColonType(type) ||
3160  parser.resolveOperand(storeValueInfo, type.getElementType(),
3161  result.operands) ||
3162  parser.resolveOperand(memrefInfo, type, result.operands) ||
3163  parser.resolveOperands(mapOperands, indexTy, result.operands));
3164 }
3165 
// Prints the stored value, the memref, the bracketed subscript map applied to
// its operands, the remaining attributes and the memref type.
// NOTE(review): the signature line of this definition is absent from this
// listing — confirm against the original file.
3167  p << " " << getValueToStore();
3168  p << ", " << getMemRef() << '[';
// The subscripts are printed only when the map attribute is present.
3169  if (AffineMapAttr mapAttr =
3170  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3171  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3172  p << ']';
3173  p.printOptionalAttrDict((*this)->getAttrs(),
3174  /*elidedAttrs=*/{getMapAttrStrName()});
3175  p << " : " << getMemRefType();
3176 }
3177 
3178 LogicalResult AffineStoreOp::verify() {
3179  // The value to store must have the same type as memref element type.
3180  auto memrefType = getMemRefType();
3181  if (getValueToStore().getType() != memrefType.getElementType())
3182  return emitOpError(
3183  "value to store must have the same type as memref element type");
3184 
3185  if (failed(verifyMemoryOpIndexing(
3186  getOperation(),
3187  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3188  getMapOperands(), memrefType,
3189  /*numIndexOperands=*/getNumOperands() - 2)))
3190  return failure();
3191 
3192  return success();
3193 }
3194 
3195 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3196  MLIRContext *context) {
3197  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3198 }
3199 
3200 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3201  SmallVectorImpl<OpFoldResult> &results) {
3202  /// store(memrefcast) -> store
3203  return memref::foldMemRefCast(*this, getValueToStore());
3204 }
3205 
3206 //===----------------------------------------------------------------------===//
3207 // AffineMinMaxOpBase
3208 //===----------------------------------------------------------------------===//
3209 
3210 template <typename T>
3211 static LogicalResult verifyAffineMinMaxOp(T op) {
3212  // Verify that operand count matches affine map dimension and symbol count.
3213  if (op.getNumOperands() !=
3214  op.getMap().getNumDims() + op.getMap().getNumSymbols())
3215  return op.emitOpError(
3216  "operand count and affine map dimension and symbol count must match");
3217 
3218  if (op.getMap().getNumResults() == 0)
3219  return op.emitOpError("affine map expect at least one result");
3220  return success();
3221 }
3222 
3223 template <typename T>
3224 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3225  p << ' ' << op->getAttr(T::getMapAttrStrName());
3226  auto operands = op.getOperands();
3227  unsigned numDims = op.getMap().getNumDims();
3228  p << '(' << operands.take_front(numDims) << ')';
3229 
3230  if (operands.size() != numDims)
3231  p << '[' << operands.drop_front(numDims) << ']';
3233  /*elidedAttrs=*/{T::getMapAttrStrName()});
3234 }
3235 
// Parses an affine min/max op: the map attribute, the dim operand list in
// parentheses, the symbol operand list, and an optional attribute dictionary.
// All operands are resolved as `index` and the single result type is `index`.
// NOTE(review): three lines are absent from this listing — the declarations of
// `dimInfos` and `symInfos`, and the delimiter argument of the symbol operand
// list — confirm against the original file.
3236 template <typename T>
3237 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3238  OperationState &result) {
3239  auto &builder = parser.getBuilder();
3240  auto indexType = builder.getIndexType();
3243  AffineMapAttr mapAttr;
3244  return failure(
3245  parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3246  result.attributes) ||
3247  parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
3248  parser.parseOperandList(symInfos,
3250  parser.parseOptionalAttrDict(result.attributes) ||
3251  parser.resolveOperands(dimInfos, indexType, result.operands) ||
3252  parser.resolveOperands(symInfos, indexType, result.operands) ||
3253  parser.addTypeToList(indexType, result.types));
3254 }
3255 
3256 /// Fold an affine min or max operation with the given operands. The operand
3257 /// list may contain nulls, which are interpreted as the operand not being a
3258 /// constant.
// NOTE(review): the signature line of this definition (declaring `op` and
// `operands`) is absent from this listing — confirm against the original file.
3259 template <typename T>
3261  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3262  "expected affine min or max op");
3263 
3264  // Fold the affine map.
3265  // TODO: Fold more cases:
3266  // min(some_affine, some_affine + constant, ...), etc.
3267  SmallVector<int64_t, 2> results;
3268  auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3269 
// A folded map that is the identity on a single symbol folds to operand 0.
3270  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3271  return op.getOperand(0);
3272 
3273  // If some of the map results are not constant, try changing the map in-place.
3274  if (results.empty()) {
3275  // If the map is the same, report that folding did not happen.
3276  if (foldedMap == op.getMap())
3277  return {};
3278  op->setAttr("map", AffineMapAttr::get(foldedMap));
3279  return op.getResult();
3280  }
3281 
3282  // Otherwise, completely fold the op into a constant.
// The minimum is taken for AffineMinOp and the maximum otherwise.
3283  auto resultIt = std::is_same<T, AffineMinOp>::value
3284  ? llvm::min_element(results)
3285  : llvm::max_element(results);
3286  if (resultIt == results.end())
3287  return {};
3288  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3289 }
3290 
3291 /// Remove duplicated expressions in affine min/max ops.
// The rewrite keeps the first occurrence of every result expression, in the
// original order, and fails when no duplicate was found.
// NOTE(review): the pattern's struct header lines are absent from this
// listing; only the `matchAndRewrite` member is visible.
3292 template <typename T>
3295 
3296  LogicalResult matchAndRewrite(T affineOp,
3297  PatternRewriter &rewriter) const override {
3298  AffineMap oldMap = affineOp.getAffineMap();
3299 
3300  SmallVector<AffineExpr, 4> newExprs;
3301  for (AffineExpr expr : oldMap.getResults()) {
3302  // This is a linear scan over newExprs, but it should be fine given that
3303  // we typically just have a few expressions per op.
3304  if (!llvm::is_contained(newExprs, expr))
3305  newExprs.push_back(expr);
3306  }
3307 
// Same number of expressions means nothing was removed: no rewrite needed.
3308  if (newExprs.size() == oldMap.getNumResults())
3309  return failure();
3310 
// The new map keeps the old dim/symbol counts, so the operands are reused.
3311  auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3312  newExprs, rewriter.getContext());
3313  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3314 
3315  return success();
3316  }
3317 };
3318 
3319 /// Merge an affine min/max op to its consumers if its consumer is also an
3320 /// affine min/max op.
3321 ///
3322 /// This pattern requires the producer affine min/max op is bound to a
3323 /// dimension/symbol that is used as a standalone expression in the consumer
3324 /// affine op's map.
3325 ///
3326 /// For example, a pattern like the following:
3327 ///
3328 /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3329 /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3330 ///
3331 /// Can be turned into:
3332 ///
3333 /// %1 = affine.min affine_map<
3334 /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
// NOTE(review): the pattern's struct header lines are absent from this
// listing; only the `matchAndRewrite` member is visible.
3335 template <typename T>
3338 
3339  LogicalResult matchAndRewrite(T affineOp,
3340  PatternRewriter &rewriter) const override {
3341  AffineMap oldMap = affineOp.getAffineMap();
// Map operands are split into a leading dimension part and a trailing symbol
// part, sized by the map's dim and symbol counts.
3342  ValueRange dimOperands =
3343  affineOp.getMapOperands().take_front(oldMap.getNumDims());
3344  ValueRange symOperands =
3345  affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3346 
3347  auto newDimOperands = llvm::to_vector<8>(dimOperands);
3348  auto newSymOperands = llvm::to_vector<8>(symOperands);
3349  SmallVector<AffineExpr, 4> newExprs;
3350  SmallVector<T, 4> producerOps;
3351 
3352  // Go over each expression to see whether it's a single dimension/symbol
3353  // with the corresponding operand which is the result of another affine
3354  // min/max op. If So it can be merged into this affine op.
3355  for (AffineExpr expr : oldMap.getResults()) {
3356  if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3357  Value symValue = symOperands[symExpr.getPosition()];
3358  if (auto producerOp = symValue.getDefiningOp<T>()) {
3359  producerOps.push_back(producerOp);
3360  continue;
3361  }
3362  } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3363  Value dimValue = dimOperands[dimExpr.getPosition()];
3364  if (auto producerOp = dimValue.getDefiningOp<T>()) {
3365  producerOps.push_back(producerOp);
3366  continue;
3367  }
3368  }
3369  // For the above cases we will remove the expression by merging the
3370  // producer affine min/max's affine expressions. Otherwise we need to
3371  // keep the existing expression.
3372  newExprs.push_back(expr);
3373  }
3374 
// No producer of the same op kind feeds a standalone dim/symbol: no match.
3375  if (producerOps.empty())
3376  return failure();
3377 
// Running counts of dims/symbols in the merged map; they start at the
// consumer's counts and grow as each producer is appended.
3378  unsigned numUsedDims = oldMap.getNumDims();
3379  unsigned numUsedSyms = oldMap.getNumSymbols();
3380 
3381  // Now go over all producer affine ops and merge their expressions.
3382  for (T producerOp : producerOps) {
3383  AffineMap producerMap = producerOp.getAffineMap();
3384  unsigned numProducerDims = producerMap.getNumDims();
3385  unsigned numProducerSyms = producerMap.getNumSymbols();
3386 
3387  // Collect all dimension/symbol values.
3388  ValueRange dimValues =
3389  producerOp.getMapOperands().take_front(numProducerDims);
3390  ValueRange symValues =
3391  producerOp.getMapOperands().take_back(numProducerSyms);
3392  newDimOperands.append(dimValues.begin(), dimValues.end());
3393  newSymOperands.append(symValues.begin(), symValues.end());
3394 
3395  // For expressions we need to shift to avoid overlap.
3396  for (AffineExpr expr : producerMap.getResults()) {
3397  newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3398  .shiftSymbols(numProducerSyms, numUsedSyms));
3399  }
3400 
3401  numUsedDims += numProducerDims;
3402  numUsedSyms += numProducerSyms;
3403  }
3404 
// The merged op takes all dimension operands first, then all symbol operands.
3405  auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3406  rewriter.getContext());
3407  auto newOperands =
3408  llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3409  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3410 
3411  return success();
3412  }
3413 };
3414 
3415 /// Canonicalize the result expression order of an affine map and return success
3416 /// if the order changed.
3417 ///
3418 /// The function flattens the map's affine expressions to coefficient arrays and
3419 /// sorts them in lexicographic order. A coefficient array contains a multiplier
3420 /// for every dimension/symbol and a constant term. The canonicalization fails
3421 /// if a result expression is not pure or if the flattening requires local
3422 /// variables that, unlike dimensions and symbols, have no global order.
3423 static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3424  SmallVector<SmallVector<int64_t>> flattenedExprs;
3425  for (const AffineExpr &resultExpr : map.getResults()) {
3426  // Fail if the expression is not pure.
3427  if (!resultExpr.isPureAffine())
3428  return failure();
3429 
3430  SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3431  auto flattenResult = flattener.walkPostOrder(resultExpr);
3432  if (failed(flattenResult))
3433  return failure();
3434 
3435  // Fail if the flattened expression has local variables.
3436  if (flattener.operandExprStack.back().size() !=
3437  map.getNumDims() + map.getNumSymbols() + 1)
3438  return failure();
3439 
3440  flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3441  flattener.operandExprStack.back().end());
3442  }
3443 
3444  // Fail if sorting is not necessary.
3445  if (llvm::is_sorted(flattenedExprs))
3446  return failure();
3447 
3448  // Reorder the result expressions according to their flattened form.
3449  SmallVector<unsigned> resultPermutation =
3450  llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3451  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3452  return flattenedExprs[lhs] < flattenedExprs[rhs];
3453  });
3454  SmallVector<AffineExpr> newExprs;
3455  for (unsigned idx : resultPermutation)
3456  newExprs.push_back(map.getResult(idx));
3457 
3458  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3459  map.getContext());
3460  return success();
3461 }
3462 
3463 /// Canonicalize the affine map result expression order of an affine min/max
3464 /// operation.
3465 ///
3466 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3467 /// expressions and replaces the operation if the order changed.
3468 ///
3469 /// For example, the following operation:
3470 ///
3471 /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3472 ///
3473 /// Turns into:
3474 ///
3475 /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
// NOTE(review): the pattern's struct header lines are absent from this
// listing; only the `matchAndRewrite` member is visible.
3476 template <typename T>
3479 
3480  LogicalResult matchAndRewrite(T affineOp,
3481  PatternRewriter &rewriter) const override {
// `map` is a local copy that the helper reorders in place on success.
3482  AffineMap map = affineOp.getAffineMap();
3483  if (failed(canonicalizeMapExprAndTermOrder(map)))
3484  return failure();
// Only the map changes; the operands are passed through unchanged.
3485  rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3486  return success();
3487  }
3488 };
3489 
// Rewrites an affine min/max op whose map has exactly one result into an
// AffineApplyOp with the same map and operands; fails for any other result
// count.
// NOTE(review): the pattern's struct header lines are absent from this
// listing; only the `matchAndRewrite` member is visible.
3490 template <typename T>
3493 
3494  LogicalResult matchAndRewrite(T affineOp,
3495  PatternRewriter &rewriter) const override {
3496  if (affineOp.getMap().getNumResults() != 1)
3497  return failure();
3498  rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3499  affineOp.getOperands());
3500  return success();
3501  }
3502 };
3503 
3504 //===----------------------------------------------------------------------===//
3505 // AffineMinOp
3506 //===----------------------------------------------------------------------===//
3507 //
3508 // %0 = affine.min affine_map<(d0) -> (1000, d0 + 512)> (%i0)
3509 //
3510 
3511 OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3512  return foldMinMaxOp(*this, adaptor.getOperands());
3513 }
3514 
// Registers the canonicalization patterns for AffineMinOp.
// NOTE(review): several lines of the registration statement are absent from
// this listing; only the MergeAffineMinMaxOp and SimplifyAffineOp
// instantiations and the trailing `context` argument are visible.
3515 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3516  MLIRContext *context) {
3519  MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3521  context);
3522 }
3523 
3524 LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3525 
3526 ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
3527  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3528 }
3529 
3530 void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3531 
3532 //===----------------------------------------------------------------------===//
3533 // AffineMaxOp
3534 //===----------------------------------------------------------------------===//
3535 //
3536 // %0 = affine.max affine_map<(d0) -> (1000, d0 + 512)> (%i0)
3537 //
3538 
3539 OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3540  return foldMinMaxOp(*this, adaptor.getOperands());
3541 }
3542 
// Registers the canonicalization patterns for AffineMaxOp.
// NOTE(review): several lines of the registration statement are absent from
// this listing; only the MergeAffineMinMaxOp and SimplifyAffineOp
// instantiations and the trailing `context` argument are visible.
3543 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3544  MLIRContext *context) {
3547  MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3549  context);
3550 }
3551 
3552 LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3553 
3554 ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
3555  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3556 }
3557 
3558 void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3559 
3560 //===----------------------------------------------------------------------===//
3561 // AffinePrefetchOp
3562 //===----------------------------------------------------------------------===//
3563 
3564 //
3565 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3566 //
// Parses the custom assembly of affine.prefetch: the memref operand, an affine
// map of SSA ids, a read/write keyword, `locality<hint>`, a cache-type keyword,
// an optional attribute dictionary and the memref type. The two keywords are
// validated after parsing and stored as boolean attributes.
3567 ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3568  OperationState &result) {
3569  auto &builder = parser.getBuilder();
3570  auto indexTy = builder.getIndexType();
3571 
3572  MemRefType type;
3573  OpAsmParser::UnresolvedOperand memrefInfo;
3574  IntegerAttr hintInfo;
3575  auto i32Type = parser.getBuilder().getIntegerType(32);
3576  StringRef readOrWrite, cacheType;
3577 
3578  AffineMapAttr mapAttr;
// NOTE(review): the declaration of `mapOperands` is absent from this listing;
// it is filled by parseAffineMapOfSSAIds and resolved as index operands below.
3580  if (parser.parseOperand(memrefInfo) ||
3581  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3582  AffinePrefetchOp::getMapAttrStrName(),
3583  result.attributes) ||
3584  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3585  parser.parseComma() || parser.parseKeyword("locality") ||
3586  parser.parseLess() ||
3587  parser.parseAttribute(hintInfo, i32Type,
3588  AffinePrefetchOp::getLocalityHintAttrStrName(),
3589  result.attributes) ||
3590  parser.parseGreater() || parser.parseComma() ||
3591  parser.parseKeyword(&cacheType) ||
3592  parser.parseOptionalAttrDict(result.attributes) ||
3593  parser.parseColonType(type) ||
3594  parser.resolveOperand(memrefInfo, type, result.operands) ||
3595  parser.resolveOperands(mapOperands, indexTy, result.operands))
3596  return failure();
3597 
// Only "read" and "write" are accepted; the attribute is true for "write".
3598  if (readOrWrite != "read" && readOrWrite != "write")
3599  return parser.emitError(parser.getNameLoc(),
3600  "rw specifier has to be 'read' or 'write'");
3601  result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3602  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
3603 
// Only "data" and "instr" are accepted; the attribute is true for "data".
3604  if (cacheType != "data" && cacheType != "instr")
3605  return parser.emitError(parser.getNameLoc(),
3606  "cache type has to be 'data' or 'instr'");
3607 
3608  result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3609  parser.getBuilder().getBoolAttr(cacheType == "data"));
3610 
3611  return success();
3612 }
3613 
// Body of the custom printer for affine.prefetch: memref, bracketed affine map
// of SSA ids (when a map attribute is present), read/write keyword, locality
// hint, cache type, remaining attributes, and the memref type.
// NOTE(review): the function signature line is absent from this listing;
// presumably AffinePrefetchOp::print(OpAsmPrinter &p) -- confirm.
3615  p << " " << getMemref() << '[';
3616  AffineMapAttr mapAttr =
3617  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3618  if (mapAttr)
3619  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3620  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3621  << "locality<" << getLocalityHint() << ">, "
3622  << (getIsDataCache() ? "data" : "instr");
// NOTE(review): the first line of the call below is absent from this listing;
// presumably it prints the attribute dictionary. The attributes already shown
// in the custom syntax above are listed as elided.
3624  (*this)->getAttrs(),
3625  /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3626  getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3627  p << " : " << getMemRefType();
3628 }
3629 
3630 LogicalResult AffinePrefetchOp::verify() {
3631  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3632  if (mapAttr) {
3633  AffineMap map = mapAttr.getValue();
3634  if (map.getNumResults() != getMemRefType().getRank())
3635  return emitOpError("affine.prefetch affine map num results must equal"
3636  " memref rank");
3637  if (map.getNumInputs() + 1 != getNumOperands())
3638  return emitOpError("too few operands");
3639  } else {
3640  if (getNumOperands() != 1)
3641  return emitOpError("too few operands");
3642  }
3643 
3644  Region *scope = getAffineScope(*this);
3645  for (auto idx : getMapOperands()) {
3646  if (!isValidAffineIndexOperand(idx, scope))
3647  return emitOpError(
3648  "index must be a valid dimension or symbol identifier");
3649  }
3650  return success();
3651 }
3652 
3653 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3654  MLIRContext *context) {
3655  // prefetch(memrefcast) -> prefetch
3656  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3657 }
3658 
3659 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3660  SmallVectorImpl<OpFoldResult> &results) {
3661  /// prefetch(memrefcast) -> prefetch
3662  return memref::foldMemRefCast(*this);
3663 }
3664 
3665 //===----------------------------------------------------------------------===//
3666 // AffineParallelOp
3667 //===----------------------------------------------------------------------===//
3668 
3669 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3670  TypeRange resultTypes,
3671  ArrayRef<arith::AtomicRMWKind> reductions,
3672  ArrayRef<int64_t> ranges) {
3673  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3674  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3675  return builder.getConstantAffineMap(value);
3676  }));
3677  SmallVector<int64_t> steps(ranges.size(), 1);
3678  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3679  /*ubArgs=*/{}, steps);
3680 }
3681 
// Builds an affine.parallel op from per-loop lower/upper bound maps.
// All lower bound maps must share one input space (same dim and symbol
// counts), and likewise all upper bound maps; `lbArgs`/`ubArgs` are the
// operands for those input spaces. The per-loop maps are concatenated into one
// map per bound kind, with a "groups" attribute recording how many results
// belong to each loop. A body block with one index argument per step is
// created, and a terminator is ensured only when there are no result types.
3682 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3683  TypeRange resultTypes,
3684  ArrayRef<arith::AtomicRMWKind> reductions,
3685  ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
3686  ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
3687  ArrayRef<int64_t> steps) {
3688  assert(llvm::all_of(lbMaps,
3689  [lbMaps](AffineMap m) {
3690  return m.getNumDims() == lbMaps[0].getNumDims() &&
3691  m.getNumSymbols() == lbMaps[0].getNumSymbols();
3692  }) &&
3693  "expected all lower bounds maps to have the same number of dimensions "
3694  "and symbols");
3695  assert(llvm::all_of(ubMaps,
3696  [ubMaps](AffineMap m) {
3697  return m.getNumDims() == ubMaps[0].getNumDims() &&
3698  m.getNumSymbols() == ubMaps[0].getNumSymbols();
3699  }) &&
3700  "expected all upper bounds maps to have the same number of dimensions "
3701  "and symbols");
3702  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3703  "expected lower bound maps to have as many inputs as lower bound "
3704  "operands");
3705  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3706  "expected upper bound maps to have as many inputs as upper bound "
3707  "operands");
3708 
// The guard restores the builder's insertion point when this function returns;
// createBlock below moves it into the new body.
3709  OpBuilder::InsertionGuard guard(builder);
3710  result.addTypes(resultTypes);
3711 
3712  // Convert the reductions to integer attributes.
3713  SmallVector<Attribute, 4> reductionAttrs;
3714  for (arith::AtomicRMWKind reduction : reductions)
3715  reductionAttrs.push_back(
3716  builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
3717  result.addAttribute(getReductionsAttrStrName(),
3718  builder.getArrayAttr(reductionAttrs));
3719 
3720  // Concatenates maps defined in the same input space (same dimensions and
3721  // symbols), assumes there is at least one map.
// For an empty list an empty map is returned and `groups` is left untouched;
// otherwise one entry per input map (its result count) is appended to `groups`.
3722  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
3723  SmallVectorImpl<int32_t> &groups) {
3724  if (maps.empty())
3725  return AffineMap::get(builder.getContext());
// NOTE(review): the declaration of `exprs` is absent from this listing; it
// accumulates the result expressions of all maps.
3727  groups.reserve(groups.size() + maps.size());
3728  exprs.reserve(maps.size());
3729  for (AffineMap m : maps) {
3730  llvm::append_range(exprs, m.getResults());
3731  groups.push_back(m.getNumResults());
3732  }
3733  return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3734  maps[0].getContext());
3735  };
3736 
3737  // Set up the bounds.
3738  SmallVector<int32_t> lbGroups, ubGroups;
3739  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3740  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3741  result.addAttribute(getLowerBoundsMapAttrStrName(),
3742  AffineMapAttr::get(lbMap));
3743  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
3744  builder.getI32TensorAttr(lbGroups));
3745  result.addAttribute(getUpperBoundsMapAttrStrName(),
3746  AffineMapAttr::get(ubMap));
3747  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
3748  builder.getI32TensorAttr(ubGroups));
3749  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
// Operand order: lower bound operands first, then upper bound operands.
3750  result.addOperands(lbArgs);
3751  result.addOperands(ubArgs);
3752 
3753  // Create a region and a block for the body.
3754  auto *bodyRegion = result.addRegion();
3755  Block *body = builder.createBlock(bodyRegion);
3756 
3757  // Add all the block arguments.
3758  for (unsigned i = 0, e = steps.size(); i < e; ++i)
3759  body->addArgument(IndexType::get(builder.getContext()), result.location);
// With result types the body is left without a terminator here.
3760  if (resultTypes.empty())
3761  ensureTerminator(*bodyRegion, builder, result.location);
3762 }
3763 
3764 SmallVector<Region *> AffineParallelOp::getLoopRegions() {
3765  return {&getRegion()};
3766 }
3767 
3768 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3769 
3770 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3771  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3772 }
3773 
3774 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3775  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3776 }
3777 
3778 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3779  auto values = getLowerBoundsGroups().getValues<int32_t>();
3780  unsigned start = 0;
3781  for (unsigned i = 0; i < pos; ++i)
3782  start += values[i];
3783  return getLowerBoundsMap().getSliceMap(start, values[pos]);
3784 }
3785 
3786 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3787  auto values = getUpperBoundsGroups().getValues<int32_t>();
3788  unsigned start = 0;
3789  for (unsigned i = 0; i < pos; ++i)
3790  start += values[i];
3791  return getUpperBoundsMap().getSliceMap(start, values[pos]);
3792 }
3793 
3794 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
3795  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3796 }
3797 
3798 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
3799  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3800 }
3801 
// Returns the per-result difference "upper bound - lower bound" when every
// difference is a constant expression. Returns std::nullopt when the op has
// min/max bounds or when any difference is not a constant.
3802 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3803  if (hasMinMaxBounds())
3804  return std::nullopt;
3805 
3806  // Try to convert all the ranges to constant expressions.
// NOTE(review): the declaration of `out` is absent from this listing;
// presumably it has the SmallVector<int64_t, 8> type of the return value.
3808  AffineValueMap rangesValueMap;
3809  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3810  &rangesValueMap);
3811  out.reserve(rangesValueMap.getNumResults());
3812  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3813  auto expr = rangesValueMap.getResult(i);
3814  auto cst = dyn_cast<AffineConstantExpr>(expr);
// A single non-constant range makes the whole query fail.
3815  if (!cst)
3816  return std::nullopt;
3817  out.push_back(cst.getValue());
3818  }
3819  return out;
3820 }
3821 
3822 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3823 
3824 OpBuilder AffineParallelOp::getBodyBuilder() {
3825  return OpBuilder(getBody(), std::prev(getBody()->end()));
3826 }
3827 
3828 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3829  assert(lbOperands.size() == map.getNumInputs() &&
3830  "operands to map must match number of inputs");
3831 
3832  auto ubOperands = getUpperBoundsOperands();
3833 
3834  SmallVector<Value, 4> newOperands(lbOperands);
3835  newOperands.append(ubOperands.begin(), ubOperands.end());
3836  (*this)->setOperands(newOperands);
3837 
3838  setLowerBoundsMapAttr(AffineMapAttr::get(map));
3839 }
3840 
3841 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3842  assert(ubOperands.size() == map.getNumInputs() &&
3843  "operands to map must match number of inputs");
3844 
3845  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3846  newOperands.append(ubOperands.begin(), ubOperands.end());
3847  (*this)->setOperands(newOperands);
3848 
3849  setUpperBoundsMapAttr(AffineMapAttr::get(map));
3850 }
3851 
3852 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
3853  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3854 }
3855 
3856 // check whether resultType match op or not in affine.parallel
3857 static bool isResultTypeMatchAtomicRMWKind(Type resultType,
3858  arith::AtomicRMWKind op) {
3859  switch (op) {
3860  case arith::AtomicRMWKind::addf:
3861  return isa<FloatType>(resultType);
3862  case arith::AtomicRMWKind::addi:
3863  return isa<IntegerType>(resultType);
3864  case arith::AtomicRMWKind::assign:
3865  return true;
3866  case arith::AtomicRMWKind::mulf:
3867  return isa<FloatType>(resultType);
3868  case arith::AtomicRMWKind::muli:
3869  return isa<IntegerType>(resultType);
3870  case arith::AtomicRMWKind::maximumf:
3871  return isa<FloatType>(resultType);
3872  case arith::AtomicRMWKind::minimumf:
3873  return isa<FloatType>(resultType);
3874  case arith::AtomicRMWKind::maxs: {
3875  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3876  return intType && intType.isSigned();
3877  }
3878  case arith::AtomicRMWKind::mins: {
3879  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3880  return intType && intType.isSigned();
3881  }
3882  case arith::AtomicRMWKind::maxu: {
3883  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3884  return intType && intType.isUnsigned();
3885  }
3886  case arith::AtomicRMWKind::minu: {
3887  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3888  return intType && intType.isUnsigned();
3889  }
3890  case arith::AtomicRMWKind::ori:
3891  return isa<IntegerType>(resultType);
3892  case arith::AtomicRMWKind::andi:
3893  return isa<IntegerType>(resultType);
3894  default:
3895  return false;
3896  }
3897 }
3898 
3899 LogicalResult AffineParallelOp::verify() {
3900  auto numDims = getNumDims();
3901  if (getLowerBoundsGroups().getNumElements() != numDims ||
3902  getUpperBoundsGroups().getNumElements() != numDims ||
3903  getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3904  return emitOpError() << "the number of region arguments ("
3905  << getBody()->getNumArguments()
3906  << ") and the number of map groups for lower ("
3907  << getLowerBoundsGroups().getNumElements()
3908  << ") and upper bound ("
3909  << getUpperBoundsGroups().getNumElements()
3910  << "), and the number of steps (" << getSteps().size()
3911  << ") must all match";
3912  }
3913 
3914  unsigned expectedNumLBResults = 0;
3915  for (APInt v : getLowerBoundsGroups())
3916  expectedNumLBResults += v.getZExtValue();
3917  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
3918  return emitOpError() << "expected lower bounds map to have "
3919  << expectedNumLBResults << " results";
3920  unsigned expectedNumUBResults = 0;
3921  for (APInt v : getUpperBoundsGroups())
3922  expectedNumUBResults += v.getZExtValue();
3923  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
3924  return emitOpError() << "expected upper bounds map to have "
3925  << expectedNumUBResults << " results";
3926 
3927  if (getReductions().size() != getNumResults())
3928  return emitOpError("a reduction must be specified for each output");
3929 
3930  // Verify reduction ops are all valid and each result type matches reduction
3931  // ops
3932  for (auto it : llvm::enumerate((getReductions()))) {
3933  Attribute attr = it.value();
3934  auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3935  if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
3936  return emitOpError("invalid reduction attribute");
3937  auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
3938  if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
3939  return emitOpError("result type cannot match reduction attribute");
3940  }
3941 
3942  // Verify that the bound operands are valid dimension/symbols.
3943  /// Lower bounds.
3944  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
3945  getLowerBoundsMap().getNumDims())))
3946  return failure();
3947  /// Upper bounds.
3948  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
3949  getUpperBoundsMap().getNumDims())))
3950  return failure();
3951  return success();
3952 }
3953 
3954 LogicalResult AffineValueMap::canonicalize() {
3955  SmallVector<Value, 4> newOperands{operands};
3956  auto newMap = getAffineMap();
3957  composeAffineMapAndOperands(&newMap, &newOperands);
3958  if (newMap == getAffineMap() && newOperands == operands)
3959  return failure();
3960  reset(newMap, newOperands);
3961  return success();
3962 }
3963 
3964 /// Canonicalize the bounds of the given loop.
3965 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3966  AffineValueMap lb = op.getLowerBoundsValueMap();
3967  bool lbCanonicalized = succeeded(lb.canonicalize());
3968 
3969  AffineValueMap ub = op.getUpperBoundsValueMap();
3970  bool ubCanonicalized = succeeded(ub.canonicalize());
3971 
3972  // Any canonicalization change always leads to updated map(s).
3973  if (!lbCanonicalized && !ubCanonicalized)
3974  return failure();
3975 
3976  if (lbCanonicalized)
3977  op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
3978  if (ubCanonicalized)
3979  op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
3980 
3981  return success();
3982 }
3983 
3984 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
3985  SmallVectorImpl<OpFoldResult> &results) {
3986  return canonicalizeLoopBounds(*this);
3987 }
3988 
3989 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
3990 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
3991 /// identifies which of the those expressions form max/min groups. `operands`
3992 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
3993 /// or "max".
3994 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
3995  DenseIntElementsAttr group, ValueRange operands,
3996  StringRef keyword) {
3997  AffineMap map = mapAttr.getValue();
3998  unsigned numDims = map.getNumDims();
3999  ValueRange dimOperands = operands.take_front(numDims);
4000  ValueRange symOperands = operands.drop_front(numDims);
4001  unsigned start = 0;
4002  for (llvm::APInt groupSize : group) {
4003  if (start != 0)
4004  p << ", ";
4005 
4006  unsigned size = groupSize.getZExtValue();
4007  if (size == 1) {
4008  p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4009  ++start;
4010  } else {
4011  p << keyword << '(';
4012  AffineMap submap = map.getSliceMap(start, size);
4013  p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4014  p << ')';
4015  start += size;
4016  }
4017  }
4018 }
4019 
// Body of the custom printer for affine.parallel: induction variables, lower
// bounds (groups printed with "max"), upper bounds (groups printed with
// "min"), optional steps, optional reductions with result types, the region,
// and the remaining attributes.
// NOTE(review): the function signature line is absent from this listing;
// presumably AffineParallelOp::print(OpAsmPrinter &p) -- confirm.
4021  p << " (" << getBody()->getArguments() << ") = (";
4022  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4023  getLowerBoundsOperands(), "max");
4024  p << ") to (";
4025  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4026  getUpperBoundsOperands(), "min");
4027  p << ')';
4028  SmallVector<int64_t, 8> steps = getSteps();
// The step list is omitted when every step equals 1.
4029  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4030  if (!elideSteps) {
4031  p << " step (";
4032  llvm::interleaveComma(steps, p);
4033  p << ')';
4034  }
// Reductions are printed only when the op has results, as quoted kind names.
4035  if (getNumResults()) {
4036  p << " reduce (";
4037  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4038  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4039  llvm::cast<IntegerAttr>(attr).getInt());
4040  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4041  });
4042  p << ") -> (" << getResultTypes() << ")";
4043  }
4044 
4045  p << ' ';
// Block terminators are printed only when the op has results.
4046  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4047  /*printBlockTerminators=*/getNumResults());
// NOTE(review): the first line of the call below is absent from this listing;
// presumably it prints the attribute dictionary. The attributes listed as
// elided are the ones named in the list that follows.
4049  (*this)->getAttrs(),
4050  /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4051  AffineParallelOp::getLowerBoundsMapAttrStrName(),
4052  AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4053  AffineParallelOp::getUpperBoundsMapAttrStrName(),
4054  AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4055  AffineParallelOp::getStepsAttrStrName()});
4056 }
4057 
4058 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
4059 /// unique operands. Also populates `replacements` with affine expressions of
4060 /// `kind` that can be used to update affine maps previously accepting
4061 /// `operands` to accept `uniqueOperands` instead.
// NOTE(review): the line carrying the function name and return type, and the
// line declaring the `operands` parameter, are absent from this listing.
4063  OpAsmParser &parser,
4065  SmallVectorImpl<Value> &uniqueOperands,
4066  SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4067  assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4068  "expected operands to be dim or symbol expression");
4069 
4070  Type indexType = parser.getBuilder().getIndexType();
4071  for (const auto &list : operands) {
// Each list is resolved to index-typed values; a failure aborts the whole call.
4072  SmallVector<Value> valueOperands;
4073  if (parser.resolveOperands(list, indexType, valueOperands))
4074  return failure();
4075  for (Value operand : valueOperands) {
// `pos` is the operand's index in `uniqueOperands`, or its size when the
// operand has not been seen yet, in which case it is appended at that index.
4076  unsigned pos = std::distance(uniqueOperands.begin(),
4077  llvm::find(uniqueOperands, operand));
4078  if (pos == uniqueOperands.size())
4079  uniqueOperands.push_back(operand);
// One replacement expression is emitted per original operand, in order.
4080  replacements.push_back(
4081  kind == AffineExprKind::DimId
4082  ? getAffineDimExpr(pos, parser.getContext())
4083  : getAffineSymbolExpr(pos, parser.getContext()));
4084  }
4085  }
4086  return success();
4087 }
4088 
namespace {
/// Tells the min/max bound parser which kind of grouped bound it is parsing.
enum class MinMaxKind {
  Min,
  Max
};
} // namespace
4092 
4093 /// Parses an affine map that can contain a min/max for groups of its results,
4094 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4095 /// `result` attributes with the map (flat list of expressions) and the grouping
4096 /// (list of integers that specify how many expressions to put into each
4097 /// min/max) attributes. Deduplicates repeated operands.
4098 ///
4099 /// parallel-bound ::= `(` parallel-group-list `)`
4100 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4101 /// parallel-group ::= simple-group | min-max-group
4102 /// simple-group ::= expr-of-ssa-ids
4103 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4104 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4105 ///
4106 /// Examples:
4107 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4108 /// (%0, max(%1 - 2 * %2))
4109 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4110  OperationState &result,
4111  MinMaxKind kind) {
4112  // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4113  // see: https://reviews.llvm.org/D134227#3821753
4114  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4115 
4116  StringRef mapName = kind == MinMaxKind::Min
4117  ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4118  : AffineParallelOp::getLowerBoundsMapAttrStrName();
4119  StringRef groupsName =
4120  kind == MinMaxKind::Min
4121  ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4122  : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4123 
4124  if (failed(parser.parseLParen()))
4125  return failure();
4126 
4127  if (succeeded(parser.parseOptionalRParen())) {
4128  result.addAttribute(
4129  mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4130  result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4131  return success();
4132  }
4133 
4134  SmallVector<AffineExpr> flatExprs;
4137  SmallVector<int32_t> numMapsPerGroup;
4139  auto parseOperands = [&]() {
4140  if (succeeded(parser.parseOptionalKeyword(
4141  kind == MinMaxKind::Min ? "min" : "max"))) {
4142  mapOperands.clear();
4143  AffineMapAttr map;
4144  if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4145  result.attributes,
4147  return failure();
4148  result.attributes.erase(tmpAttrStrName);
4149  llvm::append_range(flatExprs, map.getValue().getResults());
4150  auto operandsRef = llvm::ArrayRef(mapOperands);
4151  auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4153  auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4155  flatDimOperands.append(map.getValue().getNumResults(), dims);
4156  flatSymOperands.append(map.getValue().getNumResults(), syms);
4157  numMapsPerGroup.push_back(map.getValue().getNumResults());
4158  } else {
4159  if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4160  flatSymOperands.emplace_back(),
4161  flatExprs.emplace_back())))
4162  return failure();
4163  numMapsPerGroup.push_back(1);
4164  }
4165  return success();
4166  };
4167  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4168  return failure();
4169 
4170  unsigned totalNumDims = 0;
4171  unsigned totalNumSyms = 0;
4172  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4173  unsigned numDims = flatDimOperands[i].size();
4174  unsigned numSyms = flatSymOperands[i].size();
4175  flatExprs[i] = flatExprs[i]
4176  .shiftDims(numDims, totalNumDims)
4177  .shiftSymbols(numSyms, totalNumSyms);
4178  totalNumDims += numDims;
4179  totalNumSyms += numSyms;
4180  }
4181 
4182  // Deduplicate map operands.
4183  SmallVector<Value> dimOperands, symOperands;
4184  SmallVector<AffineExpr> dimRplacements, symRepacements;
4185  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4186  dimRplacements, AffineExprKind::DimId) ||
4187  deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4188  symRepacements, AffineExprKind::SymbolId))
4189  return failure();
4190 
4191  result.operands.append(dimOperands.begin(), dimOperands.end());
4192  result.operands.append(symOperands.begin(), symOperands.end());
4193 
4194  Builder &builder = parser.getBuilder();
4195  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4196  parser.getContext());
4197  flatMap = flatMap.replaceDimsAndSymbols(
4198  dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4199 
4200  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4201  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4202  return success();
4203 }
4204 
4205 //
4206 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4207 // `to` parallel-bound steps? region attr-dict?
4208 // steps ::= `steps` `(` integer-literals `)`
4209 //
4210 ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4211  OperationState &result) {
4212  auto &builder = parser.getBuilder();
4213  auto indexType = builder.getIndexType();
4216  parser.parseEqual() ||
4217  parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4218  parser.parseKeyword("to") ||
4219  parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4220  return failure();
4221 
4222  AffineMapAttr stepsMapAttr;
4223  NamedAttrList stepsAttrs;
4225  if (failed(parser.parseOptionalKeyword("step"))) {
4226  SmallVector<int64_t, 4> steps(ivs.size(), 1);
4227  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4228  builder.getI64ArrayAttr(steps));
4229  } else {
4230  if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4231  AffineParallelOp::getStepsAttrStrName(),
4232  stepsAttrs,
4234  return failure();
4235 
4236  // Convert steps from an AffineMap into an I64ArrayAttr.
4238  auto stepsMap = stepsMapAttr.getValue();
4239  for (const auto &result : stepsMap.getResults()) {
4240  auto constExpr = dyn_cast<AffineConstantExpr>(result);
4241  if (!constExpr)
4242  return parser.emitError(parser.getNameLoc(),
4243  "steps must be constant integers");
4244  steps.push_back(constExpr.getValue());
4245  }
4246  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4247  builder.getI64ArrayAttr(steps));
4248  }
4249 
4250  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4251  // quoted strings are a member of the enum AtomicRMWKind.
4252  SmallVector<Attribute, 4> reductions;
4253  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4254  if (parser.parseLParen())
4255  return failure();
4256  auto parseAttributes = [&]() -> ParseResult {
4257  // Parse a single quoted string via the attribute parsing, and then
4258  // verify it is a member of the enum and convert to it's integer
4259  // representation.
4260  StringAttr attrVal;
4261  NamedAttrList attrStorage;
4262  auto loc = parser.getCurrentLocation();
4263  if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4264  attrStorage))
4265  return failure();
4266  std::optional<arith::AtomicRMWKind> reduction =
4267  arith::symbolizeAtomicRMWKind(attrVal.getValue());
4268  if (!reduction)
4269  return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4270  reductions.push_back(
4271  builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4272  // While we keep getting commas, keep parsing.
4273  return success();
4274  };
4275  if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4276  return failure();
4277  }
4278  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4279  builder.getArrayAttr(reductions));
4280 
4281  // Parse return types of reductions (if any)
4282  if (parser.parseOptionalArrowTypeList(result.types))
4283  return failure();
4284 
4285  // Now parse the body.
4286  Region *body = result.addRegion();
4287  for (auto &iv : ivs)
4288  iv.type = indexType;
4289  if (parser.parseRegion(*body, ivs) ||
4290  parser.parseOptionalAttrDict(result.attributes))
4291  return failure();
4292 
4293  // Add a terminator if none was parsed.
4294  AffineParallelOp::ensureTerminator(*body, builder, result.location);
4295  return success();
4296 }
4297 
4298 //===----------------------------------------------------------------------===//
4299 // AffineYieldOp
4300 //===----------------------------------------------------------------------===//
4301 
4302 LogicalResult AffineYieldOp::verify() {
4303  auto *parentOp = (*this)->getParentOp();
4304  auto results = parentOp->getResults();
4305  auto operands = getOperands();
4306 
4307  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4308  return emitOpError() << "only terminates affine.if/for/parallel regions";
4309  if (parentOp->getNumResults() != getNumOperands())
4310  return emitOpError() << "parent of yield must have same number of "
4311  "results as the yield operands";
4312  for (auto it : llvm::zip(results, operands)) {
4313  if (std::get<0>(it).getType() != std::get<1>(it).getType())
4314  return emitOpError() << "types mismatch between yield op and its parent";
4315  }
4316 
4317  return success();
4318 }
4319 
4320 //===----------------------------------------------------------------------===//
4321 // AffineVectorLoadOp
4322 //===----------------------------------------------------------------------===//
4323 
4324 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4325  VectorType resultType, AffineMap map,
4326  ValueRange operands) {
4327  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4328  result.addOperands(operands);
4329  if (map)
4330  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4331  result.types.push_back(resultType);
4332 }
4333 
4334 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4335  VectorType resultType, Value memref,
4336  AffineMap map, ValueRange mapOperands) {
4337  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4338  result.addOperands(memref);
4339  result.addOperands(mapOperands);
4340  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4341  result.types.push_back(resultType);
4342 }
4343 
4344 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4345  VectorType resultType, Value memref,
4346  ValueRange indices) {
4347  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4348  int64_t rank = memrefType.getRank();
4349  // Create identity map for memrefs with at least one dimension or () -> ()
4350  // for zero-dimensional memrefs.
4351  auto map =
4352  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4353  build(builder, result, resultType, memref, map, indices);
4354 }
4355 
4356 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4357  MLIRContext *context) {
4358  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4359 }
4360 
4361 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4362  OperationState &result) {
4363  auto &builder = parser.getBuilder();
4364  auto indexTy = builder.getIndexType();
4365 
4366  MemRefType memrefType;
4367  VectorType resultType;
4368  OpAsmParser::UnresolvedOperand memrefInfo;
4369  AffineMapAttr mapAttr;
4371  return failure(
4372  parser.parseOperand(memrefInfo) ||
4373  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4374  AffineVectorLoadOp::getMapAttrStrName(),
4375  result.attributes) ||
4376  parser.parseOptionalAttrDict(result.attributes) ||
4377  parser.parseColonType(memrefType) || parser.parseComma() ||
4378  parser.parseType(resultType) ||
4379  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4380  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4381  parser.addTypeToList(resultType, result.types));
4382 }
4383 
4385  p << " " << getMemRef() << '[';
4386  if (AffineMapAttr mapAttr =
4387  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4388  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4389  p << ']';
4390  p.printOptionalAttrDict((*this)->getAttrs(),
4391  /*elidedAttrs=*/{getMapAttrStrName()});
4392  p << " : " << getMemRefType() << ", " << getType();
4393 }
4394 
4395 /// Verify common invariants of affine.vector_load and affine.vector_store.
4396 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4397  VectorType vectorType) {
4398  // Check that memref and vector element types match.
4399  if (memrefType.getElementType() != vectorType.getElementType())
4400  return op->emitOpError(
4401  "requires memref and vector types of the same elemental type");
4402  return success();
4403 }
4404 
4405 LogicalResult AffineVectorLoadOp::verify() {
4406  MemRefType memrefType = getMemRefType();
4407  if (failed(verifyMemoryOpIndexing(
4408  getOperation(),
4409  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4410  getMapOperands(), memrefType,
4411  /*numIndexOperands=*/getNumOperands() - 1)))
4412  return failure();
4413 
4414  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4415  return failure();
4416 
4417  return success();
4418 }
4419 
4420 //===----------------------------------------------------------------------===//
4421 // AffineVectorStoreOp
4422 //===----------------------------------------------------------------------===//
4423 
4424 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4425  Value valueToStore, Value memref, AffineMap map,
4426  ValueRange mapOperands) {
4427  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4428  result.addOperands(valueToStore);
4429  result.addOperands(memref);
4430  result.addOperands(mapOperands);
4431  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4432 }
4433 
4434 // Use identity map.
4435 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4436  Value valueToStore, Value memref,
4437  ValueRange indices) {
4438  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4439  int64_t rank = memrefType.getRank();
4440  // Create identity map for memrefs with at least one dimension or () -> ()
4441  // for zero-dimensional memrefs.
4442  auto map =
4443  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4444  build(builder, result, valueToStore, memref, map, indices);
4445 }
4446 void AffineVectorStoreOp::getCanonicalizationPatterns(
4447  RewritePatternSet &results, MLIRContext *context) {
4448  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4449 }
4450 
4451 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4452  OperationState &result) {
4453  auto indexTy = parser.getBuilder().getIndexType();
4454 
4455  MemRefType memrefType;
4456  VectorType resultType;
4457  OpAsmParser::UnresolvedOperand storeValueInfo;
4458  OpAsmParser::UnresolvedOperand memrefInfo;
4459  AffineMapAttr mapAttr;
4461  return failure(
4462  parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4463  parser.parseOperand(memrefInfo) ||
4464  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4465  AffineVectorStoreOp::getMapAttrStrName(),
4466  result.attributes) ||
4467  parser.parseOptionalAttrDict(result.attributes) ||
4468  parser.parseColonType(memrefType) || parser.parseComma() ||
4469  parser.parseType(resultType) ||
4470  parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4471  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4472  parser.resolveOperands(mapOperands, indexTy, result.operands));
4473 }
4474 
4476  p << " " << getValueToStore();
4477  p << ", " << getMemRef() << '[';
4478  if (AffineMapAttr mapAttr =
4479  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4480  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4481  p << ']';
4482  p.printOptionalAttrDict((*this)->getAttrs(),
4483  /*elidedAttrs=*/{getMapAttrStrName()});
4484  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4485 }
4486 
4487 LogicalResult AffineVectorStoreOp::verify() {
4488  MemRefType memrefType = getMemRefType();
4489  if (failed(verifyMemoryOpIndexing(
4490  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4491  getMapOperands(), memrefType,
4492  /*numIndexOperands=*/getNumOperands() - 2)))
4493  return failure();
4494 
4495  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4496  return failure();
4497 
4498  return success();
4499 }
4500 
4501 //===----------------------------------------------------------------------===//
4502 // DelinearizeIndexOp
4503 //===----------------------------------------------------------------------===//
4504 
4505 LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
4506  MLIRContext *context, std::optional<::mlir::Location> location,
4507  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
4508  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
4509  AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
4510  regions);
4511  inferredReturnTypes.assign(adaptor.getBasis().size(),
4512  IndexType::get(context));
4513  return success();
4514 }
4515 
4516 void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
4517  Value linearIndex,
4518  ArrayRef<OpFoldResult> basis) {
4519  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
4520  result.addOperands(linearIndex);
4521  SmallVector<Value> basisValues =
4522  llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value {
4523  std::optional<int64_t> staticDim = getConstantIntValue(ofr);
4524  if (staticDim.has_value())
4525  return builder.create<arith::ConstantIndexOp>(result.location,
4526  *staticDim);
4527  return llvm::dyn_cast_if_present<Value>(ofr);
4528  });
4529  result.addOperands(basisValues);
4530 }
4531 
4532 LogicalResult AffineDelinearizeIndexOp::verify() {
4533  if (getBasis().empty())
4534  return emitOpError("basis should not be empty");
4535  if (getNumResults() != getBasis().size())
4536  return emitOpError("should return an index for each basis element");
4537  return success();
4538 }
4539 
4540 namespace {
4541 
4542 // Drops delinearization indices that correspond to unit-extent basis
4543 struct DropUnitExtentBasis
4544  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4546 
4547  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4548  PatternRewriter &rewriter) const override {
4549  SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
4550  std::optional<Value> zero = std::nullopt;
4551  Location loc = delinearizeOp->getLoc();
4552  auto getZero = [&]() -> Value {
4553  if (!zero)
4554  zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4555  return zero.value();
4556  };
4557 
4558  // Replace all indices corresponding to unit-extent basis with 0.
4559  // Remaining basis can be used to get a new `affine.delinearize_index` op.
4560  SmallVector<Value> newOperands;
4561  for (auto [index, basis] : llvm::enumerate(delinearizeOp.getBasis())) {
4562  if (matchPattern(basis, m_One()))
4563  replacements[index] = getZero();
4564  else
4565  newOperands.push_back(basis);
4566  }
4567 
4568  if (newOperands.size() == delinearizeOp.getBasis().size())
4569  return failure();
4570 
4571  if (!newOperands.empty()) {
4572  auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4573  loc, delinearizeOp.getLinearIndex(), newOperands);
4574  int newIndex = 0;
4575  // Map back the new delinearized indices to the values they replace.
4576  for (auto &replacement : replacements) {
4577  if (replacement)
4578  continue;
4579  replacement = newDelinearizeOp->getResult(newIndex++);
4580  }
4581  }
4582 
4583  rewriter.replaceOp(delinearizeOp, replacements);
4584  return success();
4585  }
4586 };
4587 
4588 /// Drop delinearization pattern related to loops in the following way
4589 ///
4590 /// ```
4591 /// <loop>(%iv) = (%c0) to (%ub) step (%c1) {
4592 /// %0 = affine.delinearize_index %iv into (%ub) : index
4593 /// <some_use>(%0)
4594 /// }
4595 /// ```
4596 ///
4597 /// can be canonicalized to
4598 ///
4599 /// ```
4600 /// <loop>(%iv) = (%c0) to (%ub) step (%c1) {
4601 /// <some_use>(%iv)
4602 /// }
4603 /// ```
4604 struct DropDelinearizeOfSingleLoop
4605  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4607 
4608  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4609  PatternRewriter &rewriter) const override {
4610  auto basis = delinearizeOp.getBasis();
4611  if (basis.size() != 1)
4612  return failure();
4613 
4614  // Check that the `linear_index` is an induction variable.
4615  auto inductionVar = dyn_cast<BlockArgument>(delinearizeOp.getLinearIndex());
4616  if (!inductionVar)
4617  return failure();
4618 
4619  // Check that the parent is a `LoopLikeOpInterface`.
4620  auto loopLikeOp = dyn_cast<LoopLikeOpInterface>(
4621  inductionVar.getParentRegion()->getParentOp());
4622  if (!loopLikeOp)
4623  return failure();
4624 
4625  // Check that loop is unit-rank and that the `linear_index` is the induction
4626  // variable.
4627  auto inductionVars = loopLikeOp.getLoopInductionVars();
4628  if (!inductionVars || inductionVars->size() != 1 ||
4629  inductionVars->front() != inductionVar) {
4630  return rewriter.notifyMatchFailure(
4631  delinearizeOp, "`linear_index` is not loop induction variable");
4632  }
4633 
4634  // Check that the upper-bound is the basis.
4635  auto upperBounds = loopLikeOp.getLoopUpperBounds();
4636  if (!upperBounds || upperBounds->size() != 1 ||
4637  upperBounds->front() != getAsOpFoldResult(basis.front())) {
4638  return rewriter.notifyMatchFailure(delinearizeOp,
4639  "`basis` is not upper bound");
4640  }
4641 
4642  // Check that the lower bound is zero.
4643  auto lowerBounds = loopLikeOp.getLoopLowerBounds();
4644  if (!lowerBounds || lowerBounds->size() != 1 ||
4645  !isZeroIndex(lowerBounds->front())) {
4646  return rewriter.notifyMatchFailure(delinearizeOp,
4647  "loop lower bound is not zero");
4648  }
4649 
4650  // Check that the step is one.
4651  auto steps = loopLikeOp.getLoopSteps();
4652  if (!steps || steps->size() != 1 || !isConstantIntValue(steps->front(), 1))
4653  return rewriter.notifyMatchFailure(delinearizeOp, "loop step is not one");
4654 
4655  rewriter.replaceOp(delinearizeOp, inductionVar);
4656  return success();
4657  }
4658 };
4659 
4660 } // namespace
4661 
4662 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4663  RewritePatternSet &patterns, MLIRContext *context) {
4664  patterns.insert<DropDelinearizeOfSingleLoop, DropUnitExtentBasis>(context);
4665 }
4666 
4667 //===----------------------------------------------------------------------===//
4668 // TableGen'd op method definitions
4669 //===----------------------------------------------------------------------===//
4670 
4671 #define GET_OP_CLASSES
4672 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
Definition: AffineOps.cpp:2651
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
Definition: AffineOps.cpp:2371
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
Definition: AffineOps.cpp:1163
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
Definition: AffineOps.cpp:3224
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
Definition: AffineOps.cpp:3857
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
Definition: AffineOps.cpp:58
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
Definition: AffineOps.cpp:3994
static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
Definition: AffineOps.cpp:999
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
Definition: AffineOps.cpp:3260
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
Definition: AffineOps.cpp:2229
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
Definition: AffineOps.cpp:785
static bool isValidAffineIndexOperand(Value value, Region *region)
Definition: AffineOps.cpp:464
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1357
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
Definition: AffineOps.cpp:1070
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:720
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
Definition: AffineOps.cpp:1918
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:712
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
Definition: AffineOps.cpp:2936
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
Definition: AffineOps.cpp:1022
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1314
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
Definition: AffineOps.cpp:4396
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
Definition: AffineOps.cpp:888
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
Definition: AffineOps.cpp:3237
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
Definition: AffineOps.cpp:327
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
Definition: AffineOps.cpp:660
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
Definition: AffineOps.cpp:4109
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
Definition: AffineOps.cpp:2660
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
Definition: AffineOps.cpp:469
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
Definition: AffineOps.cpp:622
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1250
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
Definition: AffineOps.cpp:2610
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
Definition: AffineOps.cpp:2183
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
Definition: AffineOps.cpp:501
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1265
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
Definition: AffineOps.cpp:346
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
Definition: AffineOps.cpp:688
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Definition: AffineOps.cpp:4062
static LogicalResult verifyAffineMinMaxOp(T op)
Definition: AffineOps.cpp:3211
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
Definition: AffineOps.cpp:2093
static LogicalResult verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
Definition: AffineOps.cpp:3042
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
Definition: AffineOps.cpp:3423
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:76
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:81
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:917
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:243
MLIRContext * getContext() const
Definition: AffineExpr.cpp:33
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:662
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
Definition: AffineMap.h:221
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:267
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
Definition: AffineMap.h:228
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineMap.cpp:500
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineMap.h:280
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:515
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
Definition: AffineMap.cpp:128
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:654
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
Definition: AffineMap.cpp:434
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
AffineMap getDimIdentityMap()
Definition: Builders.cpp:423
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:427
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition: Builders.cpp:219
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
NoneType getNoneType()
Definition: Builders.cpp:128
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
Definition: Builders.cpp:416
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition: Builders.cpp:418
MLIRContext * getContext() const
Definition: Builders.h:55
AffineMap getSymbolIdentityMap()
Definition: Builders.cpp:436
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:321
IndexType getIndexType()
Definition: Builders.cpp:95
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44
unsigned getNumDims() const
Definition: IntegerSet.cpp:15
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
Definition: IntegerSet.cpp:57
unsigned getNumInputs() const
Definition: IntegerSet.cpp:17
ArrayRef< AffineExpr > getConstraints() const
Definition: IntegerSet.cpp:41
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
Definition: IntegerSet.cpp:51
unsigned getNumSymbols() const
Definition: IntegerSet.cpp:16
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:453
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:328
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:450
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
operand_range::iterator operand_iterator
Definition: Operation.h:367
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:931
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:41
AffineBound represents a lower or upper bound in the for operation.
Definition: AffineOps.h:511
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
Definition: AffineOps.h:94
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Definition: AffineOps.h:303
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
Definition: AffineOps.cpp:3954
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
unsigned getNumResults() const
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
Definition: AffineOps.cpp:2673
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
Definition: AffineOps.cpp:1132
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
Definition: AffineOps.cpp:2587
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
Definition: AffineOps.cpp:276
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1239
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
Definition: AffineOps.cpp:2559
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: AffineOps.cpp:2563
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1142
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1433
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1305
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Definition: AffineOps.cpp:2551
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1298
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Definition: AffineOps.cpp:246
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
Definition: AffineOps.cpp:1438
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
Definition: AffineOps.cpp:2594
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Definition: AffineOps.cpp:390
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1192
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp parent of an induction variable.
Definition: AffineOps.cpp:2574
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
Definition: AffineOps.cpp:261
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
Definition: AffineOps.cpp:479
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
Definition: AffineOps.cpp:2555
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
Definition: AffineOps.cpp:1259
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:773
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:485
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
bool isZeroIndex(OpFoldResult v)
Return true if v is an IntegerAttr with value 0 or a ConstantIndexOp with attribute with value 0.
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
Definition: AffineMap.cpp:783
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
AffineExprKind
Definition: AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:70
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_One()
Matches a constant scalar / vector splat / tensor splat integer one.
Definition: Matchers.h:473
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
Definition: AffineMap.cpp:745
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:627
Canonicalize the affine map result expression order of an affine min/max operation.
Definition: AffineOps.cpp:3477
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3480
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3494
Remove duplicated expressions in affine min/max ops.
Definition: AffineOps.cpp:3293
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3296
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
Definition: AffineOps.cpp:3336
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3339
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:293
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.