1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/UB/IR/UBOps.h"
13 #include "mlir/Dialect/Utils/StaticValueUtils.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/Value.h"
20 #include "mlir/Interfaces/ShapedOpInterfaces.h"
21 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
22 #include "mlir/Transforms/InliningUtils.h"
23 #include "llvm/ADT/ScopeExit.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 #include "llvm/ADT/SmallVectorExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/Debug.h"
28 #include <numeric>
29 #include <optional>
30 
31 using namespace mlir;
32 using namespace mlir::affine;
33 
34 #define DEBUG_TYPE "affine-ops"
35 
36 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
37 
38 /// A utility function to check if a value is defined at the top level of
39 /// `region` or is an argument of `region`. A value of index type defined at the
40 /// top level of an `AffineScope` region is always a valid symbol for all
41 /// uses in that region.
42 bool mlir::affine::isTopLevelValue(Value value, Region *region) {
43  if (auto arg = llvm::dyn_cast<BlockArgument>(value))
44  return arg.getParentRegion() == region;
45  return value.getDefiningOp()->getParentRegion() == region;
46 }
47 
48 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
49 /// region remains legal if the operation that uses it is inlined into `dest`
50 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
51 /// `isValidSymbol`, depending on the value being required to remain a valid
52 /// dimension or symbol.
53 static bool
54 remainsLegalAfterInline(Value value, Region *src, Region *dest,
55  const IRMapping &mapping,
56  function_ref<bool(Value, Region *)> legalityCheck) {
57  // If the value is a valid dimension for any other reason than being
58  // a top-level value, it will remain valid: constants get inlined
59  // with the function, transitive affine applies also get inlined and
60  // will be checked themselves, etc.
61  if (!isTopLevelValue(value, src))
62  return true;
63 
64  // If it's a top-level value because it's a block operand, i.e. a
65  // function argument, check whether the value replacing it after
66  // inlining is a valid dimension in the new region.
67  if (llvm::isa<BlockArgument>(value))
68  return legalityCheck(mapping.lookup(value), dest);
69 
70  // If it's a top-level value because it's defined in the region,
71  // it can only be inlined if the defining op is a constant or a
72  // `dim`, which can appear anywhere and be valid, since the defining
73  // op won't be top-level anymore after inlining.
74  Attribute operandCst;
75  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
76  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
77  isDimLikeOp;
78 }
79 
80 /// Checks if all values known to be legal affine dimensions or symbols in `src`
81 /// remain so if their respective users are inlined into `dest`.
82 static bool
83 remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
84  const IRMapping &mapping,
85  function_ref<bool(Value, Region *)> legalityCheck) {
86  return llvm::all_of(values, [&](Value v) {
87  return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
88  });
89 }
90 
91 /// Checks if an affine read or write operation remains legal after inlining
92 /// from `src` to `dest`.
93 template <typename OpTy>
94 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
95  const IRMapping &mapping) {
96  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
97  AffineWriteOpInterface>::value,
98  "only ops with affine read/write interface are supported");
99 
100  AffineMap map = op.getAffineMap();
101  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
102  ValueRange symbolOperands =
103  op.getMapOperands().take_back(map.getNumSymbols());
104  if (!remainsLegalAfterInline(
105  dimOperands, src, dest, mapping,
106  static_cast<bool (*)(Value, Region *)>(isValidDim)))
107  return false;
108  if (!remainsLegalAfterInline(
109  symbolOperands, src, dest, mapping,
110  static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
111  return false;
112  return true;
113 }
114 
115 /// Checks if an affine apply operation remains legal after inlining from `src`
116 /// to `dest`.
117 // Use "unused attribute" marker to silence clang-tidy warning stemming from
118 // the inability to see through "llvm::TypeSwitch".
119 template <>
120 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
121  Region *src, Region *dest,
122  const IRMapping &mapping) {
123  // If it's a valid dimension, we need to check that it remains so.
124  if (isValidDim(op.getResult(), src))
125  return remainsLegalAfterInline(
126  op.getMapOperands(), src, dest, mapping,
127  static_cast<bool (*)(Value, Region *)>(isValidDim));
128 
129  // Otherwise it must be a valid symbol, check that it remains so.
130  return remainsLegalAfterInline(
131  op.getMapOperands(), src, dest, mapping,
132  static_cast<bool (*)(Value, Region *)>(isValidSymbol));
133 }
134 
135 //===----------------------------------------------------------------------===//
136 // AffineDialect Interfaces
137 //===----------------------------------------------------------------------===//
138 
139 namespace {
140 /// This class defines the interface for handling inlining with affine
141 /// operations.
142 struct AffineInlinerInterface : public DialectInlinerInterface {
143  using DialectInlinerInterface::DialectInlinerInterface;
144 
145  //===--------------------------------------------------------------------===//
146  // Analysis Hooks
147  //===--------------------------------------------------------------------===//
148 
149  /// Returns true if the given region 'src' can be inlined into the region
150  /// 'dest' that is attached to an operation registered to the current dialect.
151  /// 'wouldBeCloned' is set if the region is cloned into its new location
152  /// rather than moved, indicating there may be other users.
153  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
154  IRMapping &valueMapping) const final {
155  // We can inline into affine loops and conditionals if this doesn't break
156  // affine value categorization rules.
157  Operation *destOp = dest->getParentOp();
158  if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
159  return false;
160 
161  // Multi-block regions cannot be inlined into affine constructs, all of
162  // which require single-block regions.
163  if (!llvm::hasSingleElement(*src))
164  return false;
165 
166  // Side-effecting operations that the affine dialect cannot understand
167  // should not be inlined.
168  Block &srcBlock = src->front();
169  for (Operation &op : srcBlock) {
170  // Ops with no side effects are fine.
171  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
172  if (iface.hasNoEffect())
173  continue;
174  }
175 
176  // Assuming the inlined region is valid, we only need to check if the
177  // inlining would change it.
178  bool remainsValid =
179  llvm::TypeSwitch<Operation *, bool>(&op)
180  .Case<AffineApplyOp, AffineReadOpInterface,
181  AffineWriteOpInterface>([&](auto op) {
182  return remainsLegalAfterInline(op, src, dest, valueMapping);
183  })
184  .Default([](Operation *) {
185  // Conservatively disallow inlining ops we cannot reason about.
186  return false;
187  });
188 
189  if (!remainsValid)
190  return false;
191  }
192 
193  return true;
194  }
195 
196  /// Returns true if the given operation 'op', that is registered to this
197  /// dialect, can be inlined into the given region, false otherwise.
198  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
199  IRMapping &valueMapping) const final {
200  // Always allow inlining affine operations into a region that is marked as
201  // affine scope, or into affine loops and conditionals. There are some edge
202  // cases when inlining *into* affine structures, but that is handled in the
203  // other 'isLegalToInline' hook above.
204  Operation *parentOp = region->getParentOp();
205  return parentOp->hasTrait<OpTrait::AffineScope>() ||
206  isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
207  }
208 
209  /// Affine regions should be analyzed recursively.
210  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
211 };
212 } // namespace
213 
214 //===----------------------------------------------------------------------===//
215 // AffineDialect
216 //===----------------------------------------------------------------------===//
217 
218 void AffineDialect::initialize() {
219  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
220 #define GET_OP_LIST
221 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
222  >();
223  addInterfaces<AffineInlinerInterface>();
224  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
225  AffineMinOp>();
226 }
227 
228 /// Materialize a single constant operation from a given attribute value with
229 /// the desired resultant type.
230 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
231  Attribute value, Type type,
232  Location loc) {
233  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
234  return builder.create<ub::PoisonOp>(loc, type, poison);
235  return arith::ConstantOp::materialize(builder, value, type, loc);
236 }
237 
238 /// A utility function to check if a value is defined at the top level of an
239 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
240 /// conservatively assume it is not top-level. A value of index type defined at
241 /// the top level is always a valid symbol.
242 bool mlir::affine::isTopLevelValue(Value value) {
243  if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
244  // The block owning the argument may be unlinked, e.g. when the surrounding
245  // region has not yet been attached to an Op, at which point the parent Op
246  // is null.
247  Operation *parentOp = arg.getOwner()->getParentOp();
248  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
249  }
250  // The defining Op may live in an unlinked block so its parent Op may be null.
251  Operation *parentOp = value.getDefiningOp()->getParentOp();
252  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
253 }
254 
255 /// Returns the closest region enclosing `op` that is held by an operation with
256 /// trait `AffineScope`; `nullptr` if there is no such region.
257 Region *mlir::affine::getAffineScope(Operation *op) {
258  auto *curOp = op;
259  while (auto *parentOp = curOp->getParentOp()) {
260  if (parentOp->hasTrait<OpTrait::AffineScope>())
261  return curOp->getParentRegion();
262  curOp = parentOp;
263  }
264  return nullptr;
265 }
266 
267 // A Value can be used as a dimension id iff it meets one of the following
268 // conditions:
269 // *) It is valid as a symbol.
270 // *) It is an induction variable.
271 // *) It is the result of an affine apply operation with dimension id arguments.
272 bool mlir::affine::isValidDim(Value value) {
273  // The value must be an index type.
274  if (!value.getType().isIndex())
275  return false;
276 
277  if (auto *defOp = value.getDefiningOp())
278  return isValidDim(value, getAffineScope(defOp));
279 
280  // This value has to be a block argument for an op that has the
281  // `AffineScope` trait or for an affine.for or affine.parallel.
282  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
283  return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
284  isa<AffineForOp, AffineParallelOp>(parentOp));
285 }
286 
287 // Value can be used as a dimension id iff it meets one of the following
288 // conditions:
289 // *) It is valid as a symbol.
290 // *) It is an induction variable.
291 // *) It is the result of an affine apply operation with dimension id operands.
292 bool mlir::affine::isValidDim(Value value, Region *region) {
293  // The value must be an index type.
294  if (!value.getType().isIndex())
295  return false;
296 
297  // All valid symbols are okay.
298  if (isValidSymbol(value, region))
299  return true;
300 
301  auto *op = value.getDefiningOp();
302  if (!op) {
303  // This value has to be a block argument for an affine.for or an
304  // affine.parallel.
305  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
306  return isa<AffineForOp, AffineParallelOp>(parentOp);
307  }
308 
309  // Affine apply operation is ok if all of its operands are ok.
310  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
311  return applyOp.isValidDim(region);
312  // The dim op is okay if its operand memref/tensor is defined at the top
313  // level.
314  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
315  return isTopLevelValue(dimOp.getShapedValue());
316  return false;
317 }
318 
319 /// Returns true if the 'index' dimension of the `memref` defined by
320 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
321 /// for `region`.
322 template <typename AnyMemRefDefOp>
323 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
324  Region *region) {
325  MemRefType memRefType = memrefDefOp.getType();
326 
327  // Dimension index is out of bounds.
328  if (index >= memRefType.getRank()) {
329  return false;
330  }
331 
332  // Statically shaped.
333  if (!memRefType.isDynamicDim(index))
334  return true;
335  // Get the position of the dimension among dynamic dimensions;
336  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
337  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
338  region);
339 }
340 
341 /// Returns true if the result of the dim op is a valid symbol for `region`.
342 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
343  // The dim op is okay if its source is defined at the top level.
344  if (isTopLevelValue(dimOp.getShapedValue()))
345  return true;
346 
347  // Conservatively handle remaining BlockArguments as non-valid symbols.
348  // E.g. scf.for iterArgs.
349  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
350  return false;
351 
352  // The dim op is also okay if its operand memref is a view/subview whose
353  // corresponding size is a valid symbol.
354  std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
355 
356  // Be conservative if we can't understand the dimension.
357  if (!index.has_value())
358  return false;
359 
360  // Skip over all memref.cast ops (if any).
361  Operation *op = dimOp.getShapedValue().getDefiningOp();
362  while (auto castOp = dyn_cast<memref::CastOp>(op)) {
363  // Bail on unranked memrefs.
364  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
365  return false;
366  op = castOp.getSource().getDefiningOp();
367  if (!op)
368  return false;
369  }
370 
371  int64_t i = index.value();
372  return llvm::TypeSwitch<Operation *, bool>(op)
373  .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
374  [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
375  .Default([](Operation *) { return false; });
376 }
377 
378 // A value can be used as a symbol (at all its use sites) iff it meets one of
379 // the following conditions:
380 // *) It is a constant.
381 // *) Its defining op or block arg appearance is immediately enclosed by an op
382 // with `AffineScope` trait.
383 // *) It is the result of an affine.apply operation with symbol operands.
384 // *) It is a result of the dim op on a memref whose corresponding size is a
385 // valid symbol.
386 bool mlir::affine::isValidSymbol(Value value) {
387  if (!value)
388  return false;
389 
390  // The value must be an index type.
391  if (!value.getType().isIndex())
392  return false;
393 
394  // Check that the value is a top level value.
395  if (isTopLevelValue(value))
396  return true;
397 
398  if (auto *defOp = value.getDefiningOp())
399  return isValidSymbol(value, getAffineScope(defOp));
400 
401  return false;
402 }
403 
404 /// A value can be used as a symbol for `region` iff it meets one of the
405 /// following conditions:
406 /// *) It is a constant.
407 /// *) It is the result of an affine apply operation with symbol arguments.
408 /// *) It is a result of the dim op on a memref whose corresponding size is
409 /// a valid symbol.
410 /// *) It is defined at the top level of 'region' or is its argument.
411 /// *) It dominates `region`'s parent op.
412 /// If `region` is null, conservatively assume the symbol definition scope does
413 /// not exist and only accept the values that would be symbols regardless of
414 /// the surrounding region structure, i.e. the first three cases above.
415 bool mlir::affine::isValidSymbol(Value value, Region *region) {
416  // The value must be an index type.
417  if (!value.getType().isIndex())
418  return false;
419 
420  // A top-level value is a valid symbol.
421  if (region && ::isTopLevelValue(value, region))
422  return true;
423 
424  auto *defOp = value.getDefiningOp();
425  if (!defOp) {
426  // A block argument that is not a top-level value is a valid symbol if it
427  // dominates region's parent op.
428  Operation *regionOp = region ? region->getParentOp() : nullptr;
429  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
430  if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
431  return isValidSymbol(value, parentOpRegion);
432  return false;
433  }
434 
435  // Constant operation is ok.
436  Attribute operandCst;
437  if (matchPattern(defOp, m_Constant(&operandCst)))
438  return true;
439 
440  // Affine apply operation is ok if all of its operands are ok.
441  if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
442  return applyOp.isValidSymbol(region);
443 
444  // Dim op results could be valid symbols at any level.
445  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
446  return isDimOpValidSymbol(dimOp, region);
447 
448  // Check for values dominating `region`'s parent op.
449  Operation *regionOp = region ? region->getParentOp() : nullptr;
450  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
451  if (auto *parentRegion = region->getParentOp()->getParentRegion())
452  return isValidSymbol(value, parentRegion);
453 
454  return false;
455 }
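// For illustration (a sketch, assuming a func-like op that carries the
// `AffineScope` trait):
//   func.func @f(%n : index) {
//     %c = arith.constant 42 : index
//     affine.for %i = 0 to %n { ... }
//   }
// Here the top-level values %n and %c are valid symbols (and therefore also
// valid dims), while the induction variable %i is only a valid dim.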
456 
457 // Returns true if 'value' is a valid index to an affine operation (e.g.
458 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
459 // `region` provides the polyhedral symbol scope. Returns false otherwise.
460 static bool isValidAffineIndexOperand(Value value, Region *region) {
461  return isValidDim(value, region) || isValidSymbol(value, region);
462 }
463 
464 /// Prints dimension and symbol list.
465 void mlir::affine::printDimAndSymbolList(Operation::operand_iterator begin,
466  Operation::operand_iterator end,
467  unsigned numDims, OpAsmPrinter &printer) {
468  OperandRange operands(begin, end);
469  printer << '(' << operands.take_front(numDims) << ')';
470  if (operands.size() > numDims)
471  printer << '[' << operands.drop_front(numDims) << ']';
472 }
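// For example, with two dimension operands and one symbol operand this prints
// `(%arg0, %arg1)[%n]`; the square-bracketed symbol list is omitted when there
// are no symbol operands.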
473 
474 /// Parses dimension and symbol list and returns true if parsing failed.
475 ParseResult mlir::affine::parseDimAndSymbolList(
476  OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
477  SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
478  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
479  return failure();
480  // Store number of dimensions for validation by caller.
481  numDims = opInfos.size();
482 
483  // Parse the optional symbol operands.
484  auto indexTy = parser.getBuilder().getIndexType();
485  return failure(parser.parseOperandList(
486  opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
487  parser.resolveOperands(opInfos, indexTy, operands));
488 }
489 
490 /// Utility function to verify that a set of operands are valid dimension and
491 /// symbol identifiers. The operands should be laid out such that the dimension
492 /// operands are before the symbol operands. This function returns failure if
493 /// there was an invalid operand. An operation is provided to emit any necessary
494 /// errors.
495 template <typename OpTy>
496 static LogicalResult
497 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
498  unsigned numDims) {
499  unsigned opIt = 0;
500  for (auto operand : operands) {
501  if (opIt++ < numDims) {
502  if (!isValidDim(operand, getAffineScope(op)))
503  return op.emitOpError("operand cannot be used as a dimension id");
504  } else if (!isValidSymbol(operand, getAffineScope(op))) {
505  return op.emitOpError("operand cannot be used as a symbol");
506  }
507  }
508  return success();
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // AffineApplyOp
513 //===----------------------------------------------------------------------===//
514 
515 AffineValueMap AffineApplyOp::getAffineValueMap() {
516  return AffineValueMap(getAffineMap(), getOperands(), getResult());
517 }
518 
519 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
520  auto &builder = parser.getBuilder();
521  auto indexTy = builder.getIndexType();
522 
523  AffineMapAttr mapAttr;
524  unsigned numDims;
525  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
526  parseDimAndSymbolList(parser, result.operands, numDims) ||
527  parser.parseOptionalAttrDict(result.attributes))
528  return failure();
529  auto map = mapAttr.getValue();
530 
531  if (map.getNumDims() != numDims ||
532  numDims + map.getNumSymbols() != result.operands.size()) {
533  return parser.emitError(parser.getNameLoc(),
534  "dimension or symbol index mismatch");
535  }
536 
537  result.types.append(map.getNumResults(), indexTy);
538  return success();
539 }
540 
541 void AffineApplyOp::print(OpAsmPrinter &p) {
542  p << " " << getMapAttr();
543  printDimAndSymbolList(operand_begin(), operand_end(),
544  getAffineMap().getNumDims(), p);
545  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
546 }
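// With the custom form above, an apply of (d0)[s0] -> (d0 + s0) would print
// roughly as:
//   %0 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%n]
// (the map attribute may also be printed through an alias such as #map).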
547 
548 LogicalResult AffineApplyOp::verify() {
549  // Check input and output dimensions match.
550  AffineMap affineMap = getMap();
551 
552  // Verify that operand count matches affine map dimension and symbol count.
553  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
554  return emitOpError(
555  "operand count and affine map dimension and symbol count must match");
556 
557  // Verify that the map only produces one result.
558  if (affineMap.getNumResults() != 1)
559  return emitOpError("mapping must produce one value");
560 
561  return success();
562 }
563 
564 // The result of the affine apply operation can be used as a dimension id if all
565 // its operands are valid dimension ids.
566 bool AffineApplyOp::isValidDim() {
567  return llvm::all_of(getOperands(),
568  [](Value op) { return affine::isValidDim(op); });
569 }
570 
571 // The result of the affine apply operation can be used as a dimension id if all
572 // its operands are valid dimension ids with the parent operation of `region`
573 // defining the polyhedral scope for symbols.
574 bool AffineApplyOp::isValidDim(Region *region) {
575  return llvm::all_of(getOperands(),
576  [&](Value op) { return ::isValidDim(op, region); });
577 }
578 
579 // The result of the affine apply operation can be used as a symbol if all its
580 // operands are symbols.
581 bool AffineApplyOp::isValidSymbol() {
582  return llvm::all_of(getOperands(),
583  [](Value op) { return affine::isValidSymbol(op); });
584 }
585 
586 // The result of the affine apply operation can be used as a symbol in `region`
587 // if all its operands are symbols in `region`.
588 bool AffineApplyOp::isValidSymbol(Region *region) {
589  return llvm::all_of(getOperands(), [&](Value operand) {
590  return affine::isValidSymbol(operand, region);
591  });
592 }
593 
594 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
595  auto map = getAffineMap();
596 
597  // Fold dims and symbols to existing values.
598  auto expr = map.getResult(0);
599  if (auto dim = dyn_cast<AffineDimExpr>(expr))
600  return getOperand(dim.getPosition());
601  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
602  return getOperand(map.getNumDims() + sym.getPosition());
603 
604  // Otherwise, default to folding the map.
605  SmallVector<Attribute, 1> result;
606  bool hasPoison = false;
607  auto foldResult =
608  map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
609  if (hasPoison)
610  return ub::PoisonAttr::get(getContext());
611  if (failed(foldResult))
612  return {};
613  return result[0];
614 }
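// Folding examples: affine_map<(d0) -> (d0)> applied to %v folds to %v
// directly; a map whose operands are all constants folds to a constant
// attribute; and a division by zero detected during constant folding is
// reported as poison and folds to a ub.poison attribute.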
615 
616 /// Returns the largest known divisor of `e`. Exploits information from the
617 /// values in `operands`.
618 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
619  // This method isn't aware of `operands`.
620  int64_t div = e.getLargestKnownDivisor();
621 
622  // We now make use of operands for the case `e` is a dim expression.
623  // TODO: More powerful simplification would have to modify
624  // getLargestKnownDivisor to take `operands` and exploit that information as
625  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
626  // can't be part of the IR library but of the `Analysis` library. The IR
627  // library can only really depend on simple O(1) checks.
628  auto dimExpr = dyn_cast<AffineDimExpr>(e);
629  // If it's not a dim expr, `div` is the best we have.
630  if (!dimExpr)
631  return div;
632 
633  // We simply exploit information from loop IVs.
634  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
635  // desired simplifications are expected to be part of other
636  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
637  // LoopAnalysis library.
638  Value operand = operands[dimExpr.getPosition()];
639  int64_t operandDivisor = 1;
640  // TODO: With the right accessors, this can be extended to
641  // LoopLikeOpInterface.
642  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
643  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
644  operandDivisor = forOp.getStepAsInt();
645  } else {
646  uint64_t lbLargestKnownDivisor =
647  forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
648  operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
649  }
650  }
651  return operandDivisor;
652 }
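// For instance, if d0 is bound to the IV of `affine.for %i = 0 to %n step 8`
// (constant lower bound 0), the divisor returned for d0 is the step, 8; with a
// non-zero lower bound it is the gcd of the step and the lower bound map's
// largest known divisor.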
653 
654 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
655 /// being an affine dim expression or a constant.
656 static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
657  int64_t k) {
658  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
659  int64_t constVal = constExpr.getValue();
660  return constVal >= 0 && constVal < k;
661  }
662  auto dimExpr = dyn_cast<AffineDimExpr>(e);
663  if (!dimExpr)
664  return false;
665  Value operand = operands[dimExpr.getPosition()];
666  // TODO: With the right accessors, this can be extended to
667  // LoopLikeOpInterface.
668  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
669  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
670  forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
671  return true;
672  }
673  }
674 
675  // We don't consider other cases like `operand` being defined by a constant or
676  // an affine.apply op since such cases will already be handled by other
677  // patterns and propagation of loop IVs or constant would happen.
678  return false;
679 }
680 
681 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
682 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
683 /// expression is in that form.
684 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
685  AffineExpr &quotientTimesDiv, AffineExpr &rem) {
686  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
687  if (!bin || bin.getKind() != AffineExprKind::Add)
688  return false;
689 
690  AffineExpr llhs = bin.getLHS();
691  AffineExpr rlhs = bin.getRHS();
692  div = getLargestKnownDivisor(llhs, operands);
693  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
694  quotientTimesDiv = llhs;
695  rem = rlhs;
696  return true;
697  }
698  div = getLargestKnownDivisor(rlhs, operands);
699  if (isNonNegativeBoundedBy(llhs, operands, div)) {
700  quotientTimesDiv = rlhs;
701  rem = llhs;
702  return true;
703  }
704  return false;
705 }
706 
707 /// Gets the constant lower bound on an `iv`.
708 static std::optional<int64_t> getLowerBound(Value iv) {
709  AffineForOp forOp = getForInductionVarOwner(iv);
710  if (forOp && forOp.hasConstantLowerBound())
711  return forOp.getConstantLowerBound();
712  return std::nullopt;
713 }
714 
715 /// Gets the constant upper bound on an affine.for `iv`.
716 static std::optional<int64_t> getUpperBound(Value iv) {
717  AffineForOp forOp = getForInductionVarOwner(iv);
718  if (!forOp || !forOp.hasConstantUpperBound())
719  return std::nullopt;
720 
721  // If its lower bound is also known, we can get a more precise bound
722  // whenever the step is not one.
723  if (forOp.hasConstantLowerBound()) {
724  return forOp.getConstantUpperBound() - 1 -
725  (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
726  forOp.getStepAsInt();
727  }
728  return forOp.getConstantUpperBound() - 1;
729 }
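// For example, for `affine.for %i = 0 to 11 step 3` the IV only takes the
// values 0, 3, 6, 9, so the bound returned is 10 - (11 - 0 - 1) % 3 = 9 rather
// than the naive 10.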
730 
731 /// Determine a constant upper bound for `expr` if one exists while exploiting
732 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
733 /// is guaranteed to be less than or equal to it.
734 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
735  unsigned numSymbols,
736  ArrayRef<Value> operands) {
737  // Get the constant lower or upper bounds on the operands.
738  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
739  constLowerBounds.reserve(operands.size());
740  constUpperBounds.reserve(operands.size());
741  for (Value operand : operands) {
742  constLowerBounds.push_back(getLowerBound(operand));
743  constUpperBounds.push_back(getUpperBound(operand));
744  }
745 
746  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
747  return constExpr.getValue();
748 
749  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
750  constUpperBounds,
751  /*isUpper=*/true);
752 }
753 
754 /// Determine a constant lower bound for `expr` if one exists while exploiting
755 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
756 /// is guaranteed to be less than or equal to it.
757 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
758  unsigned numSymbols,
759  ArrayRef<Value> operands) {
760  // Get the constant lower or upper bounds on the operands.
761  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
762  constLowerBounds.reserve(operands.size());
763  constUpperBounds.reserve(operands.size());
764  for (Value operand : operands) {
765  constLowerBounds.push_back(getLowerBound(operand));
766  constUpperBounds.push_back(getUpperBound(operand));
767  }
768 
769  std::optional<int64_t> lowerBound;
770  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
771  lowerBound = constExpr.getValue();
772  } else {
773  lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
774  constLowerBounds, constUpperBounds,
775  /*isUpper=*/false);
776  }
777  return lowerBound;
778 }
779 
780 /// Simplify `expr` while exploiting information from the values in `operands`.
781 static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
782  unsigned numSymbols,
783  ArrayRef<Value> operands) {
784  // We do this only for certain floordiv/mod expressions.
785  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
786  if (!binExpr)
787  return;
788 
789  // Simplify the child expressions first.
790  AffineExpr lhs = binExpr.getLHS();
791  AffineExpr rhs = binExpr.getRHS();
792  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
793  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
794  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
795 
796  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
797  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
798  expr.getKind() != AffineExprKind::CeilDiv &&
799  expr.getKind() != AffineExprKind::Mod)) {
800  return;
801  }
802 
803  // The `lhs` and `rhs` may be different post construction of simplified expr.
804  lhs = binExpr.getLHS();
805  rhs = binExpr.getRHS();
806  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
807  if (!rhsConst)
808  return;
809 
810  int64_t rhsConstVal = rhsConst.getValue();
811  // Undefined expressions aren't touched; IR can still be valid with them.
812  if (rhsConstVal <= 0)
813  return;
814 
815  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
816  MLIRContext *context = expr.getContext();
817  std::optional<int64_t> lhsLbConst =
818  getLowerBound(lhs, numDims, numSymbols, operands);
819  std::optional<int64_t> lhsUbConst =
820  getUpperBound(lhs, numDims, numSymbols, operands);
821  if (lhsLbConst && lhsUbConst) {
822  int64_t lhsLbConstVal = *lhsLbConst;
823  int64_t lhsUbConstVal = *lhsUbConst;
824  // lhs floordiv c is a single value if lhs is bounded in a range `c` that has
825  // the same quotient.
826  if (binExpr.getKind() == AffineExprKind::FloorDiv &&
827  floorDiv(lhsLbConstVal, rhsConstVal) ==
828  floorDiv(lhsUbConstVal, rhsConstVal)) {
829  expr =
830  getAffineConstantExpr(floorDiv(lhsLbConstVal, rhsConstVal), context);
831  return;
832  }
833  // lhs ceildiv c is a single value if the entire range has the same ceil
834  // quotient.
835  if (binExpr.getKind() == AffineExprKind::CeilDiv &&
836  ceilDiv(lhsLbConstVal, rhsConstVal) ==
837  ceilDiv(lhsUbConstVal, rhsConstVal)) {
838  expr =
839  getAffineConstantExpr(ceilDiv(lhsLbConstVal, rhsConstVal), context);
840  return;
841  }
842  // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
843  if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
844  lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
845  expr = lhs;
846  return;
847  }
848  }
849 
850  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
851  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
852  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
853  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
854  AffineExpr quotientTimesDiv, rem;
855  int64_t divisor;
856  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
857  if (rhsConstVal % divisor == 0 &&
858  binExpr.getKind() == AffineExprKind::FloorDiv) {
859  expr = quotientTimesDiv.floorDiv(rhsConst);
860  } else if (divisor % rhsConstVal == 0 &&
861  binExpr.getKind() == AffineExprKind::Mod) {
862  expr = rem % rhsConst;
863  }
864  return;
865  }
866 
867  // Handle the simple case when the LHS expression can be either upper
868  // bounded or is a known multiple of RHS constant.
869  // lhs floordiv c -> 0 if 0 <= lhs < c,
870  // lhs mod c -> 0 if lhs % c = 0.
871  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
872  binExpr.getKind() == AffineExprKind::FloorDiv) ||
873  (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
874  binExpr.getKind() == AffineExprKind::Mod)) {
875  expr = getAffineConstantExpr(0, expr.getContext());
876  }
877 }
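// Examples of the above: with d0 bound to the IV of `affine.for %i = 0 to 64`,
// `d0 floordiv 64` simplifies to 0; if the loop instead has lower bound 0 and
// step 4, `d0 mod 4` simplifies to 0 since 4 divides every IV value.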
878 
879 /// Simplify the expressions in `map` while making use of lower or upper bounds
880 /// of its operands. If `isMax` is true, the map is to be treated as a max of
881 /// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
882 /// d1) can be simplified to (8) if the operands are respectively lower bounded
883 /// by 2 and 0 (the second expression can't be lower than 8).
884 static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
885  ArrayRef<Value> operands,
886  bool isMax) {
887  // Can't simplify.
888  if (operands.empty())
889  return;
890 
891  // Get the upper or lower bound on an affine.for op IV using its range.
892  // Get the constant lower or upper bounds on the operands.
893  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
894  constLowerBounds.reserve(operands.size());
895  constUpperBounds.reserve(operands.size());
896  for (Value operand : operands) {
897  constLowerBounds.push_back(getLowerBound(operand));
898  constUpperBounds.push_back(getUpperBound(operand));
899  }
900 
901  // We will compute the lower and upper bounds on each of the expressions
902  // Then, we will check (depending on max or min) as to whether a specific
903  // bound is redundant by checking if its highest (in case of max) and its
904  // lowest (in the case of min) value is already lower than (or higher than)
905  // the lower bound (or upper bound in the case of min) of another bound.
906  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
907  lowerBounds.reserve(map.getNumResults());
908  upperBounds.reserve(map.getNumResults());
909  for (AffineExpr e : map.getResults()) {
910  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
911  lowerBounds.push_back(constExpr.getValue());
912  upperBounds.push_back(constExpr.getValue());
913  } else {
914  lowerBounds.push_back(
915  getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
916  constLowerBounds, constUpperBounds,
917  /*isUpper=*/false));
918  upperBounds.push_back(
919  getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
920  constLowerBounds, constUpperBounds,
921  /*isUpper=*/true));
922  }
923  }
924 
925  // Collect expressions that are not redundant.
926  SmallVector<AffineExpr, 4> irredundantExprs;
927  for (auto exprEn : llvm::enumerate(map.getResults())) {
928  AffineExpr e = exprEn.value();
929  unsigned i = exprEn.index();
930  // Some expressions can be turned into constants.
931  if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
932  e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
933 
934  // Check if the expression is redundant.
935  if (isMax) {
936  if (!upperBounds[i]) {
937  irredundantExprs.push_back(e);
938  continue;
939  }
940  // If there exists another expression such that its lower bound is greater
941  // than this expression's upper bound, it's redundant.
942  if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
943  auto otherLowerBound = en.value();
944  unsigned pos = en.index();
945  if (pos == i || !otherLowerBound)
946  return false;
947  if (*otherLowerBound > *upperBounds[i])
948  return true;
949  if (*otherLowerBound < *upperBounds[i])
950  return false;
951  // Equality case. When both expressions are considered redundant, we
952  // don't want to get both of them. We keep the one that appears
953  // first.
954  if (upperBounds[pos] && lowerBounds[i] &&
955  lowerBounds[i] == upperBounds[i] &&
956  otherLowerBound == *upperBounds[pos] && i < pos)
957  return false;
958  return true;
959  }))
960  irredundantExprs.push_back(e);
961  } else {
962  if (!lowerBounds[i]) {
963  irredundantExprs.push_back(e);
964  continue;
965  }
966  // Likewise for the `min` case. Use the complement of the condition above.
967  if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
968  auto otherUpperBound = en.value();
969  unsigned pos = en.index();
970  if (pos == i || !otherUpperBound)
971  return false;
972  if (*otherUpperBound < *lowerBounds[i])
973  return true;
974  if (*otherUpperBound > *lowerBounds[i])
975  return false;
976  if (lowerBounds[pos] && upperBounds[i] &&
977  lowerBounds[i] == upperBounds[i] &&
978  otherUpperBound == lowerBounds[pos] && i < pos)
979  return false;
980  return true;
981  }))
982  irredundantExprs.push_back(e);
983  }
984  }
985 
986  // Create the map without the redundant expressions.
987  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
988  map.getContext());
989 }
990 
991 /// Simplify the map while exploiting information on the values in `operands`.
992 // Use "unused attribute" marker to silence warning stemming from the inability
993 // to see through the template expansion.
994 static void LLVM_ATTRIBUTE_UNUSED
995 simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
996  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
997  SmallVector<AffineExpr> newResults;
998  newResults.reserve(map.getNumResults());
999  for (AffineExpr expr : map.getResults()) {
1000  simplifyExprAndOperands(expr, map.getNumDims(), map.getNumSymbols(),
1001  operands);
1002  newResults.push_back(expr);
1003  }
1004  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1005  map.getContext());
1006 }
1007 
1008 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
1009 /// defining AffineApplyOp expression and operands.
1010 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
1011 /// When `dimOrSymbolPosition >= dims.size()`,
1012 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
1013 /// Mutate `map`,`dims` and `syms` in place as follows:
1014 /// 1. `dims` and `syms` are only appended to.
1015 /// 2. `map` dim and symbols are gradually shifted to higher positions.
1016 /// 3. Old `dim` and `sym` entries are replaced by nullptr
1017 /// This avoids the need for any bookkeeping.
1018 static LogicalResult replaceDimOrSym(AffineMap *map,
1019  unsigned dimOrSymbolPosition,
1020  SmallVectorImpl<Value> &dims,
1021  SmallVectorImpl<Value> &syms) {
1022  MLIRContext *ctx = map->getContext();
1023  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1024  unsigned pos = isDimReplacement ? dimOrSymbolPosition
1025  : dimOrSymbolPosition - dims.size();
1026  Value &v = isDimReplacement ? dims[pos] : syms[pos];
1027  if (!v)
1028  return failure();
1029 
1030  auto affineApply = v.getDefiningOp<AffineApplyOp>();
1031  if (!affineApply)
1032  return failure();
1033 
1034  // At this point we will perform a replacement of `v`, set the entry in `dim`
1035  // or `sym` to nullptr immediately.
1036  v = nullptr;
1037 
1038  // Compute the map, dims and symbols coming from the AffineApplyOp.
1039  AffineMap composeMap = affineApply.getAffineMap();
1040  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1041  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1042  affineApply.getMapOperands().end());
1043  // Canonicalize the map to promote dims to symbols when possible. This is to
1044  // avoid generating invalid maps.
1045  canonicalizeMapAndOperands(&composeMap, &composeOperands);
1046  AffineExpr replacementExpr =
1047  composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
1048  ValueRange composeDims =
1049  ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
1050  ValueRange composeSyms =
1051  ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
1052  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
1053  : getAffineSymbolExpr(pos, ctx);
1054 
1055  // Append the dims and symbols where relevant and perform the replacement.
1056  dims.append(composeDims.begin(), composeDims.end());
1057  syms.append(composeSyms.begin(), composeSyms.end());
1058  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1059 
1060  return success();
1061 }
1062 
1063 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1064 /// iteratively. Perform canonicalization of map and operands as well as
1065 /// AffineMap simplification. `map` and `operands` are mutated in place.
1066 static void composeAffineMapAndOperands(AffineMap *map,
1067  SmallVectorImpl<Value> *operands) {
1068  if (map->getNumResults() == 0) {
1069  canonicalizeMapAndOperands(map, operands);
1070  *map = simplifyAffineMap(*map);
1071  return;
1072  }
1073 
1074  MLIRContext *ctx = map->getContext();
1075  SmallVector<Value, 4> dims(operands->begin(),
1076  operands->begin() + map->getNumDims());
1077  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1078  operands->end());
1079 
1080  // Iterate over dims and symbols coming from AffineApplyOp and replace until
1081  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1082  // and `syms` can only increase by construction.
1083  // The implementation uses a `while` loop to support the case of symbols
1084  // that may be constructed from dims; this may be overkill.
1085  while (true) {
1086  bool changed = false;
1087  for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1088  if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
1089  break;
1090  if (!changed)
1091  break;
1092  }
1093 
1094  // Clear operands so we can fill them anew.
1095  operands->clear();
1096 
1097  // At this point we may have introduced null operands, prune them out before
1098  // canonicalizing map and operands.
1099  unsigned nDims = 0, nSyms = 0;
1100  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1101  dimReplacements.reserve(dims.size());
1102  symReplacements.reserve(syms.size());
1103  for (auto *container : {&dims, &syms}) {
1104  bool isDim = (container == &dims);
1105  auto &repls = isDim ? dimReplacements : symReplacements;
1106  for (const auto &en : llvm::enumerate(*container)) {
1107  Value v = en.value();
1108  if (!v) {
1109  assert(isDim ? !map->isFunctionOfDim(en.index())
1110  : !map->isFunctionOfSymbol(en.index()) &&
1111  "map is function of unexpected expr@pos");
1112  repls.push_back(getAffineConstantExpr(0, ctx));
1113  continue;
1114  }
1115  repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1116  : getAffineSymbolExpr(nSyms++, ctx));
1117  operands->push_back(v);
1118  }
1119  }
1120  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1121  nSyms);
1122 
1123  // Canonicalize and simplify before returning.
1124  canonicalizeMapAndOperands(map, operands);
1125  *map = simplifyAffineMap(*map);
1126 }
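// As an example of the composition performed above, assuming
//   %a = affine.apply affine_map<(d0) -> (d0 * 4)>(%i)
// feeds a map (d0) -> (d0 + 1) with operand %a, the result is the composed
// map (d0) -> (d0 * 4 + 1) applied directly to %i.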
1127 
1128 void mlir::affine::fullyComposeAffineMapAndOperands(
1129  AffineMap *map, SmallVectorImpl<Value> *operands) {
1130  while (llvm::any_of(*operands, [](Value v) {
1131  return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1132  })) {
1133  composeAffineMapAndOperands(map, operands);
1134  }
1135 }
1136 
1137 AffineApplyOp
1138 mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
1139  ArrayRef<OpFoldResult> operands) {
1140  SmallVector<Value> valueOperands;
1141  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1142  composeAffineMapAndOperands(&map, &valueOperands);
1143  assert(map);
1144  return b.create<AffineApplyOp>(loc, map, valueOperands);
1145 }
1146 
1147 AffineApplyOp
1148 mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
1149  ArrayRef<OpFoldResult> operands) {
1150  return makeComposedAffineApply(
1151  b, loc,
1152  AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}, b.getContext())
1153  .front(),
1154  operands);
1155 }
1156 
1157 /// Composes the given affine map with the given list of operands, pulling in
1158 /// the maps from any affine.apply operations that supply the operands.
1159 static void composeMultiResultAffineMap(AffineMap &map,
1160  SmallVectorImpl<Value> &operands) {
1161  // Compose and canonicalize each expression in the map individually because
1162  // composition only applies to single-result maps, collecting potentially
1163  // duplicate operands in a single list with shifted dimensions and symbols.
1164  SmallVector<Value> dims, symbols;
1165  SmallVector<AffineExpr> exprs;
1166  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1167  SmallVector<Value> submapOperands(operands.begin(), operands.end());
1168  AffineMap submap = map.getSubMap({i});
1169  fullyComposeAffineMapAndOperands(&submap, &submapOperands);
1170  canonicalizeMapAndOperands(&submap, &submapOperands);
1171  unsigned numNewDims = submap.getNumDims();
1172  submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1173  llvm::append_range(dims,
1174  ArrayRef<Value>(submapOperands).take_front(numNewDims));
1175  llvm::append_range(symbols,
1176  ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1177  exprs.push_back(submap.getResult(0));
1178  }
1179 
1180  // Canonicalize the map created from composed expressions to deduplicate the
1181  // dimension and symbol operands.
1182  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1183  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1184  canonicalizeMapAndOperands(&map, &operands);
1185 }
1186 
1187 OpFoldResult
1188 mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
1189  AffineMap map,
1190  ArrayRef<OpFoldResult> operands) {
1191  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1192 
1193  // Create new builder without a listener, so that no notification is
1194  // triggered if the op is folded.
1195  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1196  // workaround is no longer needed.
1197  OpBuilder newBuilder(b.getContext());
1198  newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1199 
1200  // Create op.
1201  AffineApplyOp applyOp =
1202  makeComposedAffineApply(newBuilder, loc, map, operands);
1203 
1204  // Get constant operands.
1205  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1206  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1207  matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1208 
1209  // Try to fold the operation.
1210  SmallVector<OpFoldResult> foldResults;
1211  if (failed(applyOp->fold(constOperands, foldResults)) ||
1212  foldResults.empty()) {
1213  if (OpBuilder::Listener *listener = b.getListener())
1214  listener->notifyOperationInserted(applyOp, /*previous=*/{});
1215  return applyOp.getResult();
1216  }
1217 
1218  applyOp->erase();
1219  assert(foldResults.size() == 1 && "expected 1 folded result");
1220  return foldResults.front();
1221 }
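// A minimal usage sketch (hypothetical caller code, not taken from this file):
//   // Compute ceil(n / 128), folding to an attribute when `n` is static.
//   AffineExpr s0 = b.getAffineSymbolExpr(0);
//   OpFoldResult tiles = makeComposedFoldedAffineApply(
//       b, loc, AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1, s0.ceilDiv(128)),
//       {n});
// Only when folding fails does this materialize an affine.apply op.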
1222 
1223 OpFoldResult
1224 mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
1225  AffineExpr expr,
1226  ArrayRef<OpFoldResult> operands) {
1227  return makeComposedFoldedAffineApply(
1228  b, loc,
1229  AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}, b.getContext())
1230  .front(),
1231  operands);
1232 }
1233 
1234 SmallVector<OpFoldResult>
1235 mlir::affine::makeComposedFoldedMultiResultAffineApply(
1236  OpBuilder &b, Location loc, AffineMap map,
1237  ArrayRef<OpFoldResult> operands) {
1238  return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
1239  [&](unsigned i) {
1240  return makeComposedFoldedAffineApply(
1241  b, loc, map.getSubMap({i}), operands);
1242  });
1243 }
1244 
1245 template <typename OpTy>
1246 static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map,
1247  ArrayRef<OpFoldResult> operands) {
1248  SmallVector<Value> valueOperands;
1249  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1250  composeMultiResultAffineMap(map, valueOperands);
1251  return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1252 }
1253 
1254 AffineMinOp
1255 mlir::affine::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
1256  ArrayRef<OpFoldResult> operands) {
1257  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1258 }
1259 
1260 template <typename OpTy>
1261 static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
1262  AffineMap map,
1263  ArrayRef<OpFoldResult> operands) {
1264  // Create new builder without a listener, so that no notification is
1265  // triggered if the op is folded.
1266  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1267  // workaround is no longer needed.
1268  OpBuilder newBuilder(b.getContext());
1269  newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1270 
1271  // Create op.
1272  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1273 
1274  // Get constant operands.
1275  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1276  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1277  matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1278 
1279  // Try to fold the operation.
1280  SmallVector<OpFoldResult> foldResults;
1281  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1282  foldResults.empty()) {
1283  if (OpBuilder::Listener *listener = b.getListener())
1284  listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1285  return minMaxOp.getResult();
1286  }
1287 
1288  minMaxOp->erase();
1289  assert(foldResults.size() == 1 && "expected 1 folded result");
1290  return foldResults.front();
1291 }
1292 
1293 OpFoldResult
1294 mlir::affine::makeComposedFoldedAffineMin(OpBuilder &b, Location loc,
1295  AffineMap map,
1296  ArrayRef<OpFoldResult> operands) {
1297  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1298 }
1299 
1300 OpFoldResult
1301 mlir::affine::makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
1302  AffineMap map,
1303  ArrayRef<OpFoldResult> operands) {
1304  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1305 }
1306 
1307 // A symbol may appear as a dim in affine.apply operations. This function
1308 // canonicalizes dims that are valid symbols into actual symbols.
1309 template <class MapOrSet>
1310 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1311  SmallVectorImpl<Value> *operands) {
1312  if (!mapOrSet || operands->empty())
1313  return;
1314 
1315  assert(mapOrSet->getNumInputs() == operands->size() &&
1316  "map/set inputs must match number of operands");
1317 
1318  auto *context = mapOrSet->getContext();
1319  SmallVector<Value, 8> resultOperands;
1320  resultOperands.reserve(operands->size());
1321  SmallVector<Value, 8> remappedSymbols;
1322  remappedSymbols.reserve(operands->size());
1323  unsigned nextDim = 0;
1324  unsigned nextSym = 0;
1325  unsigned oldNumSyms = mapOrSet->getNumSymbols();
1326  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1327  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1328  if (i < mapOrSet->getNumDims()) {
1329  if (isValidSymbol((*operands)[i])) {
1330  // This is a valid symbol that appears as a dim, canonicalize it.
1331  dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1332  remappedSymbols.push_back((*operands)[i]);
1333  } else {
1334  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1335  resultOperands.push_back((*operands)[i]);
1336  }
1337  } else {
1338  resultOperands.push_back((*operands)[i]);
1339  }
1340  }
1341 
1342  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1343  *operands = resultOperands;
1344  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1345  oldNumSyms + nextSym);
1346 
1347  assert(mapOrSet->getNumInputs() == operands->size() &&
1348  "map/set inputs must match number of operands");
1349 }
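// For instance, a map (d0, d1) -> (d0 + d1) whose second operand is a valid
// symbol (e.g. a top-level constant) is rewritten here to
// (d0)[s0] -> (d0 + s0), with that operand moved to the symbol operand list.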
1350 
1351 // Works for either an affine map or an integer set.
1352 template <class MapOrSet>
1353 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1354  SmallVectorImpl<Value> *operands) {
1355  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1356  "Argument must be either of AffineMap or IntegerSet type");
1357 
1358  if (!mapOrSet || operands->empty())
1359  return;
1360 
1361  assert(mapOrSet->getNumInputs() == operands->size() &&
1362  "map/set inputs must match number of operands");
1363 
1364  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1365 
1366  // Check to see what dims are used.
1367  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1368  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1369  mapOrSet->walkExprs([&](AffineExpr expr) {
1370  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1371  usedDims[dimExpr.getPosition()] = true;
1372  else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1373  usedSyms[symExpr.getPosition()] = true;
1374  });
1375 
1376  auto *context = mapOrSet->getContext();
1377 
1378  SmallVector<Value, 8> resultOperands;
1379  resultOperands.reserve(operands->size());
1380 
1381  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1382  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1383  unsigned nextDim = 0;
1384  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1385  if (usedDims[i]) {
1386  // Remap dim positions for duplicate operands.
1387  auto it = seenDims.find((*operands)[i]);
1388  if (it == seenDims.end()) {
1389  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1390  resultOperands.push_back((*operands)[i]);
1391  seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1392  } else {
1393  dimRemapping[i] = it->second;
1394  }
1395  }
1396  }
1397  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1398  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1399  unsigned nextSym = 0;
1400  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1401  if (!usedSyms[i])
1402  continue;
1403  // Handle constant operands (only needed for symbolic operands since
1404  // constant operands in dimensional positions would have already been
1405  // promoted to symbolic positions above).
1406  IntegerAttr operandCst;
1407  if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1408  m_Constant(&operandCst))) {
1409  symRemapping[i] =
1410  getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1411  continue;
1412  }
1413  // Remap symbol positions for duplicate operands.
1414  auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1415  if (it == seenSymbols.end()) {
1416  symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1417  resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1418  seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1419  symRemapping[i]));
1420  } else {
1421  symRemapping[i] = it->second;
1422  }
1423  }
1424  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1425  nextDim, nextSym);
1426  *operands = resultOperands;
1427 }
1428 
1429 void mlir::affine::canonicalizeMapAndOperands(
1430  AffineMap *map, SmallVectorImpl<Value> *operands) {
1431  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1432 }
1433 
1434 void mlir::affine::canonicalizeSetAndOperands(
1435  IntegerSet *set, SmallVectorImpl<Value> *operands) {
1436  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1437 }
1438 
1439 namespace {
1440 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1441 /// maps that supply results into them.
1442 ///
1443 template <typename AffineOpTy>
1444 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1445  using OpRewritePattern<AffineOpTy>::OpRewritePattern;
1446 
1447  /// Replace the affine op with another instance of it with the supplied
1448  /// map and mapOperands.
1449  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1450  AffineMap map, ArrayRef<Value> mapOperands) const;
1451 
1452  LogicalResult matchAndRewrite(AffineOpTy affineOp,
1453  PatternRewriter &rewriter) const override {
1454  static_assert(
1455  llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1456  AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1457  AffineVectorStoreOp, AffineVectorLoadOp>::value,
1458  "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1459  "expected");
1460  auto map = affineOp.getAffineMap();
1461  AffineMap oldMap = map;
1462  auto oldOperands = affineOp.getMapOperands();
1463  SmallVector<Value, 8> resultOperands(oldOperands);
1464  composeAffineMapAndOperands(&map, &resultOperands);
1465  canonicalizeMapAndOperands(&map, &resultOperands);
1466  simplifyMapWithOperands(map, resultOperands);
1467  if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1468  resultOperands.begin()))
1469  return failure();
1470 
1471  replaceAffineOp(rewriter, affineOp, map, resultOperands);
1472  return success();
1473  }
1474 };
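// Editor's illustrative sketch of this pattern's effect (not from the source):
//   %idx = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
//   %v = affine.load %A[%idx] : memref<100xf32>
// is expected to be rewritten, by composing the apply's map into the load, to
//   %v = affine.load %A[%i + 1] : memref<100xf32>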
1475 
1476 // Specialize the template to account for the different build signatures for
1477 // affine load, store, and apply ops.
1478 template <>
1479 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1480  PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1481  ArrayRef<Value> mapOperands) const {
1482  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1483  mapOperands);
1484 }
1485 template <>
1486 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1487  PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1488  ArrayRef<Value> mapOperands) const {
1489  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1490  prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1491  prefetch.getLocalityHint(), prefetch.getIsDataCache());
1492 }
1493 template <>
1494 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1495  PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1496  ArrayRef<Value> mapOperands) const {
1497  rewriter.replaceOpWithNewOp<AffineStoreOp>(
1498  store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1499 }
1500 template <>
1501 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1502  PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1503  ArrayRef<Value> mapOperands) const {
1504  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1505  vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1506  mapOperands);
1507 }
1508 template <>
1509 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1510  PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1511  ArrayRef<Value> mapOperands) const {
1512  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1513  vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1514  mapOperands);
1515 }
1516 
1517 // Generic version for ops that don't have extra operands.
1518 template <typename AffineOpTy>
1519 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1520  PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1521  ArrayRef<Value> mapOperands) const {
1522  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1523 }
1524 } // namespace
1525 
1526 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1527  MLIRContext *context) {
1528  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1529 }
1530 
1531 //===----------------------------------------------------------------------===//
1532 // AffineDmaStartOp
1533 //===----------------------------------------------------------------------===//
1534 
1535 // TODO: Check that map operands are loop IVs or symbols.
1536 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1537  Value srcMemRef, AffineMap srcMap,
1538  ValueRange srcIndices, Value destMemRef,
1539  AffineMap dstMap, ValueRange destIndices,
1540  Value tagMemRef, AffineMap tagMap,
1541  ValueRange tagIndices, Value numElements,
1542  Value stride, Value elementsPerStride) {
1543  result.addOperands(srcMemRef);
1544  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1545  result.addOperands(srcIndices);
1546  result.addOperands(destMemRef);
1547  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1548  result.addOperands(destIndices);
1549  result.addOperands(tagMemRef);
1550  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1551  result.addOperands(tagIndices);
1552  result.addOperands(numElements);
1553  if (stride) {
1554  result.addOperands({stride, elementsPerStride});
1555  }
1556 }
1557 
1558 void AffineDmaStartOp::print(OpAsmPrinter &p) {
1559  p << " " << getSrcMemRef() << '[';
1560  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1561  p << "], " << getDstMemRef() << '[';
1562  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1563  p << "], " << getTagMemRef() << '[';
1564  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1565  p << "], " << getNumElements();
1566  if (isStrided()) {
1567  p << ", " << getStride();
1568  p << ", " << getNumElementsPerStride();
1569  }
1570  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1571  << getTagMemRefType();
1572 }
1573 
1574 // Parse AffineDmaStartOp.
1575 // Ex:
1576 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1577 // %stride, %num_elt_per_stride
1578 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1579 //
1580 ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
1581  OperationState &result) {
1582  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1583  AffineMapAttr srcMapAttr;
1584  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcMapOperands;
1585  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1586  AffineMapAttr dstMapAttr;
1587  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstMapOperands;
1588  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1589  AffineMapAttr tagMapAttr;
1590  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagMapOperands;
1591  OpAsmParser::UnresolvedOperand numElementsInfo;
1592  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1593 
1594  SmallVector<Type, 3> types;
1595  auto indexType = parser.getBuilder().getIndexType();
1596 
1597  // Parse and resolve the following list of operands:
1598  // *) dst memref followed by its affine maps operands (in square brackets).
1599  // *) src memref followed by its affine map operands (in square brackets).
1600  // *) tag memref followed by its affine map operands (in square brackets).
1601  // *) number of elements transferred by DMA operation.
1602  if (parser.parseOperand(srcMemRefInfo) ||
1603  parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1604  getSrcMapAttrStrName(),
1605  result.attributes) ||
1606  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1607  parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1608  getDstMapAttrStrName(),
1609  result.attributes) ||
1610  parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1611  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1612  getTagMapAttrStrName(),
1613  result.attributes) ||
1614  parser.parseComma() || parser.parseOperand(numElementsInfo))
1615  return failure();
1616 
1617  // Parse optional stride and elements per stride.
1618  if (parser.parseTrailingOperandList(strideInfo))
1619  return failure();
1620 
1621  if (!strideInfo.empty() && strideInfo.size() != 2) {
1622  return parser.emitError(parser.getNameLoc(),
1623  "expected two stride related operands");
1624  }
1625  bool isStrided = strideInfo.size() == 2;
1626 
1627  if (parser.parseColonTypeList(types))
1628  return failure();
1629 
1630  if (types.size() != 3)
1631  return parser.emitError(parser.getNameLoc(), "expected three types");
1632 
1633  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1634  parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1635  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1636  parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1637  parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1638  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1639  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1640  return failure();
1641 
1642  if (isStrided) {
1643  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1644  return failure();
1645  }
1646 
1647  // Check that src/dst/tag operand counts match their map.numInputs.
1648  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1649  dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1650  tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1651  return parser.emitError(parser.getNameLoc(),
1652  "memref operand count not equal to map.numInputs");
1653  return success();
1654 }
1655 
1656 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1657  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1658  return emitOpError("expected DMA source to be of memref type");
1659  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1660  return emitOpError("expected DMA destination to be of memref type");
1661  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1662  return emitOpError("expected DMA tag to be of memref type");
1663 
1664  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1665  getDstMap().getNumInputs() +
1666  getTagMap().getNumInputs();
1667  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1668  getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1669  return emitOpError("incorrect number of operands");
1670  }
1671 
1672  Region *scope = getAffineScope(*this);
1673  for (auto idx : getSrcIndices()) {
1674  if (!idx.getType().isIndex())
1675  return emitOpError("src index to dma_start must have 'index' type");
1676  if (!isValidAffineIndexOperand(idx, scope))
1677  return emitOpError(
1678  "src index must be a valid dimension or symbol identifier");
1679  }
1680  for (auto idx : getDstIndices()) {
1681  if (!idx.getType().isIndex())
1682  return emitOpError("dst index to dma_start must have 'index' type");
1683  if (!isValidAffineIndexOperand(idx, scope))
1684  return emitOpError(
1685  "dst index must be a valid dimension or symbol identifier");
1686  }
1687  for (auto idx : getTagIndices()) {
1688  if (!idx.getType().isIndex())
1689  return emitOpError("tag index to dma_start must have 'index' type");
1690  if (!isValidAffineIndexOperand(idx, scope))
1691  return emitOpError(
1692  "tag index must be a valid dimension or symbol identifier");
1693  }
1694  return success();
1695 }
1696 
1697 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1698  SmallVectorImpl<OpFoldResult> &results) {
1699  /// dma_start(memrefcast) -> dma_start
1700  return memref::foldMemRefCast(*this);
1701 }
1702 
1703 void AffineDmaStartOp::getEffects(
1704  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1705  &effects) {
1706  effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
1707  SideEffects::DefaultResource::get());
1708  effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
1709  SideEffects::DefaultResource::get());
1710  effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
1711  SideEffects::DefaultResource::get());
1712 }
1713 
1714 //===----------------------------------------------------------------------===//
1715 // AffineDmaWaitOp
1716 //===----------------------------------------------------------------------===//
1717 
1718 // TODO: Check that map operands are loop IVs or symbols.
1719 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1720  Value tagMemRef, AffineMap tagMap,
1721  ValueRange tagIndices, Value numElements) {
1722  result.addOperands(tagMemRef);
1723  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1724  result.addOperands(tagIndices);
1725  result.addOperands(numElements);
1726 }
1727 
1728 void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1729  p << " " << getTagMemRef() << '[';
1730  SmallVector<Value, 2> operands(getTagIndices());
1731  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1732  p << "], ";
1733  p.printOperand(getNumElements());
1734  p << " : " << getTagMemRef().getType();
1735 }
1736 
1737 // Parse AffineDmaWaitOp.
1738 // Eg:
1739 // affine.dma_wait %tag[%index], %num_elements
1740 // : memref<1 x i32, (d0) -> (d0), 4>
1741 //
1742 ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1743  OperationState &result) {
1744  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1745  AffineMapAttr tagMapAttr;
1746  SmallVector<OpAsmParser::UnresolvedOperand, 2> tagMapOperands;
1747  Type type;
1748  auto indexType = parser.getBuilder().getIndexType();
1749  OpAsmParser::UnresolvedOperand numElementsInfo;
1750 
1751  // Parse tag memref, its map operands, and dma size.
1752  if (parser.parseOperand(tagMemRefInfo) ||
1753  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1754  getTagMapAttrStrName(),
1755  result.attributes) ||
1756  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1757  parser.parseColonType(type) ||
1758  parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1759  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1760  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1761  return failure();
1762 
1763  if (!llvm::isa<MemRefType>(type))
1764  return parser.emitError(parser.getNameLoc(),
1765  "expected tag to be of memref type");
1766 
1767  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1768  return parser.emitError(parser.getNameLoc(),
1769  "tag memref operand count != to map.numInputs");
1770  return success();
1771 }
1772 
1773 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1774  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1775  return emitOpError("expected DMA tag to be of memref type");
1776  Region *scope = getAffineScope(*this);
1777  for (auto idx : getTagIndices()) {
1778  if (!idx.getType().isIndex())
1779  return emitOpError("index to dma_wait must have 'index' type");
1780  if (!isValidAffineIndexOperand(idx, scope))
1781  return emitOpError(
1782  "index must be a valid dimension or symbol identifier");
1783  }
1784  return success();
1785 }
1786 
1787 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1788  SmallVectorImpl<OpFoldResult> &results) {
1789  /// dma_wait(memrefcast) -> dma_wait
1790  return memref::foldMemRefCast(*this);
1791 }
1792 
1793 void AffineDmaWaitOp::getEffects(
1794  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1795  &effects) {
1796  effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
1797  SideEffects::DefaultResource::get());
1798 }
1799 
1800 //===----------------------------------------------------------------------===//
1801 // AffineForOp
1802 //===----------------------------------------------------------------------===//
1803 
1804 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1805 /// bodyBuilder are empty/null, we include the default terminator op.
1806 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1807  ValueRange lbOperands, AffineMap lbMap,
1808  ValueRange ubOperands, AffineMap ubMap, int64_t step,
1809  ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1810  assert(((!lbMap && lbOperands.empty()) ||
1811  lbOperands.size() == lbMap.getNumInputs()) &&
1812  "lower bound operand count does not match the affine map");
1813  assert(((!ubMap && ubOperands.empty()) ||
1814  ubOperands.size() == ubMap.getNumInputs()) &&
1815  "upper bound operand count does not match the affine map");
1816  assert(step > 0 && "step has to be a positive integer constant");
1817 
1818  OpBuilder::InsertionGuard guard(builder);
1819 
1820  // Set variadic segment sizes.
1821  result.addAttribute(
1822  getOperandSegmentSizeAttr(),
1823  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
1824  static_cast<int32_t>(ubOperands.size()),
1825  static_cast<int32_t>(iterArgs.size())}));
1826 
1827  for (Value val : iterArgs)
1828  result.addTypes(val.getType());
1829 
1830  // Add an attribute for the step.
1831  result.addAttribute(getStepAttrName(result.name),
1832  builder.getIntegerAttr(builder.getIndexType(), step));
1833 
1834  // Add the lower bound.
1835  result.addAttribute(getLowerBoundMapAttrName(result.name),
1836  AffineMapAttr::get(lbMap));
1837  result.addOperands(lbOperands);
1838 
1839  // Add the upper bound.
1840  result.addAttribute(getUpperBoundMapAttrName(result.name),
1841  AffineMapAttr::get(ubMap));
1842  result.addOperands(ubOperands);
1843 
1844  result.addOperands(iterArgs);
1845  // Create a region and a block for the body. The argument of the region is
1846  // the loop induction variable.
1847  Region *bodyRegion = result.addRegion();
1848  Block *bodyBlock = builder.createBlock(bodyRegion);
1849  Value inductionVar =
1850  bodyBlock->addArgument(builder.getIndexType(), result.location);
1851  for (Value val : iterArgs)
1852  bodyBlock->addArgument(val.getType(), val.getLoc());
1853 
1854  // Create the default terminator if the builder is not provided and if the
1855  // iteration arguments are not provided. Otherwise, leave this to the caller
1856  // because we don't know which values to return from the loop.
1857  if (iterArgs.empty() && !bodyBuilder) {
1858  ensureTerminator(*bodyRegion, builder, result.location);
1859  } else if (bodyBuilder) {
1860  OpBuilder::InsertionGuard guard(builder);
1861  builder.setInsertionPointToStart(bodyBlock);
1862  bodyBuilder(builder, result.location, inductionVar,
1863  bodyBlock->getArguments().drop_front());
1864  }
1865 }
1866 
1867 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1868  int64_t ub, int64_t step, ValueRange iterArgs,
1869  BodyBuilderFn bodyBuilder) {
1870  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1871  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1872  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1873  bodyBuilder);
1874 }
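// Editor's note (illustrative, not from the source): building with lb = 0,
// ub = 10, step = 2 and no iter_args or body builder yields a loop that
// round-trips through the custom assembly as:
//   affine.for %iv = 0 to 10 step 2 {
//   }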
1875 
1876 LogicalResult AffineForOp::verifyRegions() {
1877  // Check that the body defines a single block argument for the induction
1878  // variable.
1879  auto *body = getBody();
1880  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1881  return emitOpError("expected body to have a single index argument for the "
1882  "induction variable");
1883 
1884  // Verify that the bound operands are valid dimension/symbols.
1885  /// Lower bound.
1886  if (getLowerBoundMap().getNumInputs() > 0)
1887  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(),
1888  getLowerBoundMap().getNumDims())))
1889  return failure();
1890  /// Upper bound.
1891  if (getUpperBoundMap().getNumInputs() > 0)
1892  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(),
1893  getUpperBoundMap().getNumDims())))
1894  return failure();
1895 
1896  unsigned opNumResults = getNumResults();
1897  if (opNumResults == 0)
1898  return success();
1899 
1900  // If ForOp defines values, check that the number and types of the defined
1901  // values match ForOp initial iter operands and backedge basic block
1902  // arguments.
1903  if (getNumIterOperands() != opNumResults)
1904  return emitOpError(
1905  "mismatch between the number of loop-carried values and results");
1906  if (getNumRegionIterArgs() != opNumResults)
1907  return emitOpError(
1908  "mismatch between the number of basic block args and results");
1909 
1910  return success();
1911 }
1912 
1913 /// Parse a for operation loop bounds.
1914 static ParseResult parseBound(bool isLower, OperationState &result,
1915  OpAsmParser &p) {
1916  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1917  // the map has multiple results.
1918  bool failedToParsedMinMax =
1919  failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1920 
1921  auto &builder = p.getBuilder();
1922  auto boundAttrStrName =
1923  isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
1924  : AffineForOp::getUpperBoundMapAttrName(result.name);
1925 
1926  // Parse ssa-id as identity map.
1927  SmallVector<OpAsmParser::UnresolvedOperand, 1> boundOpInfos;
1928  if (p.parseOperandList(boundOpInfos))
1929  return failure();
1930 
1931  if (!boundOpInfos.empty()) {
1932  // Check that only one operand was parsed.
1933  if (boundOpInfos.size() > 1)
1934  return p.emitError(p.getNameLoc(),
1935  "expected only one loop bound operand");
1936 
1937  // TODO: improve error message when SSA value is not of index type.
1938  // Currently it is 'use of value ... expects different type than prior uses'
1939  if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1940  result.operands))
1941  return failure();
1942 
1943  // Create an identity map using symbol id. This representation is optimized
1944  // for storage. Analysis passes may expand it into a multi-dimensional map
1945  // if desired.
1946  AffineMap map = builder.getSymbolIdentityMap();
1947  result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
1948  return success();
1949  }
1950 
1951  // Get the attribute location.
1952  SMLoc attrLoc = p.getCurrentLocation();
1953 
1954  Attribute boundAttr;
1955  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
1956  result.attributes))
1957  return failure();
1958 
1959  // Parse full form - affine map followed by dim and symbol list.
1960  if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
1961  unsigned currentNumOperands = result.operands.size();
1962  unsigned numDims;
1963  if (parseDimAndSymbolList(p, result.operands, numDims))
1964  return failure();
1965 
1966  auto map = affineMapAttr.getValue();
1967  if (map.getNumDims() != numDims)
1968  return p.emitError(
1969  p.getNameLoc(),
1970  "dim operand count and affine map dim count must match");
1971 
1972  unsigned numDimAndSymbolOperands =
1973  result.operands.size() - currentNumOperands;
1974  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1975  return p.emitError(
1976  p.getNameLoc(),
1977  "symbol operand count and affine map symbol count must match");
1978 
1979  // If the map has multiple results, make sure that we parsed the min/max
1980  // prefix.
1981  if (map.getNumResults() > 1 && failedToParsedMinMax) {
1982  if (isLower) {
1983  return p.emitError(attrLoc, "lower loop bound affine map with "
1984  "multiple results requires 'max' prefix");
1985  }
1986  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1987  "results requires 'min' prefix");
1988  }
1989  return success();
1990  }
1991 
1992  // Parse custom assembly form.
1993  if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
1994  result.attributes.pop_back();
1995  result.addAttribute(
1996  boundAttrStrName,
1997  AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
1998  return success();
1999  }
2000 
2001  return p.emitError(
2002  p.getNameLoc(),
2003  "expected valid affine map representation for loop bounds");
2004 }
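// Editor's note (illustrative, not from the source): the bound forms accepted
// above are a single SSA operand such as "%n", a bare integer constant such as
// "10", or the full form such as "max affine_map<(d0) -> (d0, 0)>(%i)", where
// the 'max'/'min' keyword is mandatory only for multi-result bound maps.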
2005 
2006 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2007  auto &builder = parser.getBuilder();
2008  OpAsmParser::Argument inductionVariable;
2009  inductionVariable.type = builder.getIndexType();
2010  // Parse the induction variable followed by '='.
2011  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2012  return failure();
2013 
2014  // Parse loop bounds.
2015  int64_t numOperands = result.operands.size();
2016  if (parseBound(/*isLower=*/true, result, parser))
2017  return failure();
2018  int64_t numLbOperands = result.operands.size() - numOperands;
2019  if (parser.parseKeyword("to", " between bounds"))
2020  return failure();
2021  numOperands = result.operands.size();
2022  if (parseBound(/*isLower=*/false, result, parser))
2023  return failure();
2024  int64_t numUbOperands = result.operands.size() - numOperands;
2025 
2026  // Parse the optional loop step, we default to 1 if one is not present.
2027  if (parser.parseOptionalKeyword("step")) {
2028  result.addAttribute(
2029  getStepAttrName(result.name),
2030  builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2031  } else {
2032  SMLoc stepLoc = parser.getCurrentLocation();
2033  IntegerAttr stepAttr;
2034  if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2035  getStepAttrName(result.name).data(),
2036  result.attributes))
2037  return failure();
2038 
2039  if (stepAttr.getValue().isNegative())
2040  return parser.emitError(
2041  stepLoc,
2042  "expected step to be representable as a positive signed integer");
2043  }
2044 
2045  // Parse the optional initial iteration arguments.
2046  SmallVector<OpAsmParser::Argument, 4> regionArgs;
2047  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2048 
2049  // Induction variable.
2050  regionArgs.push_back(inductionVariable);
2051 
2052  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2053  // Parse assignment list and results type list.
2054  if (parser.parseAssignmentList(regionArgs, operands) ||
2055  parser.parseArrowTypeList(result.types))
2056  return failure();
2057  // Resolve input operands.
2058  for (auto argOperandType :
2059  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2060  Type type = std::get<2>(argOperandType);
2061  std::get<0>(argOperandType).type = type;
2062  if (parser.resolveOperand(std::get<1>(argOperandType), type,
2063  result.operands))
2064  return failure();
2065  }
2066  }
2067 
2068  result.addAttribute(
2069  getOperandSegmentSizeAttr(),
2070  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2071  static_cast<int32_t>(numUbOperands),
2072  static_cast<int32_t>(operands.size())}));
2073 
2074  // Parse the body region.
2075  Region *body = result.addRegion();
2076  if (regionArgs.size() != result.types.size() + 1)
2077  return parser.emitError(
2078  parser.getNameLoc(),
2079  "mismatch between the number of loop-carried values and results");
2080  if (parser.parseRegion(*body, regionArgs))
2081  return failure();
2082 
2083  AffineForOp::ensureTerminator(*body, builder, result.location);
2084 
2085  // Parse the optional attribute list.
2086  return parser.parseOptionalAttrDict(result.attributes);
2087 }
2088 
2089 static void printBound(AffineMapAttr boundMap,
2090  Operation::operand_range boundOperands,
2091  const char *prefix, OpAsmPrinter &p) {
2092  AffineMap map = boundMap.getValue();
2093 
2094  // Check if this bound should be printed using custom assembly form.
2095  // The decision to restrict printing custom assembly form to trivial cases
2096  // comes from the desire to round-trip MLIR binary -> text -> binary in a
2097  // lossless way.
2098  // Therefore, custom assembly form parsing and printing is only supported for
2099  // zero-operand constant maps and single symbol operand identity maps.
2100  if (map.getNumResults() == 1) {
2101  AffineExpr expr = map.getResult(0);
2102 
2103  // Print constant bound.
2104  if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2105  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2106  p << constExpr.getValue();
2107  return;
2108  }
2109  }
2110 
2111  // Print bound that consists of a single SSA symbol if the map is over a
2112  // single symbol.
2113  if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2114  if (dyn_cast<AffineSymbolExpr>(expr)) {
2115  p.printOperand(*boundOperands.begin());
2116  return;
2117  }
2118  }
2119  } else {
2120  // Map has multiple results. Print 'min' or 'max' prefix.
2121  p << prefix << ' ';
2122  }
2123 
2124  // Print the map and its operands.
2125  p << boundMap;
2126  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2127  map.getNumDims(), p);
2128 }
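// Editor's note (illustrative, not from the source): depending on the map, a
// bound prints as a bare constant ("0"), a single SSA symbol ("%n"), or the
// general form, with a 'min'/'max' prefix for multi-result maps, e.g.
// "min affine_map<(d0)[s0] -> (d0, s0)>(%i)[%n]".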
2129 
2130 unsigned AffineForOp::getNumIterOperands() {
2131  AffineMap lbMap = getLowerBoundMapAttr().getValue();
2132  AffineMap ubMap = getUpperBoundMapAttr().getValue();
2133 
2134  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2135 }
2136 
2137 std::optional<MutableArrayRef<OpOperand>>
2138 AffineForOp::getYieldedValuesMutable() {
2139  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2140 }
2141 
2142 void AffineForOp::print(OpAsmPrinter &p) {
2143  p << ' ';
2144  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2145  /*omitType=*/true);
2146  p << " = ";
2147  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2148  p << " to ";
2149  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2150 
2151  if (getStepAsInt() != 1)
2152  p << " step " << getStepAsInt();
2153 
2154  bool printBlockTerminators = false;
2155  if (getNumIterOperands() > 0) {
2156  p << " iter_args(";
2157  auto regionArgs = getRegionIterArgs();
2158  auto operands = getInits();
2159 
2160  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2161  p << std::get<0>(it) << " = " << std::get<1>(it);
2162  });
2163  p << ") -> (" << getResultTypes() << ")";
2164  printBlockTerminators = true;
2165  }
2166 
2167  p << ' ';
2168  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2169  printBlockTerminators);
2170  p.printOptionalAttrDict(
2171  (*this)->getAttrs(),
2172  /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2173  getUpperBoundMapAttrName(getOperation()->getName()),
2174  getStepAttrName(getOperation()->getName()),
2175  getOperandSegmentSizeAttr()});
2176 }
2177 
2178 /// Fold the constant bounds of a loop.
2179 static LogicalResult foldLoopBounds(AffineForOp forOp) {
2180  auto foldLowerOrUpperBound = [&forOp](bool lower) {
2181  // Check to see if each of the operands is the result of a constant. If
2182  // so, get the value. If not, ignore it.
2183  SmallVector<Attribute, 8> operandConstants;
2184  auto boundOperands =
2185  lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2186  for (auto operand : boundOperands) {
2187  Attribute operandCst;
2188  matchPattern(operand, m_Constant(&operandCst));
2189  operandConstants.push_back(operandCst);
2190  }
2191 
2192  AffineMap boundMap =
2193  lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2194  assert(boundMap.getNumResults() >= 1 &&
2195  "bound maps should have at least one result");
2196  SmallVector<Attribute, 4> foldedResults;
2197  if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2198  return failure();
2199 
2200  // Compute the max or min as applicable over the results.
2201  assert(!foldedResults.empty() && "bounds should have at least one result");
2202  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2203  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2204  auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2205  maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2206  : llvm::APIntOps::smin(maxOrMin, foldedResult);
2207  }
2208  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2209  : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2210  return success();
2211  };
2212 
2213  // Try to fold the lower bound.
2214  bool folded = false;
2215  if (!forOp.hasConstantLowerBound())
2216  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2217 
2218  // Try to fold the upper bound.
2219  if (!forOp.hasConstantUpperBound())
2220  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2221  return success(folded);
2222 }
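// Editor's worked example (not from the source): a non-constant lower bound
// map (d0, d1) -> (d0, d1) whose operands fold to the constants 2 and 5 yields
// foldedResults {2, 5}; taking the smax sets a constant lower bound of 5,
// while an upper bound would take the smin of its folded results instead.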
2223 
2224 /// Canonicalize the bounds of the given loop.
2225 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2226  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2227  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2228 
2229  auto lbMap = forOp.getLowerBoundMap();
2230  auto ubMap = forOp.getUpperBoundMap();
2231  auto prevLbMap = lbMap;
2232  auto prevUbMap = ubMap;
2233 
2234  composeAffineMapAndOperands(&lbMap, &lbOperands);
2235  canonicalizeMapAndOperands(&lbMap, &lbOperands);
2236  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
2237  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2238  lbMap = removeDuplicateExprs(lbMap);
2239 
2240  composeAffineMapAndOperands(&ubMap, &ubOperands);
2241  canonicalizeMapAndOperands(&ubMap, &ubOperands);
2242  ubMap = removeDuplicateExprs(ubMap);
2243 
2244  // Any canonicalization change always leads to updated map(s).
2245  if (lbMap == prevLbMap && ubMap == prevUbMap)
2246  return failure();
2247 
2248  if (lbMap != prevLbMap)
2249  forOp.setLowerBound(lbOperands, lbMap);
2250  if (ubMap != prevUbMap)
2251  forOp.setUpperBound(ubOperands, ubMap);
2252  return success();
2253 }
2254 
2255 namespace {
2256 /// Returns constant trip count in trivial cases.
2257 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2258  int64_t step = forOp.getStepAsInt();
2259  if (!forOp.hasConstantBounds() || step <= 0)
2260  return std::nullopt;
2261  int64_t lb = forOp.getConstantLowerBound();
2262  int64_t ub = forOp.getConstantUpperBound();
2263  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2264 }
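// Editor's worked example (not from the source): for lb = 0, ub = 10 and
// step = 3 this returns (10 - 0 + 3 - 1) / 3 = 4 iterations; if ub <= lb it
// returns 0.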
2265 
2266 /// This is a pattern to fold trivially empty loop bodies.
2267 /// TODO: This should be moved into the folding hook.
2268 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
2269  using OpRewritePattern<AffineForOp>::OpRewritePattern;
2270 
2271  LogicalResult matchAndRewrite(AffineForOp forOp,
2272  PatternRewriter &rewriter) const override {
2273  // Check that the body only contains a yield.
2274  if (!llvm::hasSingleElement(*forOp.getBody()))
2275  return failure();
2276  if (forOp.getNumResults() == 0)
2277  return success();
2278  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2279  if (tripCount && *tripCount == 0) {
2280  // The initial values of the iteration arguments would be the op's
2281  // results.
2282  rewriter.replaceOp(forOp, forOp.getInits());
2283  return success();
2284  }
2285  SmallVector<Value, 4> replacements;
2286  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2287  auto iterArgs = forOp.getRegionIterArgs();
2288  bool hasValDefinedOutsideLoop = false;
2289  bool iterArgsNotInOrder = false;
2290  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2291  Value val = yieldOp.getOperand(i);
2292  auto *iterArgIt = llvm::find(iterArgs, val);
2293  if (iterArgIt == iterArgs.end()) {
2294  // `val` is defined outside of the loop.
2295  assert(forOp.isDefinedOutsideOfLoop(val) &&
2296  "must be defined outside of the loop");
2297  hasValDefinedOutsideLoop = true;
2298  replacements.push_back(val);
2299  } else {
2300  unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2301  if (pos != i)
2302  iterArgsNotInOrder = true;
2303  replacements.push_back(forOp.getInits()[pos]);
2304  }
2305  }
2306  // Bail out when the trip count is unknown and the loop returns any value
2307  // defined outside of the loop or any iterArg out of order.
2308  if (!tripCount.has_value() &&
2309  (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2310  return failure();
2311  // Bail out when the loop iterates more than once and it returns any iterArg
2312  // out of order.
2313  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2314  return failure();
2315  rewriter.replaceOp(forOp, replacements);
2316  return success();
2317  }
2318 };
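// Editor's illustrative sketch (not from the source): a loop whose body only
// contains "affine.yield %arg" where %arg is the loop's sole iter_arg is
// replaced by the corresponding init value, since every iteration forwards the
// carried value unchanged.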
2319 } // namespace
2320 
2321 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2322  MLIRContext *context) {
2323  results.add<AffineForEmptyLoopFolder>(context);
2324 }
2325 
2326 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2327  assert((point.isParent() || point == getRegion()) && "invalid region point");
2328 
2329  // The initial operands map to the loop arguments after the induction
2330  // variable or are forwarded to the results when the trip count is zero.
2331  return getInits();
2332 }
2333 
2334 void AffineForOp::getSuccessorRegions(
2335  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2336  assert((point.isParent() || point == getRegion()) && "expected loop region");
2337  // The loop may typically branch back to its body or to the parent operation.
2338  // If the predecessor is the parent op and the trip count is known to be at
2339  // least one, branch into the body using the iterator arguments. If the trip
2340  // count is known to be zero, control can only return to the parent.
2341  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2342  if (point.isParent() && tripCount.has_value()) {
2343  if (tripCount.value() > 0) {
2344  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2345  return;
2346  }
2347  if (tripCount.value() == 0) {
2348  regions.push_back(RegionSuccessor(getResults()));
2349  return;
2350  }
2351  }
2352 
2353  // From the loop body, if the trip count is one, we can only branch back to
2354  // the parent.
2355  if (!point.isParent() && tripCount && *tripCount == 1) {
2356  regions.push_back(RegionSuccessor(getResults()));
2357  return;
2358  }
2359 
2360  // In all other cases, the loop may branch back to itself or the parent
2361  // operation.
2362  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2363  regions.push_back(RegionSuccessor(getResults()));
2364 }
2365 
2366 /// Returns true if the affine.for has zero iterations in trivial cases.
2367 static bool hasTrivialZeroTripCount(AffineForOp op) {
2368  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2369  return tripCount && *tripCount == 0;
2370 }
2371 
2372 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2373  SmallVectorImpl<OpFoldResult> &results) {
2374  bool folded = succeeded(foldLoopBounds(*this));
2375  folded |= succeeded(canonicalizeLoopBounds(*this));
2376  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2377  // The initial values of the loop-carried variables (iter_args) are the
2378  // results of the op. But this must be avoided for an affine.for op that
2379  // does not return any results. Since ops that do not return results cannot
2380  // be folded away, we would enter an infinite loop of folds on the same
2381  // affine.for op.
2382  results.assign(getInits().begin(), getInits().end());
2383  folded = true;
2384  }
2385  return success(folded);
2386 }
2387 
2388 AffineBound AffineForOp::getLowerBound() {
2389  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2390 }
2391 
2392 AffineBound AffineForOp::getUpperBound() {
2393  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2394 }
2395 
2396 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2397  assert(lbOperands.size() == map.getNumInputs());
2398  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2399  getLowerBoundOperandsMutable().assign(lbOperands);
2400  setLowerBoundMap(map);
2401 }
2402 
2403 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2404  assert(ubOperands.size() == map.getNumInputs());
2405  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2406  getUpperBoundOperandsMutable().assign(ubOperands);
2407  setUpperBoundMap(map);
2408 }
2409 
2410 bool AffineForOp::hasConstantLowerBound() {
2411  return getLowerBoundMap().isSingleConstant();
2412 }
2413 
2414 bool AffineForOp::hasConstantUpperBound() {
2415  return getUpperBoundMap().isSingleConstant();
2416 }
2417 
2418 int64_t AffineForOp::getConstantLowerBound() {
2419  return getLowerBoundMap().getSingleConstantResult();
2420 }
2421 
2422 int64_t AffineForOp::getConstantUpperBound() {
2423  return getUpperBoundMap().getSingleConstantResult();
2424 }
2425 
2426 void AffineForOp::setConstantLowerBound(int64_t value) {
2427  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2428 }
2429 
2430 void AffineForOp::setConstantUpperBound(int64_t value) {
2431  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2432 }
2433 
2434 AffineForOp::operand_range AffineForOp::getControlOperands() {
2435  return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2436  getUpperBoundOperands().size()};
2437 }
2438 
2439 bool AffineForOp::matchingBoundOperandList() {
2440  auto lbMap = getLowerBoundMap();
2441  auto ubMap = getUpperBoundMap();
2442  if (lbMap.getNumDims() != ubMap.getNumDims() ||
2443  lbMap.getNumSymbols() != ubMap.getNumSymbols())
2444  return false;
2445 
2446  unsigned numOperands = lbMap.getNumInputs();
2447  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2448  // Compare the operand Values.
2449  if (getOperand(i) != getOperand(numOperands + i))
2450  return false;
2451  }
2452  return true;
2453 }
2454 
2455 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2456 
2457 std::optional<Value> AffineForOp::getSingleInductionVar() {
2458  return getInductionVar();
2459 }
2460 
2461 std::optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
2462  if (!hasConstantLowerBound())
2463  return std::nullopt;
2464  OpBuilder b(getContext());
2465  return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()));
2466 }
2467 
2468 std::optional<OpFoldResult> AffineForOp::getSingleStep() {
2469  OpBuilder b(getContext());
2470  return OpFoldResult(b.getI64IntegerAttr(getStepAsInt()));
2471 }
2472 
2473 std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
2474  if (!hasConstantUpperBound())
2475  return std::nullopt;
2476  OpBuilder b(getContext());
2477  return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
2478 }
2479 
2480 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2481  RewriterBase &rewriter, ValueRange newInitOperands,
2482  bool replaceInitOperandUsesInLoop,
2483  const NewYieldValuesFn &newYieldValuesFn) {
2484  // Create a new loop before the existing one, with the extra operands.
2485  OpBuilder::InsertionGuard g(rewriter);
2486  rewriter.setInsertionPoint(getOperation());
2487  auto inits = llvm::to_vector(getInits());
2488  inits.append(newInitOperands.begin(), newInitOperands.end());
2489  AffineForOp newLoop = rewriter.create<AffineForOp>(
2490  getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2491  getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
2492 
2493  // Generate the new yield values and append them to the affine.yield operation.
2494  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2495  ArrayRef<BlockArgument> newIterArgs =
2496  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2497  {
2498  OpBuilder::InsertionGuard g(rewriter);
2499  rewriter.setInsertionPoint(yieldOp);
2500  SmallVector<Value> newYieldedValues =
2501  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2502  assert(newInitOperands.size() == newYieldedValues.size() &&
2503  "expected as many new yield values as new iter operands");
2504  rewriter.modifyOpInPlace(yieldOp, [&]() {
2505  yieldOp.getOperandsMutable().append(newYieldedValues);
2506  });
2507  }
2508 
2509  // Move the loop body to the new op.
2510  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2511  newLoop.getBody()->getArguments().take_front(
2512  getBody()->getNumArguments()));
2513 
2514  if (replaceInitOperandUsesInLoop) {
2515  // Replace all uses of `newInitOperands` with the corresponding basic block
2516  // arguments.
2517  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2518  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2519  [&](OpOperand &use) {
2520  Operation *user = use.getOwner();
2521  return newLoop->isProperAncestor(user);
2522  });
2523  }
2524  }
2525 
2526  // Replace the old loop.
2527  rewriter.replaceOp(getOperation(),
2528  newLoop->getResults().take_front(getNumResults()));
2529  return cast<LoopLikeOpInterface>(newLoop.getOperation());
2530 }
2531 
2532 Speculation::Speculatability AffineForOp::getSpeculatability() {
2533  // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2534  // Start and End.
2535  //
2536  // For Step != 1, the loop may not terminate. We can add more smarts here if
2537  // needed.
2538  return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2539  : Speculation::NotSpeculatable;
2540 }
2541 
2542 /// Returns true if the provided value is the induction variable of an
2543 /// AffineForOp.
2544 bool mlir::affine::isAffineForInductionVar(Value val) {
2545  return getForInductionVarOwner(val) != AffineForOp();
2546 }
2547 
2548 bool mlir::affine::isAffineParallelInductionVar(Value val) {
2549  return getAffineParallelInductionVarOwner(val) != nullptr;
2550 }
2551 
2552 bool mlir::affine::isAffineInductionVar(Value val) {
2553  return isAffineForInductionVar(val) || isAffineParallelInductionVar(val);
2554 }
2555 
2556 AffineForOp mlir::affine::getForInductionVarOwner(Value val) {
2557  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2558  if (!ivArg || !ivArg.getOwner())
2559  return AffineForOp();
2560  auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
2561  if (auto forOp = dyn_cast<AffineForOp>(containingInst))
2562  // Check to make sure `val` is the induction variable, not an iter_arg.
2563  return forOp.getInductionVar() == val ? forOp : AffineForOp();
2564  return AffineForOp();
2565 }
2566 
2567 AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
2568  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2569  if (!ivArg || !ivArg.getOwner())
2570  return nullptr;
2571  Operation *containingOp = ivArg.getOwner()->getParentOp();
2572  auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2573  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2574  return parallelOp;
2575  return nullptr;
2576 }
2577 
2578 /// Extracts the induction variables from a list of AffineForOps and returns
2579 /// them.
2580 void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
2581  SmallVectorImpl<Value> *ivs) {
2582  ivs->reserve(forInsts.size());
2583  for (auto forInst : forInsts)
2584  ivs->push_back(forInst.getInductionVar());
2585 }
2586 
2587 void mlir::affine::extractInductionVars(ArrayRef<Operation *> affineOps,
2588  SmallVectorImpl<Value> &ivs) {
2589  ivs.reserve(affineOps.size());
2590  for (Operation *op : affineOps) {
2591  // Add constraints from forOp's bounds.
2592  if (auto forOp = dyn_cast<AffineForOp>(op))
2593  ivs.push_back(forOp.getInductionVar());
2594  else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2595  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2596  ivs.push_back(parallelOp.getBody()->getArgument(i));
2597  }
2598 }
2599 
2600 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2601 /// operations.
2602 template <typename BoundListTy, typename LoopCreatorTy>
2603 static void buildAffineLoopNestImpl(
2604  OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2605  ArrayRef<int64_t> steps,
2606  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2607  LoopCreatorTy &&loopCreatorFn) {
2608  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2609  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2610 
2611  // If there are no loops to be constructed, construct the body anyway.
2612  OpBuilder::InsertionGuard guard(builder);
2613  if (lbs.empty()) {
2614  if (bodyBuilderFn)
2615  bodyBuilderFn(builder, loc, ValueRange());
2616  return;
2617  }
2618 
2619  // Create the loops iteratively and store the induction variables.
2620  SmallVector<Value, 4> ivs;
2621  ivs.reserve(lbs.size());
2622  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2623  // Callback for creating the loop body, always creates the terminator.
2624  auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
2625  ValueRange iterArgs) {
2626  ivs.push_back(iv);
2627  // In the innermost loop, call the body builder.
2628  if (i == e - 1 && bodyBuilderFn) {
2629  OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
2630  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2631  }
2632  nestedBuilder.create<AffineYieldOp>(nestedLoc);
2633  };
2634 
2635  // Delegate actual loop creation to the callback in order to dispatch
2636  // between constant- and variable-bound loops.
2637  auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2638  builder.setInsertionPointToStart(loop.getBody());
2639  }
2640 }
2641 
2642 /// Creates an affine loop from the bounds known to be constants.
2643 static AffineForOp
2644 buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
2645  int64_t ub, int64_t step,
2646  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2647  return builder.create<AffineForOp>(loc, lb, ub, step,
2648  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2649 }
2650 
2651 /// Creates an affine loop from the bounds that may or may not be constants.
2652 static AffineForOp
2653 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
2654  int64_t step,
2655  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2656  std::optional<int64_t> lbConst = getConstantIntValue(lb);
2657  std::optional<int64_t> ubConst = getConstantIntValue(ub);
2658  if (lbConst && ubConst)
2659  return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2660  ubConst.value(), step, bodyBuilderFn);
2661  return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2662  builder.getDimIdentityMap(), step,
2663  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2664 }
2665 
2666 void mlir::affine::buildAffineLoopNest(
2667  OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2668  ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
2669  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2670  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2671  buildAffineLoopFromConstants);
2672 }
2673 
2674 void mlir::affine::buildAffineLoopNest(
2675  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2676  ArrayRef<int64_t> steps,
2677  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2678  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2679  buildAffineLoopFromValues);
2680 }
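// Editor's usage sketch (not from the source; assumes a valid OpBuilder `b`
// and Location `loc`): build a 2-D constant-bound nest equivalent to
//   affine.for %i = 0 to 8 { affine.for %j = 0 to 16 { ... } }
//
//   buildAffineLoopNest(b, loc, /*lbs=*/{0, 0}, /*ubs=*/{8, 16},
//                       /*steps=*/{1, 1},
//                       [](OpBuilder &nested, Location nestedLoc,
//                          ValueRange ivs) {
//                         // Emit the innermost body using ivs[0] and ivs[1].
//                       });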
2681 
2682 //===----------------------------------------------------------------------===//
2683 // AffineIfOp
2684 //===----------------------------------------------------------------------===//
2685 
2686 namespace {
2687 /// Remove else blocks that have nothing other than a zero value yield.
2688 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
2689  using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2690 
2691  LogicalResult matchAndRewrite(AffineIfOp ifOp,
2692  PatternRewriter &rewriter) const override {
2693  if (ifOp.getElseRegion().empty() ||
2694  !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2695  return failure();
2696 
2697  rewriter.startOpModification(ifOp);
2698  rewriter.eraseBlock(ifOp.getElseBlock());
2699  rewriter.finalizeOpModification(ifOp);
2700  return success();
2701  }
2702 };
2703 
2704 /// Removes affine.if cond if the condition is always true or false in certain
2705 /// trivial cases. Promotes the then/else block in the parent operation block.
2706 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
2707  using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2708 
2709  LogicalResult matchAndRewrite(AffineIfOp op,
2710  PatternRewriter &rewriter) const override {
2711 
2712  auto isTriviallyFalse = [](IntegerSet iSet) {
2713  return iSet.isEmptyIntegerSet();
2714  };
2715 
2716  auto isTriviallyTrue = [](IntegerSet iSet) {
2717  return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2718  iSet.getConstraint(0) == 0);
2719  };
2720 
2721  IntegerSet affineIfConditions = op.getIntegerSet();
2722  Block *blockToMove;
2723  if (isTriviallyFalse(affineIfConditions)) {
2724  // The absence, or equivalently, the emptiness of the else region need not
2725  // be checked when affine.if is returning results because if an affine.if
2726  // operation is returning results, it always has a non-empty else region.
2727  if (op.getNumResults() == 0 && !op.hasElse()) {
2728  // If the else region is absent, or equivalently, empty, remove the
2729  // affine.if operation (which is not returning any results).
2730  rewriter.eraseOp(op);
2731  return success();
2732  }
2733  blockToMove = op.getElseBlock();
2734  } else if (isTriviallyTrue(affineIfConditions)) {
2735  blockToMove = op.getThenBlock();
2736  } else {
2737  return failure();
2738  }
2739  Operation *blockToMoveTerminator = blockToMove->getTerminator();
2740  // Promote the "blockToMove" block to the parent operation block between the
2741  // prologue and epilogue of "op".
2742  rewriter.inlineBlockBefore(blockToMove, op);
2743  // Replace the "op" operation with the operands of the
2744  // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2745  // the affine.yield operation present in the "blockToMove" block. It has no
2746  // operands when affine.if is not returning results and therefore, in that
2747  // case, replaceOp just erases "op". When affine.if is not returning
2748  // results, the affine.yield operation can be omitted. It gets inserted
2749  // implicitly.
2750  rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2751  // Erase the "blockToMoveTerminator" operation since it is now in the parent
2752  // operation block, which already has its own terminator.
2753  rewriter.eraseOp(blockToMoveTerminator);
2754  return success();
2755  }
2756 };
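// Editor's illustrative cases (not from the source): a condition such as
// affine_set<(d0) : (1 == 0)> is trivially false, so the else block (if any)
// is promoted or the op is erased; affine_set<() : (0 == 0)> is trivially
// true, so the then block is promoted in place of the affine.if.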
2757 } // namespace
2758 
2759 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
2760 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
2761 void AffineIfOp::getSuccessorRegions(
2762  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2763  // If the predecessor is an AffineIfOp, then branching into both `then` and
2764  // `else` region is valid.
2765  if (point.isParent()) {
2766  regions.reserve(2);
2767  regions.push_back(
2768  RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
2770  // If the "else" region is empty, branch back into the parent.
2770  if (getElseRegion().empty()) {
2771  regions.push_back(getResults());
2772  } else {
2773  regions.push_back(
2774  RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
2775  }
2776  return;
2777  }
2778 
2779  // If the predecessor is the `else`/`then` region, then branching into parent
2780  // op is valid.
2781  regions.push_back(RegionSuccessor(getResults()));
2782 }
2783 
2784 LogicalResult AffineIfOp::verify() {
2785  // Verify that we have a condition attribute.
2786  // FIXME: This should be specified in the arguments list in ODS.
2787  auto conditionAttr =
2788  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2789  if (!conditionAttr)
2790  return emitOpError("requires an integer set attribute named 'condition'");
2791 
2792  // Verify that there are enough operands for the condition.
2793  IntegerSet condition = conditionAttr.getValue();
2794  if (getNumOperands() != condition.getNumInputs())
2795  return emitOpError("operand count and condition integer set dimension and "
2796  "symbol count must match");
2797 
2798  // Verify that the operands are valid dimension/symbols.
2799  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2800  condition.getNumDims())))
2801  return failure();
2802 
2803  return success();
2804 }
2805 
2806 ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
2807  // Parse the condition attribute set.
2808  IntegerSetAttr conditionAttr;
2809  unsigned numDims;
2810  if (parser.parseAttribute(conditionAttr,
2811  AffineIfOp::getConditionAttrStrName(),
2812  result.attributes) ||
2813  parseDimAndSymbolList(parser, result.operands, numDims))
2814  return failure();
2815 
2816  // Verify the condition operands.
2817  auto set = conditionAttr.getValue();
2818  if (set.getNumDims() != numDims)
2819  return parser.emitError(
2820  parser.getNameLoc(),
2821  "dim operand count and integer set dim count must match");
2822  if (numDims + set.getNumSymbols() != result.operands.size())
2823  return parser.emitError(
2824  parser.getNameLoc(),
2825  "symbol operand count and integer set symbol count must match");
2826 
2827  if (parser.parseOptionalArrowTypeList(result.types))
2828  return failure();
2829 
2830  // Create the regions for 'then' and 'else'. The latter must be created even
2831  // if it remains empty for the validity of the operation.
2832  result.regions.reserve(2);
2833  Region *thenRegion = result.addRegion();
2834  Region *elseRegion = result.addRegion();
2835 
2836  // Parse the 'then' region.
2837  if (parser.parseRegion(*thenRegion, {}, {}))
2838  return failure();
2839  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2840  result.location);
2841 
2842  // If we find an 'else' keyword then parse the 'else' region.
2843  if (!parser.parseOptionalKeyword("else")) {
2844  if (parser.parseRegion(*elseRegion, {}, {}))
2845  return failure();
2846  AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2847  result.location);
2848  }
2849 
2850  // Parse the optional attribute list.
2851  if (parser.parseOptionalAttrDict(result.attributes))
2852  return failure();
2853 
2854  return success();
2855 }
2856 
2857 void AffineIfOp::print(OpAsmPrinter &p) {
2858  auto conditionAttr =
2859  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2860  p << " " << conditionAttr;
2861  printDimAndSymbolList(operand_begin(), operand_end(),
2862  conditionAttr.getValue().getNumDims(), p);
2863  p.printOptionalArrowTypeList(getResultTypes());
2864  p << ' ';
2865  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
2866  /*printBlockTerminators=*/getNumResults());
2867 
2868  // Print the 'else' regions if it has any blocks.
2869  auto &elseRegion = this->getElseRegion();
2870  if (!elseRegion.empty()) {
2871  p << " else ";
2872  p.printRegion(elseRegion,
2873  /*printEntryBlockArgs=*/false,
2874  /*printBlockTerminators=*/getNumResults());
2875  }
2876 
2877  // Print the attribute list.
2878  p.printOptionalAttrDict((*this)->getAttrs(),
2879  /*elidedAttrs=*/getConditionAttrStrName());
2880 }
2881 
2882 IntegerSet AffineIfOp::getIntegerSet() {
2883  return (*this)
2884  ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2885  .getValue();
2886 }
2887 
2888 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2889  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
2890 }
2891 
2892 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2893  setIntegerSet(set);
2894  (*this)->setOperands(operands);
2895 }
2896 
2897 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2898  TypeRange resultTypes, IntegerSet set, ValueRange args,
2899  bool withElseRegion) {
2900  assert(resultTypes.empty() || withElseRegion);
2901  OpBuilder::InsertionGuard guard(builder);
2902 
2903  result.addTypes(resultTypes);
2904  result.addOperands(args);
2905  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
2906 
2907  Region *thenRegion = result.addRegion();
2908  builder.createBlock(thenRegion);
2909  if (resultTypes.empty())
2910  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2911 
2912  Region *elseRegion = result.addRegion();
2913  if (withElseRegion) {
2914  builder.createBlock(elseRegion);
2915  if (resultTypes.empty())
2916  AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2917  }
2918 }
2919 
2920 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2921  IntegerSet set, ValueRange args, bool withElseRegion) {
2922  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2923  withElseRegion);
2924 }
2925 
2926 /// Compose any affine.apply ops feeding into `operands` of the integer set
2927 /// `set` by composing the maps of such affine.apply ops with the integer
2928 /// set constraints.
2929 static void composeSetAndOperands(IntegerSet &set,
2930  SmallVectorImpl<Value> &operands) {
2931  // We will simply reuse the API of the map composition by viewing the LHSs of
2932  // the equalities and inequalities of `set` as the affine exprs of an affine
2933  // map. Convert to equivalent map, compose, and convert back to set.
2934  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
2935  set.getConstraints(), set.getContext());
2936  // Check if any composition is possible.
2937  if (llvm::none_of(operands,
2938  [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
2939  return;
2940 
2941  composeAffineMapAndOperands(&map, &operands);
2942  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
2943  set.getEqFlags());
2944 }
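// For illustration only (assumed IR, not taken from this file): composing an
// affine.apply that feeds the condition folds its map into the set, e.g.
//
//   %a = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
//   affine.if affine_set<(d0) : (d0 >= 0)>(%a) { ... }
//
// can be rewritten by the fold below as
//
//   affine.if affine_set<(d0) : (d0 + 1 >= 0)>(%i) { ... }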
2945 
2946 /// Canonicalize an affine if op's conditional (integer set + operands).
2947 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2948  auto set = getIntegerSet();
2949  SmallVector<Value, 4> operands(getOperands());
2950  composeSetAndOperands(set, operands);
2951  canonicalizeSetAndOperands(&set, &operands);
2952 
2953  // Check if the canonicalization or composition led to any change.
2954  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2955  return failure();
2956 
2957  setConditional(set, operands);
2958  return success();
2959 }
2960 
2961 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2962  MLIRContext *context) {
2963  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2964 }
2965 
2966 //===----------------------------------------------------------------------===//
2967 // AffineLoadOp
2968 //===----------------------------------------------------------------------===//
2969 
2970 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2971  AffineMap map, ValueRange operands) {
2972  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2973  result.addOperands(operands);
2974  if (map)
2975  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2976  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
2977  result.types.push_back(memrefType.getElementType());
2978 }
2979 
2980 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2981  Value memref, AffineMap map, ValueRange mapOperands) {
2982  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2983  result.addOperands(memref);
2984  result.addOperands(mapOperands);
2985  auto memrefType = llvm::cast<MemRefType>(memref.getType());
2986  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2987  result.types.push_back(memrefType.getElementType());
2988 }
2989 
2990 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2991  Value memref, ValueRange indices) {
2992  auto memrefType = llvm::cast<MemRefType>(memref.getType());
2993  int64_t rank = memrefType.getRank();
2994  // Create identity map for memrefs with at least one dimension or () -> ()
2995  // for zero-dimensional memrefs.
2996  auto map =
2997  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2998  build(builder, result, memref, map, indices);
2999 }
3000 
3001 ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
3002  auto &builder = parser.getBuilder();
3003  auto indexTy = builder.getIndexType();
3004 
3005  MemRefType type;
3006  OpAsmParser::UnresolvedOperand memrefInfo;
3007  AffineMapAttr mapAttr;
3009  return failure(
3010  parser.parseOperand(memrefInfo) ||
3011  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3012  AffineLoadOp::getMapAttrStrName(),
3013  result.attributes) ||
3014  parser.parseOptionalAttrDict(result.attributes) ||
3015  parser.parseColonType(type) ||
3016  parser.resolveOperand(memrefInfo, type, result.operands) ||
3017  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3018  parser.addTypeToList(type.getElementType(), result.types));
3019 }
3020 
3021 void AffineLoadOp::print(OpAsmPrinter &p) {
3022  p << " " << getMemRef() << '[';
3023  if (AffineMapAttr mapAttr =
3024  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3025  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3026  p << ']';
3027  p.printOptionalAttrDict((*this)->getAttrs(),
3028  /*elidedAttrs=*/{getMapAttrStrName()});
3029  p << " : " << getMemRefType();
3030 }
3031 
3032 /// Verify common indexing invariants of affine.load, affine.store,
3033 /// affine.vector_load and affine.vector_store.
3034 static LogicalResult
3035 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
3036  Operation::operand_range mapOperands,
3037  MemRefType memrefType, unsigned numIndexOperands) {
3038  AffineMap map = mapAttr.getValue();
3039  if (map.getNumResults() != memrefType.getRank())
3040  return op->emitOpError("affine map num results must equal memref rank");
3041  if (map.getNumInputs() != numIndexOperands)
3042  return op->emitOpError("expects as many subscripts as affine map inputs");
3043 
3044  Region *scope = getAffineScope(op);
3045  for (auto idx : mapOperands) {
3046  if (!idx.getType().isIndex())
3047  return op->emitOpError("index to load must have 'index' type");
3048  if (!isValidAffineIndexOperand(idx, scope))
3049  return op->emitOpError(
3050  "index must be a valid dimension or symbol identifier");
3051  }
3052 
3053  return success();
3054 }
3055 
3056 LogicalResult AffineLoadOp::verify() {
3057  auto memrefType = getMemRefType();
3058  if (getType() != memrefType.getElementType())
3059  return emitOpError("result type must match element type of memref");
3060 
3061  if (failed(verifyMemoryOpIndexing(
3062  getOperation(),
3063  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3064  getMapOperands(), memrefType,
3065  /*numIndexOperands=*/getNumOperands() - 1)))
3066  return failure();
3067 
3068  return success();
3069 }
3070 
3071 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3072  MLIRContext *context) {
3073  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3074 }
3075 
3076 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3077  /// load(memrefcast) -> load
3078  if (succeeded(memref::foldMemRefCast(*this)))
3079  return getResult();
3080 
3081  // Fold load from a global constant memref.
3082  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3083  if (!getGlobalOp)
3084  return {};
3085  // Get to the memref.global defining the symbol.
3086  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3087  if (!symbolTableOp)
3088  return {};
3089  auto global = dyn_cast_or_null<memref::GlobalOp>(
3090  SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
3091  if (!global)
3092  return {};
3093 
3094  // Check if the global memref is a constant.
3095  auto cstAttr =
3096  llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3097  if (!cstAttr)
3098  return {};
3099  // If it's a splat constant, we can fold irrespective of indices.
3100  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3101  return splatAttr.getSplatValue<Attribute>();
3102  // Otherwise, we can fold only if we know the indices.
3103  if (!getAffineMap().isConstant())
3104  return {};
3105  auto indices = llvm::to_vector<4>(
3106  llvm::map_range(getAffineMap().getConstantResults(),
3107  [](int64_t v) -> uint64_t { return v; }));
3108  return cstAttr.getValues<Attribute>()[indices];
3109 }
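// A sketch of the global-constant fold above (illustrative IR; names such as
// @cst are assumptions, not taken from this file):
//
//   memref.global "private" constant @cst : memref<2xi32> = dense<[7, 9]>
//   ...
//   %m = memref.get_global @cst : memref<2xi32>
//   %v = affine.load %m[1] : memref<2xi32>    // folds to the constant 9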
3110 
3111 //===----------------------------------------------------------------------===//
3112 // AffineStoreOp
3113 //===----------------------------------------------------------------------===//
3114 
3115 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3116  Value valueToStore, Value memref, AffineMap map,
3117  ValueRange mapOperands) {
3118  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3119  result.addOperands(valueToStore);
3120  result.addOperands(memref);
3121  result.addOperands(mapOperands);
3122  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3123 }
3124 
3125 // Use identity map.
3126 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3127  Value valueToStore, Value memref,
3128  ValueRange indices) {
3129  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3130  int64_t rank = memrefType.getRank();
3131  // Create identity map for memrefs with at least one dimension or () -> ()
3132  // for zero-dimensional memrefs.
3133  auto map =
3134  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3135  build(builder, result, valueToStore, memref, map, indices);
3136 }
3137 
3138 ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3139  auto indexTy = parser.getBuilder().getIndexType();
3140 
3141  MemRefType type;
3142  OpAsmParser::UnresolvedOperand storeValueInfo;
3143  OpAsmParser::UnresolvedOperand memrefInfo;
3144  AffineMapAttr mapAttr;
3146  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3147  parser.parseOperand(memrefInfo) ||
3148  parser.parseAffineMapOfSSAIds(
3149  mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3150  result.attributes) ||
3151  parser.parseOptionalAttrDict(result.attributes) ||
3152  parser.parseColonType(type) ||
3153  parser.resolveOperand(storeValueInfo, type.getElementType(),
3154  result.operands) ||
3155  parser.resolveOperand(memrefInfo, type, result.operands) ||
3156  parser.resolveOperands(mapOperands, indexTy, result.operands));
3157 }
3158 
3159 void AffineStoreOp::print(OpAsmPrinter &p) {
3160  p << " " << getValueToStore();
3161  p << ", " << getMemRef() << '[';
3162  if (AffineMapAttr mapAttr =
3163  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3164  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3165  p << ']';
3166  p.printOptionalAttrDict((*this)->getAttrs(),
3167  /*elidedAttrs=*/{getMapAttrStrName()});
3168  p << " : " << getMemRefType();
3169 }
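// For reference (illustrative IR matching the parser/printer above, with
// assumed value names): the custom form of affine.store is
//
//   affine.store %v, %A[%i + 3, %j] : memref<100x100xf32>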
3170 
3171 LogicalResult AffineStoreOp::verify() {
3172  // The value to store must have the same type as memref element type.
3173  auto memrefType = getMemRefType();
3174  if (getValueToStore().getType() != memrefType.getElementType())
3175  return emitOpError(
3176  "value to store must have the same type as memref element type");
3177 
3178  if (failed(verifyMemoryOpIndexing(
3179  getOperation(),
3180  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3181  getMapOperands(), memrefType,
3182  /*numIndexOperands=*/getNumOperands() - 2)))
3183  return failure();
3184 
3185  return success();
3186 }
3187 
3188 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3189  MLIRContext *context) {
3190  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3191 }
3192 
3193 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3194  SmallVectorImpl<OpFoldResult> &results) {
3195  /// store(memrefcast) -> store
3196  return memref::foldMemRefCast(*this, getValueToStore());
3197 }
3198 
3199 //===----------------------------------------------------------------------===//
3200 // AffineMinMaxOpBase
3201 //===----------------------------------------------------------------------===//
3202 
3203 template <typename T>
3204 static LogicalResult verifyAffineMinMaxOp(T op) {
3205  // Verify that operand count matches affine map dimension and symbol count.
3206  if (op.getNumOperands() !=
3207  op.getMap().getNumDims() + op.getMap().getNumSymbols())
3208  return op.emitOpError(
3209  "operand count and affine map dimension and symbol count must match");
3210  return success();
3211 }
3212 
3213 template <typename T>
3214 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3215  p << ' ' << op->getAttr(T::getMapAttrStrName());
3216  auto operands = op.getOperands();
3217  unsigned numDims = op.getMap().getNumDims();
3218  p << '(' << operands.take_front(numDims) << ')';
3219 
3220  if (operands.size() != numDims)
3221  p << '[' << operands.drop_front(numDims) << ']';
3222  p.printOptionalAttrDict(op->getAttrs(),
3223  /*elidedAttrs=*/{T::getMapAttrStrName()});
3224 }
3225 
3226 template <typename T>
3227 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3228  OperationState &result) {
3229  auto &builder = parser.getBuilder();
3230  auto indexType = builder.getIndexType();
3233  AffineMapAttr mapAttr;
3234  return failure(
3235  parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3236  result.attributes) ||
3237  parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
3238  parser.parseOperandList(symInfos,
3240  parser.parseOptionalAttrDict(result.attributes) ||
3241  parser.resolveOperands(dimInfos, indexType, result.operands) ||
3242  parser.resolveOperands(symInfos, indexType, result.operands) ||
3243  parser.addTypeToList(indexType, result.types));
3244 }
3245 
3246 /// Fold an affine min or max operation with the given operands. The operand
3247 /// list may contain nulls, which are interpreted as the operand not being a
3248 /// constant.
3249 template <typename T>
3250 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
3251  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3252  "expected affine min or max op");
3253 
3254  // Fold the affine map.
3255  // TODO: Fold more cases:
3256  // min(some_affine, some_affine + constant, ...), etc.
3257  SmallVector<int64_t, 2> results;
3258  auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3259 
3260  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3261  return op.getOperand(0);
3262 
3263  // If some of the map results are not constant, try changing the map in-place.
3264  if (results.empty()) {
3265  // If the map is the same, report that folding did not happen.
3266  if (foldedMap == op.getMap())
3267  return {};
3268  op->setAttr("map", AffineMapAttr::get(foldedMap));
3269  return op.getResult();
3270  }
3271 
3272  // Otherwise, completely fold the op into a constant.
3273  auto resultIt = std::is_same<T, AffineMinOp>::value
3274  ? llvm::min_element(results)
3275  : llvm::max_element(results);
3276  if (resultIt == results.end())
3277  return {};
3278  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3279 }
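// Worked example (illustrative, not from the source): if %c is known to be the
// constant 3, then the map of
//
//   %0 = affine.min affine_map<(d0) -> (d0 + 2, 16)>(%c)
//
// partially folds to the constants (5, 16), and the op folds to the constant 5.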
3280 
3281 /// Remove duplicated expressions in affine min/max ops.
3282 template <typename T>
3283 struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
3284  using OpRewritePattern<T>::OpRewritePattern;
3285 
3286  LogicalResult matchAndRewrite(T affineOp,
3287  PatternRewriter &rewriter) const override {
3288  AffineMap oldMap = affineOp.getAffineMap();
3289 
3290  SmallVector<AffineExpr, 4> newExprs;
3291  for (AffineExpr expr : oldMap.getResults()) {
3292  // This is a linear scan over newExprs, but it should be fine given that
3293  // we typically just have a few expressions per op.
3294  if (!llvm::is_contained(newExprs, expr))
3295  newExprs.push_back(expr);
3296  }
3297 
3298  if (newExprs.size() == oldMap.getNumResults())
3299  return failure();
3300 
3301  auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3302  newExprs, rewriter.getContext());
3303  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3304 
3305  return success();
3306  }
3307 };
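// Illustrative example of the deduplication above (assumed IR):
//
//   %0 = affine.min affine_map<(d0) -> (d0, d0, d0 + 2)>(%i)
//
// is rewritten to
//
//   %0 = affine.min affine_map<(d0) -> (d0, d0 + 2)>(%i)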
3308 
3309 /// Merge an affine min/max op into its consumer when the consumer is also
3310 /// an affine min/max op.
3311 ///
3312 /// This pattern requires that the producer affine min/max op be bound to a
3313 /// dimension/symbol that is used as a standalone expression in the consumer
3314 /// affine op's map.
3315 ///
3316 /// For example, a pattern like the following:
3317 ///
3318 /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3319 /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3320 ///
3321 /// Can be turned into:
3322 ///
3323 /// %1 = affine.min affine_map<
3324 /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3325 template <typename T>
3326 struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
3327  using OpRewritePattern<T>::OpRewritePattern;
3328 
3329  LogicalResult matchAndRewrite(T affineOp,
3330  PatternRewriter &rewriter) const override {
3331  AffineMap oldMap = affineOp.getAffineMap();
3332  ValueRange dimOperands =
3333  affineOp.getMapOperands().take_front(oldMap.getNumDims());
3334  ValueRange symOperands =
3335  affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3336 
3337  auto newDimOperands = llvm::to_vector<8>(dimOperands);
3338  auto newSymOperands = llvm::to_vector<8>(symOperands);
3339  SmallVector<AffineExpr, 4> newExprs;
3340  SmallVector<T, 4> producerOps;
3341 
3342  // Go over each expression to see whether it is a single dimension/symbol
3343  // whose corresponding operand is the result of another affine min/max
3344  // op. If so, it can be merged into this affine op.
3345  for (AffineExpr expr : oldMap.getResults()) {
3346  if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3347  Value symValue = symOperands[symExpr.getPosition()];
3348  if (auto producerOp = symValue.getDefiningOp<T>()) {
3349  producerOps.push_back(producerOp);
3350  continue;
3351  }
3352  } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3353  Value dimValue = dimOperands[dimExpr.getPosition()];
3354  if (auto producerOp = dimValue.getDefiningOp<T>()) {
3355  producerOps.push_back(producerOp);
3356  continue;
3357  }
3358  }
3359  // In the above cases the expression is removed and the producer affine
3360  // min/max op's expressions are merged in below. Otherwise the existing
3361  // expression is kept.
3362  newExprs.push_back(expr);
3363  }
3364 
3365  if (producerOps.empty())
3366  return failure();
3367 
3368  unsigned numUsedDims = oldMap.getNumDims();
3369  unsigned numUsedSyms = oldMap.getNumSymbols();
3370 
3371  // Now go over all producer affine ops and merge their expressions.
3372  for (T producerOp : producerOps) {
3373  AffineMap producerMap = producerOp.getAffineMap();
3374  unsigned numProducerDims = producerMap.getNumDims();
3375  unsigned numProducerSyms = producerMap.getNumSymbols();
3376 
3377  // Collect all dimension/symbol values.
3378  ValueRange dimValues =
3379  producerOp.getMapOperands().take_front(numProducerDims);
3380  ValueRange symValues =
3381  producerOp.getMapOperands().take_back(numProducerSyms);
3382  newDimOperands.append(dimValues.begin(), dimValues.end());
3383  newSymOperands.append(symValues.begin(), symValues.end());
3384 
3385  // For expressions we need to shift to avoid overlap.
3386  for (AffineExpr expr : producerMap.getResults()) {
3387  newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3388  .shiftSymbols(numProducerSyms, numUsedSyms));
3389  }
3390 
3391  numUsedDims += numProducerDims;
3392  numUsedSyms += numProducerSyms;
3393  }
3394 
3395  auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3396  rewriter.getContext());
3397  auto newOperands =
3398  llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3399  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3400 
3401  return success();
3402  }
3403 };
3404 
3405 /// Canonicalize the result expression order of an affine map and return success
3406 /// if the order changed.
3407 ///
3408 /// The function flattens the map's affine expressions to coefficient arrays and
3409 /// sorts them in lexicographic order. A coefficient array contains a multiplier
3410 /// for every dimension/symbol and a constant term. The canonicalization fails
3411 /// if a result expression is not pure or if the flattening requires local
3412 /// variables that, unlike dimensions and symbols, have no global order.
3413 static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3414  SmallVector<SmallVector<int64_t>> flattenedExprs;
3415  for (const AffineExpr &resultExpr : map.getResults()) {
3416  // Fail if the expression is not pure.
3417  if (!resultExpr.isPureAffine())
3418  return failure();
3419 
3420  SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3421  auto flattenResult = flattener.walkPostOrder(resultExpr);
3422  if (failed(flattenResult))
3423  return failure();
3424 
3425  // Fail if the flattened expression has local variables.
3426  if (flattener.operandExprStack.back().size() !=
3427  map.getNumDims() + map.getNumSymbols() + 1)
3428  return failure();
3429 
3430  flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3431  flattener.operandExprStack.back().end());
3432  }
3433 
3434  // Fail if sorting is not necessary.
3435  if (llvm::is_sorted(flattenedExprs))
3436  return failure();
3437 
3438  // Reorder the result expressions according to their flattened form.
3439  SmallVector<unsigned> resultPermutation =
3440  llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3441  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3442  return flattenedExprs[lhs] < flattenedExprs[rhs];
3443  });
3444  SmallVector<AffineExpr> newExprs;
3445  for (unsigned idx : resultPermutation)
3446  newExprs.push_back(map.getResult(idx));
3447 
3448  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3449  map.getContext());
3450  return success();
3451 }
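// A worked illustration of the flattening above (using the example from the
// pattern documentation below): for (d0, d1) -> (d0 + d1, d1 + 16, 32) the
// coefficient arrays are [1, 1, 0], [0, 1, 16] and [0, 0, 32] (layout
// [d0, d1, constant]); their lexicographic order yields (32, d1 + 16, d0 + d1).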
3452 
3453 /// Canonicalize the affine map result expression order of an affine min/max
3454 /// operation.
3455 ///
3456 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3457 /// expressions and replaces the operation if the order changed.
3458 ///
3459 /// For example, the following operation:
3460 ///
3461 /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3462 ///
3463 /// Turns into:
3464 ///
3465 /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3466 template <typename T>
3467 struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> {
3468  using OpRewritePattern<T>::OpRewritePattern;
3469 
3470  LogicalResult matchAndRewrite(T affineOp,
3471  PatternRewriter &rewriter) const override {
3472  AffineMap map = affineOp.getAffineMap();
3473  if (failed(canonicalizeMapExprAndTermOrder(map)))
3474  return failure();
3475  rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3476  return success();
3477  }
3478 };
3479 
3480 template <typename T>
3481 struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
3482  using OpRewritePattern<T>::OpRewritePattern;
3483 
3484  LogicalResult matchAndRewrite(T affineOp,
3485  PatternRewriter &rewriter) const override {
3486  if (affineOp.getMap().getNumResults() != 1)
3487  return failure();
3488  rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3489  affineOp.getOperands());
3490  return success();
3491  }
3492 };
3493 
3494 //===----------------------------------------------------------------------===//
3495 // AffineMinOp
3496 //===----------------------------------------------------------------------===//
3497 //
3498 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3499 //
3500 
3501 OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3502  return foldMinMaxOp(*this, adaptor.getOperands());
3503 }
3504 
3505 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3506  MLIRContext *context) {
3509  MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3511  context);
3512 }
3513 
3514 LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3515 
3516 ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
3517  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3518 }
3519 
3520 void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3521 
3522 //===----------------------------------------------------------------------===//
3523 // AffineMaxOp
3524 //===----------------------------------------------------------------------===//
3525 //
3526 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3527 //
3528 
3529 OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3530  return foldMinMaxOp(*this, adaptor.getOperands());
3531 }
3532 
3533 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3534  MLIRContext *context) {
3537  MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3539  context);
3540 }
3541 
3542 LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3543 
3544 ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
3545  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3546 }
3547 
3548 void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3549 
3550 //===----------------------------------------------------------------------===//
3551 // AffinePrefetchOp
3552 //===----------------------------------------------------------------------===//
3553 
3554 //
3555 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3556 //
3557 ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3558  OperationState &result) {
3559  auto &builder = parser.getBuilder();
3560  auto indexTy = builder.getIndexType();
3561 
3562  MemRefType type;
3563  OpAsmParser::UnresolvedOperand memrefInfo;
3564  IntegerAttr hintInfo;
3565  auto i32Type = parser.getBuilder().getIntegerType(32);
3566  StringRef readOrWrite, cacheType;
3567 
3568  AffineMapAttr mapAttr;
3570  if (parser.parseOperand(memrefInfo) ||
3571  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3572  AffinePrefetchOp::getMapAttrStrName(),
3573  result.attributes) ||
3574  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3575  parser.parseComma() || parser.parseKeyword("locality") ||
3576  parser.parseLess() ||
3577  parser.parseAttribute(hintInfo, i32Type,
3578  AffinePrefetchOp::getLocalityHintAttrStrName(),
3579  result.attributes) ||
3580  parser.parseGreater() || parser.parseComma() ||
3581  parser.parseKeyword(&cacheType) ||
3582  parser.parseOptionalAttrDict(result.attributes) ||
3583  parser.parseColonType(type) ||
3584  parser.resolveOperand(memrefInfo, type, result.operands) ||
3585  parser.resolveOperands(mapOperands, indexTy, result.operands))
3586  return failure();
3587 
3588  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
3589  return parser.emitError(parser.getNameLoc(),
3590  "rw specifier has to be 'read' or 'write'");
3591  result.addAttribute(
3592  AffinePrefetchOp::getIsWriteAttrStrName(),
3593  parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
3594 
3595  if (!cacheType.equals("data") && !cacheType.equals("instr"))
3596  return parser.emitError(parser.getNameLoc(),
3597  "cache type has to be 'data' or 'instr'");
3598 
3599  result.addAttribute(
3600  AffinePrefetchOp::getIsDataCacheAttrStrName(),
3601  parser.getBuilder().getBoolAttr(cacheType.equals("data")));
3602 
3603  return success();
3604 }
3605 
3606 void AffinePrefetchOp::print(OpAsmPrinter &p) {
3607  p << " " << getMemref() << '[';
3608  AffineMapAttr mapAttr =
3609  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3610  if (mapAttr)
3611  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3612  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3613  << "locality<" << getLocalityHint() << ">, "
3614  << (getIsDataCache() ? "data" : "instr");
3615  p.printOptionalAttrDict(
3616  (*this)->getAttrs(),
3617  /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3618  getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3619  p << " : " << getMemRefType();
3620 }
3621 
3622 LogicalResult AffinePrefetchOp::verify() {
3623  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3624  if (mapAttr) {
3625  AffineMap map = mapAttr.getValue();
3626  if (map.getNumResults() != getMemRefType().getRank())
3627  return emitOpError("affine.prefetch affine map num results must equal"
3628  " memref rank");
3629  if (map.getNumInputs() + 1 != getNumOperands())
3630  return emitOpError("too few operands");
3631  } else {
3632  if (getNumOperands() != 1)
3633  return emitOpError("too few operands");
3634  }
3635 
3636  Region *scope = getAffineScope(*this);
3637  for (auto idx : getMapOperands()) {
3638  if (!isValidAffineIndexOperand(idx, scope))
3639  return emitOpError(
3640  "index must be a valid dimension or symbol identifier");
3641  }
3642  return success();
3643 }
3644 
3645 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3646  MLIRContext *context) {
3647  // prefetch(memrefcast) -> prefetch
3648  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3649 }
3650 
3651 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3652  SmallVectorImpl<OpFoldResult> &results) {
3653  /// prefetch(memrefcast) -> prefetch
3654  return memref::foldMemRefCast(*this);
3655 }
3656 
3657 //===----------------------------------------------------------------------===//
3658 // AffineParallelOp
3659 //===----------------------------------------------------------------------===//
3660 
3661 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3662  TypeRange resultTypes,
3663  ArrayRef<arith::AtomicRMWKind> reductions,
3664  ArrayRef<int64_t> ranges) {
3665  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3666  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3667  return builder.getConstantAffineMap(value);
3668  }));
3669  SmallVector<int64_t> steps(ranges.size(), 1);
3670  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3671  /*ubArgs=*/{}, steps);
3672 }
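// For illustration (assumed call, not from this file): building with empty
// resultTypes and reductions and ranges = {128, 256} yields the equivalent of
//
//   affine.parallel (%i, %j) = (0, 0) to (128, 256) {
//     ...
//   }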
3673 
3674 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3675  TypeRange resultTypes,
3676  ArrayRef<arith::AtomicRMWKind> reductions,
3677  ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
3678  ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
3679  ArrayRef<int64_t> steps) {
3680  assert(llvm::all_of(lbMaps,
3681  [lbMaps](AffineMap m) {
3682  return m.getNumDims() == lbMaps[0].getNumDims() &&
3683  m.getNumSymbols() == lbMaps[0].getNumSymbols();
3684  }) &&
3685  "expected all lower bounds maps to have the same number of dimensions "
3686  "and symbols");
3687  assert(llvm::all_of(ubMaps,
3688  [ubMaps](AffineMap m) {
3689  return m.getNumDims() == ubMaps[0].getNumDims() &&
3690  m.getNumSymbols() == ubMaps[0].getNumSymbols();
3691  }) &&
3692  "expected all upper bounds maps to have the same number of dimensions "
3693  "and symbols");
3694  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3695  "expected lower bound maps to have as many inputs as lower bound "
3696  "operands");
3697  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3698  "expected upper bound maps to have as many inputs as upper bound "
3699  "operands");
3700 
3701  OpBuilder::InsertionGuard guard(builder);
3702  result.addTypes(resultTypes);
3703 
3704  // Convert the reductions to integer attributes.
3705  SmallVector<Attribute, 4> reductionAttrs;
3706  for (arith::AtomicRMWKind reduction : reductions)
3707  reductionAttrs.push_back(
3708  builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
3709  result.addAttribute(getReductionsAttrStrName(),
3710  builder.getArrayAttr(reductionAttrs));
3711 
3712  // Concatenates maps defined in the same input space (same dimensions and
3713  // symbols), assumes there is at least one map.
3714  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
3715  SmallVectorImpl<int32_t> &groups) {
3716  if (maps.empty())
3717  return AffineMap::get(builder.getContext());
3718  SmallVector<AffineExpr> exprs;
3719  groups.reserve(groups.size() + maps.size());
3720  exprs.reserve(maps.size());
3721  for (AffineMap m : maps) {
3722  llvm::append_range(exprs, m.getResults());
3723  groups.push_back(m.getNumResults());
3724  }
3725  return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3726  maps[0].getContext());
3727  };
3728 
3729  // Set up the bounds.
3730  SmallVector<int32_t> lbGroups, ubGroups;
3731  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3732  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3733  result.addAttribute(getLowerBoundsMapAttrStrName(),
3734  AffineMapAttr::get(lbMap));
3735  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
3736  builder.getI32TensorAttr(lbGroups));
3737  result.addAttribute(getUpperBoundsMapAttrStrName(),
3738  AffineMapAttr::get(ubMap));
3739  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
3740  builder.getI32TensorAttr(ubGroups));
3741  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
3742  result.addOperands(lbArgs);
3743  result.addOperands(ubArgs);
3744 
3745  // Create a region and a block for the body.
3746  auto *bodyRegion = result.addRegion();
3747  Block *body = builder.createBlock(bodyRegion);
3748 
3749  // Add all the block arguments.
3750  for (unsigned i = 0, e = steps.size(); i < e; ++i)
3751  body->addArgument(IndexType::get(builder.getContext()), result.location);
3752  if (resultTypes.empty())
3753  ensureTerminator(*bodyRegion, builder, result.location);
3754 }
3755 
3756 SmallVector<Region *> AffineParallelOp::getLoopRegions() {
3757  return {&getRegion()};
3758 }
3759 
3760 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3761 
3762 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3763  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3764 }
3765 
3766 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3767  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3768 }
3769 
3770 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3771  auto values = getLowerBoundsGroups().getValues<int32_t>();
3772  unsigned start = 0;
3773  for (unsigned i = 0; i < pos; ++i)
3774  start += values[i];
3775  return getLowerBoundsMap().getSliceMap(start, values[pos]);
3776 }
3777 
3778 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3779  auto values = getUpperBoundsGroups().getValues<int32_t>();
3780  unsigned start = 0;
3781  for (unsigned i = 0; i < pos; ++i)
3782  start += values[i];
3783  return getUpperBoundsMap().getSliceMap(start, values[pos]);
3784 }
3785 
3786 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
3787  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3788 }
3789 
3790 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
3791  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3792 }
3793 
3794 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3795  if (hasMinMaxBounds())
3796  return std::nullopt;
3797 
3798  // Try to convert all the ranges to constant expressions.
3799  SmallVector<int64_t, 8> out;
3800  AffineValueMap rangesValueMap;
3801  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3802  &rangesValueMap);
3803  out.reserve(rangesValueMap.getNumResults());
3804  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3805  auto expr = rangesValueMap.getResult(i);
3806  auto cst = dyn_cast<AffineConstantExpr>(expr);
3807  if (!cst)
3808  return std::nullopt;
3809  out.push_back(cst.getValue());
3810  }
3811  return out;
3812 }
3813 
3814 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3815 
3816 OpBuilder AffineParallelOp::getBodyBuilder() {
3817  return OpBuilder(getBody(), std::prev(getBody()->end()));
3818 }
3819 
3820 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3821  assert(lbOperands.size() == map.getNumInputs() &&
3822  "operands to map must match number of inputs");
3823 
3824  auto ubOperands = getUpperBoundsOperands();
3825 
3826  SmallVector<Value, 4> newOperands(lbOperands);
3827  newOperands.append(ubOperands.begin(), ubOperands.end());
3828  (*this)->setOperands(newOperands);
3829 
3830  setLowerBoundsMapAttr(AffineMapAttr::get(map));
3831 }
3832 
3833 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3834  assert(ubOperands.size() == map.getNumInputs() &&
3835  "operands to map must match number of inputs");
3836 
3837  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3838  newOperands.append(ubOperands.begin(), ubOperands.end());
3839  (*this)->setOperands(newOperands);
3840 
3841  setUpperBoundsMapAttr(AffineMapAttr::get(map));
3842 }
3843 
3844 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
3845  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3846 }
3847 
3848 // Check whether `resultType` matches the reduction kind in affine.parallel.
3849 static bool isResultTypeMatchAtomicRMWKind(Type resultType,
3850  arith::AtomicRMWKind op) {
3851  switch (op) {
3852  case arith::AtomicRMWKind::addf:
3853  return isa<FloatType>(resultType);
3854  case arith::AtomicRMWKind::addi:
3855  return isa<IntegerType>(resultType);
3856  case arith::AtomicRMWKind::assign:
3857  return true;
3858  case arith::AtomicRMWKind::mulf:
3859  return isa<FloatType>(resultType);
3860  case arith::AtomicRMWKind::muli:
3861  return isa<IntegerType>(resultType);
3862  case arith::AtomicRMWKind::maximumf:
3863  return isa<FloatType>(resultType);
3864  case arith::AtomicRMWKind::minimumf:
3865  return isa<FloatType>(resultType);
3866  case arith::AtomicRMWKind::maxs: {
3867  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3868  return intType && intType.isSigned();
3869  }
3870  case arith::AtomicRMWKind::mins: {
3871  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3872  return intType && intType.isSigned();
3873  }
3874  case arith::AtomicRMWKind::maxu: {
3875  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3876  return intType && intType.isUnsigned();
3877  }
3878  case arith::AtomicRMWKind::minu: {
3879  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3880  return intType && intType.isUnsigned();
3881  }
3882  case arith::AtomicRMWKind::ori:
3883  return isa<IntegerType>(resultType);
3884  case arith::AtomicRMWKind::andi:
3885  return isa<IntegerType>(resultType);
3886  default:
3887  return false;
3888  }
3889 }
3890 
3891 LogicalResult AffineParallelOp::verify() {
3892  auto numDims = getNumDims();
3893  if (getLowerBoundsGroups().getNumElements() != numDims ||
3894  getUpperBoundsGroups().getNumElements() != numDims ||
3895  getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3896  return emitOpError() << "the number of region arguments ("
3897  << getBody()->getNumArguments()
3898  << ") and the number of map groups for lower ("
3899  << getLowerBoundsGroups().getNumElements()
3900  << ") and upper bound ("
3901  << getUpperBoundsGroups().getNumElements()
3902  << "), and the number of steps (" << getSteps().size()
3903  << ") must all match";
3904  }
3905 
3906  unsigned expectedNumLBResults = 0;
3907  for (APInt v : getLowerBoundsGroups())
3908  expectedNumLBResults += v.getZExtValue();
3909  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
3910  return emitOpError() << "expected lower bounds map to have "
3911  << expectedNumLBResults << " results";
3912  unsigned expectedNumUBResults = 0;
3913  for (APInt v : getUpperBoundsGroups())
3914  expectedNumUBResults += v.getZExtValue();
3915  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
3916  return emitOpError() << "expected upper bounds map to have "
3917  << expectedNumUBResults << " results";
3918 
3919  if (getReductions().size() != getNumResults())
3920  return emitOpError("a reduction must be specified for each output");
3921 
3922  // Verify that the reduction ops are all valid and that each result type
3923  // matches its reduction op.
3924  for (auto it : llvm::enumerate((getReductions()))) {
3925  Attribute attr = it.value();
3926  auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3927  if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
3928  return emitOpError("invalid reduction attribute");
3929  auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
3930  if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
3931  return emitOpError("result type cannot match reduction attribute");
3932  }
3933 
3934  // Verify that the bound operands are valid dimension/symbols.
3935  /// Lower bounds.
3936  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
3937  getLowerBoundsMap().getNumDims())))
3938  return failure();
3939  /// Upper bounds.
3940  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
3941  getUpperBoundsMap().getNumDims())))
3942  return failure();
3943  return success();
3944 }
3945 
3946 LogicalResult AffineValueMap::canonicalize() {
3947  SmallVector<Value, 4> newOperands{operands};
3948  auto newMap = getAffineMap();
3949  composeAffineMapAndOperands(&newMap, &newOperands);
3950  if (newMap == getAffineMap() && newOperands == operands)
3951  return failure();
3952  reset(newMap, newOperands);
3953  return success();
3954 }
3955 
3956 /// Canonicalize the bounds of the given loop.
3957 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3958  AffineValueMap lb = op.getLowerBoundsValueMap();
3959  bool lbCanonicalized = succeeded(lb.canonicalize());
3960 
3961  AffineValueMap ub = op.getUpperBoundsValueMap();
3962  bool ubCanonicalized = succeeded(ub.canonicalize());
3963 
3964  // Any canonicalization change always leads to updated map(s).
3965  if (!lbCanonicalized && !ubCanonicalized)
3966  return failure();
3967 
3968  if (lbCanonicalized)
3969  op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
3970  if (ubCanonicalized)
3971  op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
3972 
3973  return success();
3974 }
3975 
3976 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
3977  SmallVectorImpl<OpFoldResult> &results) {
3978  return canonicalizeLoopBounds(*this);
3979 }
3980 
3981 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
3982 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
3983 /// identifies which of those expressions form max/min groups. `operands`
3984 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
3985 /// or "max".
3986 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
3987  DenseIntElementsAttr group, ValueRange operands,
3988  StringRef keyword) {
3989  AffineMap map = mapAttr.getValue();
3990  unsigned numDims = map.getNumDims();
3991  ValueRange dimOperands = operands.take_front(numDims);
3992  ValueRange symOperands = operands.drop_front(numDims);
3993  unsigned start = 0;
3994  for (llvm::APInt groupSize : group) {
3995  if (start != 0)
3996  p << ", ";
3997 
3998  unsigned size = groupSize.getZExtValue();
3999  if (size == 1) {
4000  p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4001  ++start;
4002  } else {
4003  p << keyword << '(';
4004  AffineMap submap = map.getSliceMap(start, size);
4005  p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4006  p << ')';
4007  start += size;
4008  }
4009  }
4010 }
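// Illustrative output of the bound printer above (assumed IR): a two-element
// lower-bound group and a singleton upper bound print as
//
//   affine.parallel (%i) = (max(%lb0, %lb1)) to (%ub) {
//     ...
//   }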
4011 
4012 void AffineParallelOp::print(OpAsmPrinter &p) {
4013  p << " (" << getBody()->getArguments() << ") = (";
4014  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4015  getLowerBoundsOperands(), "max");
4016  p << ") to (";
4017  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4018  getUpperBoundsOperands(), "min");
4019  p << ')';
4020  SmallVector<int64_t, 8> steps = getSteps();
4021  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4022  if (!elideSteps) {
4023  p << " step (";
4024  llvm::interleaveComma(steps, p);
4025  p << ')';
4026  }
4027  if (getNumResults()) {
4028  p << " reduce (";
4029  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4030  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4031  llvm::cast<IntegerAttr>(attr).getInt());
4032  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4033  });
4034  p << ") -> (" << getResultTypes() << ")";
4035  }
4036 
4037  p << ' ';
4038  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4039  /*printBlockTerminators=*/getNumResults());
4040  p.printOptionalAttrDict(
4041  (*this)->getAttrs(),
4042  /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4043  AffineParallelOp::getLowerBoundsMapAttrStrName(),
4044  AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4045  AffineParallelOp::getUpperBoundsMapAttrStrName(),
4046  AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4047  AffineParallelOp::getStepsAttrStrName()});
4048 }
4049 
4050 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
4051 /// unique operands. Also populates `replacements` with affine expressions of
4052 /// `kind` that can be used to update affine maps previously accepting a
4053 /// `operands` to accept `uniqueOperands` instead.
4054 static ParseResult deduplicateAndResolveOperands(
4055  OpAsmParser &parser,
4057  SmallVectorImpl<Value> &uniqueOperands,
4058  SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4059  assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4060  "expected operands to be dim or symbol expression");
4061 
4062  Type indexType = parser.getBuilder().getIndexType();
4063  for (const auto &list : operands) {
4064  SmallVector<Value> valueOperands;
4065  if (parser.resolveOperands(list, indexType, valueOperands))
4066  return failure();
4067  for (Value operand : valueOperands) {
4068  unsigned pos = std::distance(uniqueOperands.begin(),
4069  llvm::find(uniqueOperands, operand));
4070  if (pos == uniqueOperands.size())
4071  uniqueOperands.push_back(operand);
4072  replacements.push_back(
4073  kind == AffineExprKind::DimId
4074  ? getAffineDimExpr(pos, parser.getContext())
4075  : getAffineSymbolExpr(pos, parser.getContext()));
4076  }
4077  }
4078  return success();
4079 }
4080 
4081 namespace {
4082 enum class MinMaxKind { Min, Max };
4083 } // namespace
4084 
4085 /// Parses an affine map that can contain a min/max for groups of its results,
4086 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4087 /// `result` attributes with the map (flat list of expressions) and the grouping
4088 /// (list of integers that specify how many expressions to put into each
4089 /// min/max) attributes. Deduplicates repeated operands.
4090 ///
4091 /// parallel-bound ::= `(` parallel-group-list `)`
4092 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4093 /// parallel-group ::= simple-group | min-max-group
4094 /// simple-group ::= expr-of-ssa-ids
4095 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4096 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4097 ///
4098 /// Examples:
4099 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4100 /// (%0, max(%1 - 2 * %2))
4101 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4102  OperationState &result,
4103  MinMaxKind kind) {
4104  // Using `const` not `constexpr` below to work around an MSVC optimizer bug,
4105  // see: https://reviews.llvm.org/D134227#3821753
4106  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4107 
4108  StringRef mapName = kind == MinMaxKind::Min
4109  ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4110  : AffineParallelOp::getLowerBoundsMapAttrStrName();
4111  StringRef groupsName =
4112  kind == MinMaxKind::Min
4113  ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4114  : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4115 
4116  if (failed(parser.parseLParen()))
4117  return failure();
4118 
4119  if (succeeded(parser.parseOptionalRParen())) {
4120  result.addAttribute(
4121  mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4122  result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4123  return success();
4124  }
4125 
4126  SmallVector<AffineExpr> flatExprs;
4129  SmallVector<int32_t> numMapsPerGroup;
4131  auto parseOperands = [&]() {
4132  if (succeeded(parser.parseOptionalKeyword(
4133  kind == MinMaxKind::Min ? "min" : "max"))) {
4134  mapOperands.clear();
4135  AffineMapAttr map;
4136  if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4137  result.attributes,
4139  return failure();
4140  result.attributes.erase(tmpAttrStrName);
4141  llvm::append_range(flatExprs, map.getValue().getResults());
4142  auto operandsRef = llvm::ArrayRef(mapOperands);
4143  auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4144  SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef.begin(),
4145  dimsRef.end());
4146  auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4147  SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef.begin(),
4148  symsRef.end());
4149  flatDimOperands.append(map.getValue().getNumResults(), dims);
4150  flatSymOperands.append(map.getValue().getNumResults(), syms);
4151  numMapsPerGroup.push_back(map.getValue().getNumResults());
4152  } else {
4153  if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4154  flatSymOperands.emplace_back(),
4155  flatExprs.emplace_back())))
4156  return failure();
4157  numMapsPerGroup.push_back(1);
4158  }
4159  return success();
4160  };
4161  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4162  return failure();
4163 
4164  unsigned totalNumDims = 0;
4165  unsigned totalNumSyms = 0;
4166  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4167  unsigned numDims = flatDimOperands[i].size();
4168  unsigned numSyms = flatSymOperands[i].size();
4169  flatExprs[i] = flatExprs[i]
4170  .shiftDims(numDims, totalNumDims)
4171  .shiftSymbols(numSyms, totalNumSyms);
4172  totalNumDims += numDims;
4173  totalNumSyms += numSyms;
4174  }
4175 
4176  // Deduplicate map operands.
4177  SmallVector<Value> dimOperands, symOperands;
4178  SmallVector<AffineExpr> dimRplacements, symRepacements;
4179  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4180  dimRplacements, AffineExprKind::DimId) ||
4181  deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4182  symRepacements, AffineExprKind::SymbolId))
4183  return failure();
4184 
4185  result.operands.append(dimOperands.begin(), dimOperands.end());
4186  result.operands.append(symOperands.begin(), symOperands.end());
4187 
4188  Builder &builder = parser.getBuilder();
4189  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4190  parser.getContext());
4191  flatMap = flatMap.replaceDimsAndSymbols(
4192  dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4193 
4194  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4195  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4196  return success();
4197 }
4198 
4199 //
4200 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4201 // `to` parallel-bound steps? region attr-dict?
4202 // steps ::= `steps` `(` integer-literals `)`
4203 //
4204 ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4205  OperationState &result) {
4206  auto &builder = parser.getBuilder();
4207  auto indexType = builder.getIndexType();
4210  parser.parseEqual() ||
4211  parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4212  parser.parseKeyword("to") ||
4213  parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4214  return failure();
4215 
4216  AffineMapAttr stepsMapAttr;
4217  NamedAttrList stepsAttrs;
4219  if (failed(parser.parseOptionalKeyword("step"))) {
4220  SmallVector<int64_t, 4> steps(ivs.size(), 1);
4221  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4222  builder.getI64ArrayAttr(steps));
4223  } else {
4224  if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4225  AffineParallelOp::getStepsAttrStrName(),
4226  stepsAttrs,
4228  return failure();
4229 
4230  // Convert steps from an AffineMap into an I64ArrayAttr.
4232  auto stepsMap = stepsMapAttr.getValue();
4233  for (const auto &result : stepsMap.getResults()) {
4234  auto constExpr = dyn_cast<AffineConstantExpr>(result);
4235  if (!constExpr)
4236  return parser.emitError(parser.getNameLoc(),
4237  "steps must be constant integers");
4238  steps.push_back(constExpr.getValue());
4239  }
4240  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4241  builder.getI64ArrayAttr(steps));
4242  }
4243 
4244  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4245  // quoted strings are members of the enum AtomicRMWKind.
4246  SmallVector<Attribute, 4> reductions;
4247  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4248  if (parser.parseLParen())
4249  return failure();
4250  auto parseAttributes = [&]() -> ParseResult {
4251  // Parse a single quoted string via the attribute parsing, and then
4252  // verify it is a member of the enum and convert it to its integer
4253  // representation.
4254  StringAttr attrVal;
4255  NamedAttrList attrStorage;
4256  auto loc = parser.getCurrentLocation();
4257  if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4258  attrStorage))
4259  return failure();
4260  std::optional<arith::AtomicRMWKind> reduction =
4261  arith::symbolizeAtomicRMWKind(attrVal.getValue());
4262  if (!reduction)
4263  return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4264  reductions.push_back(
4265  builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4266  // While we keep getting commas, keep parsing.
4267  return success();
4268  };
4269  if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4270  return failure();
4271  }
4272  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4273  builder.getArrayAttr(reductions));
4274 
4275  // Parse return types of reductions (if any)
4276  if (parser.parseOptionalArrowTypeList(result.types))
4277  return failure();
4278 
4279  // Now parse the body.
4280  Region *body = result.addRegion();
4281  for (auto &iv : ivs)
4282  iv.type = indexType;
4283  if (parser.parseRegion(*body, ivs) ||
4284  parser.parseOptionalAttrDict(result.attributes))
4285  return failure();
4286 
4287  // Add a terminator if none was parsed.
4288  AffineParallelOp::ensureTerminator(*body, builder, result.location);
4289  return success();
4290 }
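// A minimal end-to-end example accepted by the parser above (illustrative IR,
// value names assumed):
//
//   %sum = affine.parallel (%i) = (0) to (128) reduce ("addf") -> (f32) {
//     %v = affine.load %A[%i] : memref<128xf32>
//     affine.yield %v : f32
//   }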
4291 
4292 //===----------------------------------------------------------------------===//
4293 // AffineYieldOp
4294 //===----------------------------------------------------------------------===//
4295 
4296 LogicalResult AffineYieldOp::verify() {
4297  auto *parentOp = (*this)->getParentOp();
4298  auto results = parentOp->getResults();
4299  auto operands = getOperands();
4300 
4301  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4302  return emitOpError() << "only terminates affine.if/for/parallel regions";
4303  if (parentOp->getNumResults() != getNumOperands())
4304  return emitOpError() << "parent of yield must have same number of "
4305  "results as the yield operands";
4306  for (auto it : llvm::zip(results, operands)) {
4307  if (std::get<0>(it).getType() != std::get<1>(it).getType())
4308  return emitOpError() << "types mismatch between yield op and its parent";
4309  }
4310 
4311  return success();
4312 }
4313 
4314 //===----------------------------------------------------------------------===//
4315 // AffineVectorLoadOp
4316 //===----------------------------------------------------------------------===//
4317 
4318 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4319  VectorType resultType, AffineMap map,
4320  ValueRange operands) {
4321  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4322  result.addOperands(operands);
4323  if (map)
4324  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4325  result.types.push_back(resultType);
4326 }
4327 
4328 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4329  VectorType resultType, Value memref,
4330  AffineMap map, ValueRange mapOperands) {
4331  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4332  result.addOperands(memref);
4333  result.addOperands(mapOperands);
4334  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4335  result.types.push_back(resultType);
4336 }
4337 
4338 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4339  VectorType resultType, Value memref,
4340  ValueRange indices) {
4341  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4342  int64_t rank = memrefType.getRank();
4343  // Create identity map for memrefs with at least one dimension or () -> ()
4344  // for zero-dimensional memrefs.
4345  auto map =
4346  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4347  build(builder, result, resultType, memref, map, indices);
4348 }
4349 
4350 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4351  MLIRContext *context) {
4352  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4353 }
4354 
4355 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4356  OperationState &result) {
4357  auto &builder = parser.getBuilder();
4358  auto indexTy = builder.getIndexType();
4359 
4360  MemRefType memrefType;
4361  VectorType resultType;
4362  OpAsmParser::UnresolvedOperand memrefInfo;
4363  AffineMapAttr mapAttr;
4365  return failure(
4366  parser.parseOperand(memrefInfo) ||
4367  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4368  AffineVectorLoadOp::getMapAttrStrName(),
4369  result.attributes) ||
4370  parser.parseOptionalAttrDict(result.attributes) ||
4371  parser.parseColonType(memrefType) || parser.parseComma() ||
4372  parser.parseType(resultType) ||
4373  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4374  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4375  parser.addTypeToList(resultType, result.types));
4376 }
4377 
4378 void AffineVectorLoadOp::print(OpAsmPrinter &p) {
4379  p << " " << getMemRef() << '[';
4380  if (AffineMapAttr mapAttr =
4381  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4382  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4383  p << ']';
4384  p.printOptionalAttrDict((*this)->getAttrs(),
4385  /*elidedAttrs=*/{getMapAttrStrName()});
4386  p << " : " << getMemRefType() << ", " << getType();
4387 }
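// For reference (illustrative IR matching the parser/printer above, names
// assumed):
//
//   %1 = affine.vector_load %0[%i + 3, %j + 7] : memref<100x100xf32>, vector<8xf32>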
4388 
4389 /// Verify common invariants of affine.vector_load and affine.vector_store.
4390 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4391  VectorType vectorType) {
4392  // Check that memref and vector element types match.
4393  if (memrefType.getElementType() != vectorType.getElementType())
4394  return op->emitOpError(
4395  "requires memref and vector types of the same elemental type");
4396  return success();
4397 }
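// Illustration of the invariant checked above, as a sketch in affine dialect
// IR (operand names and shapes are assumed for the example):
//   %v = affine.vector_load %m[%i, %j] : memref<100x100xf32>, vector<8xf32>
//   affine.vector_store %v, %m[%i, %j] : memref<100x100xf32>, vector<8xf32>
// Both forms pair an f32 memref with an f32 vector; an i32 vector over the
// same memref would be rejected by this check.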
4398 
4399 LogicalResult AffineVectorLoadOp::verify() {
4400  MemRefType memrefType = getMemRefType();
4401  if (failed(verifyMemoryOpIndexing(
4402  getOperation(),
4403  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4404  getMapOperands(), memrefType,
4405  /*numIndexOperands=*/getNumOperands() - 1)))
4406  return failure();
4407 
4408  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4409  return failure();
4410 
4411  return success();
4412 }
4413 
4414 //===----------------------------------------------------------------------===//
4415 // AffineVectorStoreOp
4416 //===----------------------------------------------------------------------===//
4417 
4418 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4419  Value valueToStore, Value memref, AffineMap map,
4420  ValueRange mapOperands) {
4421  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4422  result.addOperands(valueToStore);
4423  result.addOperands(memref);
4424  result.addOperands(mapOperands);
4425  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4426 }
4427 
4428 // Use identity map.
4429 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4430  Value valueToStore, Value memref,
4431  ValueRange indices) {
4432  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4433  int64_t rank = memrefType.getRank();
4434  // Create identity map for memrefs with at least one dimension or () -> ()
4435  // for zero-dimensional memrefs.
4436  auto map =
4437  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4438  build(builder, result, valueToStore, memref, map, indices);
4439 }
4440 void AffineVectorStoreOp::getCanonicalizationPatterns(
4441  RewritePatternSet &results, MLIRContext *context) {
4442  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4443 }
4444 
4445 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4446  OperationState &result) {
4447  auto indexTy = parser.getBuilder().getIndexType();
4448 
4449  MemRefType memrefType;
4450  VectorType resultType;
4451  OpAsmParser::UnresolvedOperand storeValueInfo;
4452  OpAsmParser::UnresolvedOperand memrefInfo;
4453  AffineMapAttr mapAttr;
4454  SmallVector<OpAsmParser::UnresolvedOperand, 2> mapOperands;
4455  return failure(
4456  parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4457  parser.parseOperand(memrefInfo) ||
4458  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4459  AffineVectorStoreOp::getMapAttrStrName(),
4460  result.attributes) ||
4461  parser.parseOptionalAttrDict(result.attributes) ||
4462  parser.parseColonType(memrefType) || parser.parseComma() ||
4463  parser.parseType(resultType) ||
4464  parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4465  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4466  parser.resolveOperands(mapOperands, indexTy, result.operands));
4467 }
4468 
4469 void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4470  p << " " << getValueToStore();
4471  p << ", " << getMemRef() << '[';
4472  if (AffineMapAttr mapAttr =
4473  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4474  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4475  p << ']';
4476  p.printOptionalAttrDict((*this)->getAttrs(),
4477  /*elidedAttrs=*/{getMapAttrStrName()});
4478  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4479 }
4480 
4481 LogicalResult AffineVectorStoreOp::verify() {
4482  MemRefType memrefType = getMemRefType();
4483  if (failed(verifyMemoryOpIndexing(
4484  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4485  getMapOperands(), memrefType,
4486  /*numIndexOperands=*/getNumOperands() - 2)))
4487  return failure();
4488 
4489  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4490  return failure();
4491 
4492  return success();
4493 }
4494 
4495 //===----------------------------------------------------------------------===//
4496 // DelinearizeIndexOp
4497 //===----------------------------------------------------------------------===//
4498 
4499 LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
4500  MLIRContext *context, std::optional<::mlir::Location> location,
4501  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
4502  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
4503  AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
4504  regions);
4505  inferredReturnTypes.assign(adaptor.getBasis().size(),
4506  IndexType::get(context));
4507  return success();
4508 }
4509 
4510 void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
4511  Value linearIndex,
4512  ArrayRef<OpFoldResult> basis) {
4513  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
4514  result.addOperands(linearIndex);
4515  SmallVector<Value> basisValues =
4516  llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value {
4517  std::optional<int64_t> staticDim = getConstantIntValue(ofr);
4518  if (staticDim.has_value())
4519  return builder.create<arith::ConstantIndexOp>(result.location,
4520  *staticDim);
4521  return llvm::dyn_cast_if_present<Value>(ofr);
4522  });
4523  result.addOperands(basisValues);
4524 }
4525 
4526 LogicalResult AffineDelinearizeIndexOp::verify() {
4527  if (getBasis().empty())
4528  return emitOpError("basis should not be empty");
4529  if (getNumResults() != getBasis().size())
4530  return emitOpError("should return an index for each basis element");
4531  return success();
4532 }
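// Sketch of the op form these checks constrain (SSA names and basis sizes are
// assumed for the example): a non-empty basis with one index result per
// element, e.g.
//   %i:3 = affine.delinearize_index %linear into (%c16, %c224, %c224)
//          : index, index, index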
4533 
4534 //===----------------------------------------------------------------------===//
4535 // TableGen'd op method definitions
4536 //===----------------------------------------------------------------------===//
4537 
4538 #define GET_OP_CLASSES
4539 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"