MLIR  19.0.0git
AffineOps.cpp
Go to the documentation of this file.
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 #include "llvm/ADT/SmallVectorExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/MathExtras.h"
28 #include <numeric>
29 #include <optional>
30 
31 using namespace mlir;
32 using namespace mlir::affine;
33 
34 using llvm::divideCeilSigned;
35 using llvm::divideFloorSigned;
36 using llvm::mod;
37 
38 #define DEBUG_TYPE "affine-ops"
39 
40 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
41 
42 /// A utility function to check if a value is defined at the top level of
43 /// `region` or is an argument of `region`. A value of index type defined at the
44 /// top level of a `AffineScope` region is always a valid symbol for all
45 /// uses in that region.
47  if (auto arg = llvm::dyn_cast<BlockArgument>(value))
48  return arg.getParentRegion() == region;
49  return value.getDefiningOp()->getParentRegion() == region;
50 }
51 
52 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
53 /// region remains legal if the operation that uses it is inlined into `dest`
54 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
55 /// `isValidSymbol`, depending on the value being required to remain a valid
56 /// dimension or symbol.
57 static bool
59  const IRMapping &mapping,
60  function_ref<bool(Value, Region *)> legalityCheck) {
61  // If the value is a valid dimension for any other reason than being
62  // a top-level value, it will remain valid: constants get inlined
63  // with the function, transitive affine applies also get inlined and
64  // will be checked themselves, etc.
65  if (!isTopLevelValue(value, src))
66  return true;
67 
68  // If it's a top-level value because it's a block operand, i.e. a
69  // function argument, check whether the value replacing it after
70  // inlining is a valid dimension in the new region.
71  if (llvm::isa<BlockArgument>(value))
72  return legalityCheck(mapping.lookup(value), dest);
73 
74  // If it's a top-level value because it's defined in the region,
75  // it can only be inlined if the defining op is a constant or a
76  // `dim`, which can appear anywhere and be valid, since the defining
77  // op won't be top-level anymore after inlining.
78  Attribute operandCst;
79  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
80  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
81  isDimLikeOp;
82 }
83 
84 /// Checks if all values known to be legal affine dimensions or symbols in `src`
85 /// remain so if their respective users are inlined into `dest`.
86 static bool
88  const IRMapping &mapping,
89  function_ref<bool(Value, Region *)> legalityCheck) {
90  return llvm::all_of(values, [&](Value v) {
91  return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
92  });
93 }
94 
95 /// Checks if an affine read or write operation remains legal after inlining
96 /// from `src` to `dest`.
97 template <typename OpTy>
98 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
99  const IRMapping &mapping) {
100  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
101  AffineWriteOpInterface>::value,
102  "only ops with affine read/write interface are supported");
103 
104  AffineMap map = op.getAffineMap();
105  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
106  ValueRange symbolOperands =
107  op.getMapOperands().take_back(map.getNumSymbols());
109  dimOperands, src, dest, mapping,
110  static_cast<bool (*)(Value, Region *)>(isValidDim)))
111  return false;
113  symbolOperands, src, dest, mapping,
114  static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
115  return false;
116  return true;
117 }
118 
119 /// Checks if an affine apply operation remains legal after inlining from `src`
120 /// to `dest`.
121 // Use "unused attribute" marker to silence clang-tidy warning stemming from
122 // the inability to see through "llvm::TypeSwitch".
123 template <>
124 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
125  Region *src, Region *dest,
126  const IRMapping &mapping) {
127  // If it's a valid dimension, we need to check that it remains so.
128  if (isValidDim(op.getResult(), src))
130  op.getMapOperands(), src, dest, mapping,
131  static_cast<bool (*)(Value, Region *)>(isValidDim));
132 
133  // Otherwise it must be a valid symbol, check that it remains so.
135  op.getMapOperands(), src, dest, mapping,
136  static_cast<bool (*)(Value, Region *)>(isValidSymbol));
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // AffineDialect Interfaces
141 //===----------------------------------------------------------------------===//
142 
143 namespace {
144 /// This class defines the interface for handling inlining with affine
145 /// operations.
146 struct AffineInlinerInterface : public DialectInlinerInterface {
148 
149  //===--------------------------------------------------------------------===//
150  // Analysis Hooks
151  //===--------------------------------------------------------------------===//
152 
153  /// Returns true if the given region 'src' can be inlined into the region
154  /// 'dest' that is attached to an operation registered to the current dialect.
155  /// 'wouldBeCloned' is set if the region is cloned into its new location
156  /// rather than moved, indicating there may be other users.
157  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
158  IRMapping &valueMapping) const final {
159  // We can inline into affine loops and conditionals if this doesn't break
160  // affine value categorization rules.
161  Operation *destOp = dest->getParentOp();
162  if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
163  return false;
164 
165  // Multi-block regions cannot be inlined into affine constructs, all of
166  // which require single-block regions.
167  if (!llvm::hasSingleElement(*src))
168  return false;
169 
170  // Side-effecting operations that the affine dialect cannot understand
171  // should not be inlined.
172  Block &srcBlock = src->front();
173  for (Operation &op : srcBlock) {
174  // Ops with no side effects are fine,
175  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
176  if (iface.hasNoEffect())
177  continue;
178  }
179 
180  // Assuming the inlined region is valid, we only need to check if the
181  // inlining would change it.
182  bool remainsValid =
184  .Case<AffineApplyOp, AffineReadOpInterface,
185  AffineWriteOpInterface>([&](auto op) {
186  return remainsLegalAfterInline(op, src, dest, valueMapping);
187  })
188  .Default([](Operation *) {
189  // Conservatively disallow inlining ops we cannot reason about.
190  return false;
191  });
192 
193  if (!remainsValid)
194  return false;
195  }
196 
197  return true;
198  }
199 
200  /// Returns true if the given operation 'op', that is registered to this
201  /// dialect, can be inlined into the given region, false otherwise.
202  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
203  IRMapping &valueMapping) const final {
204  // Always allow inlining affine operations into a region that is marked as
205  // affine scope, or into affine loops and conditionals. There are some edge
206  // cases when inlining *into* affine structures, but that is handled in the
207  // other 'isLegalToInline' hook above.
208  Operation *parentOp = region->getParentOp();
209  return parentOp->hasTrait<OpTrait::AffineScope>() ||
210  isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
211  }
212 
213  /// Affine regions should be analyzed recursively.
214  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
215 };
216 } // namespace
217 
218 //===----------------------------------------------------------------------===//
219 // AffineDialect
220 //===----------------------------------------------------------------------===//
221 
222 void AffineDialect::initialize() {
223  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
224 #define GET_OP_LIST
225 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
226  >();
227  addInterfaces<AffineInlinerInterface>();
228  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
229  AffineMinOp>();
230 }
231 
232 /// Materialize a single constant operation from a given attribute value with
233 /// the desired resultant type.
235  Attribute value, Type type,
236  Location loc) {
237  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
238  return builder.create<ub::PoisonOp>(loc, type, poison);
239  return arith::ConstantOp::materialize(builder, value, type, loc);
240 }
241 
242 /// A utility function to check if a value is defined at the top level of an
243 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
244 /// conservatively assume it is not top-level. A value of index type defined at
245 /// the top level is always a valid symbol.
247  if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
248  // The block owning the argument may be unlinked, e.g. when the surrounding
249  // region has not yet been attached to an Op, at which point the parent Op
250  // is null.
251  Operation *parentOp = arg.getOwner()->getParentOp();
252  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
253  }
254  // The defining Op may live in an unlinked block so its parent Op may be null.
255  Operation *parentOp = value.getDefiningOp()->getParentOp();
256  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
257 }
258 
259 /// Returns the closest region enclosing `op` that is held by an operation with
260 /// trait `AffineScope`; `nullptr` if there is no such region.
262  auto *curOp = op;
263  while (auto *parentOp = curOp->getParentOp()) {
264  if (parentOp->hasTrait<OpTrait::AffineScope>())
265  return curOp->getParentRegion();
266  curOp = parentOp;
267  }
268  return nullptr;
269 }
270 
271 // A Value can be used as a dimension id iff it meets one of the following
272 // conditions:
273 // *) It is valid as a symbol.
274 // *) It is an induction variable.
275 // *) It is the result of affine apply operation with dimension id arguments.
277  // The value must be an index type.
278  if (!value.getType().isIndex())
279  return false;
280 
281  if (auto *defOp = value.getDefiningOp())
282  return isValidDim(value, getAffineScope(defOp));
283 
284  // This value has to be a block argument for an op that has the
285  // `AffineScope` trait or for an affine.for or affine.parallel.
286  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
287  return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
288  isa<AffineForOp, AffineParallelOp>(parentOp));
289 }
290 
291 // Value can be used as a dimension id iff it meets one of the following
292 // conditions:
293 // *) It is valid as a symbol.
294 // *) It is an induction variable.
295 // *) It is the result of an affine apply operation with dimension id operands.
296 bool mlir::affine::isValidDim(Value value, Region *region) {
297  // The value must be an index type.
298  if (!value.getType().isIndex())
299  return false;
300 
301  // All valid symbols are okay.
302  if (isValidSymbol(value, region))
303  return true;
304 
305  auto *op = value.getDefiningOp();
306  if (!op) {
307  // This value has to be a block argument for an affine.for or an
308  // affine.parallel.
309  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
310  return isa<AffineForOp, AffineParallelOp>(parentOp);
311  }
312 
313  // Affine apply operation is ok if all of its operands are ok.
314  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
315  return applyOp.isValidDim(region);
316  // The dim op is okay if its operand memref/tensor is defined at the top
317  // level.
318  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
319  return isTopLevelValue(dimOp.getShapedValue());
320  return false;
321 }
322 
323 /// Returns true if the 'index' dimension of the `memref` defined by
324 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
325 /// for `region`.
326 template <typename AnyMemRefDefOp>
327 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
328  Region *region) {
329  MemRefType memRefType = memrefDefOp.getType();
330 
331  // Dimension index is out of bounds.
332  if (index >= memRefType.getRank()) {
333  return false;
334  }
335 
336  // Statically shaped.
337  if (!memRefType.isDynamicDim(index))
338  return true;
339  // Get the position of the dimension among dynamic dimensions;
340  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
341  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
342  region);
343 }
344 
345 /// Returns true if the result of the dim op is a valid symbol for `region`.
346 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
347  // The dim op is okay if its source is defined at the top level.
348  if (isTopLevelValue(dimOp.getShapedValue()))
349  return true;
350 
351  // Conservatively handle remaining BlockArguments as non-valid symbols.
352  // E.g. scf.for iterArgs.
353  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
354  return false;
355 
356  // The dim op is also okay if its operand memref is a view/subview whose
357  // corresponding size is a valid symbol.
358  std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
359 
360  // Be conservative if we can't understand the dimension.
361  if (!index.has_value())
362  return false;
363 
364  // Skip over all memref.cast ops (if any).
365  Operation *op = dimOp.getShapedValue().getDefiningOp();
366  while (auto castOp = dyn_cast<memref::CastOp>(op)) {
367  // Bail on unranked memrefs.
368  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
369  return false;
370  op = castOp.getSource().getDefiningOp();
371  if (!op)
372  return false;
373  }
374 
375  int64_t i = index.value();
377  .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
378  [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
379  .Default([](Operation *) { return false; });
380 }
381 
382 // A value can be used as a symbol (at all its use sites) iff it meets one of
383 // the following conditions:
384 // *) It is a constant.
385 // *) Its defining op or block arg appearance is immediately enclosed by an op
386 // with `AffineScope` trait.
387 // *) It is the result of an affine.apply operation with symbol operands.
388 // *) It is a result of the dim op on a memref whose corresponding size is a
389 // valid symbol.
391  if (!value)
392  return false;
393 
394  // The value must be an index type.
395  if (!value.getType().isIndex())
396  return false;
397 
398  // Check that the value is a top level value.
399  if (isTopLevelValue(value))
400  return true;
401 
402  if (auto *defOp = value.getDefiningOp())
403  return isValidSymbol(value, getAffineScope(defOp));
404 
405  return false;
406 }
407 
408 /// A value can be used as a symbol for `region` iff it meets one of the
409 /// following conditions:
410 /// *) It is a constant.
411 /// *) It is the result of an affine apply operation with symbol arguments.
412 /// *) It is a result of the dim op on a memref whose corresponding size is
413 /// a valid symbol.
414 /// *) It is defined at the top level of 'region' or is its argument.
415 /// *) It dominates `region`'s parent op.
416 /// If `region` is null, conservatively assume the symbol definition scope does
417 /// not exist and only accept the values that would be symbols regardless of
418 /// the surrounding region structure, i.e. the first three cases above.
420  // The value must be an index type.
421  if (!value.getType().isIndex())
422  return false;
423 
424  // A top-level value is a valid symbol.
425  if (region && ::isTopLevelValue(value, region))
426  return true;
427 
428  auto *defOp = value.getDefiningOp();
429  if (!defOp) {
430  // A block argument that is not a top-level value is a valid symbol if it
431  // dominates region's parent op.
432  Operation *regionOp = region ? region->getParentOp() : nullptr;
433  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
434  if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
435  return isValidSymbol(value, parentOpRegion);
436  return false;
437  }
438 
439  // Constant operation is ok.
440  Attribute operandCst;
441  if (matchPattern(defOp, m_Constant(&operandCst)))
442  return true;
443 
444  // Affine apply operation is ok if all of its operands are ok.
445  if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
446  return applyOp.isValidSymbol(region);
447 
448  // Dim op results could be valid symbols at any level.
449  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
450  return isDimOpValidSymbol(dimOp, region);
451 
452  // Check for values dominating `region`'s parent op.
453  Operation *regionOp = region ? region->getParentOp() : nullptr;
454  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
455  if (auto *parentRegion = region->getParentOp()->getParentRegion())
456  return isValidSymbol(value, parentRegion);
457 
458  return false;
459 }
460 
461 // Returns true if 'value' is a valid index to an affine operation (e.g.
462 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
463 // `region` provides the polyhedral symbol scope. Returns false otherwise.
464 static bool isValidAffineIndexOperand(Value value, Region *region) {
465  return isValidDim(value, region) || isValidSymbol(value, region);
466 }
467 
468 /// Prints dimension and symbol list.
471  unsigned numDims, OpAsmPrinter &printer) {
472  OperandRange operands(begin, end);
473  printer << '(' << operands.take_front(numDims) << ')';
474  if (operands.size() > numDims)
475  printer << '[' << operands.drop_front(numDims) << ']';
476 }
477 
478 /// Parses dimension and symbol list and returns true if parsing failed.
480  OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
482  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
483  return failure();
484  // Store number of dimensions for validation by caller.
485  numDims = opInfos.size();
486 
487  // Parse the optional symbol operands.
488  auto indexTy = parser.getBuilder().getIndexType();
489  return failure(parser.parseOperandList(
491  parser.resolveOperands(opInfos, indexTy, operands));
492 }
493 
494 /// Utility function to verify that a set of operands are valid dimension and
495 /// symbol identifiers. The operands should be laid out such that the dimension
496 /// operands are before the symbol operands. This function returns failure if
497 /// there was an invalid operand. An operation is provided to emit any necessary
498 /// errors.
499 template <typename OpTy>
500 static LogicalResult
502  unsigned numDims) {
503  unsigned opIt = 0;
504  for (auto operand : operands) {
505  if (opIt++ < numDims) {
506  if (!isValidDim(operand, getAffineScope(op)))
507  return op.emitOpError("operand cannot be used as a dimension id");
508  } else if (!isValidSymbol(operand, getAffineScope(op))) {
509  return op.emitOpError("operand cannot be used as a symbol");
510  }
511  }
512  return success();
513 }
514 
515 //===----------------------------------------------------------------------===//
516 // AffineApplyOp
517 //===----------------------------------------------------------------------===//
518 
519 AffineValueMap AffineApplyOp::getAffineValueMap() {
520  return AffineValueMap(getAffineMap(), getOperands(), getResult());
521 }
522 
523 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
524  auto &builder = parser.getBuilder();
525  auto indexTy = builder.getIndexType();
526 
527  AffineMapAttr mapAttr;
528  unsigned numDims;
529  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
530  parseDimAndSymbolList(parser, result.operands, numDims) ||
531  parser.parseOptionalAttrDict(result.attributes))
532  return failure();
533  auto map = mapAttr.getValue();
534 
535  if (map.getNumDims() != numDims ||
536  numDims + map.getNumSymbols() != result.operands.size()) {
537  return parser.emitError(parser.getNameLoc(),
538  "dimension or symbol index mismatch");
539  }
540 
541  result.types.append(map.getNumResults(), indexTy);
542  return success();
543 }
544 
546  p << " " << getMapAttr();
547  printDimAndSymbolList(operand_begin(), operand_end(),
548  getAffineMap().getNumDims(), p);
549  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
550 }
551 
552 LogicalResult AffineApplyOp::verify() {
553  // Check input and output dimensions match.
554  AffineMap affineMap = getMap();
555 
556  // Verify that operand count matches affine map dimension and symbol count.
557  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
558  return emitOpError(
559  "operand count and affine map dimension and symbol count must match");
560 
561  // Verify that the map only produces one result.
562  if (affineMap.getNumResults() != 1)
563  return emitOpError("mapping must produce one value");
564 
565  return success();
566 }
567 
568 // The result of the affine apply operation can be used as a dimension id if all
569 // its operands are valid dimension ids.
571  return llvm::all_of(getOperands(),
572  [](Value op) { return affine::isValidDim(op); });
573 }
574 
575 // The result of the affine apply operation can be used as a dimension id if all
576 // its operands are valid dimension ids with the parent operation of `region`
577 // defining the polyhedral scope for symbols.
578 bool AffineApplyOp::isValidDim(Region *region) {
579  return llvm::all_of(getOperands(),
580  [&](Value op) { return ::isValidDim(op, region); });
581 }
582 
583 // The result of the affine apply operation can be used as a symbol if all its
584 // operands are symbols.
586  return llvm::all_of(getOperands(),
587  [](Value op) { return affine::isValidSymbol(op); });
588 }
589 
590 // The result of the affine apply operation can be used as a symbol in `region`
591 // if all its operands are symbols in `region`.
592 bool AffineApplyOp::isValidSymbol(Region *region) {
593  return llvm::all_of(getOperands(), [&](Value operand) {
594  return affine::isValidSymbol(operand, region);
595  });
596 }
597 
598 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
599  auto map = getAffineMap();
600 
601  // Fold dims and symbols to existing values.
602  auto expr = map.getResult(0);
603  if (auto dim = dyn_cast<AffineDimExpr>(expr))
604  return getOperand(dim.getPosition());
605  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
606  return getOperand(map.getNumDims() + sym.getPosition());
607 
608  // Otherwise, default to folding the map.
610  bool hasPoison = false;
611  auto foldResult =
612  map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
613  if (hasPoison)
615  if (failed(foldResult))
616  return {};
617  return result[0];
618 }
619 
620 /// Returns the largest known divisor of `e`. Exploits information from the
621 /// values in `operands`.
622 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
623  // This method isn't aware of `operands`.
624  int64_t div = e.getLargestKnownDivisor();
625 
626  // We now make use of operands for the case `e` is a dim expression.
627  // TODO: More powerful simplification would have to modify
628  // getLargestKnownDivisor to take `operands` and exploit that information as
629  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
630  // can't be part of the IR library but of the `Analysis` library. The IR
631  // library can only really depend on simple O(1) checks.
632  auto dimExpr = dyn_cast<AffineDimExpr>(e);
633  // If it's not a dim expr, `div` is the best we have.
634  if (!dimExpr)
635  return div;
636 
637  // We simply exploit information from loop IVs.
638  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
639  // desired simplifications are expected to be part of other
640  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
641  // LoopAnalysis library.
642  Value operand = operands[dimExpr.getPosition()];
643  int64_t operandDivisor = 1;
644  // TODO: With the right accessors, this can be extended to
645  // LoopLikeOpInterface.
646  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
647  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
648  operandDivisor = forOp.getStepAsInt();
649  } else {
650  uint64_t lbLargestKnownDivisor =
651  forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
652  operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
653  }
654  }
655  return operandDivisor;
656 }
657 
658 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
659 /// being an affine dim expression or a constant.
661  int64_t k) {
662  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
663  int64_t constVal = constExpr.getValue();
664  return constVal >= 0 && constVal < k;
665  }
666  auto dimExpr = dyn_cast<AffineDimExpr>(e);
667  if (!dimExpr)
668  return false;
669  Value operand = operands[dimExpr.getPosition()];
670  // TODO: With the right accessors, this can be extended to
671  // LoopLikeOpInterface.
672  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
673  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
674  forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
675  return true;
676  }
677  }
678 
679  // We don't consider other cases like `operand` being defined by a constant or
680  // an affine.apply op since such cases will already be handled by other
681  // patterns and propagation of loop IVs or constant would happen.
682  return false;
683 }
684 
685 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
686 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
687 /// expression is in that form.
688 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
689  AffineExpr &quotientTimesDiv, AffineExpr &rem) {
690  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
691  if (!bin || bin.getKind() != AffineExprKind::Add)
692  return false;
693 
694  AffineExpr llhs = bin.getLHS();
695  AffineExpr rlhs = bin.getRHS();
696  div = getLargestKnownDivisor(llhs, operands);
697  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
698  quotientTimesDiv = llhs;
699  rem = rlhs;
700  return true;
701  }
702  div = getLargestKnownDivisor(rlhs, operands);
703  if (isNonNegativeBoundedBy(llhs, operands, div)) {
704  quotientTimesDiv = rlhs;
705  rem = llhs;
706  return true;
707  }
708  return false;
709 }
710 
711 /// Gets the constant lower bound on an `iv`.
712 static std::optional<int64_t> getLowerBound(Value iv) {
713  AffineForOp forOp = getForInductionVarOwner(iv);
714  if (forOp && forOp.hasConstantLowerBound())
715  return forOp.getConstantLowerBound();
716  return std::nullopt;
717 }
718 
719 /// Gets the constant upper bound on an affine.for `iv`.
720 static std::optional<int64_t> getUpperBound(Value iv) {
721  AffineForOp forOp = getForInductionVarOwner(iv);
722  if (!forOp || !forOp.hasConstantUpperBound())
723  return std::nullopt;
724 
725  // If its lower bound is also known, we can get a more precise bound
726  // whenever the step is not one.
727  if (forOp.hasConstantLowerBound()) {
728  return forOp.getConstantUpperBound() - 1 -
729  (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
730  forOp.getStepAsInt();
731  }
732  return forOp.getConstantUpperBound() - 1;
733 }
734 
735 /// Determine a constant upper bound for `expr` if one exists while exploiting
736 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
737 /// is guaranteed to be less than or equal to it.
738 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
739  unsigned numSymbols,
740  ArrayRef<Value> operands) {
741  // Get the constant lower or upper bounds on the operands.
742  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
743  constLowerBounds.reserve(operands.size());
744  constUpperBounds.reserve(operands.size());
745  for (Value operand : operands) {
746  constLowerBounds.push_back(getLowerBound(operand));
747  constUpperBounds.push_back(getUpperBound(operand));
748  }
749 
750  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
751  return constExpr.getValue();
752 
753  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
754  constUpperBounds,
755  /*isUpper=*/true);
756 }
757 
758 /// Determine a constant lower bound for `expr` if one exists while exploiting
759 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
760 /// is guaranteed to be less than or equal to it.
761 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
762  unsigned numSymbols,
763  ArrayRef<Value> operands) {
764  // Get the constant lower or upper bounds on the operands.
765  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
766  constLowerBounds.reserve(operands.size());
767  constUpperBounds.reserve(operands.size());
768  for (Value operand : operands) {
769  constLowerBounds.push_back(getLowerBound(operand));
770  constUpperBounds.push_back(getUpperBound(operand));
771  }
772 
773  std::optional<int64_t> lowerBound;
774  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
775  lowerBound = constExpr.getValue();
776  } else {
777  lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
778  constLowerBounds, constUpperBounds,
779  /*isUpper=*/false);
780  }
781  return lowerBound;
782 }
783 
784 /// Simplify `expr` while exploiting information from the values in `operands`.
785 static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
786  unsigned numSymbols,
787  ArrayRef<Value> operands) {
788  // We do this only for certain floordiv/mod expressions.
789  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
790  if (!binExpr)
791  return;
792 
793  // Simplify the child expressions first.
794  AffineExpr lhs = binExpr.getLHS();
795  AffineExpr rhs = binExpr.getRHS();
796  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
797  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
798  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
799 
800  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
801  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
802  expr.getKind() != AffineExprKind::CeilDiv &&
803  expr.getKind() != AffineExprKind::Mod)) {
804  return;
805  }
806 
807  // The `lhs` and `rhs` may be different post construction of simplified expr.
808  lhs = binExpr.getLHS();
809  rhs = binExpr.getRHS();
810  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
811  if (!rhsConst)
812  return;
813 
814  int64_t rhsConstVal = rhsConst.getValue();
815  // Undefined expressions aren't touched; IR can still be valid with them.
816  if (rhsConstVal <= 0)
817  return;
818 
819  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
820  MLIRContext *context = expr.getContext();
821  std::optional<int64_t> lhsLbConst =
822  getLowerBound(lhs, numDims, numSymbols, operands);
823  std::optional<int64_t> lhsUbConst =
824  getUpperBound(lhs, numDims, numSymbols, operands);
825  if (lhsLbConst && lhsUbConst) {
826  int64_t lhsLbConstVal = *lhsLbConst;
827  int64_t lhsUbConstVal = *lhsUbConst;
828  // lhs floordiv c is a single value lhs is bounded in a range `c` that has
829  // the same quotient.
830  if (binExpr.getKind() == AffineExprKind::FloorDiv &&
831  divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
832  divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
833  expr = getAffineConstantExpr(
834  divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
835  return;
836  }
837  // lhs ceildiv c is a single value if the entire range has the same ceil
838  // quotient.
839  if (binExpr.getKind() == AffineExprKind::CeilDiv &&
840  divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
841  divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
842  expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
843  context);
844  return;
845  }
846  // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
847  if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
848  lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
849  expr = lhs;
850  return;
851  }
852  }
853 
854  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
855  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
856  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
857  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
858  AffineExpr quotientTimesDiv, rem;
859  int64_t divisor;
860  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
861  if (rhsConstVal % divisor == 0 &&
862  binExpr.getKind() == AffineExprKind::FloorDiv) {
863  expr = quotientTimesDiv.floorDiv(rhsConst);
864  } else if (divisor % rhsConstVal == 0 &&
865  binExpr.getKind() == AffineExprKind::Mod) {
866  expr = rem % rhsConst;
867  }
868  return;
869  }
870 
871  // Handle the simple case when the LHS expression can be either upper
872  // bounded or is a known multiple of RHS constant.
873  // lhs floordiv c -> 0 if 0 <= lhs < c,
874  // lhs mod c -> 0 if lhs % c = 0.
875  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
876  binExpr.getKind() == AffineExprKind::FloorDiv) ||
877  (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
878  binExpr.getKind() == AffineExprKind::Mod)) {
879  expr = getAffineConstantExpr(0, expr.getContext());
880  }
881 }
882 
883 /// Simplify the expressions in `map` while making use of lower or upper bounds
884 /// of its operands. If `isMax` is true, the map is to be treated as a max of
885 /// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
886 /// d1) can be simplified to (8) if the operands are respectively lower bounded
887 /// by 2 and 0 (the second expression can't be lower than 8).
889  ArrayRef<Value> operands,
890  bool isMax) {
891  // Can't simplify.
892  if (operands.empty())
893  return;
894 
895  // Get the upper or lower bound on an affine.for op IV using its range.
896  // Get the constant lower or upper bounds on the operands.
897  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
898  constLowerBounds.reserve(operands.size());
899  constUpperBounds.reserve(operands.size());
900  for (Value operand : operands) {
901  constLowerBounds.push_back(getLowerBound(operand));
902  constUpperBounds.push_back(getUpperBound(operand));
903  }
904 
905  // We will compute the lower and upper bounds on each of the expressions
906  // Then, we will check (depending on max or min) as to whether a specific
907  // bound is redundant by checking if its highest (in case of max) and its
908  // lowest (in the case of min) value is already lower than (or higher than)
909  // the lower bound (or upper bound in the case of min) of another bound.
910  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
911  lowerBounds.reserve(map.getNumResults());
912  upperBounds.reserve(map.getNumResults());
913  for (AffineExpr e : map.getResults()) {
914  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
915  lowerBounds.push_back(constExpr.getValue());
916  upperBounds.push_back(constExpr.getValue());
917  } else {
918  lowerBounds.push_back(
920  constLowerBounds, constUpperBounds,
921  /*isUpper=*/false));
922  upperBounds.push_back(
924  constLowerBounds, constUpperBounds,
925  /*isUpper=*/true));
926  }
927  }
928 
929  // Collect expressions that are not redundant.
930  SmallVector<AffineExpr, 4> irredundantExprs;
931  for (auto exprEn : llvm::enumerate(map.getResults())) {
932  AffineExpr e = exprEn.value();
933  unsigned i = exprEn.index();
934  // Some expressions can be turned into constants.
935  if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
936  e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
937 
938  // Check if the expression is redundant.
939  if (isMax) {
940  if (!upperBounds[i]) {
941  irredundantExprs.push_back(e);
942  continue;
943  }
944  // If there exists another expression such that its lower bound is greater
945  // than this expression's upper bound, it's redundant.
946  if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
947  auto otherLowerBound = en.value();
948  unsigned pos = en.index();
949  if (pos == i || !otherLowerBound)
950  return false;
951  if (*otherLowerBound > *upperBounds[i])
952  return true;
953  if (*otherLowerBound < *upperBounds[i])
954  return false;
955  // Equality case. When both expressions are considered redundant, we
956  // don't want to get both of them. We keep the one that appears
957  // first.
958  if (upperBounds[pos] && lowerBounds[i] &&
959  lowerBounds[i] == upperBounds[i] &&
960  otherLowerBound == *upperBounds[pos] && i < pos)
961  return false;
962  return true;
963  }))
964  irredundantExprs.push_back(e);
965  } else {
966  if (!lowerBounds[i]) {
967  irredundantExprs.push_back(e);
968  continue;
969  }
970  // Likewise for the `min` case. Use the complement of the condition above.
971  if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
972  auto otherUpperBound = en.value();
973  unsigned pos = en.index();
974  if (pos == i || !otherUpperBound)
975  return false;
976  if (*otherUpperBound < *lowerBounds[i])
977  return true;
978  if (*otherUpperBound > *lowerBounds[i])
979  return false;
980  if (lowerBounds[pos] && upperBounds[i] &&
981  lowerBounds[i] == upperBounds[i] &&
982  otherUpperBound == lowerBounds[pos] && i < pos)
983  return false;
984  return true;
985  }))
986  irredundantExprs.push_back(e);
987  }
988  }
989 
990  // Create the map without the redundant expressions.
991  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
992  map.getContext());
993 }
994 
995 /// Simplify the map while exploiting information on the values in `operands`.
996 // Use "unused attribute" marker to silence warning stemming from the inability
997 // to see through the template expansion.
998 static void LLVM_ATTRIBUTE_UNUSED
1000  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1001  SmallVector<AffineExpr> newResults;
1002  newResults.reserve(map.getNumResults());
1003  for (AffineExpr expr : map.getResults()) {
1005  operands);
1006  newResults.push_back(expr);
1007  }
1008  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1009  map.getContext());
1010 }
1011 
1012 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
1013 /// defining AffineApplyOp expression and operands.
1014 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
1015 /// When `dimOrSymbolPosition >= dims.size()`,
1016 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
1017 /// Mutate `map`,`dims` and `syms` in place as follows:
1018 /// 1. `dims` and `syms` are only appended to.
1019 /// 2. `map` dim and symbols are gradually shifted to higher positions.
1020 /// 3. Old `dim` and `sym` entries are replaced by nullptr
1021 /// This avoids the need for any bookkeeping.
1022 static LogicalResult replaceDimOrSym(AffineMap *map,
1023  unsigned dimOrSymbolPosition,
1024  SmallVectorImpl<Value> &dims,
1025  SmallVectorImpl<Value> &syms) {
1026  MLIRContext *ctx = map->getContext();
1027  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1028  unsigned pos = isDimReplacement ? dimOrSymbolPosition
1029  : dimOrSymbolPosition - dims.size();
1030  Value &v = isDimReplacement ? dims[pos] : syms[pos];
1031  if (!v)
1032  return failure();
1033 
1034  auto affineApply = v.getDefiningOp<AffineApplyOp>();
1035  if (!affineApply)
1036  return failure();
1037 
1038  // At this point we will perform a replacement of `v`, set the entry in `dim`
1039  // or `sym` to nullptr immediately.
1040  v = nullptr;
1041 
1042  // Compute the map, dims and symbols coming from the AffineApplyOp.
1043  AffineMap composeMap = affineApply.getAffineMap();
1044  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1045  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1046  affineApply.getMapOperands().end());
1047  // Canonicalize the map to promote dims to symbols when possible. This is to
1048  // avoid generating invalid maps.
1049  canonicalizeMapAndOperands(&composeMap, &composeOperands);
1050  AffineExpr replacementExpr =
1051  composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
1052  ValueRange composeDims =
1053  ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
1054  ValueRange composeSyms =
1055  ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
1056  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
1057  : getAffineSymbolExpr(pos, ctx);
1058 
1059  // Append the dims and symbols where relevant and perform the replacement.
1060  dims.append(composeDims.begin(), composeDims.end());
1061  syms.append(composeSyms.begin(), composeSyms.end());
1062  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1063 
1064  return success();
1065 }
1066 
1067 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1068 /// iteratively. Perform canonicalization of map and operands as well as
1069 /// AffineMap simplification. `map` and `operands` are mutated in place.
1071  SmallVectorImpl<Value> *operands) {
1072  if (map->getNumResults() == 0) {
1073  canonicalizeMapAndOperands(map, operands);
1074  *map = simplifyAffineMap(*map);
1075  return;
1076  }
1077 
1078  MLIRContext *ctx = map->getContext();
1079  SmallVector<Value, 4> dims(operands->begin(),
1080  operands->begin() + map->getNumDims());
1081  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1082  operands->end());
1083 
1084  // Iterate over dims and symbols coming from AffineApplyOp and replace until
1085  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1086  // and `syms` can only increase by construction.
  // The implementation uses a `while` loop to support the case of symbols
  // that may be constructed from dims; this may be overkill.
1089  while (true) {
1090  bool changed = false;
1091  for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1092  if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
1093  break;
1094  if (!changed)
1095  break;
1096  }
1097 
1098  // Clear operands so we can fill them anew.
1099  operands->clear();
1100 
1101  // At this point we may have introduced null operands, prune them out before
1102  // canonicalizing map and operands.
1103  unsigned nDims = 0, nSyms = 0;
1104  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1105  dimReplacements.reserve(dims.size());
1106  symReplacements.reserve(syms.size());
1107  for (auto *container : {&dims, &syms}) {
1108  bool isDim = (container == &dims);
1109  auto &repls = isDim ? dimReplacements : symReplacements;
1110  for (const auto &en : llvm::enumerate(*container)) {
1111  Value v = en.value();
1112  if (!v) {
1113  assert(isDim ? !map->isFunctionOfDim(en.index())
1114  : !map->isFunctionOfSymbol(en.index()) &&
1115  "map is function of unexpected expr@pos");
1116  repls.push_back(getAffineConstantExpr(0, ctx));
1117  continue;
1118  }
1119  repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1120  : getAffineSymbolExpr(nSyms++, ctx));
1121  operands->push_back(v);
1122  }
1123  }
1124  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1125  nSyms);
1126 
1127  // Canonicalize and simplify before returning.
1128  canonicalizeMapAndOperands(map, operands);
1129  *map = simplifyAffineMap(*map);
1130 }
1131 
1133  AffineMap *map, SmallVectorImpl<Value> *operands) {
1134  while (llvm::any_of(*operands, [](Value v) {
1135  return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1136  })) {
1137  composeAffineMapAndOperands(map, operands);
1138  }
1139 }
1140 
1141 AffineApplyOp
1143  ArrayRef<OpFoldResult> operands) {
1144  SmallVector<Value> valueOperands;
1145  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1146  composeAffineMapAndOperands(&map, &valueOperands);
1147  assert(map);
1148  return b.create<AffineApplyOp>(loc, map, valueOperands);
1149 }
1150 
1151 AffineApplyOp
1153  ArrayRef<OpFoldResult> operands) {
1154  return makeComposedAffineApply(
1155  b, loc,
1157  .front(),
1158  operands);
1159 }
1160 
1161 /// Composes the given affine map with the given list of operands, pulling in
1162 /// the maps from any affine.apply operations that supply the operands.
1164  SmallVectorImpl<Value> &operands) {
1165  // Compose and canonicalize each expression in the map individually because
1166  // composition only applies to single-result maps, collecting potentially
1167  // duplicate operands in a single list with shifted dimensions and symbols.
1168  SmallVector<Value> dims, symbols;
1170  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1171  SmallVector<Value> submapOperands(operands.begin(), operands.end());
1172  AffineMap submap = map.getSubMap({i});
1173  fullyComposeAffineMapAndOperands(&submap, &submapOperands);
1174  canonicalizeMapAndOperands(&submap, &submapOperands);
1175  unsigned numNewDims = submap.getNumDims();
1176  submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1177  llvm::append_range(dims,
1178  ArrayRef<Value>(submapOperands).take_front(numNewDims));
1179  llvm::append_range(symbols,
1180  ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1181  exprs.push_back(submap.getResult(0));
1182  }
1183 
1184  // Canonicalize the map created from composed expressions to deduplicate the
1185  // dimension and symbol operands.
1186  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1187  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1188  canonicalizeMapAndOperands(&map, &operands);
1189 }
1190 
1193  AffineMap map,
1194  ArrayRef<OpFoldResult> operands) {
1195  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1196 
1197  // Create new builder without a listener, so that no notification is
1198  // triggered if the op is folded.
1199  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1200  // workaround is no longer needed.
1201  OpBuilder newBuilder(b.getContext());
1203 
1204  // Create op.
1205  AffineApplyOp applyOp =
1206  makeComposedAffineApply(newBuilder, loc, map, operands);
1207 
1208  // Get constant operands.
1209  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1210  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1211  matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1212 
1213  // Try to fold the operation.
1214  SmallVector<OpFoldResult> foldResults;
1215  if (failed(applyOp->fold(constOperands, foldResults)) ||
1216  foldResults.empty()) {
1217  if (OpBuilder::Listener *listener = b.getListener())
1218  listener->notifyOperationInserted(applyOp, /*previous=*/{});
1219  return applyOp.getResult();
1220  }
1221 
1222  applyOp->erase();
1223  assert(foldResults.size() == 1 && "expected 1 folded result");
1224  return foldResults.front();
1225 }
1226 
1229  AffineExpr expr,
1230  ArrayRef<OpFoldResult> operands) {
1232  b, loc,
1234  .front(),
1235  operands);
1236 }
1237 
1240  OpBuilder &b, Location loc, AffineMap map,
1241  ArrayRef<OpFoldResult> operands) {
1242  return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
1243  [&](unsigned i) {
1244  return makeComposedFoldedAffineApply(
1245  b, loc, map.getSubMap({i}), operands);
1246  });
1247 }
1248 
1249 template <typename OpTy>
1251  ArrayRef<OpFoldResult> operands) {
1252  SmallVector<Value> valueOperands;
1253  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1254  composeMultiResultAffineMap(map, valueOperands);
1255  return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1256 }
1257 
1258 AffineMinOp
1260  ArrayRef<OpFoldResult> operands) {
1261  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1262 }
1263 
1264 template <typename OpTy>
1266  AffineMap map,
1267  ArrayRef<OpFoldResult> operands) {
1268  // Create new builder without a listener, so that no notification is
1269  // triggered if the op is folded.
1270  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1271  // workaround is no longer needed.
1272  OpBuilder newBuilder(b.getContext());
1274 
1275  // Create op.
1276  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1277 
1278  // Get constant operands.
1279  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1280  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1281  matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1282 
1283  // Try to fold the operation.
1284  SmallVector<OpFoldResult> foldResults;
1285  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1286  foldResults.empty()) {
1287  if (OpBuilder::Listener *listener = b.getListener())
1288  listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1289  return minMaxOp.getResult();
1290  }
1291 
1292  minMaxOp->erase();
1293  assert(foldResults.size() == 1 && "expected 1 folded result");
1294  return foldResults.front();
1295 }
1296 
1299  AffineMap map,
1300  ArrayRef<OpFoldResult> operands) {
1301  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1302 }
1303 
1306  AffineMap map,
1307  ArrayRef<OpFoldResult> operands) {
1308  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1309 }
1310 
1311 // A symbol may appear as a dim in affine.apply operations. This function
1312 // canonicalizes dims that are valid symbols into actual symbols.
1313 template <class MapOrSet>
1314 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1315  SmallVectorImpl<Value> *operands) {
1316  if (!mapOrSet || operands->empty())
1317  return;
1318 
1319  assert(mapOrSet->getNumInputs() == operands->size() &&
1320  "map/set inputs must match number of operands");
1321 
1322  auto *context = mapOrSet->getContext();
1323  SmallVector<Value, 8> resultOperands;
1324  resultOperands.reserve(operands->size());
1325  SmallVector<Value, 8> remappedSymbols;
1326  remappedSymbols.reserve(operands->size());
1327  unsigned nextDim = 0;
1328  unsigned nextSym = 0;
1329  unsigned oldNumSyms = mapOrSet->getNumSymbols();
1330  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1331  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1332  if (i < mapOrSet->getNumDims()) {
1333  if (isValidSymbol((*operands)[i])) {
1334  // This is a valid symbol that appears as a dim, canonicalize it.
1335  dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1336  remappedSymbols.push_back((*operands)[i]);
1337  } else {
1338  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1339  resultOperands.push_back((*operands)[i]);
1340  }
1341  } else {
1342  resultOperands.push_back((*operands)[i]);
1343  }
1344  }
1345 
1346  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1347  *operands = resultOperands;
1348  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1349  oldNumSyms + nextSym);
1350 
1351  assert(mapOrSet->getNumInputs() == operands->size() &&
1352  "map/set inputs must match number of operands");
1353 }
1354 
1355 // Works for either an affine map or an integer set.
1356 template <class MapOrSet>
1357 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1358  SmallVectorImpl<Value> *operands) {
1359  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1360  "Argument must be either of AffineMap or IntegerSet type");
1361 
1362  if (!mapOrSet || operands->empty())
1363  return;
1364 
1365  assert(mapOrSet->getNumInputs() == operands->size() &&
1366  "map/set inputs must match number of operands");
1367 
1368  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1369 
1370  // Check to see what dims are used.
1371  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1372  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1373  mapOrSet->walkExprs([&](AffineExpr expr) {
1374  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1375  usedDims[dimExpr.getPosition()] = true;
1376  else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1377  usedSyms[symExpr.getPosition()] = true;
1378  });
1379 
1380  auto *context = mapOrSet->getContext();
1381 
1382  SmallVector<Value, 8> resultOperands;
1383  resultOperands.reserve(operands->size());
1384 
1385  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1386  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1387  unsigned nextDim = 0;
1388  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1389  if (usedDims[i]) {
1390  // Remap dim positions for duplicate operands.
1391  auto it = seenDims.find((*operands)[i]);
1392  if (it == seenDims.end()) {
1393  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1394  resultOperands.push_back((*operands)[i]);
1395  seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1396  } else {
1397  dimRemapping[i] = it->second;
1398  }
1399  }
1400  }
1401  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1402  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1403  unsigned nextSym = 0;
1404  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1405  if (!usedSyms[i])
1406  continue;
1407  // Handle constant operands (only needed for symbolic operands since
1408  // constant operands in dimensional positions would have already been
1409  // promoted to symbolic positions above).
1410  IntegerAttr operandCst;
1411  if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1412  m_Constant(&operandCst))) {
1413  symRemapping[i] =
1414  getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1415  continue;
1416  }
1417  // Remap symbol positions for duplicate operands.
1418  auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1419  if (it == seenSymbols.end()) {
1420  symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1421  resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1422  seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1423  symRemapping[i]));
1424  } else {
1425  symRemapping[i] = it->second;
1426  }
1427  }
1428  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1429  nextDim, nextSym);
1430  *operands = resultOperands;
1431 }
1432 
1434  AffineMap *map, SmallVectorImpl<Value> *operands) {
1435  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1436 }
1437 
1439  IntegerSet *set, SmallVectorImpl<Value> *operands) {
1440  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1441 }
1442 
1443 namespace {
/// Simplify affine operations that carry an affine map (apply, load, store,
/// prefetch, min, max, vector_load, vector_store) by composing into that map
/// the maps of the affine.apply operations that supply its operands.
///
1447 template <typename AffineOpTy>
1448 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1450 
1451  /// Replace the affine op with another instance of it with the supplied
1452  /// map and mapOperands.
1453  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1454  AffineMap map, ArrayRef<Value> mapOperands) const;
1455 
1456  LogicalResult matchAndRewrite(AffineOpTy affineOp,
1457  PatternRewriter &rewriter) const override {
1458  static_assert(
1459  llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1460  AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1461  AffineVectorStoreOp, AffineVectorLoadOp>::value,
1462  "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1463  "expected");
1464  auto map = affineOp.getAffineMap();
1465  AffineMap oldMap = map;
1466  auto oldOperands = affineOp.getMapOperands();
1467  SmallVector<Value, 8> resultOperands(oldOperands);
1468  composeAffineMapAndOperands(&map, &resultOperands);
1469  canonicalizeMapAndOperands(&map, &resultOperands);
1470  simplifyMapWithOperands(map, resultOperands);
1471  if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1472  resultOperands.begin()))
1473  return failure();
1474 
1475  replaceAffineOp(rewriter, affineOp, map, resultOperands);
1476  return success();
1477  }
1478 };
1479 
// Specialize `replaceAffineOp` for ops whose builders need more than the map
// and its operands: affine load, prefetch, store, vector_load and
// vector_store. All other ops use the generic version below.
1482 template <>
1483 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1484  PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1485  ArrayRef<Value> mapOperands) const {
1486  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1487  mapOperands);
1488 }
1489 template <>
1490 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1491  PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1492  ArrayRef<Value> mapOperands) const {
1493  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1494  prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1495  prefetch.getLocalityHint(), prefetch.getIsDataCache());
1496 }
1497 template <>
1498 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1499  PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1500  ArrayRef<Value> mapOperands) const {
1501  rewriter.replaceOpWithNewOp<AffineStoreOp>(
1502  store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1503 }
1504 template <>
1505 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1506  PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1507  ArrayRef<Value> mapOperands) const {
1508  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1509  vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1510  mapOperands);
1511 }
1512 template <>
1513 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1514  PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1515  ArrayRef<Value> mapOperands) const {
1516  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1517  vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1518  mapOperands);
1519 }
1520 
1521 // Generic version for ops that don't have extra operands.
1522 template <typename AffineOpTy>
1523 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1524  PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1525  ArrayRef<Value> mapOperands) const {
1526  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1527 }
1528 } // namespace
1529 
1530 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1531  MLIRContext *context) {
1532  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1533 }
1534 
1535 //===----------------------------------------------------------------------===//
1536 // AffineDmaStartOp
1537 //===----------------------------------------------------------------------===//
1538 
1539 // TODO: Check that map operands are loop IVs or symbols.
1540 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1541  Value srcMemRef, AffineMap srcMap,
1542  ValueRange srcIndices, Value destMemRef,
1543  AffineMap dstMap, ValueRange destIndices,
1544  Value tagMemRef, AffineMap tagMap,
1545  ValueRange tagIndices, Value numElements,
1546  Value stride, Value elementsPerStride) {
1547  result.addOperands(srcMemRef);
1548  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1549  result.addOperands(srcIndices);
1550  result.addOperands(destMemRef);
1551  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1552  result.addOperands(destIndices);
1553  result.addOperands(tagMemRef);
1554  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1555  result.addOperands(tagIndices);
1556  result.addOperands(numElements);
1557  if (stride) {
1558  result.addOperands({stride, elementsPerStride});
1559  }
1560 }
1561 
1563  p << " " << getSrcMemRef() << '[';
1564  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1565  p << "], " << getDstMemRef() << '[';
1566  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1567  p << "], " << getTagMemRef() << '[';
1568  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1569  p << "], " << getNumElements();
1570  if (isStrided()) {
1571  p << ", " << getStride();
1572  p << ", " << getNumElementsPerStride();
1573  }
1574  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1575  << getTagMemRefType();
1576 }
1577 
1578 // Parse AffineDmaStartOp.
1579 // Ex:
1580 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1581 // %stride, %num_elt_per_stride
1582 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1583 //
1585  OperationState &result) {
1586  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1587  AffineMapAttr srcMapAttr;
1589  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1590  AffineMapAttr dstMapAttr;
1592  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1593  AffineMapAttr tagMapAttr;
1595  OpAsmParser::UnresolvedOperand numElementsInfo;
1597 
1598  SmallVector<Type, 3> types;
1599  auto indexType = parser.getBuilder().getIndexType();
1600 
1601  // Parse and resolve the following list of operands:
  // *) src memref followed by its affine map operands (in square brackets).
  // *) dst memref followed by its affine map operands (in square brackets).
1604  // *) tag memref followed by its affine map operands (in square brackets).
1605  // *) number of elements transferred by DMA operation.
1606  if (parser.parseOperand(srcMemRefInfo) ||
1607  parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1608  getSrcMapAttrStrName(),
1609  result.attributes) ||
1610  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1611  parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1612  getDstMapAttrStrName(),
1613  result.attributes) ||
1614  parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1615  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1616  getTagMapAttrStrName(),
1617  result.attributes) ||
1618  parser.parseComma() || parser.parseOperand(numElementsInfo))
1619  return failure();
1620 
1621  // Parse optional stride and elements per stride.
1622  if (parser.parseTrailingOperandList(strideInfo))
1623  return failure();
1624 
1625  if (!strideInfo.empty() && strideInfo.size() != 2) {
1626  return parser.emitError(parser.getNameLoc(),
1627  "expected two stride related operands");
1628  }
1629  bool isStrided = strideInfo.size() == 2;
1630 
1631  if (parser.parseColonTypeList(types))
1632  return failure();
1633 
1634  if (types.size() != 3)
1635  return parser.emitError(parser.getNameLoc(), "expected three types");
1636 
1637  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1638  parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1639  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1640  parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1641  parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1642  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1643  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1644  return failure();
1645 
1646  if (isStrided) {
1647  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1648  return failure();
1649  }
1650 
1651  // Check that src/dst/tag operand counts match their map.numInputs.
1652  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1653  dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1654  tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1655  return parser.emitError(parser.getNameLoc(),
1656  "memref operand count not equal to map.numInputs");
1657  return success();
1658 }
1659 
1660 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1661  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1662  return emitOpError("expected DMA source to be of memref type");
1663  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1664  return emitOpError("expected DMA destination to be of memref type");
1665  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1666  return emitOpError("expected DMA tag to be of memref type");
1667 
1668  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1669  getDstMap().getNumInputs() +
1670  getTagMap().getNumInputs();
1671  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1672  getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1673  return emitOpError("incorrect number of operands");
1674  }
1675 
1676  Region *scope = getAffineScope(*this);
1677  for (auto idx : getSrcIndices()) {
1678  if (!idx.getType().isIndex())
1679  return emitOpError("src index to dma_start must have 'index' type");
1680  if (!isValidAffineIndexOperand(idx, scope))
1681  return emitOpError(
1682  "src index must be a valid dimension or symbol identifier");
1683  }
1684  for (auto idx : getDstIndices()) {
1685  if (!idx.getType().isIndex())
1686  return emitOpError("dst index to dma_start must have 'index' type");
1687  if (!isValidAffineIndexOperand(idx, scope))
1688  return emitOpError(
1689  "dst index must be a valid dimension or symbol identifier");
1690  }
1691  for (auto idx : getTagIndices()) {
1692  if (!idx.getType().isIndex())
1693  return emitOpError("tag index to dma_start must have 'index' type");
1694  if (!isValidAffineIndexOperand(idx, scope))
1695  return emitOpError(
1696  "tag index must be a valid dimension or symbol identifier");
1697  }
1698  return success();
1699 }
1700 
1701 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1702  SmallVectorImpl<OpFoldResult> &results) {
1703  /// dma_start(memrefcast) -> dma_start
1704  return memref::foldMemRefCast(*this);
1705 }
1706 
// Reports the memory effects of the DMA start: a read of the source memref, a
// write to the destination memref, and a read of the tag memref.
// NOTE(review): the parameter-type line and the final argument of each
// emplace_back call are not visible in this listing.
// NOTE(review): the tag memref is only reported as read; if the DMA signals
// completion by writing the tag, a Write effect may be warranted — confirm.
1707 void AffineDmaStartOp::getEffects(
1709  &effects) {
1710  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
1712  effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
1714  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1716 }
1717 
1718 //===----------------------------------------------------------------------===//
1719 // AffineDmaWaitOp
1720 //===----------------------------------------------------------------------===//
1721 
1722 // TODO: Check that map operands are loop IVs or symbols.
1723 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1724  Value tagMemRef, AffineMap tagMap,
1725  ValueRange tagIndices, Value numElements) {
1726  result.addOperands(tagMemRef);
1727  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1728  result.addOperands(tagIndices);
1729  result.addOperands(numElements);
1730 }
1731 
// Prints the custom form `%tag[<tag map of indices>], ... : <tag memref type>`.
// NOTE(review): the signature line and the line between `"], "` and the type
// are not visible in this listing; the latter presumably prints the element
// count, mirroring the parser (`%tag[%index], %num_elements`) — confirm.
1733  p << " " << getTagMemRef() << '[';
1734  SmallVector<Value, 2> operands(getTagIndices());
1735  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1736  p << "], ";
1738  p << " : " << getTagMemRef().getType();
1739 }
1740 
1741 // Parse AffineDmaWaitOp.
1742 // Example:
1743 // affine.dma_wait %tag[%index], %num_elements
1744 // : memref<1 x i32, (d0) -> (d0), 4>
1745 //
1747  OperationState &result) {
// Operands are appended in this order: tag memref, tag map operands, number of
// elements. The tag type is checked to be a memref only after resolution.
// NOTE(review): the signature line preceding `OperationState &result)` and the
// declaration of tagMapOperands are not visible in this listing.
1748  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1749  AffineMapAttr tagMapAttr;
1751  Type type;
1752  auto indexType = parser.getBuilder().getIndexType();
1753  OpAsmParser::UnresolvedOperand numElementsInfo;
1754 
1755  // Parse tag memref, its map operands, and the number of elements.
1756  if (parser.parseOperand(tagMemRefInfo) ||
1757  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1758  getTagMapAttrStrName(),
1759  result.attributes) ||
1760  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1761  parser.parseColonType(type) ||
1762  parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1763  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1764  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1765  return failure();
1766 
1767  if (!llvm::isa<MemRefType>(type))
1768  return parser.emitError(parser.getNameLoc(),
1769  "expected tag to be of memref type");
1770 
// The number of parsed index operands must equal the tag map's inputs.
1771  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1772  return parser.emitError(parser.getNameLoc(),
1773  "tag memref operand count != to map.numInputs");
1774  return success();
1775 }
1776 
1777 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1778  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1779  return emitOpError("expected DMA tag to be of memref type");
1780  Region *scope = getAffineScope(*this);
1781  for (auto idx : getTagIndices()) {
1782  if (!idx.getType().isIndex())
1783  return emitOpError("index to dma_wait must have 'index' type");
1784  if (!isValidAffineIndexOperand(idx, scope))
1785  return emitOpError(
1786  "index must be a valid dimension or symbol identifier");
1787  }
1788  return success();
1789 }
1790 
1791 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1792  SmallVectorImpl<OpFoldResult> &results) {
1793  /// dma_wait(memrefcast) -> dma_wait
1794  return memref::foldMemRefCast(*this);
1795 }
1796 
// Reports a single memory effect: a read of the tag memref.
// NOTE(review): the parameter-type line and the final argument of the
// emplace_back call are not visible in this listing.
1797 void AffineDmaWaitOp::getEffects(
1799  &effects) {
1800  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1802 }
1803 
1804 //===----------------------------------------------------------------------===//
1805 // AffineForOp
1806 //===----------------------------------------------------------------------===//
1807 
1808 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1809 /// bodyBuilder are empty/null, we include default terminator op.
1810 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1811  ValueRange lbOperands, AffineMap lbMap,
1812  ValueRange ubOperands, AffineMap ubMap, int64_t step,
1813  ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1814  assert(((!lbMap && lbOperands.empty()) ||
1815  lbOperands.size() == lbMap.getNumInputs()) &&
1816  "lower bound operand count does not match the affine map");
1817  assert(((!ubMap && ubOperands.empty()) ||
1818  ubOperands.size() == ubMap.getNumInputs()) &&
1819  "upper bound operand count does not match the affine map");
1820  assert(step > 0 && "step has to be a positive integer constant");
1821 
1822  OpBuilder::InsertionGuard guard(builder);
1823 
1824  // Set variadic segment sizes.
1825  result.addAttribute(
1826  getOperandSegmentSizeAttr(),
1827  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
1828  static_cast<int32_t>(ubOperands.size()),
1829  static_cast<int32_t>(iterArgs.size())}));
1830 
1831  for (Value val : iterArgs)
1832  result.addTypes(val.getType());
1833 
1834  // Add an attribute for the step.
1835  result.addAttribute(getStepAttrName(result.name),
1836  builder.getIntegerAttr(builder.getIndexType(), step));
1837 
1838  // Add the lower bound.
1839  result.addAttribute(getLowerBoundMapAttrName(result.name),
1840  AffineMapAttr::get(lbMap));
1841  result.addOperands(lbOperands);
1842 
1843  // Add the upper bound.
1844  result.addAttribute(getUpperBoundMapAttrName(result.name),
1845  AffineMapAttr::get(ubMap));
1846  result.addOperands(ubOperands);
1847 
1848  result.addOperands(iterArgs);
1849  // Create a region and a block for the body. The argument of the region is
1850  // the loop induction variable.
1851  Region *bodyRegion = result.addRegion();
1852  Block *bodyBlock = builder.createBlock(bodyRegion);
1853  Value inductionVar =
1854  bodyBlock->addArgument(builder.getIndexType(), result.location);
1855  for (Value val : iterArgs)
1856  bodyBlock->addArgument(val.getType(), val.getLoc());
1857 
1858  // Create the default terminator if the builder is not provided and if the
1859  // iteration arguments are not provided. Otherwise, leave this to the caller
1860  // because we don't know which values to return from the loop.
1861  if (iterArgs.empty() && !bodyBuilder) {
1862  ensureTerminator(*bodyRegion, builder, result.location);
1863  } else if (bodyBuilder) {
1864  OpBuilder::InsertionGuard guard(builder);
1865  builder.setInsertionPointToStart(bodyBlock);
1866  bodyBuilder(builder, result.location, inductionVar,
1867  bodyBlock->getArguments().drop_front());
1868  }
1869 }
1870 
1871 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1872  int64_t ub, int64_t step, ValueRange iterArgs,
1873  BodyBuilderFn bodyBuilder) {
1874  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1875  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1876  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1877  bodyBuilder);
1878 }
1879 
// Verifies the body's induction-variable argument, the bound operands, and,
// for loops with results, the counts of iter operands and region iter_args.
// NOTE(review): listing lines 1891 and 1896 (continuations of the two bound
// checks) are not visible; presumably they verify the dim/symbol operands of
// each bound map — confirm.
1880 LogicalResult AffineForOp::verifyRegions() {
1881  // Check that the body's first block argument, the induction variable,
1882  // exists and has index type.
1883  auto *body = getBody();
1884  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1885  return emitOpError("expected body to have a single index argument for the "
1886  "induction variable");
1887 
1888  // Verify that the bound operands are valid dimension/symbols.
1889  // Lower bound.
1890  if (getLowerBoundMap().getNumInputs() > 0)
1892  getLowerBoundMap().getNumDims())))
1893  return failure();
1894  // Upper bound.
1895  if (getUpperBoundMap().getNumInputs() > 0)
1897  getUpperBoundMap().getNumDims())))
1898  return failure();
1899 
1900  unsigned opNumResults = getNumResults();
1901  if (opNumResults == 0)
1902  return success();
1903 
1904  // If the loop defines values, check that their number matches both the
1905  // number of initial iter operands and the number of region iter_args
1906  // (only counts are compared here, not types).
1907  if (getNumIterOperands() != opNumResults)
1908  return emitOpError(
1909  "mismatch between the number of loop-carried values and results");
1910  if (getNumRegionIterArgs() != opNumResults)
1911  return emitOpError(
1912  "mismatch between the number of basic block args and results");
1913 
1914  return success();
1915 }
1916 
1917 /// Parses one bound of an affine.for: the lower bound if `isLower` is set,
/// otherwise the upper bound. An optional 'max' (lower) / 'min' (upper)
/// keyword may precede the bound; it is required when the map has multiple
/// results. The bound itself is one of:
///   - a single SSA value, stored as a symbol identity map over that operand;
///   - an affine map followed by its dim and symbol operand lists;
///   - an integer constant, stored as a constant map without operands.
/// The bound map attribute goes into `result.attributes` and any operands into
/// `result.operands`.
/// NOTE(review): listing line 1931 (not visible) presumably declares
/// `boundOpInfos`, which is used below.
1918 static ParseResult parseBound(bool isLower, OperationState &result,
1919  OpAsmParser &p) {
1920  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1921  // the map has multiple results.
// True when the prefix keyword was absent.
1922  bool failedToParsedMinMax =
1923  failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1924 
1925  auto &builder = p.getBuilder();
1926  auto boundAttrStrName =
1927  isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
1928  : AffineForOp::getUpperBoundMapAttrName(result.name);
1929 
1930  // Parse ssa-id as identity map.
1932  if (p.parseOperandList(boundOpInfos))
1933  return failure();
1934 
1935  if (!boundOpInfos.empty()) {
1936  // Check that only one operand was parsed.
1937  if (boundOpInfos.size() > 1)
1938  return p.emitError(p.getNameLoc(),
1939  "expected only one loop bound operand");
1940 
1941  // TODO: improve error message when SSA value is not of index type.
1942  // Currently it is 'use of value ... expects different type than prior uses'
1943  if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1944  result.operands))
1945  return failure();
1946 
1947  // Create an identity map using symbol id. This representation is optimized
1948  // for storage. Analysis passes may expand it into a multi-dimensional map
1949  // if desired.
1950  AffineMap map = builder.getSymbolIdentityMap();
1951  result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
1952  return success();
1953  }
1954 
1955  // Get the attribute location.
1956  SMLoc attrLoc = p.getCurrentLocation();
1957 
1958  Attribute boundAttr;
1959  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
1960  result.attributes))
1961  return failure();
1962 
1963  // Parse full form - affine map followed by dim and symbol list.
1964  if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
1965  unsigned currentNumOperands = result.operands.size();
1966  unsigned numDims;
1967  if (parseDimAndSymbolList(p, result.operands, numDims))
1968  return failure();
1969 
1970  auto map = affineMapAttr.getValue();
1971  if (map.getNumDims() != numDims)
1972  return p.emitError(
1973  p.getNameLoc(),
1974  "dim operand count and affine map dim count must match");
1975 
// Operands added by parseDimAndSymbolList = dims + symbols.
1976  unsigned numDimAndSymbolOperands =
1977  result.operands.size() - currentNumOperands;
1978  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1979  return p.emitError(
1980  p.getNameLoc(),
1981  "symbol operand count and affine map symbol count must match");
1982 
1983  // If the map has multiple results, make sure that we parsed the min/max
1984  // prefix.
1985  if (map.getNumResults() > 1 && failedToParsedMinMax) {
1986  if (isLower) {
1987  return p.emitError(attrLoc, "lower loop bound affine map with "
1988  "multiple results requires 'max' prefix");
1989  }
1990  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1991  "results requires 'min' prefix");
1992  }
1993  return success();
1994  }
1995 
1996  // Parse custom assembly form.
1997  if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
// Replace the IntegerAttr appended by parseAttribute above with an
// equivalent constant affine map under the same name.
1998  result.attributes.pop_back();
1999  result.addAttribute(
2000  boundAttrStrName,
2001  AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2002  return success();
2003  }
2004 
2005  return p.emitError(
2006  p.getNameLoc(),
2007  "expected valid affine map representation for loop bounds");
2008 }
2009 
// Parses the custom form of affine.for:
//   %iv = <lower bound> to <upper bound> [step <int>]
//         [iter_args(%arg = %init, ...) -> (<types>)] <region> [attr-dict]
// NOTE(review): listing lines 2050-2051 (not visible) presumably declare
// `regionArgs` and `operands`, which are used below.
2010 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2011  auto &builder = parser.getBuilder();
2012  OpAsmParser::Argument inductionVariable;
2013  inductionVariable.type = builder.getIndexType();
2014  // Parse the induction variable followed by '='.
2015  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2016  return failure();
2017 
2018  // Parse loop bounds.
// Each bound's operand count is the growth of result.operands while parsing
// it; the counts feed the operand segment sizes below.
2019  int64_t numOperands = result.operands.size();
2020  if (parseBound(/*isLower=*/true, result, parser))
2021  return failure();
2022  int64_t numLbOperands = result.operands.size() - numOperands;
2023  if (parser.parseKeyword("to", " between bounds"))
2024  return failure();
2025  numOperands = result.operands.size();
2026  if (parseBound(/*isLower=*/false, result, parser))
2027  return failure();
2028  int64_t numUbOperands = result.operands.size() - numOperands;
2029 
2030  // Parse the optional loop step, we default to 1 if one is not present.
// parseOptionalKeyword yields a failing (truthy) result when the keyword is
// absent, so the first branch is the "no step" case.
2031  if (parser.parseOptionalKeyword("step")) {
2032  result.addAttribute(
2033  getStepAttrName(result.name),
2034  builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2035  } else {
2036  SMLoc stepLoc = parser.getCurrentLocation();
2037  IntegerAttr stepAttr;
2038  if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2039  getStepAttrName(result.name).data(),
2040  result.attributes))
2041  return failure();
2042 
// NOTE(review): only negative steps are rejected here; a zero step passes
// this check — confirm it is rejected elsewhere (attribute constraint or
// verifier).
2043  if (stepAttr.getValue().isNegative())
2044  return parser.emitError(
2045  stepLoc,
2046  "expected step to be representable as a positive signed integer");
2047  }
2048 
2049  // Parse the optional initial iteration arguments.
2052 
2053  // Induction variable.
2054  regionArgs.push_back(inductionVariable);
2055 
2056  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2057  // Parse assignment list and results type list.
2058  if (parser.parseAssignmentList(regionArgs, operands) ||
2059  parser.parseArrowTypeList(result.types))
2060  return failure();
2061  // Resolve input operands.
// llvm::zip stops at the shortest range; a count mismatch between the
// assignments and the result types is diagnosed further below.
2062  for (auto argOperandType :
2063  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2064  Type type = std::get<2>(argOperandType);
2065  std::get<0>(argOperandType).type = type;
2066  if (parser.resolveOperand(std::get<1>(argOperandType), type,
2067  result.operands))
2068  return failure();
2069  }
2070  }
2071 
2072  result.addAttribute(
2073  getOperandSegmentSizeAttr(),
2074  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2075  static_cast<int32_t>(numUbOperands),
2076  static_cast<int32_t>(operands.size())}));
2077 
2078  // Parse the body region.
2079  Region *body = result.addRegion();
2080  if (regionArgs.size() != result.types.size() + 1)
2081  return parser.emitError(
2082  parser.getNameLoc(),
2083  "mismatch between the number of loop-carried values and results");
2084  if (parser.parseRegion(*body, regionArgs))
2085  return failure();
2086 
// The terminator may be omitted in the textual form; re-create it if so.
2087  AffineForOp::ensureTerminator(*body, builder, result.location);
2088 
2089  // Parse the optional attribute list.
2090  return parser.parseOptionalAttrDict(result.attributes);
2091 }
2092 
2093 static void printBound(AffineMapAttr boundMap,
2094  Operation::operand_range boundOperands,
2095  const char *prefix, OpAsmPrinter &p) {
2096  AffineMap map = boundMap.getValue();
2097 
2098  // Check if this bound should be printed using custom assembly form.
2099  // The decision to restrict printing custom assembly form to trivial cases
2100  // comes from the will to roundtrip MLIR binary -> text -> binary in a
2101  // lossless way.
2102  // Therefore, custom assembly form parsing and printing is only supported for
2103  // zero-operand constant maps and single symbol operand identity maps.
2104  if (map.getNumResults() == 1) {
2105  AffineExpr expr = map.getResult(0);
2106 
2107  // Print constant bound.
2108  if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2109  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2110  p << constExpr.getValue();
2111  return;
2112  }
2113  }
2114 
2115  // Print bound that consists of a single SSA symbol if the map is over a
2116  // single symbol.
2117  if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2118  if (dyn_cast<AffineSymbolExpr>(expr)) {
2119  p.printOperand(*boundOperands.begin());
2120  return;
2121  }
2122  }
2123  } else {
2124  // Map has multiple results. Print 'min' or 'max' prefix.
2125  p << prefix << ' ';
2126  }
2127 
2128  // Print the map and its operands.
2129  p << boundMap;
2130  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2131  map.getNumDims(), p);
2132 }
2133 
2134 unsigned AffineForOp::getNumIterOperands() {
2135  AffineMap lbMap = getLowerBoundMapAttr().getValue();
2136  AffineMap ubMap = getUpperBoundMapAttr().getValue();
2137 
2138  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2139 }
2140 
2141 std::optional<MutableArrayRef<OpOperand>>
2142 AffineForOp::getYieldedValuesMutable() {
2143  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2144 }
2145 
// Prints the custom form of affine.for:
//   %iv = <lb> to <ub> [step <n>] [iter_args(%a = %init, ...) -> (<types>)]
//   <region> [attr-dict]
// The step is omitted when it equals 1. The body terminator is printed only
// when there are loop-carried values; otherwise it is left implicit (the
// parser re-creates it via ensureTerminator).
// NOTE(review): the signature line and listing line 2174 (presumably the
// optional-attribute-dict printing call whose arguments follow) are not
// visible here. The elided attributes are the bound maps, the step and the
// operand segment sizes, which the custom syntax already encodes.
2147  p << ' ';
2148  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2149  /*omitType=*/true);
2150  p << " = ";
2151  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2152  p << " to ";
2153  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2154 
2155  if (getStepAsInt() != 1)
2156  p << " step " << getStepAsInt();
2157 
2158  bool printBlockTerminators = false;
2159  if (getNumIterOperands() > 0) {
2160  p << " iter_args(";
2161  auto regionArgs = getRegionIterArgs();
2162  auto operands = getInits();
2163 
2164  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2165  p << std::get<0>(it) << " = " << std::get<1>(it);
2166  });
2167  p << ") -> (" << getResultTypes() << ")";
2168  printBlockTerminators = true;
2169  }
2170 
2171  p << ' ';
2172  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2173  printBlockTerminators);
2175  (*this)->getAttrs(),
2176  /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2177  getUpperBoundMapAttrName(getOperation()->getName()),
2178  getStepAttrName(getOperation()->getName()),
2179  getOperandSegmentSizeAttr()});
2180 }
2181 
2182 /// Fold the constant bounds of a loop.
2183 static LogicalResult foldLoopBounds(AffineForOp forOp) {
2184  auto foldLowerOrUpperBound = [&forOp](bool lower) {
2185  // Check to see if each of the operands is the result of a constant. If
2186  // so, get the value. If not, ignore it.
2187  SmallVector<Attribute, 8> operandConstants;
2188  auto boundOperands =
2189  lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2190  for (auto operand : boundOperands) {
2191  Attribute operandCst;
2192  matchPattern(operand, m_Constant(&operandCst));
2193  operandConstants.push_back(operandCst);
2194  }
2195 
2196  AffineMap boundMap =
2197  lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2198  assert(boundMap.getNumResults() >= 1 &&
2199  "bound maps should have at least one result");
2200  SmallVector<Attribute, 4> foldedResults;
2201  if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2202  return failure();
2203 
2204  // Compute the max or min as applicable over the results.
2205  assert(!foldedResults.empty() && "bounds should have at least one result");
2206  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2207  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2208  auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2209  maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2210  : llvm::APIntOps::smin(maxOrMin, foldedResult);
2211  }
2212  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2213  : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2214  return success();
2215  };
2216 
2217  // Try to fold the lower bound.
2218  bool folded = false;
2219  if (!forOp.hasConstantLowerBound())
2220  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2221 
2222  // Try to fold the upper bound.
2223  if (!forOp.hasConstantUpperBound())
2224  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2225  return success(folded);
2226 }
2227 
2228 /// Canonicalize the bounds of the given loop.
2229 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2230  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2231  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2232 
2233  auto lbMap = forOp.getLowerBoundMap();
2234  auto ubMap = forOp.getUpperBoundMap();
2235  auto prevLbMap = lbMap;
2236  auto prevUbMap = ubMap;
2237 
2238  composeAffineMapAndOperands(&lbMap, &lbOperands);
2239  canonicalizeMapAndOperands(&lbMap, &lbOperands);
2240  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
2241  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2242  lbMap = removeDuplicateExprs(lbMap);
2243 
2244  composeAffineMapAndOperands(&ubMap, &ubOperands);
2245  canonicalizeMapAndOperands(&ubMap, &ubOperands);
2246  ubMap = removeDuplicateExprs(ubMap);
2247 
2248  // Any canonicalization change always leads to updated map(s).
2249  if (lbMap == prevLbMap && ubMap == prevUbMap)
2250  return failure();
2251 
2252  if (lbMap != prevLbMap)
2253  forOp.setLowerBound(lbOperands, lbMap);
2254  if (ubMap != prevUbMap)
2255  forOp.setUpperBound(ubOperands, ubMap);
2256  return success();
2257 }
2258 
2259 namespace {
2260 /// Returns constant trip count in trivial cases.
2261 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2262  int64_t step = forOp.getStepAsInt();
2263  if (!forOp.hasConstantBounds() || step <= 0)
2264  return std::nullopt;
2265  int64_t lb = forOp.getConstantLowerBound();
2266  int64_t ub = forOp.getConstantUpperBound();
2267  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2268 }
2269 
2270 /// This is a pattern to fold trivially empty loop bodies.
2271 /// TODO: This should be moved into the folding hook.
2272 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
2274 
2275  LogicalResult matchAndRewrite(AffineForOp forOp,
2276  PatternRewriter &rewriter) const override {
2277  // Check that the body only contains a yield.
2278  if (!llvm::hasSingleElement(*forOp.getBody()))
2279  return failure();
2280  if (forOp.getNumResults() == 0)
2281  return success();
2282  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2283  if (tripCount && *tripCount == 0) {
2284  // The initial values of the iteration arguments would be the op's
2285  // results.
2286  rewriter.replaceOp(forOp, forOp.getInits());
2287  return success();
2288  }
2289  SmallVector<Value, 4> replacements;
2290  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2291  auto iterArgs = forOp.getRegionIterArgs();
2292  bool hasValDefinedOutsideLoop = false;
2293  bool iterArgsNotInOrder = false;
2294  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2295  Value val = yieldOp.getOperand(i);
2296  auto *iterArgIt = llvm::find(iterArgs, val);
2297  if (iterArgIt == iterArgs.end()) {
2298  // `val` is defined outside of the loop.
2299  assert(forOp.isDefinedOutsideOfLoop(val) &&
2300  "must be defined outside of the loop");
2301  hasValDefinedOutsideLoop = true;
2302  replacements.push_back(val);
2303  } else {
2304  unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2305  if (pos != i)
2306  iterArgsNotInOrder = true;
2307  replacements.push_back(forOp.getInits()[pos]);
2308  }
2309  }
2310  // Bail out when the trip count is unknown and the loop returns any value
2311  // defined outside of the loop or any iterArg out of order.
2312  if (!tripCount.has_value() &&
2313  (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2314  return failure();
2315  // Bail out when the loop iterates more than once and it returns any iterArg
2316  // out of order.
2317  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2318  return failure();
2319  rewriter.replaceOp(forOp, replacements);
2320  return success();
2321  }
2322 };
2323 } // namespace
2324 
2325 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2326  MLIRContext *context) {
2327  results.add<AffineForEmptyLoopFolder>(context);
2328 }
2329 
2330 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2331  assert((point.isParent() || point == getRegion()) && "invalid region point");
2332 
2333  // The initial operands map to the loop arguments after the induction
2334  // variable or are forwarded to the results when the trip count is zero.
2335  return getInits();
2336 }
2337 
// Computes the possible successors of `point` (the parent op or the loop
// body), pruning them when the trip count is a known constant:
//   - from the parent: only the body if the trip count is > 0, only the
//     results if it is 0;
//   - from the body: only the results if the trip count is exactly 1;
//   - otherwise both the body (via the iter_args) and the results.
// NOTE(review): listing line 2339 (the parameter list, presumably including
// `point` and the `regions` output vector) is not visible here.
2338 void AffineForOp::getSuccessorRegions(
2340  assert((point.isParent() || point == getRegion()) && "expected loop region");
2341  // The loop may typically branch back to its body or to the parent operation.
2342  // If the predecessor is the parent op and the trip count is known to be at
2343  // least one, branch into the body using the iterator arguments. And in cases
2344  // we know the trip count is zero, it can only branch back to its parent.
2345  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2346  if (point.isParent() && tripCount.has_value()) {
2347  if (tripCount.value() > 0) {
2348  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2349  return;
2350  }
// For an unsigned count this condition always holds once `> 0` failed.
2351  if (tripCount.value() == 0) {
2352  regions.push_back(RegionSuccessor(getResults()));
2353  return;
2354  }
2355  }
2356 
2357  // From the loop body, if the trip count is one, we can only branch back to
2358  // the parent.
2359  if (!point.isParent() && tripCount && *tripCount == 1) {
2360  regions.push_back(RegionSuccessor(getResults()));
2361  return;
2362  }
2363 
2364  // In all other cases, the loop may branch back to itself or the parent
2365  // operation.
2366  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2367  regions.push_back(RegionSuccessor(getResults()));
2368 }
2369 
2370 /// Returns true if the affine.for has zero iterations in trivial cases.
2371 static bool hasTrivialZeroTripCount(AffineForOp op) {
2372  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2373  return tripCount && *tripCount == 0;
2374 }
2375 
2376 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2377  SmallVectorImpl<OpFoldResult> &results) {
2378  bool folded = succeeded(foldLoopBounds(*this));
2379  folded |= succeeded(canonicalizeLoopBounds(*this));
2380  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2381  // The initial values of the loop-carried variables (iter_args) are the
2382  // results of the op. But this must be avoided for an affine.for op that
2383  // does not return any results. Since ops that do not return results cannot
2384  // be folded away, we would enter an infinite loop of folds on the same
2385  // affine.for op.
2386  results.assign(getInits().begin(), getInits().end());
2387  folded = true;
2388  }
2389  return success(folded);
2390 }
2391 
// Returns the lower bound as an AffineBound over the lower bound operands and
// map.
// NOTE(review): the signature line is not visible in this listing.
2393  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2394 }
2395 
// Returns the upper bound as an AffineBound over the upper bound operands and
// map.
// NOTE(review): the signature line is not visible in this listing.
2397  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2398 }
2399 
2400 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2401  assert(lbOperands.size() == map.getNumInputs());
2402  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2403  getLowerBoundOperandsMutable().assign(lbOperands);
2404  setLowerBoundMap(map);
2405 }
2406 
2407 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2408  assert(ubOperands.size() == map.getNumInputs());
2409  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2410  getUpperBoundOperandsMutable().assign(ubOperands);
2411  setUpperBoundMap(map);
2412 }
2413 
2414 bool AffineForOp::hasConstantLowerBound() {
2415  return getLowerBoundMap().isSingleConstant();
2416 }
2417 
2418 bool AffineForOp::hasConstantUpperBound() {
2419  return getUpperBoundMap().isSingleConstant();
2420 }
2421 
2422 int64_t AffineForOp::getConstantLowerBound() {
2423  return getLowerBoundMap().getSingleConstantResult();
2424 }
2425 
2426 int64_t AffineForOp::getConstantUpperBound() {
2427  return getUpperBoundMap().getSingleConstantResult();
2428 }
2429 
2430 void AffineForOp::setConstantLowerBound(int64_t value) {
2431  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2432 }
2433 
2434 void AffineForOp::setConstantUpperBound(int64_t value) {
2435  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2436 }
2437 
2438 AffineForOp::operand_range AffineForOp::getControlOperands() {
2439  return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2440  getUpperBoundOperands().size()};
2441 }
2442 
2443 bool AffineForOp::matchingBoundOperandList() {
2444  auto lbMap = getLowerBoundMap();
2445  auto ubMap = getUpperBoundMap();
2446  if (lbMap.getNumDims() != ubMap.getNumDims() ||
2447  lbMap.getNumSymbols() != ubMap.getNumSymbols())
2448  return false;
2449 
2450  unsigned numOperands = lbMap.getNumInputs();
2451  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2452  // Compare Value 's.
2453  if (getOperand(i) != getOperand(numOperands + i))
2454  return false;
2455  }
2456  return true;
2457 }
2458 
2459 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2460 
2461 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2462  return SmallVector<Value>{getInductionVar()};
2463 }
2464 
std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
  // Only a single-constant lower bound is reported; any other bound map
  // yields std::nullopt.
  if (!hasConstantLowerBound())
    return std::nullopt;
  // The builder is only needed to create the i64 integer attribute.
  OpBuilder b(getContext());
      OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
}
2472 
std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
  // No early exit: the step is always available as an integer through
  // getStepAsInt(), and is reported as an i64 integer attribute.
  OpBuilder b(getContext());
      OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
}
2478 
std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
  // Only a single-constant upper bound is reported. `return {}` value-
  // initializes the optional, i.e. it returns std::nullopt.
  if (!hasConstantUpperBound())
    return {};
  // The builder is only needed to create the i64 integer attribute.
  OpBuilder b(getContext());
      OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
}
2486 
2487 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2488  RewriterBase &rewriter, ValueRange newInitOperands,
2489  bool replaceInitOperandUsesInLoop,
2490  const NewYieldValuesFn &newYieldValuesFn) {
2491  // Create a new loop before the existing one, with the extra operands.
2492  OpBuilder::InsertionGuard g(rewriter);
2493  rewriter.setInsertionPoint(getOperation());
2494  auto inits = llvm::to_vector(getInits());
2495  inits.append(newInitOperands.begin(), newInitOperands.end());
2496  AffineForOp newLoop = rewriter.create<AffineForOp>(
2497  getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2498  getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
2499 
2500  // Generate the new yield values and append them to the scf.yield operation.
2501  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2502  ArrayRef<BlockArgument> newIterArgs =
2503  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2504  {
2505  OpBuilder::InsertionGuard g(rewriter);
2506  rewriter.setInsertionPoint(yieldOp);
2507  SmallVector<Value> newYieldedValues =
2508  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2509  assert(newInitOperands.size() == newYieldedValues.size() &&
2510  "expected as many new yield values as new iter operands");
2511  rewriter.modifyOpInPlace(yieldOp, [&]() {
2512  yieldOp.getOperandsMutable().append(newYieldedValues);
2513  });
2514  }
2515 
2516  // Move the loop body to the new op.
2517  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2518  newLoop.getBody()->getArguments().take_front(
2519  getBody()->getNumArguments()));
2520 
2521  if (replaceInitOperandUsesInLoop) {
2522  // Replace all uses of `newInitOperands` with the corresponding basic block
2523  // arguments.
2524  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2525  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2526  [&](OpOperand &use) {
2527  Operation *user = use.getOwner();
2528  return newLoop->isProperAncestor(user);
2529  });
2530  }
2531  }
2532 
2533  // Replace the old loop.
2534  rewriter.replaceOp(getOperation(),
2535  newLoop->getResults().take_front(getNumResults()));
2536  return cast<LoopLikeOpInterface>(newLoop.getOperation());
2537 }
2538 
/// Reports whether this loop may be speculatively executed. The decision
/// depends only on the constant step.
Speculation::Speculatability AffineForOp::getSpeculatability() {
  // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
  // Start and End.
  //
  // For Step != 1, the loop may not terminate. We can add more smarts here if
  // needed.
  return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
}
2548 
/// Returns true if the provided value is the induction variable of an
/// AffineForOp (and not one of its iter_args).
  return getForInductionVarOwner(val) != AffineForOp();
}

  // True iff `val` is one of the IVs of an owning AffineParallelOp.
  return getAffineParallelInductionVarOwner(val) != nullptr;
}

}
2562 
  // Only block arguments can be induction variables.
  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
  if (!ivArg || !ivArg.getOwner())
    return AffineForOp();
  // Walk block -> region -> op owning that region.
  // NOTE(review): getParent() is null for a block not inserted into a region,
  // and dyn_cast asserts on a null op; confirm callers never pass such values.
  auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
  if (auto forOp = dyn_cast<AffineForOp>(containingInst))
    // Check to make sure `val` is the induction variable, not an iter_arg.
    return forOp.getInductionVar() == val ? forOp : AffineForOp();
  return AffineForOp();
}
2573 
  // Only block arguments can be induction variables.
  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
  if (!ivArg || !ivArg.getOwner())
    return nullptr;
  Operation *containingOp = ivArg.getOwner()->getParentOp();
  // NOTE(review): containingOp is null for a detached block, and dyn_cast
  // asserts on null; dyn_cast_or_null may be intended here.
  auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
  // Accept only if `val` is actually one of the op's induction variables.
  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
    return parallelOp;
  return nullptr;
}
2584 
/// Extracts the induction variables from a list of AffineForOps and returns
/// them by appending them, in order, to `*ivs` (existing contents are kept).
                                           SmallVectorImpl<Value> *ivs) {
  ivs->reserve(forInsts.size());
  for (auto forInst : forInsts)
    ivs->push_back(forInst.getInductionVar());
}
2593 
  // A lower bound on the final size: each affine.parallel may contribute
  // several IVs.
  ivs.reserve(affineOps.size());
  for (Operation *op : affineOps) {
    // Collect the single IV of an affine.for, or every IV (entry-block
    // argument) of an affine.parallel; other ops are skipped.
    if (auto forOp = dyn_cast<AffineForOp>(op))
      ivs.push_back(forOp.getInductionVar());
    else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
      for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
        ivs.push_back(parallelOp.getBody()->getArgument(i));
  }
}
2606 
/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
/// operations.
///
/// `lbs`, `ubs` and `steps` must have the same length; loop `i` uses entry `i`
/// of each. `bodyBuilderFn`, if set, is called once inside the innermost loop
/// with the IVs of all loops, outermost first. If there are no loops it is
/// called directly at the current insertion point with an empty range.
/// The builder's insertion point is restored on return.
template <typename BoundListTy, typename LoopCreatorTy>
    OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    LoopCreatorTy &&loopCreatorFn) {
  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");

  // If there are no loops to be constructed, construct the body anyway.
  // The guard undoes the insertion-point moves made while nesting below.
  OpBuilder::InsertionGuard guard(builder);
  if (lbs.empty()) {
    if (bodyBuilderFn)
      bodyBuilderFn(builder, loc, ValueRange());
    return;
  }

  // Create the loops iteratively and store the induction variables.
  ivs.reserve(lbs.size());
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    // Callback for creating the loop body, always creates the terminator.
    // It captures `i` and `ivs` by reference; this presumes loopCreatorFn
    // invokes it synchronously while building loop `i`.
    auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
                        ValueRange iterArgs) {
      ivs.push_back(iv);
      // In the innermost loop, call the body builder.
      if (i == e - 1 && bodyBuilderFn) {
        OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
      }
      nestedBuilder.create<AffineYieldOp>(nestedLoc);
    };

    // Delegate actual loop creation to the callback in order to dispatch
    // between constant- and variable-bound loops.
    auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
    // Nest the next loop at the start of this loop's body.
    builder.setInsertionPointToStart(loop.getBody());
  }
}
2648 
/// Creates an affine loop from the bounds known to be constants.
/// The loop carries no iter_args; `bodyBuilderFn` is forwarded to the
/// AffineForOp builder to populate the body.
static AffineForOp
                             int64_t ub, int64_t step,
                             AffineForOp::BodyBuilderFn bodyBuilderFn) {
  return builder.create<AffineForOp>(loc, lb, ub, step,
                                     /*iterArgs=*/std::nullopt, bodyBuilderFn);
}
2657 
/// Creates an affine loop from the bounds that may or may not be constants.
/// The loop carries no iter_args.
static AffineForOp
                          int64_t step,
                          AffineForOp::BodyBuilderFn bodyBuilderFn) {
  // Prefer the constant form when both bounds are known integer constants.
  std::optional<int64_t> lbConst = getConstantIntValue(lb);
  std::optional<int64_t> ubConst = getConstantIntValue(ub);
  if (lbConst && ubConst)
    return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
                                        ubConst.value(), step, bodyBuilderFn);
  // Otherwise each bound is its Value passed through a 1-D identity map.
  return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
                                     builder.getDimIdentityMap(), step,
                                     /*iterArgs=*/std::nullopt, bodyBuilderFn);
}
2672 
    OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  // Overload for bounds that are compile-time integers (ArrayRef<int64_t>);
  // the shared implementation does the nesting.
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
}
2680 
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  // Overload for SSA-value bounds (which may still turn out to be constants);
  // the shared implementation does the nesting.
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
}
2688 
2689 //===----------------------------------------------------------------------===//
2690 // AffineIfOp
2691 //===----------------------------------------------------------------------===//
2692 
namespace {
/// Removes the else block of an affine.if with no results when that block
/// contains nothing but its terminator (an operand-less affine.yield).
struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {

  LogicalResult matchAndRewrite(AffineIfOp ifOp,
                                PatternRewriter &rewriter) const override {
    // Applies only when an else block exists, consists of a single op (its
    // terminator), and the affine.if yields no values.
    if (ifOp.getElseRegion().empty() ||
        !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
      return failure();

    // Erasing the block mutates `ifOp` in place, so it is bracketed with
    // start/finalizeOpModification to keep the rewriter informed.
    rewriter.startOpModification(ifOp);
    rewriter.eraseBlock(ifOp.getElseBlock());
    rewriter.finalizeOpModification(ifOp);
    return success();
  }
};

/// Removes affine.if cond if the condition is always true or false in certain
/// trivial cases. Promotes the then/else block in the parent operation block.
struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {

  LogicalResult matchAndRewrite(AffineIfOp op,
                                PatternRewriter &rewriter) const override {

    // Trivially false: IntegerSet::isEmptyIntegerSet (presumably a structural
    // check for the canonical empty set, not a full emptiness proof).
    auto isTriviallyFalse = [](IntegerSet iSet) {
      return iSet.isEmptyIntegerSet();
    };

    // Trivially true: the set is exactly the single constraint `0 == 0`.
    auto isTriviallyTrue = [](IntegerSet iSet) {
      return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
              iSet.getConstraint(0) == 0);
    };

    IntegerSet affineIfConditions = op.getIntegerSet();
    Block *blockToMove;
    if (isTriviallyFalse(affineIfConditions)) {
      // The absence, or equivalently, the emptiness of the else region need not
      // be checked when affine.if is returning results because if an affine.if
      // operation is returning results, it always has a non-empty else region.
      if (op.getNumResults() == 0 && !op.hasElse()) {
        // If the else region is absent, or equivalently, empty, remove the
        // affine.if operation (which is not returning any results).
        rewriter.eraseOp(op);
        return success();
      }
      blockToMove = op.getElseBlock();
    } else if (isTriviallyTrue(affineIfConditions)) {
      blockToMove = op.getThenBlock();
    } else {
      return failure();
    }
    Operation *blockToMoveTerminator = blockToMove->getTerminator();
    // Promote the "blockToMove" block to the parent operation block between the
    // prologue and epilogue of "op".
    rewriter.inlineBlockBefore(blockToMove, op);
    // Replace the "op" operation with the operands of the
    // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
    // the affine.yield operation present in the "blockToMove" block. It has no
    // operands when affine.if is not returning results and therefore, in that
    // case, replaceOp just erases "op". When affine.if is not returning
    // results, the affine.yield operation can be omitted. It gets inserted
    // implicitly.
    rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
    // Erase the "blockToMoveTerminator" operation since it is now in the parent
    // operation block, which already has its own terminator.
    rewriter.eraseOp(blockToMoveTerminator);
    return success();
  }
};
} // namespace
2765 
/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
void AffineIfOp::getSuccessorRegions(
  // If the predecessor is an AffineIfOp, then branching into both `then` and
  // `else` region is valid.
  if (point.isParent()) {
    regions.reserve(2);
    regions.push_back(
        RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
    // If the "else" region is empty, branch back into parent: control can skip
    // the `then` region and flow straight to the op's results.
    if (getElseRegion().empty()) {
      regions.push_back(getResults());
    } else {
      regions.push_back(
          RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
    }
    return;
  }

  // If the predecessor is the `else`/`then` region, then branching into parent
  // op is valid.
  regions.push_back(RegionSuccessor(getResults()));
}
2790 
2791 LogicalResult AffineIfOp::verify() {
2792  // Verify that we have a condition attribute.
2793  // FIXME: This should be specified in the arguments list in ODS.
2794  auto conditionAttr =
2795  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2796  if (!conditionAttr)
2797  return emitOpError("requires an integer set attribute named 'condition'");
2798 
2799  // Verify that there are enough operands for the condition.
2800  IntegerSet condition = conditionAttr.getValue();
2801  if (getNumOperands() != condition.getNumInputs())
2802  return emitOpError("operand count and condition integer set dimension and "
2803  "symbol count must match");
2804 
2805  // Verify that the operands are valid dimension/symbols.
2806  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2807  condition.getNumDims())))
2808  return failure();
2809 
2810  return success();
2811 }
2812 
2813 ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
2814  // Parse the condition attribute set.
2815  IntegerSetAttr conditionAttr;
2816  unsigned numDims;
2817  if (parser.parseAttribute(conditionAttr,
2818  AffineIfOp::getConditionAttrStrName(),
2819  result.attributes) ||
2820  parseDimAndSymbolList(parser, result.operands, numDims))
2821  return failure();
2822 
2823  // Verify the condition operands.
2824  auto set = conditionAttr.getValue();
2825  if (set.getNumDims() != numDims)
2826  return parser.emitError(
2827  parser.getNameLoc(),
2828  "dim operand count and integer set dim count must match");
2829  if (numDims + set.getNumSymbols() != result.operands.size())
2830  return parser.emitError(
2831  parser.getNameLoc(),
2832  "symbol operand count and integer set symbol count must match");
2833 
2834  if (parser.parseOptionalArrowTypeList(result.types))
2835  return failure();
2836 
2837  // Create the regions for 'then' and 'else'. The latter must be created even
2838  // if it remains empty for the validity of the operation.
2839  result.regions.reserve(2);
2840  Region *thenRegion = result.addRegion();
2841  Region *elseRegion = result.addRegion();
2842 
2843  // Parse the 'then' region.
2844  if (parser.parseRegion(*thenRegion, {}, {}))
2845  return failure();
2846  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2847  result.location);
2848 
2849  // If we find an 'else' keyword then parse the 'else' region.
2850  if (!parser.parseOptionalKeyword("else")) {
2851  if (parser.parseRegion(*elseRegion, {}, {}))
2852  return failure();
2853  AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2854  result.location);
2855  }
2856 
2857  // Parse the optional attribute list.
2858  if (parser.parseOptionalAttrDict(result.attributes))
2859  return failure();
2860 
2861  return success();
2862 }
2863 
  // Mirror of AffineIfOp::parse: condition set, dim/symbol operands, optional
  // result types, then region, optional else region, attribute dictionary.
  auto conditionAttr =
      (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
  p << " " << conditionAttr;
  printDimAndSymbolList(operand_begin(), operand_end(),
                        conditionAttr.getValue().getNumDims(), p);
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  // Terminators are printed only when the op has results; otherwise the
  // parser re-inserts the implicit terminator via ensureTerminator.
  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults());

  // Print the 'else' region if it has any blocks.
  auto &elseRegion = this->getElseRegion();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/getNumResults());
  }

  // Print the attribute list (the condition was already printed above).
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/getConditionAttrStrName());
}
2888 
2889 IntegerSet AffineIfOp::getIntegerSet() {
2890  return (*this)
2891  ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2892  .getValue();
2893 }
2894 
2895 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2896  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
2897 }
2898 
2899 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2900  setIntegerSet(set);
2901  (*this)->setOperands(operands);
2902 }
2903 
2904 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2905  TypeRange resultTypes, IntegerSet set, ValueRange args,
2906  bool withElseRegion) {
2907  assert(resultTypes.empty() || withElseRegion);
2908  OpBuilder::InsertionGuard guard(builder);
2909 
2910  result.addTypes(resultTypes);
2911  result.addOperands(args);
2912  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
2913 
2914  Region *thenRegion = result.addRegion();
2915  builder.createBlock(thenRegion);
2916  if (resultTypes.empty())
2917  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2918 
2919  Region *elseRegion = result.addRegion();
2920  if (withElseRegion) {
2921  builder.createBlock(elseRegion);
2922  if (resultTypes.empty())
2923  AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2924  }
2925 }
2926 
2927 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2928  IntegerSet set, ValueRange args, bool withElseRegion) {
2929  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2930  withElseRegion);
2931 }
2932 
/// Compose any affine.apply ops feeding into `operands` of the integer set
/// `set` by composing the maps of such affine.apply ops with the integer
/// set constraints. Both `set` and `operands` are updated in place; nothing
/// changes if no operand is produced by an affine.apply.
    SmallVectorImpl<Value> &operands) {
  // We will simply reuse the API of the map composition by viewing the LHSs of
  // the equalities and inequalities of `set` as the affine exprs of an affine
  // map. Convert to equivalent map, compose, and convert back to set.
  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
                            set.getConstraints(), set.getContext());
  // Check if any composition is possible.
  if (llvm::none_of(operands,
                    [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
    return;

  composeAffineMapAndOperands(&map, &operands);
  // Reusing the original eq flags assumes composition keeps the number and
  // order of map results (one per constraint) — TODO confirm.
  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
                        set.getEqFlags());
}
2952 
2953 /// Canonicalize an affine if op's conditional (integer set + operands).
2954 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2955  auto set = getIntegerSet();
2956  SmallVector<Value, 4> operands(getOperands());
2957  composeSetAndOperands(set, operands);
2958  canonicalizeSetAndOperands(&set, &operands);
2959 
2960  // Check if the canonicalization or composition led to any change.
2961  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2962  return failure();
2963 
2964  setConditional(set, operands);
2965  return success();
2966 }
2967 
2968 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2969  MLIRContext *context) {
2970  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2971 }
2972 
2973 //===----------------------------------------------------------------------===//
2974 // AffineLoadOp
2975 //===----------------------------------------------------------------------===//
2976 
2977 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2978  AffineMap map, ValueRange operands) {
2979  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2980  result.addOperands(operands);
2981  if (map)
2982  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2983  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
2984  result.types.push_back(memrefType.getElementType());
2985 }
2986 
2987 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2988  Value memref, AffineMap map, ValueRange mapOperands) {
2989  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2990  result.addOperands(memref);
2991  result.addOperands(mapOperands);
2992  auto memrefType = llvm::cast<MemRefType>(memref.getType());
2993  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2994  result.types.push_back(memrefType.getElementType());
2995 }
2996 
2997 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2998  Value memref, ValueRange indices) {
2999  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3000  int64_t rank = memrefType.getRank();
3001  // Create identity map for memrefs with at least one dimension or () -> ()
3002  // for zero-dimensional memrefs.
3003  auto map =
3004  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3005  build(builder, result, memref, map, indices);
3006 }
3007 
/// Parses `%memref[affine-map-of-ssa-ids] attr-dict : memref-type`.
/// Every map operand is resolved as `index`; the result type is the memref's
/// element type.
ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  // Each step short-circuits on failure; the memref operand is resolved
  // before the map operands, matching the op's operand order.
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineLoadOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}
3027 
  // Inverse of AffineLoadOp::parse; the map attribute is printed inline
  // inside the brackets and therefore elided from the attribute dictionary.
  p << " " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType();
}
3038 
3039 /// Verify common indexing invariants of affine.load, affine.store,
3040 /// affine.vector_load and affine.vector_store.
3041 static LogicalResult
3042 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
3043  Operation::operand_range mapOperands,
3044  MemRefType memrefType, unsigned numIndexOperands) {
3045  AffineMap map = mapAttr.getValue();
3046  if (map.getNumResults() != memrefType.getRank())
3047  return op->emitOpError("affine map num results must equal memref rank");
3048  if (map.getNumInputs() != numIndexOperands)
3049  return op->emitOpError("expects as many subscripts as affine map inputs");
3050 
3051  Region *scope = getAffineScope(op);
3052  for (auto idx : mapOperands) {
3053  if (!idx.getType().isIndex())
3054  return op->emitOpError("index to load must have 'index' type");
3055  if (!isValidAffineIndexOperand(idx, scope))
3056  return op->emitOpError(
3057  "index must be a valid dimension or symbol identifier");
3058  }
3059 
3060  return success();
3061 }
3062 
3063 LogicalResult AffineLoadOp::verify() {
3064  auto memrefType = getMemRefType();
3065  if (getType() != memrefType.getElementType())
3066  return emitOpError("result type must match element type of memref");
3067 
3068  if (failed(verifyMemoryOpIndexing(
3069  getOperation(),
3070  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3071  getMapOperands(), memrefType,
3072  /*numIndexOperands=*/getNumOperands() - 1)))
3073  return failure();
3074 
3075  return success();
3076 }
3077 
3078 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3079  MLIRContext *context) {
3080  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3081 }
3082 
3083 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3084  /// load(memrefcast) -> load
3085  if (succeeded(memref::foldMemRefCast(*this)))
3086  return getResult();
3087 
3088  // Fold load from a global constant memref.
3089  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3090  if (!getGlobalOp)
3091  return {};
3092  // Get to the memref.global defining the symbol.
3093  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3094  if (!symbolTableOp)
3095  return {};
3096  auto global = dyn_cast_or_null<memref::GlobalOp>(
3097  SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
3098  if (!global)
3099  return {};
3100 
3101  // Check if the global memref is a constant.
3102  auto cstAttr =
3103  llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3104  if (!cstAttr)
3105  return {};
3106  // If it's a splat constant, we can fold irrespective of indices.
3107  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3108  return splatAttr.getSplatValue<Attribute>();
3109  // Otherwise, we can fold only if we know the indices.
3110  if (!getAffineMap().isConstant())
3111  return {};
3112  auto indices = llvm::to_vector<4>(
3113  llvm::map_range(getAffineMap().getConstantResults(),
3114  [](int64_t v) -> uint64_t { return v; }));
3115  return cstAttr.getValues<Attribute>()[indices];
3116 }
3117 
3118 //===----------------------------------------------------------------------===//
3119 // AffineStoreOp
3120 //===----------------------------------------------------------------------===//
3121 
3122 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3123  Value valueToStore, Value memref, AffineMap map,
3124  ValueRange mapOperands) {
3125  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3126  result.addOperands(valueToStore);
3127  result.addOperands(memref);
3128  result.addOperands(mapOperands);
3129  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3130 }
3131 
3132 // Use identity map.
3133 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3134  Value valueToStore, Value memref,
3135  ValueRange indices) {
3136  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3137  int64_t rank = memrefType.getRank();
3138  // Create identity map for memrefs with at least one dimension or () -> ()
3139  // for zero-dimensional memrefs.
3140  auto map =
3141  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3142  build(builder, result, valueToStore, memref, map, indices);
3143 }
3144 
/// Parses `%value, %memref[affine-map-of-ssa-ids] attr-dict : memref-type`.
/// The stored value is resolved with the memref's element type and every map
/// operand as `index`.
ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  // Operands are resolved in op order: stored value, memref, map operands.
  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
                 parser.parseOperand(memrefInfo) ||
                 parser.parseAffineMapOfSSAIds(
                     mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
                     result.attributes) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.resolveOperand(storeValueInfo, type.getElementType(),
                                       result.operands) ||
                 parser.resolveOperand(memrefInfo, type, result.operands) ||
                 parser.resolveOperands(mapOperands, indexTy, result.operands));
}
3165 
  // Inverse of AffineStoreOp::parse; the map attribute is printed inline
  // inside the brackets and therefore elided from the attribute dictionary.
  p << " " << getValueToStore();
  p << ", " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType();
}
3177 
3178 LogicalResult AffineStoreOp::verify() {
3179  // The value to store must have the same type as memref element type.
3180  auto memrefType = getMemRefType();
3181  if (getValueToStore().getType() != memrefType.getElementType())
3182  return emitOpError(
3183  "value to store must have the same type as memref element type");
3184 
3185  if (failed(verifyMemoryOpIndexing(
3186  getOperation(),
3187  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3188  getMapOperands(), memrefType,
3189  /*numIndexOperands=*/getNumOperands() - 2)))
3190  return failure();
3191 
3192  return success();
3193 }
3194 
3195 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3196  MLIRContext *context) {
3197  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3198 }
3199 
3200 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3201  SmallVectorImpl<OpFoldResult> &results) {
3202  /// store(memrefcast) -> store
3203  return memref::foldMemRefCast(*this, getValueToStore());
3204 }
3205 
3206 //===----------------------------------------------------------------------===//
3207 // AffineMinMaxOpBase
3208 //===----------------------------------------------------------------------===//
3209 
3210 template <typename T>
3211 static LogicalResult verifyAffineMinMaxOp(T op) {
3212  // Verify that operand count matches affine map dimension and symbol count.
3213  if (op.getNumOperands() !=
3214  op.getMap().getNumDims() + op.getMap().getNumSymbols())
3215  return op.emitOpError(
3216  "operand count and affine map dimension and symbol count must match");
3217  return success();
3218 }
3219 
/// Prints `map (dim-operands) [symbol-operands]? attr-dict`. The symbol list
/// is omitted when there are no symbol operands; the map attribute is printed
/// up front and elided from the dictionary.
template <typename T>
static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
  p << ' ' << op->getAttr(T::getMapAttrStrName());
  auto operands = op.getOperands();
  unsigned numDims = op.getMap().getNumDims();
  p << '(' << operands.take_front(numDims) << ')';

  if (operands.size() != numDims)
    p << '[' << operands.drop_front(numDims) << ']';
                          /*elidedAttrs=*/{T::getMapAttrStrName()});
}
3232 
/// Parses the form produced by printAffineMinMaxOp: a map attribute, a
/// parenthesized dim-operand list, an optional bracketed symbol list and an
/// optional attribute dictionary. All operands and the single result are
/// `index`.
template <typename T>
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
                                       OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
  AffineMapAttr mapAttr;
  // Dim operands are resolved before symbol operands, matching map inputs.
  return failure(
      parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
                            result.attributes) ||
      parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
      parser.parseOperandList(symInfos,
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.resolveOperands(dimInfos, indexType, result.operands) ||
      parser.resolveOperands(symInfos, indexType, result.operands) ||
      parser.addTypeToList(indexType, result.types));
}
3252 
3253 /// Fold an affine min or max operation with the given operands. The operand
3254 /// list may contain nulls, which are interpreted as the operand not being a
3255 /// constant.
3256 template <typename T>
3258  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3259  "expected affine min or max op");
3260 
3261  // Fold the affine map.
3262  // TODO: Fold more cases:
3263  // min(some_affine, some_affine + constant, ...), etc.
3264  SmallVector<int64_t, 2> results;
3265  auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3266 
3267  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3268  return op.getOperand(0);
3269 
3270  // If some of the map results are not constant, try changing the map in-place.
3271  if (results.empty()) {
3272  // If the map is the same, report that folding did not happen.
3273  if (foldedMap == op.getMap())
3274  return {};
3275  op->setAttr("map", AffineMapAttr::get(foldedMap));
3276  return op.getResult();
3277  }
3278 
3279  // Otherwise, completely fold the op into a constant.
3280  auto resultIt = std::is_same<T, AffineMinOp>::value
3281  ? llvm::min_element(results)
3282  : llvm::max_element(results);
3283  if (resultIt == results.end())
3284  return {};
3285  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3286 }
3287 
3288 /// Remove duplicated expressions in affine min/max ops.
3289 template <typename T>
3292 
3293  LogicalResult matchAndRewrite(T affineOp,
3294  PatternRewriter &rewriter) const override {
3295  AffineMap oldMap = affineOp.getAffineMap();
3296 
3297  SmallVector<AffineExpr, 4> newExprs;
3298  for (AffineExpr expr : oldMap.getResults()) {
3299  // This is a linear scan over newExprs, but it should be fine given that
3300  // we typically just have a few expressions per op.
3301  if (!llvm::is_contained(newExprs, expr))
3302  newExprs.push_back(expr);
3303  }
3304 
3305  if (newExprs.size() == oldMap.getNumResults())
3306  return failure();
3307 
3308  auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3309  newExprs, rewriter.getContext());
3310  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3311 
3312  return success();
3313  }
3314 };
3315 
3316 /// Merge an affine min/max op to its consumers if its consumer is also an
3317 /// affine min/max op.
3318 ///
3319 /// This pattern requires the producer affine min/max op is bound to a
3320 /// dimension/symbol that is used as a standalone expression in the consumer
3321 /// affine op's map.
3322 ///
3323 /// For example, a pattern like the following:
3324 ///
3325 /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3326 /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3327 ///
3328 /// Can be turned into:
3329 ///
3330 /// %1 = affine.min affine_map<
3331 /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3332 template <typename T>
3335 
3336  LogicalResult matchAndRewrite(T affineOp,
3337  PatternRewriter &rewriter) const override {
3338  AffineMap oldMap = affineOp.getAffineMap();
3339  ValueRange dimOperands =
3340  affineOp.getMapOperands().take_front(oldMap.getNumDims());
3341  ValueRange symOperands =
3342  affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3343 
3344  auto newDimOperands = llvm::to_vector<8>(dimOperands);
3345  auto newSymOperands = llvm::to_vector<8>(symOperands);
3346  SmallVector<AffineExpr, 4> newExprs;
3347  SmallVector<T, 4> producerOps;
3348 
3349  // Go over each expression to see whether it's a single dimension/symbol
3350  // with the corresponding operand which is the result of another affine
3351  // min/max op. If So it can be merged into this affine op.
3352  for (AffineExpr expr : oldMap.getResults()) {
3353  if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3354  Value symValue = symOperands[symExpr.getPosition()];
3355  if (auto producerOp = symValue.getDefiningOp<T>()) {
3356  producerOps.push_back(producerOp);
3357  continue;
3358  }
3359  } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3360  Value dimValue = dimOperands[dimExpr.getPosition()];
3361  if (auto producerOp = dimValue.getDefiningOp<T>()) {
3362  producerOps.push_back(producerOp);
3363  continue;
3364  }
3365  }
3366  // For the above cases we will remove the expression by merging the
3367  // producer affine min/max's affine expressions. Otherwise we need to
3368  // keep the existing expression.
3369  newExprs.push_back(expr);
3370  }
3371 
3372  if (producerOps.empty())
3373  return failure();
3374 
3375  unsigned numUsedDims = oldMap.getNumDims();
3376  unsigned numUsedSyms = oldMap.getNumSymbols();
3377 
3378  // Now go over all producer affine ops and merge their expressions.
3379  for (T producerOp : producerOps) {
3380  AffineMap producerMap = producerOp.getAffineMap();
3381  unsigned numProducerDims = producerMap.getNumDims();
3382  unsigned numProducerSyms = producerMap.getNumSymbols();
3383 
3384  // Collect all dimension/symbol values.
3385  ValueRange dimValues =
3386  producerOp.getMapOperands().take_front(numProducerDims);
3387  ValueRange symValues =
3388  producerOp.getMapOperands().take_back(numProducerSyms);
3389  newDimOperands.append(dimValues.begin(), dimValues.end());
3390  newSymOperands.append(symValues.begin(), symValues.end());
3391 
3392  // For expressions we need to shift to avoid overlap.
3393  for (AffineExpr expr : producerMap.getResults()) {
3394  newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3395  .shiftSymbols(numProducerSyms, numUsedSyms));
3396  }
3397 
3398  numUsedDims += numProducerDims;
3399  numUsedSyms += numProducerSyms;
3400  }
3401 
3402  auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3403  rewriter.getContext());
3404  auto newOperands =
3405  llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3406  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3407 
3408  return success();
3409  }
3410 };
3411 
3412 /// Canonicalize the result expression order of an affine map and return success
3413 /// if the order changed.
3414 ///
3415 /// The function flattens the map's affine expressions to coefficient arrays and
3416 /// sorts them in lexicographic order. A coefficient array contains a multiplier
3417 /// for every dimension/symbol and a constant term. The canonicalization fails
3418 /// if a result expression is not pure or if the flattening requires local
3419 /// variables that, unlike dimensions and symbols, have no global order.
3420 static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3421  SmallVector<SmallVector<int64_t>> flattenedExprs;
3422  for (const AffineExpr &resultExpr : map.getResults()) {
3423  // Fail if the expression is not pure.
3424  if (!resultExpr.isPureAffine())
3425  return failure();
3426 
3427  SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3428  auto flattenResult = flattener.walkPostOrder(resultExpr);
3429  if (failed(flattenResult))
3430  return failure();
3431 
3432  // Fail if the flattened expression has local variables.
3433  if (flattener.operandExprStack.back().size() !=
3434  map.getNumDims() + map.getNumSymbols() + 1)
3435  return failure();
3436 
3437  flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3438  flattener.operandExprStack.back().end());
3439  }
3440 
3441  // Fail if sorting is not necessary.
3442  if (llvm::is_sorted(flattenedExprs))
3443  return failure();
3444 
3445  // Reorder the result expressions according to their flattened form.
3446  SmallVector<unsigned> resultPermutation =
3447  llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3448  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3449  return flattenedExprs[lhs] < flattenedExprs[rhs];
3450  });
3451  SmallVector<AffineExpr> newExprs;
3452  for (unsigned idx : resultPermutation)
3453  newExprs.push_back(map.getResult(idx));
3454 
3455  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3456  map.getContext());
3457  return success();
3458 }
3459 
3460 /// Canonicalize the affine map result expression order of an affine min/max
3461 /// operation.
3462 ///
3463 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3464 /// expressions and replaces the operation if the order changed.
3465 ///
3466 /// For example, the following operation:
3467 ///
3468 /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3469 ///
3470 /// Turns into:
3471 ///
3472 /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3473 template <typename T>
3476 
3477  LogicalResult matchAndRewrite(T affineOp,
3478  PatternRewriter &rewriter) const override {
3479  AffineMap map = affineOp.getAffineMap();
3480  if (failed(canonicalizeMapExprAndTermOrder(map)))
3481  return failure();
3482  rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3483  return success();
3484  }
3485 };
3486 
3487 template <typename T>
3490 
3491  LogicalResult matchAndRewrite(T affineOp,
3492  PatternRewriter &rewriter) const override {
3493  if (affineOp.getMap().getNumResults() != 1)
3494  return failure();
3495  rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3496  affineOp.getOperands());
3497  return success();
3498  }
3499 };
3500 
3501 //===----------------------------------------------------------------------===//
3502 // AffineMinOp
3503 //===----------------------------------------------------------------------===//
3504 //
3505 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3506 //
3507 
3508 OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3509  return foldMinMaxOp(*this, adaptor.getOperands());
3510 }
3511 
3512 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3513  MLIRContext *context) {
3516  MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3518  context);
3519 }
3520 
3521 LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3522 
3523 ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
3524  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3525 }
3526 
3527 void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3528 
3529 //===----------------------------------------------------------------------===//
3530 // AffineMaxOp
3531 //===----------------------------------------------------------------------===//
3532 //
3533 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3534 //
3535 
3536 OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3537  return foldMinMaxOp(*this, adaptor.getOperands());
3538 }
3539 
3540 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3541  MLIRContext *context) {
3544  MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3546  context);
3547 }
3548 
3549 LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3550 
3551 ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
3552  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3553 }
3554 
3555 void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3556 
3557 //===----------------------------------------------------------------------===//
3558 // AffinePrefetchOp
3559 //===----------------------------------------------------------------------===//
3560 
3561 //
3562 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3563 //
3564 ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3565  OperationState &result) {
3566  auto &builder = parser.getBuilder();
3567  auto indexTy = builder.getIndexType();
3568 
3569  MemRefType type;
3570  OpAsmParser::UnresolvedOperand memrefInfo;
3571  IntegerAttr hintInfo;
3572  auto i32Type = parser.getBuilder().getIntegerType(32);
3573  StringRef readOrWrite, cacheType;
3574 
3575  AffineMapAttr mapAttr;
3577  if (parser.parseOperand(memrefInfo) ||
3578  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3579  AffinePrefetchOp::getMapAttrStrName(),
3580  result.attributes) ||
3581  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3582  parser.parseComma() || parser.parseKeyword("locality") ||
3583  parser.parseLess() ||
3584  parser.parseAttribute(hintInfo, i32Type,
3585  AffinePrefetchOp::getLocalityHintAttrStrName(),
3586  result.attributes) ||
3587  parser.parseGreater() || parser.parseComma() ||
3588  parser.parseKeyword(&cacheType) ||
3589  parser.parseOptionalAttrDict(result.attributes) ||
3590  parser.parseColonType(type) ||
3591  parser.resolveOperand(memrefInfo, type, result.operands) ||
3592  parser.resolveOperands(mapOperands, indexTy, result.operands))
3593  return failure();
3594 
3595  if (readOrWrite != "read" && readOrWrite != "write")
3596  return parser.emitError(parser.getNameLoc(),
3597  "rw specifier has to be 'read' or 'write'");
3598  result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3599  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
3600 
3601  if (cacheType != "data" && cacheType != "instr")
3602  return parser.emitError(parser.getNameLoc(),
3603  "cache type has to be 'data' or 'instr'");
3604 
3605  result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3606  parser.getBuilder().getBoolAttr(cacheType == "data"));
3607 
3608  return success();
3609 }
3610 
3612  p << " " << getMemref() << '[';
3613  AffineMapAttr mapAttr =
3614  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3615  if (mapAttr)
3616  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3617  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3618  << "locality<" << getLocalityHint() << ">, "
3619  << (getIsDataCache() ? "data" : "instr");
3621  (*this)->getAttrs(),
3622  /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3623  getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3624  p << " : " << getMemRefType();
3625 }
3626 
3627 LogicalResult AffinePrefetchOp::verify() {
3628  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3629  if (mapAttr) {
3630  AffineMap map = mapAttr.getValue();
3631  if (map.getNumResults() != getMemRefType().getRank())
3632  return emitOpError("affine.prefetch affine map num results must equal"
3633  " memref rank");
3634  if (map.getNumInputs() + 1 != getNumOperands())
3635  return emitOpError("too few operands");
3636  } else {
3637  if (getNumOperands() != 1)
3638  return emitOpError("too few operands");
3639  }
3640 
3641  Region *scope = getAffineScope(*this);
3642  for (auto idx : getMapOperands()) {
3643  if (!isValidAffineIndexOperand(idx, scope))
3644  return emitOpError(
3645  "index must be a valid dimension or symbol identifier");
3646  }
3647  return success();
3648 }
3649 
3650 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3651  MLIRContext *context) {
3652  // prefetch(memrefcast) -> prefetch
3653  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3654 }
3655 
3656 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3657  SmallVectorImpl<OpFoldResult> &results) {
3658  /// prefetch(memrefcast) -> prefetch
3659  return memref::foldMemRefCast(*this);
3660 }
3661 
3662 //===----------------------------------------------------------------------===//
3663 // AffineParallelOp
3664 //===----------------------------------------------------------------------===//
3665 
3666 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3667  TypeRange resultTypes,
3668  ArrayRef<arith::AtomicRMWKind> reductions,
3669  ArrayRef<int64_t> ranges) {
3670  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3671  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3672  return builder.getConstantAffineMap(value);
3673  }));
3674  SmallVector<int64_t> steps(ranges.size(), 1);
3675  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3676  /*ubArgs=*/{}, steps);
3677 }
3678 
3679 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3680  TypeRange resultTypes,
3681  ArrayRef<arith::AtomicRMWKind> reductions,
3682  ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
3683  ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
3684  ArrayRef<int64_t> steps) {
3685  assert(llvm::all_of(lbMaps,
3686  [lbMaps](AffineMap m) {
3687  return m.getNumDims() == lbMaps[0].getNumDims() &&
3688  m.getNumSymbols() == lbMaps[0].getNumSymbols();
3689  }) &&
3690  "expected all lower bounds maps to have the same number of dimensions "
3691  "and symbols");
3692  assert(llvm::all_of(ubMaps,
3693  [ubMaps](AffineMap m) {
3694  return m.getNumDims() == ubMaps[0].getNumDims() &&
3695  m.getNumSymbols() == ubMaps[0].getNumSymbols();
3696  }) &&
3697  "expected all upper bounds maps to have the same number of dimensions "
3698  "and symbols");
3699  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3700  "expected lower bound maps to have as many inputs as lower bound "
3701  "operands");
3702  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3703  "expected upper bound maps to have as many inputs as upper bound "
3704  "operands");
3705 
3706  OpBuilder::InsertionGuard guard(builder);
3707  result.addTypes(resultTypes);
3708 
3709  // Convert the reductions to integer attributes.
3710  SmallVector<Attribute, 4> reductionAttrs;
3711  for (arith::AtomicRMWKind reduction : reductions)
3712  reductionAttrs.push_back(
3713  builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
3714  result.addAttribute(getReductionsAttrStrName(),
3715  builder.getArrayAttr(reductionAttrs));
3716 
3717  // Concatenates maps defined in the same input space (same dimensions and
3718  // symbols), assumes there is at least one map.
3719  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
3720  SmallVectorImpl<int32_t> &groups) {
3721  if (maps.empty())
3722  return AffineMap::get(builder.getContext());
3724  groups.reserve(groups.size() + maps.size());
3725  exprs.reserve(maps.size());
3726  for (AffineMap m : maps) {
3727  llvm::append_range(exprs, m.getResults());
3728  groups.push_back(m.getNumResults());
3729  }
3730  return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3731  maps[0].getContext());
3732  };
3733 
3734  // Set up the bounds.
3735  SmallVector<int32_t> lbGroups, ubGroups;
3736  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3737  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3738  result.addAttribute(getLowerBoundsMapAttrStrName(),
3739  AffineMapAttr::get(lbMap));
3740  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
3741  builder.getI32TensorAttr(lbGroups));
3742  result.addAttribute(getUpperBoundsMapAttrStrName(),
3743  AffineMapAttr::get(ubMap));
3744  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
3745  builder.getI32TensorAttr(ubGroups));
3746  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
3747  result.addOperands(lbArgs);
3748  result.addOperands(ubArgs);
3749 
3750  // Create a region and a block for the body.
3751  auto *bodyRegion = result.addRegion();
3752  Block *body = builder.createBlock(bodyRegion);
3753 
3754  // Add all the block arguments.
3755  for (unsigned i = 0, e = steps.size(); i < e; ++i)
3756  body->addArgument(IndexType::get(builder.getContext()), result.location);
3757  if (resultTypes.empty())
3758  ensureTerminator(*bodyRegion, builder, result.location);
3759 }
3760 
3761 SmallVector<Region *> AffineParallelOp::getLoopRegions() {
3762  return {&getRegion()};
3763 }
3764 
3765 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3766 
3767 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3768  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3769 }
3770 
3771 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3772  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3773 }
3774 
3775 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3776  auto values = getLowerBoundsGroups().getValues<int32_t>();
3777  unsigned start = 0;
3778  for (unsigned i = 0; i < pos; ++i)
3779  start += values[i];
3780  return getLowerBoundsMap().getSliceMap(start, values[pos]);
3781 }
3782 
3783 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3784  auto values = getUpperBoundsGroups().getValues<int32_t>();
3785  unsigned start = 0;
3786  for (unsigned i = 0; i < pos; ++i)
3787  start += values[i];
3788  return getUpperBoundsMap().getSliceMap(start, values[pos]);
3789 }
3790 
3791 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
3792  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3793 }
3794 
3795 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
3796  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3797 }
3798 
3799 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3800  if (hasMinMaxBounds())
3801  return std::nullopt;
3802 
3803  // Try to convert all the ranges to constant expressions.
3805  AffineValueMap rangesValueMap;
3806  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3807  &rangesValueMap);
3808  out.reserve(rangesValueMap.getNumResults());
3809  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3810  auto expr = rangesValueMap.getResult(i);
3811  auto cst = dyn_cast<AffineConstantExpr>(expr);
3812  if (!cst)
3813  return std::nullopt;
3814  out.push_back(cst.getValue());
3815  }
3816  return out;
3817 }
3818 
3819 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3820 
3821 OpBuilder AffineParallelOp::getBodyBuilder() {
3822  return OpBuilder(getBody(), std::prev(getBody()->end()));
3823 }
3824 
3825 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3826  assert(lbOperands.size() == map.getNumInputs() &&
3827  "operands to map must match number of inputs");
3828 
3829  auto ubOperands = getUpperBoundsOperands();
3830 
3831  SmallVector<Value, 4> newOperands(lbOperands);
3832  newOperands.append(ubOperands.begin(), ubOperands.end());
3833  (*this)->setOperands(newOperands);
3834 
3835  setLowerBoundsMapAttr(AffineMapAttr::get(map));
3836 }
3837 
3838 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3839  assert(ubOperands.size() == map.getNumInputs() &&
3840  "operands to map must match number of inputs");
3841 
3842  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3843  newOperands.append(ubOperands.begin(), ubOperands.end());
3844  (*this)->setOperands(newOperands);
3845 
3846  setUpperBoundsMapAttr(AffineMapAttr::get(map));
3847 }
3848 
3849 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
3850  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3851 }
3852 
3853 // check whether resultType match op or not in affine.parallel
3854 static bool isResultTypeMatchAtomicRMWKind(Type resultType,
3855  arith::AtomicRMWKind op) {
3856  switch (op) {
3857  case arith::AtomicRMWKind::addf:
3858  return isa<FloatType>(resultType);
3859  case arith::AtomicRMWKind::addi:
3860  return isa<IntegerType>(resultType);
3861  case arith::AtomicRMWKind::assign:
3862  return true;
3863  case arith::AtomicRMWKind::mulf:
3864  return isa<FloatType>(resultType);
3865  case arith::AtomicRMWKind::muli:
3866  return isa<IntegerType>(resultType);
3867  case arith::AtomicRMWKind::maximumf:
3868  return isa<FloatType>(resultType);
3869  case arith::AtomicRMWKind::minimumf:
3870  return isa<FloatType>(resultType);
3871  case arith::AtomicRMWKind::maxs: {
3872  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3873  return intType && intType.isSigned();
3874  }
3875  case arith::AtomicRMWKind::mins: {
3876  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3877  return intType && intType.isSigned();
3878  }
3879  case arith::AtomicRMWKind::maxu: {
3880  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3881  return intType && intType.isUnsigned();
3882  }
3883  case arith::AtomicRMWKind::minu: {
3884  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3885  return intType && intType.isUnsigned();
3886  }
3887  case arith::AtomicRMWKind::ori:
3888  return isa<IntegerType>(resultType);
3889  case arith::AtomicRMWKind::andi:
3890  return isa<IntegerType>(resultType);
3891  default:
3892  return false;
3893  }
3894 }
3895 
3896 LogicalResult AffineParallelOp::verify() {
3897  auto numDims = getNumDims();
3898  if (getLowerBoundsGroups().getNumElements() != numDims ||
3899  getUpperBoundsGroups().getNumElements() != numDims ||
3900  getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3901  return emitOpError() << "the number of region arguments ("
3902  << getBody()->getNumArguments()
3903  << ") and the number of map groups for lower ("
3904  << getLowerBoundsGroups().getNumElements()
3905  << ") and upper bound ("
3906  << getUpperBoundsGroups().getNumElements()
3907  << "), and the number of steps (" << getSteps().size()
3908  << ") must all match";
3909  }
3910 
3911  unsigned expectedNumLBResults = 0;
3912  for (APInt v : getLowerBoundsGroups())
3913  expectedNumLBResults += v.getZExtValue();
3914  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
3915  return emitOpError() << "expected lower bounds map to have "
3916  << expectedNumLBResults << " results";
3917  unsigned expectedNumUBResults = 0;
3918  for (APInt v : getUpperBoundsGroups())
3919  expectedNumUBResults += v.getZExtValue();
3920  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
3921  return emitOpError() << "expected upper bounds map to have "
3922  << expectedNumUBResults << " results";
3923 
3924  if (getReductions().size() != getNumResults())
3925  return emitOpError("a reduction must be specified for each output");
3926 
3927  // Verify reduction ops are all valid and each result type matches reduction
3928  // ops
3929  for (auto it : llvm::enumerate((getReductions()))) {
3930  Attribute attr = it.value();
3931  auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3932  if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
3933  return emitOpError("invalid reduction attribute");
3934  auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
3935  if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
3936  return emitOpError("result type cannot match reduction attribute");
3937  }
3938 
3939  // Verify that the bound operands are valid dimension/symbols.
3940  /// Lower bounds.
3941  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
3942  getLowerBoundsMap().getNumDims())))
3943  return failure();
3944  /// Upper bounds.
3945  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
3946  getUpperBoundsMap().getNumDims())))
3947  return failure();
3948  return success();
3949 }
3950 
3951 LogicalResult AffineValueMap::canonicalize() {
3952  SmallVector<Value, 4> newOperands{operands};
3953  auto newMap = getAffineMap();
3954  composeAffineMapAndOperands(&newMap, &newOperands);
3955  if (newMap == getAffineMap() && newOperands == operands)
3956  return failure();
3957  reset(newMap, newOperands);
3958  return success();
3959 }
3960 
3961 /// Canonicalize the bounds of the given loop.
3962 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3963  AffineValueMap lb = op.getLowerBoundsValueMap();
3964  bool lbCanonicalized = succeeded(lb.canonicalize());
3965 
3966  AffineValueMap ub = op.getUpperBoundsValueMap();
3967  bool ubCanonicalized = succeeded(ub.canonicalize());
3968 
3969  // Any canonicalization change always leads to updated map(s).
3970  if (!lbCanonicalized && !ubCanonicalized)
3971  return failure();
3972 
3973  if (lbCanonicalized)
3974  op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
3975  if (ubCanonicalized)
3976  op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
3977 
3978  return success();
3979 }
3980 
3981 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
3982  SmallVectorImpl<OpFoldResult> &results) {
3983  return canonicalizeLoopBounds(*this);
3984 }
3985 
3986 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
3987 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
3988 /// identifies which of the those expressions form max/min groups. `operands`
3989 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
3990 /// or "max".
3991 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
3992  DenseIntElementsAttr group, ValueRange operands,
3993  StringRef keyword) {
3994  AffineMap map = mapAttr.getValue();
3995  unsigned numDims = map.getNumDims();
3996  ValueRange dimOperands = operands.take_front(numDims);
3997  ValueRange symOperands = operands.drop_front(numDims);
3998  unsigned start = 0;
3999  for (llvm::APInt groupSize : group) {
4000  if (start != 0)
4001  p << ", ";
4002 
4003  unsigned size = groupSize.getZExtValue();
4004  if (size == 1) {
4005  p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4006  ++start;
4007  } else {
4008  p << keyword << '(';
4009  AffineMap submap = map.getSliceMap(start, size);
4010  p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4011  p << ')';
4012  start += size;
4013  }
4014  }
4015 }
4016 
4018  p << " (" << getBody()->getArguments() << ") = (";
4019  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4020  getLowerBoundsOperands(), "max");
4021  p << ") to (";
4022  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4023  getUpperBoundsOperands(), "min");
4024  p << ')';
4025  SmallVector<int64_t, 8> steps = getSteps();
4026  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4027  if (!elideSteps) {
4028  p << " step (";
4029  llvm::interleaveComma(steps, p);
4030  p << ')';
4031  }
4032  if (getNumResults()) {
4033  p << " reduce (";
4034  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4035  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4036  llvm::cast<IntegerAttr>(attr).getInt());
4037  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4038  });
4039  p << ") -> (" << getResultTypes() << ")";
4040  }
4041 
4042  p << ' ';
4043  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4044  /*printBlockTerminators=*/getNumResults());
4046  (*this)->getAttrs(),
4047  /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4048  AffineParallelOp::getLowerBoundsMapAttrStrName(),
4049  AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4050  AffineParallelOp::getUpperBoundsMapAttrStrName(),
4051  AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4052  AffineParallelOp::getStepsAttrStrName()});
4053 }
4054 
4055 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
4056 /// unique operands. Also populates `replacements with affine expressions of
4057 /// `kind` that can be used to update affine maps previously accepting a
4058 /// `operands` to accept `uniqueOperands` instead.
4060  OpAsmParser &parser,
4062  SmallVectorImpl<Value> &uniqueOperands,
4063  SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4064  assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4065  "expected operands to be dim or symbol expression");
4066 
4067  Type indexType = parser.getBuilder().getIndexType();
4068  for (const auto &list : operands) {
4069  SmallVector<Value> valueOperands;
4070  if (parser.resolveOperands(list, indexType, valueOperands))
4071  return failure();
4072  for (Value operand : valueOperands) {
4073  unsigned pos = std::distance(uniqueOperands.begin(),
4074  llvm::find(uniqueOperands, operand));
4075  if (pos == uniqueOperands.size())
4076  uniqueOperands.push_back(operand);
4077  replacements.push_back(
4078  kind == AffineExprKind::DimId
4079  ? getAffineDimExpr(pos, parser.getContext())
4080  : getAffineSymbolExpr(pos, parser.getContext()));
4081  }
4082  }
4083  return success();
4084 }
4085 
namespace {
/// Which kind of affine.parallel bound is being parsed: `Min` selects the
/// upper-bound attributes (groups combined with min), `Max` the lower-bound
/// attributes (groups combined with max).
enum class MinMaxKind {
  Min,
  Max,
};
} // namespace
4089 
/// Parses an affine map that can contain a min/max for groups of its results,
/// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
/// `result` attributes with the map (flat list of expressions) and the grouping
/// (list of integers that specify how many expressions to put into each
/// min/max) attributes. Deduplicates repeated operands.
///
/// `kind` selects both the accepted group keyword (`min` or `max`) and the
/// attributes that are written: `Min` populates the upper-bound map/groups
/// attributes, `Max` the lower-bound ones. The resolved, deduplicated dimension
/// operands, followed by the symbol operands, are appended to
/// `result.operands`. An empty bound `()` yields an empty map and an empty
/// group list.
///
/// parallel-bound ::= `(` parallel-group-list? `)`
/// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
/// parallel-group ::= simple-group | min-max-group
/// simple-group ::= expr-of-ssa-ids
/// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
/// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-ids-list)?
///
/// Examples:
///   (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
///   (%0, max(%1 - 2 * %2))
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
                                            OperationState &result,
                                            MinMaxKind kind) {
  // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
  // see: https://reviews.llvm.org/D134227#3821753
  // Scratch attribute name: each parsed min/max map is stored under it and
  // erased from `result.attributes` immediately afterwards.
  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";

  StringRef mapName = kind == MinMaxKind::Min
                          ? AffineParallelOp::getUpperBoundsMapAttrStrName()
                          : AffineParallelOp::getLowerBoundsMapAttrStrName();
  StringRef groupsName =
      kind == MinMaxKind::Min
          ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
          : AffineParallelOp::getLowerBoundsGroupsAttrStrName();

  if (failed(parser.parseLParen()))
    return failure();

  // `()`: record an empty map with no groups and no operands.
  if (succeeded(parser.parseOptionalRParen())) {
    result.addAttribute(
        mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
    result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
    return success();
  }

  // Parallel arrays with one entry per flattened result expression: the
  // expression plus the dim/symbol operand lists it was parsed against. Each
  // expression starts out numbered against its own local operand lists.
  SmallVector<AffineExpr> flatExprs;
  SmallVector<int32_t> numMapsPerGroup;
  auto parseOperands = [&]() {
    if (succeeded(parser.parseOptionalKeyword(
            kind == MinMaxKind::Min ? "min" : "max"))) {
      // A min/max group is parsed as one multi-result map. Its dim and symbol
      // operand lists are replicated once per result so that the flat arrays
      // stay index-aligned with `flatExprs`.
      mapOperands.clear();
      AffineMapAttr map;
      if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
                                               result.attributes,
        return failure();
      result.attributes.erase(tmpAttrStrName);
      llvm::append_range(flatExprs, map.getValue().getResults());
      auto operandsRef = llvm::ArrayRef(mapOperands);
      auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
      SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef.begin(),
                                                       dimsRef.end());
      auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
      SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef.begin(),
                                                       symsRef.end());
      flatDimOperands.append(map.getValue().getNumResults(), dims);
      flatSymOperands.append(map.getValue().getNumResults(), syms);
      numMapsPerGroup.push_back(map.getValue().getNumResults());
    } else {
      // A plain expression forms a group of size one.
      if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
                                                flatSymOperands.emplace_back(),
                                                flatExprs.emplace_back())))
        return failure();
      numMapsPerGroup.push_back(1);
    }
    return success();
  };
  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
    return failure();

  // Make the per-expression numberings disjoint: shift expression i's dims and
  // symbols past the totals of all preceding expressions, so that every
  // expression refers into one concatenated operand list.
  unsigned totalNumDims = 0;
  unsigned totalNumSyms = 0;
  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
    unsigned numDims = flatDimOperands[i].size();
    unsigned numSyms = flatSymOperands[i].size();
    flatExprs[i] = flatExprs[i]
                       .shiftDims(numDims, totalNumDims)
                       .shiftSymbols(numSyms, totalNumSyms);
    totalNumDims += numDims;
    totalNumSyms += numSyms;
  }

  // Deduplicate map operands.
  SmallVector<Value> dimOperands, symOperands;
  SmallVector<AffineExpr> dimRplacements, symRepacements;
  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
                                    dimRplacements, AffineExprKind::DimId) ||
      deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
                                    symRepacements, AffineExprKind::SymbolId))
    return failure();

  result.operands.append(dimOperands.begin(), dimOperands.end());
  result.operands.append(symOperands.begin(), symOperands.end());

  // Build the map over the concatenated (duplicated) operands, then remap its
  // dims/symbols onto the deduplicated operand lists.
  Builder &builder = parser.getBuilder();
  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
                                parser.getContext());
  flatMap = flatMap.replaceDimsAndSymbols(
      dimRplacements, symRepacements, dimOperands.size(), symOperands.size());

  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
  return success();
}
4203 
//
// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
//               `to` parallel-bound steps? reductions? result-types?
//               region attr-dict?
// steps ::= `step` `(` integer-literals `)`
// reductions ::= `reduce` `(` string-literal (`,` string-literal)* `)`
// result-types ::= `->` type-list
//
ParseResult AffineParallelOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  // Induction variables, then `= lower-bound to upper-bound`. Lower bounds are
  // parsed with `max` groups and upper bounds with `min` groups.
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
      parser.parseEqual() ||
      parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
      parser.parseKeyword("to") ||
      parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
    return failure();

  AffineMapAttr stepsMapAttr;
  NamedAttrList stepsAttrs;
  if (failed(parser.parseOptionalKeyword("step"))) {
    // No `step` clause: every induction variable gets a step of 1.
    SmallVector<int64_t, 4> steps(ivs.size(), 1);
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  } else {
    if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
                                      AffineParallelOp::getStepsAttrStrName(),
                                      stepsAttrs,
      return failure();

    // Convert steps from an AffineMap into an I64ArrayAttr.
    auto stepsMap = stepsMapAttr.getValue();
    // NOTE(review): the loop variable `result` shadows the OperationState
    // parameter; inside this loop it names a step-map result expression.
    for (const auto &result : stepsMap.getResults()) {
      auto constExpr = dyn_cast<AffineConstantExpr>(result);
      if (!constExpr)
        return parser.emitError(parser.getNameLoc(),
                                "steps must be constant integers");
      steps.push_back(constExpr.getValue());
    }
    result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
                        builder.getI64ArrayAttr(steps));
  }

  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
  // quoted strings are a member of the enum AtomicRMWKind.
  SmallVector<Attribute, 4> reductions;
  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
    if (parser.parseLParen())
      return failure();
    auto parseAttributes = [&]() -> ParseResult {
      // Parse a single quoted string via the attribute parsing, and then
      // verify it is a member of the enum and convert it to its integer
      // representation.
      StringAttr attrVal;
      NamedAttrList attrStorage;
      auto loc = parser.getCurrentLocation();
      if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
                                attrStorage))
        return failure();
      std::optional<arith::AtomicRMWKind> reduction =
          arith::symbolizeAtomicRMWKind(attrVal.getValue());
      if (!reduction)
        return parser.emitError(loc, "invalid reduction value: ") << attrVal;
      reductions.push_back(
          builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
      // parseCommaSeparatedList invokes this lambda once per list entry.
      return success();
    };
    if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
      return failure();
  }
  // The reductions attribute is always attached, empty when no `reduce`
  // clause was present.
  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
                      builder.getArrayAttr(reductions));

  // Parse return types of reductions (if any)
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Now parse the body; the induction variables become index-typed region
  // arguments.
  Region *body = result.addRegion();
  for (auto &iv : ivs)
    iv.type = indexType;
  if (parser.parseRegion(*body, ivs) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Add a terminator if none was parsed.
  AffineParallelOp::ensureTerminator(*body, builder, result.location);
  return success();
}
4296 
4297 //===----------------------------------------------------------------------===//
4298 // AffineYieldOp
4299 //===----------------------------------------------------------------------===//
4300 
4301 LogicalResult AffineYieldOp::verify() {
4302  auto *parentOp = (*this)->getParentOp();
4303  auto results = parentOp->getResults();
4304  auto operands = getOperands();
4305 
4306  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4307  return emitOpError() << "only terminates affine.if/for/parallel regions";
4308  if (parentOp->getNumResults() != getNumOperands())
4309  return emitOpError() << "parent of yield must have same number of "
4310  "results as the yield operands";
4311  for (auto it : llvm::zip(results, operands)) {
4312  if (std::get<0>(it).getType() != std::get<1>(it).getType())
4313  return emitOpError() << "types mismatch between yield op and its parent";
4314  }
4315 
4316  return success();
4317 }
4318 
4319 //===----------------------------------------------------------------------===//
4320 // AffineVectorLoadOp
4321 //===----------------------------------------------------------------------===//
4322 
4323 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4324  VectorType resultType, AffineMap map,
4325  ValueRange operands) {
4326  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4327  result.addOperands(operands);
4328  if (map)
4329  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4330  result.types.push_back(resultType);
4331 }
4332 
4333 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4334  VectorType resultType, Value memref,
4335  AffineMap map, ValueRange mapOperands) {
4336  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4337  result.addOperands(memref);
4338  result.addOperands(mapOperands);
4339  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4340  result.types.push_back(resultType);
4341 }
4342 
4343 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4344  VectorType resultType, Value memref,
4345  ValueRange indices) {
4346  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4347  int64_t rank = memrefType.getRank();
4348  // Create identity map for memrefs with at least one dimension or () -> ()
4349  // for zero-dimensional memrefs.
4350  auto map =
4351  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4352  build(builder, result, resultType, memref, map, indices);
4353 }
4354 
4355 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4356  MLIRContext *context) {
4357  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4358 }
4359 
/// Parses `affine.vector_load` in the same form that its printer emits:
///   %memref `[` map-of-ssa-ids `]` attr-dict? `:` memref-type `,` vector-type
/// The memref operand is resolved with the parsed memref type, the map
/// operands with `index`, and the vector type becomes the single result type.
ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType memrefType;
  VectorType resultType;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  // The `||` chain stops at the first failing step. Operands are resolved only
  // after both types are known, with the memref first so that it becomes
  // operand 0.
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineVectorLoadOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(memrefType) || parser.parseComma() ||
      parser.parseType(resultType) ||
      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(resultType, result.types));
}
4382 
  // Emits `%memref[<map of SSA ids>] {attrs} : memref-type, vector-type`,
  // matching the syntax accepted by AffineVectorLoadOp::parse.
  p << " " << getMemRef() << '[';
  // If no map attribute is attached, only `[]` is printed.
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  // The map was printed inline above, so elide it from the attribute dict.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType() << ", " << getType();
}
4393 
4394 /// Verify common invariants of affine.vector_load and affine.vector_store.
4395 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4396  VectorType vectorType) {
4397  // Check that memref and vector element types match.
4398  if (memrefType.getElementType() != vectorType.getElementType())
4399  return op->emitOpError(
4400  "requires memref and vector types of the same elemental type");
4401  return success();
4402 }
4403 
4404 LogicalResult AffineVectorLoadOp::verify() {
4405  MemRefType memrefType = getMemRefType();
4406  if (failed(verifyMemoryOpIndexing(
4407  getOperation(),
4408  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4409  getMapOperands(), memrefType,
4410  /*numIndexOperands=*/getNumOperands() - 1)))
4411  return failure();
4412 
4413  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4414  return failure();
4415 
4416  return success();
4417 }
4418 
4419 //===----------------------------------------------------------------------===//
4420 // AffineVectorStoreOp
4421 //===----------------------------------------------------------------------===//
4422 
4423 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4424  Value valueToStore, Value memref, AffineMap map,
4425  ValueRange mapOperands) {
4426  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4427  result.addOperands(valueToStore);
4428  result.addOperands(memref);
4429  result.addOperands(mapOperands);
4430  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4431 }
4432 
4433 // Use identity map.
4434 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4435  Value valueToStore, Value memref,
4436  ValueRange indices) {
4437  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4438  int64_t rank = memrefType.getRank();
4439  // Create identity map for memrefs with at least one dimension or () -> ()
4440  // for zero-dimensional memrefs.
4441  auto map =
4442  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4443  build(builder, result, valueToStore, memref, map, indices);
4444 }
4445 void AffineVectorStoreOp::getCanonicalizationPatterns(
4446  RewritePatternSet &results, MLIRContext *context) {
4447  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4448 }
4449 
/// Parses `affine.vector_store` in the same form that its printer emits:
///   %value `,` %memref `[` map-of-ssa-ids `]` attr-dict? `:` memref-type `,`
///   vector-type
/// The stored value is resolved with the vector type, the memref with the
/// memref type, and the map operands with `index`. No result types are added.
ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType memrefType;
  // Despite its name, this is the type of the stored value, not of an op
  // result.
  VectorType resultType;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  // Resolution order fixes the operand layout: value, memref, map operands.
  return failure(
      parser.parseOperand(storeValueInfo) || parser.parseComma() ||
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineVectorStoreOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(memrefType) || parser.parseComma() ||
      parser.parseType(resultType) ||
      parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands));
}
4473 
  // Emits `%value, %memref[<map of SSA ids>] {attrs} : memref-type,
  // vector-type`, matching the syntax accepted by AffineVectorStoreOp::parse.
  p << " " << getValueToStore();
  p << ", " << getMemRef() << '[';
  // If no map attribute is attached, only `[]` is printed.
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  // The map was printed inline above, so elide it from the attribute dict.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
}
4485 
4486 LogicalResult AffineVectorStoreOp::verify() {
4487  MemRefType memrefType = getMemRefType();
4488  if (failed(verifyMemoryOpIndexing(
4489  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4490  getMapOperands(), memrefType,
4491  /*numIndexOperands=*/getNumOperands() - 2)))
4492  return failure();
4493 
4494  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4495  return failure();
4496 
4497  return success();
4498 }
4499 
4500 //===----------------------------------------------------------------------===//
4501 // DelinearizeIndexOp
4502 //===----------------------------------------------------------------------===//
4503 
4504 LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
4505  MLIRContext *context, std::optional<::mlir::Location> location,
4506  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
4507  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
4508  AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
4509  regions);
4510  inferredReturnTypes.assign(adaptor.getBasis().size(),
4511  IndexType::get(context));
4512  return success();
4513 }
4514 
4515 void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
4516  Value linearIndex,
4517  ArrayRef<OpFoldResult> basis) {
4518  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
4519  result.addOperands(linearIndex);
4520  SmallVector<Value> basisValues =
4521  llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value {
4522  std::optional<int64_t> staticDim = getConstantIntValue(ofr);
4523  if (staticDim.has_value())
4524  return builder.create<arith::ConstantIndexOp>(result.location,
4525  *staticDim);
4526  return llvm::dyn_cast_if_present<Value>(ofr);
4527  });
4528  result.addOperands(basisValues);
4529 }
4530 
4531 LogicalResult AffineDelinearizeIndexOp::verify() {
4532  if (getBasis().empty())
4533  return emitOpError("basis should not be empty");
4534  if (getNumResults() != getBasis().size())
4535  return emitOpError("should return an index for each basis element");
4536  return success();
4537 }
4538 
4539 //===----------------------------------------------------------------------===//
4540 // TableGen'd op method definitions
4541 //===----------------------------------------------------------------------===//
4542 
4543 #define GET_OP_CLASSES
4544 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
Definition: AffineOps.cpp:2651
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
Definition: AffineOps.cpp:2371
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
Definition: AffineOps.cpp:1163
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
Definition: AffineOps.cpp:3221
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
Definition: AffineOps.cpp:3854
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
Definition: AffineOps.cpp:58
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
Definition: AffineOps.cpp:3991
static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
Definition: AffineOps.cpp:999
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
Definition: AffineOps.cpp:3257
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
Definition: AffineOps.cpp:2229
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
Definition: AffineOps.cpp:785
static bool isValidAffineIndexOperand(Value value, Region *region)
Definition: AffineOps.cpp:464
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1357
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
Definition: AffineOps.cpp:1070
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:720
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
Definition: AffineOps.cpp:1918
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:712
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
Definition: AffineOps.cpp:2936
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
Definition: AffineOps.cpp:1022
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1314
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
Definition: AffineOps.cpp:4395
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
Definition: AffineOps.cpp:888
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
Definition: AffineOps.cpp:3234
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
Definition: AffineOps.cpp:327
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
Definition: AffineOps.cpp:660
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
Definition: AffineOps.cpp:4106
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
Definition: AffineOps.cpp:2660
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
Definition: AffineOps.cpp:469
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
Definition: AffineOps.cpp:622
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1250
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
Definition: AffineOps.cpp:2610
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
Definition: AffineOps.cpp:2183
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
Definition: AffineOps.cpp:501
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1265
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
Definition: AffineOps.cpp:346
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
Definition: AffineOps.cpp:688
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Definition: AffineOps.cpp:4059
static LogicalResult verifyAffineMinMaxOp(T op)
Definition: AffineOps.cpp:3211
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
Definition: AffineOps.cpp:2093
static LogicalResult verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
Definition: AffineOps.cpp:3042
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
Definition: AffineOps.cpp:3420
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:76
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:81
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
static int64_t getNumElements(ShapedType type)
Definition: TensorOps.cpp:1545
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:901
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:34
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:242
MLIRContext * getContext() const
Definition: AffineExpr.cpp:32
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:626
MLIRContext * getContext() const
Definition: AffineMap.cpp:330
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
Definition: AffineMap.h:213
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:259
unsigned getNumSymbols() const
Definition: AffineMap.cpp:385
unsigned getNumDims() const
Definition: AffineMap.cpp:381
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:394
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
Definition: AffineMap.h:220
unsigned getNumResults() const
Definition: AffineMap.cpp:389
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineMap.cpp:487
unsigned getNumInputs() const
Definition: AffineMap.cpp:390
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineMap.h:272
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:398
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:502
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
Definition: AffineMap.cpp:128
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:618
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
Definition: AffineMap.cpp:421
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:299
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:179
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
AffineMap getDimIdentityMap()
Definition: Builders.cpp:390
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:394
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition: Builders.cpp:195
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:128
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
NoneType getNoneType()
Definition: Builders.cpp:104
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:116
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
Definition: Builders.cpp:383
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition: Builders.cpp:385
MLIRContext * getContext() const
Definition: Builders.h:55
AffineMap getSymbolIdentityMap()
Definition: Builders.cpp:403
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:273
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
IndexType getIndexType()
Definition: Builders.cpp:71
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44
unsigned getNumDims() const
Definition: IntegerSet.cpp:15
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
Definition: IntegerSet.cpp:57
unsigned getNumInputs() const
Definition: IntegerSet.cpp:17
ArrayRef< AffineExpr > getConstraints() const
Definition: IntegerSet.cpp:41
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
Definition: IntegerSet.cpp:51
unsigned getNumSymbols() const
Definition: IntegerSet.cpp:16
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:447
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:322
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:437
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:444
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:529
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumOperands()
Definition: Operation.h:341
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
Region * getParentRegion()
Returns the region to which the operation belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
operand_range::iterator operand_iterator
Definition: Operation.h:367
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:614
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name within the regions of 'op'.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
AffineBound represents a lower or upper bound in the for operation.
Definition: AffineOps.h:511
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
Definition: AffineOps.h:94
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Definition: AffineOps.h:303
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
Definition: AffineOps.cpp:3951
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
unsigned getNumResults() const
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
Definition: AffineOps.cpp:2673
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
Definition: AffineOps.cpp:1132
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
Definition: AffineOps.cpp:2587
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
Definition: AffineOps.cpp:276
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1239
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
Definition: AffineOps.cpp:2559
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: AffineOps.cpp:2563
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1142
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1433
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1305
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Definition: AffineOps.cpp:2551
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1298
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Definition: AffineOps.cpp:246
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
Definition: AffineOps.cpp:1438
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
Definition: AffineOps.cpp:2594
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Definition: AffineOps.cpp:390
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1192
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp that has the provided value among its induction variables.
Definition: AffineOps.cpp:2574
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
Definition: AffineOps.cpp:261
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
Definition: AffineOps.cpp:479
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
Definition: AffineOps.cpp:2555
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
Definition: AffineOps.cpp:1259
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:737
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:401
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
Definition: AffineMap.cpp:747
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
AffineExprKind
Definition: AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:69
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:630
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:310
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:606
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
Definition: AffineMap.cpp:709
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:616
Canonicalize the affine map result expression order of an affine min/max operation.
Definition: AffineOps.cpp:3474
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3477
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3491
Remove duplicated expressions in affine min/max ops.
Definition: AffineOps.cpp:3290
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3293
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
Definition: AffineOps.cpp:3333
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3336
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:287
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.