AffineOps.cpp
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/IntegerSet.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/ADT/ScopeExit.h"
24 #include "llvm/ADT/SmallBitVector.h"
25 #include "llvm/ADT/SmallVectorExtras.h"
26 #include "llvm/ADT/TypeSwitch.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/MathExtras.h"
29 #include <numeric>
30 #include <optional>
31 
32 using namespace mlir;
33 using namespace mlir::affine;
34 
35 using llvm::divideCeilSigned;
36 using llvm::divideFloorSigned;
37 using llvm::mod;
38 
39 #define DEBUG_TYPE "affine-ops"
40 
41 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
42 
43 /// A utility function to check if a value is defined at the top level of
44 /// `region` or is an argument of `region`. A value of index type defined at the
45 /// top level of an `AffineScope` region is always a valid symbol for all
46 /// uses in that region.
47 bool mlir::affine::isTopLevelValue(Value value, Region *region) {
48  if (auto arg = llvm::dyn_cast<BlockArgument>(value))
49  return arg.getParentRegion() == region;
50  return value.getDefiningOp()->getParentRegion() == region;
51 }
52 
53 /// Checks if `value`, known to be a legal affine dimension or symbol in the
54 /// `src` region, remains legal if the operation that uses it is inlined into `dest`
55 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
56 /// `isValidSymbol`, depending on the value being required to remain a valid
57 /// dimension or symbol.
58 static bool
59 remainsLegalAfterInline(Value value, Region *src, Region *dest,
60  const IRMapping &mapping,
61  function_ref<bool(Value, Region *)> legalityCheck) {
62  // If the value is a valid dimension for any other reason than being
63  // a top-level value, it will remain valid: constants get inlined
64  // with the function, transitive affine applies also get inlined and
65  // will be checked themselves, etc.
66  if (!isTopLevelValue(value, src))
67  return true;
68 
69  // If it's a top-level value because it's a block operand, i.e. a
70  // function argument, check whether the value replacing it after
71  // inlining is a valid dimension in the new region.
72  if (llvm::isa<BlockArgument>(value))
73  return legalityCheck(mapping.lookup(value), dest);
74 
75  // If it's a top-level value because it's defined in the region,
76  // it can only be inlined if the defining op is a constant or a
77  // `dim`, which can appear anywhere and be valid, since the defining
78  // op won't be top-level anymore after inlining.
79  Attribute operandCst;
80  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
81  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
82  isDimLikeOp;
83 }
84 
85 /// Checks if all values known to be legal affine dimensions or symbols in `src`
86 /// remain so if their respective users are inlined into `dest`.
87 static bool
88 remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
89  const IRMapping &mapping,
90  function_ref<bool(Value, Region *)> legalityCheck) {
91  return llvm::all_of(values, [&](Value v) {
92  return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
93  });
94 }
95 
96 /// Checks if an affine read or write operation remains legal after inlining
97 /// from `src` to `dest`.
98 template <typename OpTy>
99 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
100  const IRMapping &mapping) {
101  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
102  AffineWriteOpInterface>::value,
103  "only ops with affine read/write interface are supported");
104 
105  AffineMap map = op.getAffineMap();
106  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
107  ValueRange symbolOperands =
108  op.getMapOperands().take_back(map.getNumSymbols());
109  if (!remainsLegalAfterInline(
110  dimOperands, src, dest, mapping,
111  static_cast<bool (*)(Value, Region *)>(isValidDim)))
112  return false;
113  if (!remainsLegalAfterInline(
114  symbolOperands, src, dest, mapping,
115  static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
116  return false;
117  return true;
118 }
119 
120 /// Checks if an affine apply operation remains legal after inlining from `src`
121 /// to `dest`.
122 // Use "unused attribute" marker to silence clang-tidy warning stemming from
123 // the inability to see through "llvm::TypeSwitch".
124 template <>
125 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
126  Region *src, Region *dest,
127  const IRMapping &mapping) {
128  // If it's a valid dimension, we need to check that it remains so.
129  if (isValidDim(op.getResult(), src))
130  return remainsLegalAfterInline(
131  op.getMapOperands(), src, dest, mapping,
132  static_cast<bool (*)(Value, Region *)>(isValidDim));
133 
134  // Otherwise it must be a valid symbol, check that it remains so.
135  return remainsLegalAfterInline(
136  op.getMapOperands(), src, dest, mapping,
137  static_cast<bool (*)(Value, Region *)>(isValidSymbol));
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // AffineDialect Interfaces
142 //===----------------------------------------------------------------------===//
143 
144 namespace {
145 /// This class defines the interface for handling inlining with affine
146 /// operations.
147 struct AffineInlinerInterface : public DialectInlinerInterface {
148  using DialectInlinerInterface::DialectInlinerInterface;
149 
150  //===--------------------------------------------------------------------===//
151  // Analysis Hooks
152  //===--------------------------------------------------------------------===//
153 
154  /// Returns true if the given region 'src' can be inlined into the region
155  /// 'dest' that is attached to an operation registered to the current dialect.
156  /// 'wouldBeCloned' is set if the region is cloned into its new location
157  /// rather than moved, indicating there may be other users.
158  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
159  IRMapping &valueMapping) const final {
160  // We can inline into affine loops and conditionals if this doesn't break
161  // affine value categorization rules.
162  Operation *destOp = dest->getParentOp();
163  if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
164  return false;
165 
166  // Multi-block regions cannot be inlined into affine constructs, all of
167  // which require single-block regions.
168  if (!llvm::hasSingleElement(*src))
169  return false;
170 
171  // Side-effecting operations that the affine dialect cannot understand
172  // should not be inlined.
173  Block &srcBlock = src->front();
174  for (Operation &op : srcBlock) {
175  // Ops with no side effects are fine.
176  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
177  if (iface.hasNoEffect())
178  continue;
179  }
180 
181  // Assuming the inlined region is valid, we only need to check if the
182  // inlining would change it.
183  bool remainsValid =
184  llvm::TypeSwitch<Operation *, bool>(&op)
185  .Case<AffineApplyOp, AffineReadOpInterface,
186  AffineWriteOpInterface>([&](auto op) {
187  return remainsLegalAfterInline(op, src, dest, valueMapping);
188  })
189  .Default([](Operation *) {
190  // Conservatively disallow inlining ops we cannot reason about.
191  return false;
192  });
193 
194  if (!remainsValid)
195  return false;
196  }
197 
198  return true;
199  }
200 
201  /// Returns true if the given operation 'op', that is registered to this
202  /// dialect, can be inlined into the given region, false otherwise.
203  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
204  IRMapping &valueMapping) const final {
205  // Always allow inlining affine operations into a region that is marked as
206  // affine scope, or into affine loops and conditionals. There are some edge
207  // cases when inlining *into* affine structures, but that is handled in the
208  // other 'isLegalToInline' hook above.
209  Operation *parentOp = region->getParentOp();
210  return parentOp->hasTrait<OpTrait::AffineScope>() ||
211  isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
212  }
213 
214  /// Affine regions should be analyzed recursively.
215  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
216 };
217 } // namespace
218 
219 //===----------------------------------------------------------------------===//
220 // AffineDialect
221 //===----------------------------------------------------------------------===//
222 
223 void AffineDialect::initialize() {
224  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
225 #define GET_OP_LIST
226 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
227  >();
228  addInterfaces<AffineInlinerInterface>();
229  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
230  AffineMinOp>();
231 }
232 
233 /// Materialize a single constant operation from a given attribute value with
234 /// the desired resultant type.
235 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
236  Attribute value, Type type,
237  Location loc) {
238  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
239  return builder.create<ub::PoisonOp>(loc, type, poison);
240  return arith::ConstantOp::materialize(builder, value, type, loc);
241 }
242 
243 /// A utility function to check if a value is defined at the top level of an
244 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
245 /// conservatively assume it is not top-level. A value of index type defined at
246 /// the top level is always a valid symbol.
247 bool mlir::affine::isTopLevelValue(Value value) {
248  if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
249  // The block owning the argument may be unlinked, e.g. when the surrounding
250  // region has not yet been attached to an Op, at which point the parent Op
251  // is null.
252  Operation *parentOp = arg.getOwner()->getParentOp();
253  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
254  }
255  // The defining Op may live in an unlinked block so its parent Op may be null.
256  Operation *parentOp = value.getDefiningOp()->getParentOp();
257  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
258 }
259 
260 /// Returns the closest region enclosing `op` that is held by an operation with
261 /// trait `AffineScope`; `nullptr` if there is no such region.
262 Region *mlir::affine::getAffineScope(Operation *op) {
263  auto *curOp = op;
264  while (auto *parentOp = curOp->getParentOp()) {
265  if (parentOp->hasTrait<OpTrait::AffineScope>())
266  return curOp->getParentRegion();
267  curOp = parentOp;
268  }
269  return nullptr;
270 }
271 
272 // A Value can be used as a dimension id iff it meets one of the following
273 // conditions:
274 // *) It is valid as a symbol.
275 // *) It is an induction variable.
276 // *) It is the result of an affine apply operation with dimension id arguments.
277 bool mlir::affine::isValidDim(Value value) {
278  // The value must be an index type.
279  if (!value.getType().isIndex())
280  return false;
281 
282  if (auto *defOp = value.getDefiningOp())
283  return isValidDim(value, getAffineScope(defOp));
284 
285  // This value has to be a block argument for an op that has the
286  // `AffineScope` trait or for an affine.for or affine.parallel.
287  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
288  return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
289  isa<AffineForOp, AffineParallelOp>(parentOp));
290 }
291 
292 // A value can be used as a dimension id iff it meets one of the following
293 // conditions:
294 // *) It is valid as a symbol.
295 // *) It is an induction variable.
296 // *) It is the result of an affine apply operation with dimension id operands.
297 bool mlir::affine::isValidDim(Value value, Region *region) {
298  // The value must be an index type.
299  if (!value.getType().isIndex())
300  return false;
301 
302  // All valid symbols are okay.
303  if (isValidSymbol(value, region))
304  return true;
305 
306  auto *op = value.getDefiningOp();
307  if (!op) {
308  // This value has to be a block argument for an affine.for or an
309  // affine.parallel.
310  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
311  return isa<AffineForOp, AffineParallelOp>(parentOp);
312  }
313 
314  // Affine apply operation is ok if all of its operands are ok.
315  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
316  return applyOp.isValidDim(region);
317  // The dim op is okay if its operand memref/tensor is defined at the top
318  // level.
319  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
320  return isTopLevelValue(dimOp.getShapedValue());
321  return false;
322 }
323 
324 /// Returns true if the `index` dimension of the memref defined by
325 /// `memrefDefOp` has a static size or has its size defined using a valid
326 /// symbol for `region`.
327 template <typename AnyMemRefDefOp>
328 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
329  Region *region) {
330  MemRefType memRefType = memrefDefOp.getType();
331 
332  // Dimension index is out of bounds.
333  if (index >= memRefType.getRank()) {
334  return false;
335  }
336 
337  // Statically shaped.
338  if (!memRefType.isDynamicDim(index))
339  return true;
340  // Get the position of the dimension among dynamic dimensions;
341  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
342  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
343  region);
344 }
345 
346 /// Returns true if the result of the dim op is a valid symbol for `region`.
347 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
348  // The dim op is okay if its source is defined at the top level.
349  if (isTopLevelValue(dimOp.getShapedValue()))
350  return true;
351 
352  // Conservatively handle remaining BlockArguments as non-valid symbols.
353  // E.g. scf.for iterArgs.
354  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
355  return false;
356 
357  // The dim op is also okay if its operand memref is a view/subview whose
358  // corresponding size is a valid symbol.
359  std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
360 
361  // Be conservative if we can't understand the dimension.
362  if (!index.has_value())
363  return false;
364 
365  // Skip over all memref.cast ops (if any).
366  Operation *op = dimOp.getShapedValue().getDefiningOp();
367  while (auto castOp = dyn_cast<memref::CastOp>(op)) {
368  // Bail on unranked memrefs.
369  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
370  return false;
371  op = castOp.getSource().getDefiningOp();
372  if (!op)
373  return false;
374  }
375 
376  int64_t i = index.value();
377  return llvm::TypeSwitch<Operation *, bool>(op)
378  .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
379  [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
380  .Default([](Operation *) { return false; });
381 }
382 
383 // A value can be used as a symbol (at all its use sites) iff it meets one of
384 // the following conditions:
385 // *) It is a constant.
386 // *) Its defining op or block arg appearance is immediately enclosed by an op
387 // with `AffineScope` trait.
388 // *) It is the result of an affine.apply operation with symbol operands.
389 // *) It is a result of the dim op on a memref whose corresponding size is a
390 // valid symbol.
391 bool mlir::affine::isValidSymbol(Value value) {
392  if (!value)
393  return false;
394 
395  // The value must be an index type.
396  if (!value.getType().isIndex())
397  return false;
398 
399  // Check that the value is a top level value.
400  if (isTopLevelValue(value))
401  return true;
402 
403  if (auto *defOp = value.getDefiningOp())
404  return isValidSymbol(value, getAffineScope(defOp));
405 
406  return false;
407 }
408 
409 /// A value can be used as a symbol for `region` iff it meets one of the
410 /// following conditions:
411 /// *) It is a constant.
412 /// *) It is the result of an affine apply operation with symbol arguments.
413 /// *) It is a result of the dim op on a memref whose corresponding size is
414 /// a valid symbol.
415 /// *) It is defined at the top level of 'region' or is its argument.
416 /// *) It dominates `region`'s parent op.
417 /// If `region` is null, conservatively assume the symbol definition scope does
418 /// not exist and only accept the values that would be symbols regardless of
419 /// the surrounding region structure, i.e. the first three cases above.
420 bool mlir::affine::isValidSymbol(Value value, Region *region) {
421  // The value must be an index type.
422  if (!value.getType().isIndex())
423  return false;
424 
425  // A top-level value is a valid symbol.
426  if (region && ::isTopLevelValue(value, region))
427  return true;
428 
429  auto *defOp = value.getDefiningOp();
430  if (!defOp) {
431  // A block argument that is not a top-level value is a valid symbol if it
432  // dominates region's parent op.
433  Operation *regionOp = region ? region->getParentOp() : nullptr;
434  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
435  if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
436  return isValidSymbol(value, parentOpRegion);
437  return false;
438  }
439 
440  // Constant operation is ok.
441  Attribute operandCst;
442  if (matchPattern(defOp, m_Constant(&operandCst)))
443  return true;
444 
445  // Affine apply operation is ok if all of its operands are ok.
446  if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
447  return applyOp.isValidSymbol(region);
448 
449  // Dim op results could be valid symbols at any level.
450  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
451  return isDimOpValidSymbol(dimOp, region);
452 
453  // Check for values dominating `region`'s parent op.
454  Operation *regionOp = region ? region->getParentOp() : nullptr;
455  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
456  if (auto *parentRegion = region->getParentOp()->getParentRegion())
457  return isValidSymbol(value, parentRegion);
458 
459  return false;
460 }
461 
462 // Returns true if 'value' is a valid index to an affine operation (e.g.
463 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
464 // `region` provides the polyhedral symbol scope. Returns false otherwise.
465 static bool isValidAffineIndexOperand(Value value, Region *region) {
466  return isValidDim(value, region) || isValidSymbol(value, region);
467 }
468 
469 /// Prints dimension and symbol list.
470 static void printDimAndSymbolList(Operation::operand_iterator begin,
471  Operation::operand_iterator end,
472  unsigned numDims, OpAsmPrinter &printer) {
473  OperandRange operands(begin, end);
474  printer << '(' << operands.take_front(numDims) << ')';
475  if (operands.size() > numDims)
476  printer << '[' << operands.drop_front(numDims) << ']';
477 }
478 
479 /// Parses dimension and symbol list and returns true if parsing failed.
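// Illustrative form of the list being parsed (dimension operands in
// parentheses followed by optional symbol operands in square brackets), e.g.
// as it appears in the custom syntax of affine.apply:
//
//   affine.apply affine_map<(d0, d1)[s0] -> (d0 + d1 + s0)>(%i, %j)[%N]
//
// Here (%i, %j) is the dimension list and [%N] the symbol list.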
480 ParseResult mlir::affine::parseDimAndSymbolList(
481  OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
482  SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
483  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
484  return failure();
485  // Store number of dimensions for validation by caller.
486  numDims = opInfos.size();
487 
488  // Parse the optional symbol operands.
489  auto indexTy = parser.getBuilder().getIndexType();
490  return failure(parser.parseOperandList(
491  opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
492  parser.resolveOperands(opInfos, indexTy, operands));
493 }
494 
495 /// Utility function to verify that a set of operands are valid dimension and
496 /// symbol identifiers. The operands should be laid out such that the dimension
497 /// operands are before the symbol operands. This function returns failure if
498 /// there was an invalid operand. An operation is provided to emit any necessary
499 /// errors.
500 template <typename OpTy>
501 static LogicalResult
502 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
503  unsigned numDims) {
504  unsigned opIt = 0;
505  for (auto operand : operands) {
506  if (opIt++ < numDims) {
507  if (!isValidDim(operand, getAffineScope(op)))
508  return op.emitOpError("operand cannot be used as a dimension id");
509  } else if (!isValidSymbol(operand, getAffineScope(op))) {
510  return op.emitOpError("operand cannot be used as a symbol");
511  }
512  }
513  return success();
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // AffineApplyOp
518 //===----------------------------------------------------------------------===//
519 
520 AffineValueMap AffineApplyOp::getAffineValueMap() {
521  return AffineValueMap(getAffineMap(), getOperands(), getResult());
522 }
523 
524 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
525  auto &builder = parser.getBuilder();
526  auto indexTy = builder.getIndexType();
527 
528  AffineMapAttr mapAttr;
529  unsigned numDims;
530  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
531  parseDimAndSymbolList(parser, result.operands, numDims) ||
532  parser.parseOptionalAttrDict(result.attributes))
533  return failure();
534  auto map = mapAttr.getValue();
535 
536  if (map.getNumDims() != numDims ||
537  numDims + map.getNumSymbols() != result.operands.size()) {
538  return parser.emitError(parser.getNameLoc(),
539  "dimension or symbol index mismatch");
540  }
541 
542  result.types.append(map.getNumResults(), indexTy);
543  return success();
544 }
545 
547  p << " " << getMapAttr();
548  printDimAndSymbolList(operand_begin(), operand_end(),
549  getAffineMap().getNumDims(), p);
550  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
551 }
552 
553 LogicalResult AffineApplyOp::verify() {
554  // Check input and output dimensions match.
555  AffineMap affineMap = getMap();
556 
557  // Verify that operand count matches affine map dimension and symbol count.
558  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
559  return emitOpError(
560  "operand count and affine map dimension and symbol count must match");
561 
562  // Verify that the map only produces one result.
563  if (affineMap.getNumResults() != 1)
564  return emitOpError("mapping must produce one value");
565 
566  return success();
567 }
568 
569 // The result of the affine apply operation can be used as a dimension id if all
570 // its operands are valid dimension ids.
571 bool AffineApplyOp::isValidDim() {
572  return llvm::all_of(getOperands(),
573  [](Value op) { return affine::isValidDim(op); });
574 }
575 
576 // The result of the affine apply operation can be used as a dimension id if all
577 // its operands are valid dimension ids with the parent operation of `region`
578 // defining the polyhedral scope for symbols.
579 bool AffineApplyOp::isValidDim(Region *region) {
580  return llvm::all_of(getOperands(),
581  [&](Value op) { return ::isValidDim(op, region); });
582 }
583 
584 // The result of the affine apply operation can be used as a symbol if all its
585 // operands are symbols.
586 bool AffineApplyOp::isValidSymbol() {
587  return llvm::all_of(getOperands(),
588  [](Value op) { return affine::isValidSymbol(op); });
589 }
590 
591 // The result of the affine apply operation can be used as a symbol in `region`
592 // if all its operands are symbols in `region`.
593 bool AffineApplyOp::isValidSymbol(Region *region) {
594  return llvm::all_of(getOperands(), [&](Value operand) {
595  return affine::isValidSymbol(operand, region);
596  });
597 }
598 
599 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
600  auto map = getAffineMap();
601 
602  // Fold dims and symbols to existing values.
603  auto expr = map.getResult(0);
604  if (auto dim = dyn_cast<AffineDimExpr>(expr))
605  return getOperand(dim.getPosition());
606  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
607  return getOperand(map.getNumDims() + sym.getPosition());
608 
609  // Otherwise, default to folding the map.
610  SmallVector<Attribute, 1> result;
611  bool hasPoison = false;
612  auto foldResult =
613  map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
614  if (hasPoison)
615  return ub::PoisonAttr::get(getContext());
616  if (failed(foldResult))
617  return {};
618  return result[0];
619 }
620 
621 /// Returns the largest known divisor of `e`. Exploits information from the
622 /// values in `operands`.
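// Illustrative example: if `e` is the dim expression d0 and the matching
// operand is the IV of `affine.for %i = 0 to 64 step 8` (constant lower bound
// zero), the IV only takes multiples of 8, so 8 is returned for d0.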
623 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
624  // This method isn't aware of `operands`.
625  int64_t div = e.getLargestKnownDivisor();
626 
627  // We now make use of operands for the case `e` is a dim expression.
628  // TODO: More powerful simplification would have to modify
629  // getLargestKnownDivisor to take `operands` and exploit that information as
630  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
631  // can't be part of the IR library but of the `Analysis` library. The IR
632  // library can only really depend on simple O(1) checks.
633  auto dimExpr = dyn_cast<AffineDimExpr>(e);
634  // If it's not a dim expr, `div` is the best we have.
635  if (!dimExpr)
636  return div;
637 
638  // We simply exploit information from loop IVs.
639  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
640  // desired simplifications are expected to be part of other
641  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
642  // LoopAnalysis library.
643  Value operand = operands[dimExpr.getPosition()];
644  int64_t operandDivisor = 1;
645  // TODO: With the right accessors, this can be extended to
646  // LoopLikeOpInterface.
647  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
648  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
649  operandDivisor = forOp.getStepAsInt();
650  } else {
651  uint64_t lbLargestKnownDivisor =
652  forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
653  operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
654  }
655  }
656  return operandDivisor;
657 }
658 
659 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
660 /// being an affine dim expression or a constant.
661 static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
662  int64_t k) {
663  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
664  int64_t constVal = constExpr.getValue();
665  return constVal >= 0 && constVal < k;
666  }
667  auto dimExpr = dyn_cast<AffineDimExpr>(e);
668  if (!dimExpr)
669  return false;
670  Value operand = operands[dimExpr.getPosition()];
671  // TODO: With the right accessors, this can be extended to
672  // LoopLikeOpInterface.
673  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
674  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
675  forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
676  return true;
677  }
678  }
679 
680  // We don't consider other cases like `operand` being defined by a constant or
681  // an affine.apply op since such cases will already be handled by other
682  // patterns and propagation of loop IVs or constant would happen.
683  return false;
684 }
685 
686 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
687 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
688 /// expression is in that form.
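// Illustrative example: for e = d0 * 16 + d1, where the operand bound to d1
// is the IV of a loop known to satisfy 0 <= d1 < 16, this sets `div` to 16,
// `quotientTimesDiv` to d0 * 16, and `rem` to d1.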
689 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
690  AffineExpr &quotientTimesDiv, AffineExpr &rem) {
691  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
692  if (!bin || bin.getKind() != AffineExprKind::Add)
693  return false;
694 
695  AffineExpr llhs = bin.getLHS();
696  AffineExpr rlhs = bin.getRHS();
697  div = getLargestKnownDivisor(llhs, operands);
698  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
699  quotientTimesDiv = llhs;
700  rem = rlhs;
701  return true;
702  }
703  div = getLargestKnownDivisor(rlhs, operands);
704  if (isNonNegativeBoundedBy(llhs, operands, div)) {
705  quotientTimesDiv = rlhs;
706  rem = llhs;
707  return true;
708  }
709  return false;
710 }
711 
712 /// Gets the constant lower bound on an `iv`.
713 static std::optional<int64_t> getLowerBound(Value iv) {
714  AffineForOp forOp = getForInductionVarOwner(iv);
715  if (forOp && forOp.hasConstantLowerBound())
716  return forOp.getConstantLowerBound();
717  return std::nullopt;
718 }
719 
720 /// Gets the constant upper bound on an affine.for `iv`.
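// Worked example (illustrative): for `affine.for %i = 0 to 10 step 3` the IV
// takes the values 0, 3, 6, 9, so the inclusive bound returned below is
// 10 - 1 - ((10 - 0 - 1) mod 3) = 9; without a known constant lower bound,
// only `upper bound - 1` can be returned.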
721 static std::optional<int64_t> getUpperBound(Value iv) {
722  AffineForOp forOp = getForInductionVarOwner(iv);
723  if (!forOp || !forOp.hasConstantUpperBound())
724  return std::nullopt;
725 
726  // If its lower bound is also known, we can get a more precise bound
727  // whenever the step is not one.
728  if (forOp.hasConstantLowerBound()) {
729  return forOp.getConstantUpperBound() - 1 -
730  (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
731  forOp.getStepAsInt();
732  }
733  return forOp.getConstantUpperBound() - 1;
734 }
735 
736 /// Determine a constant upper bound for `expr` if one exists while exploiting
737 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
738 /// is guaranteed to be less than or equal to it.
739 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
740  unsigned numSymbols,
741  ArrayRef<Value> operands) {
742  // Get the constant lower or upper bounds on the operands.
743  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
744  constLowerBounds.reserve(operands.size());
745  constUpperBounds.reserve(operands.size());
746  for (Value operand : operands) {
747  constLowerBounds.push_back(getLowerBound(operand));
748  constUpperBounds.push_back(getUpperBound(operand));
749  }
750 
751  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
752  return constExpr.getValue();
753 
754  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
755  constUpperBounds,
756  /*isUpper=*/true);
757 }
758 
759 /// Determine a constant lower bound for `expr` if one exists while exploiting
760 /// values in `operands`. Note that the lower bound is an inclusive one. `expr`
761 /// is guaranteed to be greater than or equal to it.
762 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
763  unsigned numSymbols,
764  ArrayRef<Value> operands) {
765  // Get the constant lower or upper bounds on the operands.
766  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
767  constLowerBounds.reserve(operands.size());
768  constUpperBounds.reserve(operands.size());
769  for (Value operand : operands) {
770  constLowerBounds.push_back(getLowerBound(operand));
771  constUpperBounds.push_back(getUpperBound(operand));
772  }
773 
774  std::optional<int64_t> lowerBound;
775  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
776  lowerBound = constExpr.getValue();
777  } else {
778  lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
779  constLowerBounds, constUpperBounds,
780  /*isUpper=*/false);
781  }
782  return lowerBound;
783 }
784 
785 /// Simplify `expr` while exploiting information from the values in `operands`.
786 static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
787  unsigned numSymbols,
788  ArrayRef<Value> operands) {
789  // We do this only for certain floordiv/mod expressions.
790  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
791  if (!binExpr)
792  return;
793 
794  // Simplify the child expressions first.
795  AffineExpr lhs = binExpr.getLHS();
796  AffineExpr rhs = binExpr.getRHS();
797  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
798  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
799  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
800 
801  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
802  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
803  expr.getKind() != AffineExprKind::CeilDiv &&
804  expr.getKind() != AffineExprKind::Mod)) {
805  return;
806  }
807 
808  // The `lhs` and `rhs` may be different post construction of simplified expr.
809  lhs = binExpr.getLHS();
810  rhs = binExpr.getRHS();
811  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
812  if (!rhsConst)
813  return;
814 
815  int64_t rhsConstVal = rhsConst.getValue();
816  // Undefined expressions aren't touched; IR can still be valid with them.
817  if (rhsConstVal <= 0)
818  return;
819 
820  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
821  MLIRContext *context = expr.getContext();
822  std::optional<int64_t> lhsLbConst =
823  getLowerBound(lhs, numDims, numSymbols, operands);
824  std::optional<int64_t> lhsUbConst =
825  getUpperBound(lhs, numDims, numSymbols, operands);
826  if (lhsLbConst && lhsUbConst) {
827  int64_t lhsLbConstVal = *lhsLbConst;
828  int64_t lhsUbConstVal = *lhsUbConst;
829  // lhs floordiv c is a single value if lhs is bounded in a range of size `c`
830  // that has the same quotient.
831  if (binExpr.getKind() == AffineExprKind::FloorDiv &&
832  divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
833  divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
834  expr = getAffineConstantExpr(
835  divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
836  return;
837  }
838  // lhs ceildiv c is a single value if the entire range has the same ceil
839  // quotient.
840  if (binExpr.getKind() == AffineExprKind::CeilDiv &&
841  divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
842  divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
843  expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
844  context);
845  return;
846  }
847  // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
848  if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
849  lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
850  expr = lhs;
851  return;
852  }
853  }
854 
855  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
856  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
857  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
858  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
859  AffineExpr quotientTimesDiv, rem;
860  int64_t divisor;
861  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
862  if (rhsConstVal % divisor == 0 &&
863  binExpr.getKind() == AffineExprKind::FloorDiv) {
864  expr = quotientTimesDiv.floorDiv(rhsConst);
865  } else if (divisor % rhsConstVal == 0 &&
866  binExpr.getKind() == AffineExprKind::Mod) {
867  expr = rem % rhsConst;
868  }
869  return;
870  }
871 
872  // Handle the simple case when the LHS expression is either bounded above by
873  // the RHS constant or is a known multiple of it.
874  // lhs floordiv c -> 0 if 0 <= lhs < c,
875  // lhs mod c -> 0 if lhs % c = 0.
876  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
877  binExpr.getKind() == AffineExprKind::FloorDiv) ||
878  (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
879  binExpr.getKind() == AffineExprKind::Mod)) {
880  expr = getAffineConstantExpr(0, expr.getContext());
881  }
882 }
883 
884 /// Simplify the expressions in `map` while making use of lower or upper bounds
885 /// of its operands. If `isMax` is true, the map is to be treated as a max of
886 /// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
887 /// d1) can be simplified to (8) if the operands are respectively lower bounded
888 /// by 2 and 0 (the second expression can't be lower than 8).
889 static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
890  ArrayRef<Value> operands,
891  bool isMax) {
892  // Can't simplify.
893  if (operands.empty())
894  return;
895 
896  // Get the upper or lower bound on an affine.for op IV using its range.
897  // Get the constant lower or upper bounds on the operands.
898  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
899  constLowerBounds.reserve(operands.size());
900  constUpperBounds.reserve(operands.size());
901  for (Value operand : operands) {
902  constLowerBounds.push_back(getLowerBound(operand));
903  constUpperBounds.push_back(getUpperBound(operand));
904  }
905 
906  // We will compute the lower and upper bounds on each of the expressions.
907  // Then, we will check (depending on max or min) whether a specific bound is
908  // redundant: for a max, an expression is redundant if its highest value is
909  // already lower than the lower bound of another expression; for a min, it is
910  // redundant if its lowest value is already higher than another's upper bound.
911  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
912  lowerBounds.reserve(map.getNumResults());
913  upperBounds.reserve(map.getNumResults());
914  for (AffineExpr e : map.getResults()) {
915  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
916  lowerBounds.push_back(constExpr.getValue());
917  upperBounds.push_back(constExpr.getValue());
918  } else {
919  lowerBounds.push_back(
920  getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
921  constLowerBounds, constUpperBounds,
922  /*isUpper=*/false));
923  upperBounds.push_back(
924  getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
925  constLowerBounds, constUpperBounds,
926  /*isUpper=*/true));
927  }
928  }
929 
930  // Collect expressions that are not redundant.
931  SmallVector<AffineExpr, 4> irredundantExprs;
932  for (auto exprEn : llvm::enumerate(map.getResults())) {
933  AffineExpr e = exprEn.value();
934  unsigned i = exprEn.index();
935  // Some expressions can be turned into constants.
936  if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
937  e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
938 
939  // Check if the expression is redundant.
940  if (isMax) {
941  if (!upperBounds[i]) {
942  irredundantExprs.push_back(e);
943  continue;
944  }
945  // If there exists another expression such that its lower bound is greater
946  // than this expression's upper bound, it's redundant.
947  if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
948  auto otherLowerBound = en.value();
949  unsigned pos = en.index();
950  if (pos == i || !otherLowerBound)
951  return false;
952  if (*otherLowerBound > *upperBounds[i])
953  return true;
954  if (*otherLowerBound < *upperBounds[i])
955  return false;
956  // Equality case. When both expressions are considered redundant, we
957  // don't want to get both of them. We keep the one that appears
958  // first.
959  if (upperBounds[pos] && lowerBounds[i] &&
960  lowerBounds[i] == upperBounds[i] &&
961  otherLowerBound == *upperBounds[pos] && i < pos)
962  return false;
963  return true;
964  }))
965  irredundantExprs.push_back(e);
966  } else {
967  if (!lowerBounds[i]) {
968  irredundantExprs.push_back(e);
969  continue;
970  }
971  // Likewise for the `min` case. Use the complement of the condition above.
972  if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
973  auto otherUpperBound = en.value();
974  unsigned pos = en.index();
975  if (pos == i || !otherUpperBound)
976  return false;
977  if (*otherUpperBound < *lowerBounds[i])
978  return true;
979  if (*otherUpperBound > *lowerBounds[i])
980  return false;
981  if (lowerBounds[pos] && upperBounds[i] &&
982  lowerBounds[i] == upperBounds[i] &&
983  otherUpperBound == lowerBounds[pos] && i < pos)
984  return false;
985  return true;
986  }))
987  irredundantExprs.push_back(e);
988  }
989  }
990 
991  // Create the map without the redundant expressions.
992  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
993  map.getContext());
994 }
995 
996 /// Simplify the map while exploiting information on the values in `operands`.
997 // Use "unused attribute" marker to silence warning stemming from the inability
998 // to see through the template expansion.
999 static void LLVM_ATTRIBUTE_UNUSED
1000 simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
1001  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1002  SmallVector<AffineExpr> newResults;
1003  newResults.reserve(map.getNumResults());
1004  for (AffineExpr expr : map.getResults()) {
1005  simplifyExprAndOperands(expr, map.getNumDims(), map.getNumSymbols(),
1006  operands);
1007  newResults.push_back(expr);
1008  }
1009  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1010  map.getContext());
1011 }
1012 
1013 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
1014 /// defining AffineApplyOp expression and operands.
1015 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
1016 /// When `dimOrSymbolPosition >= dims.size()`,
1017 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
1018 /// Mutate `map`,`dims` and `syms` in place as follows:
1019 /// 1. `dims` and `syms` are only appended to.
1020 /// 2. `map` dim and symbols are gradually shifted to higher positions.
1021 /// 3. Old `dim` and `sym` entries are replaced by nullptr
1022 /// This avoids the need for any bookkeeping.
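// Illustrative single step (assuming %c is a loop IV, so it stays a dim after
// canonicalization): with map (d0, d1) -> (d0 + d1), dims = [%a, %b], and
//   %b = affine.apply affine_map<(d0) -> (d0 floordiv 4)>(%c)
// the dim d1 is replaced by the apply's expression shifted past the existing
// dims, the entry for %b becomes nullptr, and %c is appended, giving
//   map  = (d0, d1, d2) -> (d0 + d2 floordiv 4), dims = [%a, null, %c].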
1023 static LogicalResult replaceDimOrSym(AffineMap *map,
1024  unsigned dimOrSymbolPosition,
1025  SmallVectorImpl<Value> &dims,
1026  SmallVectorImpl<Value> &syms) {
1027  MLIRContext *ctx = map->getContext();
1028  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1029  unsigned pos = isDimReplacement ? dimOrSymbolPosition
1030  : dimOrSymbolPosition - dims.size();
1031  Value &v = isDimReplacement ? dims[pos] : syms[pos];
1032  if (!v)
1033  return failure();
1034 
1035  auto affineApply = v.getDefiningOp<AffineApplyOp>();
1036  if (!affineApply)
1037  return failure();
1038 
1039  // At this point we will perform a replacement of `v`, set the entry in `dim`
1040  // or `sym` to nullptr immediately.
1041  v = nullptr;
1042 
1043  // Compute the map, dims and symbols coming from the AffineApplyOp.
1044  AffineMap composeMap = affineApply.getAffineMap();
1045  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1046  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1047  affineApply.getMapOperands().end());
1048  // Canonicalize the map to promote dims to symbols when possible. This is to
1049  // avoid generating invalid maps.
1050  canonicalizeMapAndOperands(&composeMap, &composeOperands);
1051  AffineExpr replacementExpr =
1052  composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
1053  ValueRange composeDims =
1054  ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
1055  ValueRange composeSyms =
1056  ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
1057  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
1058  : getAffineSymbolExpr(pos, ctx);
1059 
1060  // Append the dims and symbols where relevant and perform the replacement.
1061  dims.append(composeDims.begin(), composeDims.end());
1062  syms.append(composeSyms.begin(), composeSyms.end());
1063  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1064 
1065  return success();
1066 }
1067 
1068 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1069 /// iteratively. Perform canonicalization of map and operands as well as
1070 /// AffineMap simplification. `map` and `operands` are mutated in place.
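// Illustrative example (assuming %i is a loop IV and %N a valid symbol):
// composing
//   map      = (d0)[s0] -> (d0 + s0)
//   operands = [%0, %N], with %0 = affine.apply affine_map<(d0) -> (d0 * 4)>(%i)
// yields
//   map      = (d0)[s0] -> (d0 * 4 + s0)
//   operands = [%i, %N]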
1071 static void composeAffineMapAndOperands(AffineMap *map,
1072  SmallVectorImpl<Value> *operands) {
1073  if (map->getNumResults() == 0) {
1074  canonicalizeMapAndOperands(map, operands);
1075  *map = simplifyAffineMap(*map);
1076  return;
1077  }
1078 
1079  MLIRContext *ctx = map->getContext();
1080  SmallVector<Value, 4> dims(operands->begin(),
1081  operands->begin() + map->getNumDims());
1082  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1083  operands->end());
1084 
1085  // Iterate over dims and symbols coming from AffineApplyOp and replace until
1086  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1087  // and `syms` can only increase by construction.
1088  // The implementation uses a `while` loop to support the case of symbols
1089  // that may be constructed from dims; this may be overkill.
1090  while (true) {
1091  bool changed = false;
1092  for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1093  if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
1094  break;
1095  if (!changed)
1096  break;
1097  }
1098 
1099  // Clear operands so we can fill them anew.
1100  operands->clear();
1101 
1102  // At this point we may have introduced null operands, prune them out before
1103  // canonicalizing map and operands.
1104  unsigned nDims = 0, nSyms = 0;
1105  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1106  dimReplacements.reserve(dims.size());
1107  symReplacements.reserve(syms.size());
1108  for (auto *container : {&dims, &syms}) {
1109  bool isDim = (container == &dims);
1110  auto &repls = isDim ? dimReplacements : symReplacements;
1111  for (const auto &en : llvm::enumerate(*container)) {
1112  Value v = en.value();
1113  if (!v) {
1114  assert(isDim ? !map->isFunctionOfDim(en.index())
1115  : !map->isFunctionOfSymbol(en.index()) &&
1116  "map is function of unexpected expr@pos");
1117  repls.push_back(getAffineConstantExpr(0, ctx));
1118  continue;
1119  }
1120  repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1121  : getAffineSymbolExpr(nSyms++, ctx));
1122  operands->push_back(v);
1123  }
1124  }
1125  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1126  nSyms);
1127 
1128  // Canonicalize and simplify before returning.
1129  canonicalizeMapAndOperands(map, operands);
1130  *map = simplifyAffineMap(*map);
1131 }
1132 
1133 void mlir::affine::fullyComposeAffineMapAndOperands(
1134  AffineMap *map, SmallVectorImpl<Value> *operands) {
1135  while (llvm::any_of(*operands, [](Value v) {
1136  return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1137  })) {
1138  composeAffineMapAndOperands(map, operands);
1139  }
1140 }
1141 
1142 AffineApplyOp
1143 mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
1144  ArrayRef<OpFoldResult> operands) {
1145  SmallVector<Value> valueOperands;
1146  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1147  composeAffineMapAndOperands(&map, &valueOperands);
1148  assert(map);
1149  return b.create<AffineApplyOp>(loc, map, valueOperands);
1150 }
1151 
1152 AffineApplyOp
1153 mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
1154  ArrayRef<OpFoldResult> operands) {
1155  return makeComposedAffineApply(
1156  b, loc,
1157  AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}, b.getContext())
1158  .front(),
1159  operands);
1160 }
1161 
1162 /// Composes the given affine map with the given list of operands, pulling in
1163 /// the maps from any affine.apply operations that supply the operands.
1164 static void composeMultiResultAffineMap(AffineMap &map,
1165  SmallVectorImpl<Value> &operands) {
1166  // Compose and canonicalize each expression in the map individually because
1167  // composition only applies to single-result maps, collecting potentially
1168  // duplicate operands in a single list with shifted dimensions and symbols.
1169  SmallVector<Value> dims, symbols;
1170  SmallVector<AffineExpr> exprs;
1171  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1172  SmallVector<Value> submapOperands(operands.begin(), operands.end());
1173  AffineMap submap = map.getSubMap({i});
1174  fullyComposeAffineMapAndOperands(&submap, &submapOperands);
1175  canonicalizeMapAndOperands(&submap, &submapOperands);
1176  unsigned numNewDims = submap.getNumDims();
1177  submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1178  llvm::append_range(dims,
1179  ArrayRef<Value>(submapOperands).take_front(numNewDims));
1180  llvm::append_range(symbols,
1181  ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1182  exprs.push_back(submap.getResult(0));
1183  }
1184 
1185  // Canonicalize the map created from composed expressions to deduplicate the
1186  // dimension and symbol operands.
1187  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1188  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1189  canonicalizeMapAndOperands(&map, &operands);
1190 }
1191 
1192 OpFoldResult
1193 mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
1194  AffineMap map,
1195  ArrayRef<OpFoldResult> operands) {
1196  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1197 
1198  // Create new builder without a listener, so that no notification is
1199  // triggered if the op is folded.
1200  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1201  // workaround is no longer needed.
1202  OpBuilder newBuilder(b.getContext());
1203  newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1204 
1205  // Create op.
1206  AffineApplyOp applyOp =
1207  makeComposedAffineApply(newBuilder, loc, map, operands);
1208 
1209  // Get constant operands.
1210  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1211  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1212  matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1213 
1214  // Try to fold the operation.
1215  SmallVector<OpFoldResult> foldResults;
1216  if (failed(applyOp->fold(constOperands, foldResults)) ||
1217  foldResults.empty()) {
1218  if (OpBuilder::Listener *listener = b.getListener())
1219  listener->notifyOperationInserted(applyOp, /*previous=*/{});
1220  return applyOp.getResult();
1221  }
1222 
1223  applyOp->erase();
1224  assert(foldResults.size() == 1 && "expected 1 folded result");
1225  return foldResults.front();
1226 }
1227 
1228 OpFoldResult
1229 mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
1230  AffineExpr expr,
1231  ArrayRef<OpFoldResult> operands) {
1232  return makeComposedFoldedAffineApply(
1233  b, loc,
1234  AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}, b.getContext())
1235  .front(),
1236  operands);
1237 }
1238 
1239 SmallVector<OpFoldResult>
1240 mlir::affine::makeComposedFoldedMultiResultAffineApply(
1241  OpBuilder &b, Location loc, AffineMap map,
1242  ArrayRef<OpFoldResult> operands) {
1243  return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
1244  [&](unsigned i) {
1245  return makeComposedFoldedAffineApply(
1246  b, loc, map.getSubMap({i}), operands);
1247  });
1248 }
1249 
1250 template <typename OpTy>
1251 static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map,
1252  ArrayRef<OpFoldResult> operands) {
1253  SmallVector<Value> valueOperands;
1254  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1255  composeMultiResultAffineMap(map, valueOperands);
1256  return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1257 }
1258 
1259 AffineMinOp
1260 mlir::affine::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
1261  ArrayRef<OpFoldResult> operands) {
1262  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1263 }
1264 
1265 template <typename OpTy>
1266 static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
1267  AffineMap map,
1268  ArrayRef<OpFoldResult> operands) {
1269  // Create new builder without a listener, so that no notification is
1270  // triggered if the op is folded.
1271  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1272  // workaround is no longer needed.
1273  OpBuilder newBuilder(b.getContext());
1274  newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1275 
1276  // Create op.
1277  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1278 
1279  // Get constant operands.
1280  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1281  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1282  matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1283 
1284  // Try to fold the operation.
1285  SmallVector<OpFoldResult> foldResults;
1286  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1287  foldResults.empty()) {
1288  if (OpBuilder::Listener *listener = b.getListener())
1289  listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1290  return minMaxOp.getResult();
1291  }
1292 
1293  minMaxOp->erase();
1294  assert(foldResults.size() == 1 && "expected 1 folded result");
1295  return foldResults.front();
1296 }
1297 
1298 OpFoldResult
1299 mlir::affine::makeComposedFoldedAffineMin(OpBuilder &b, Location loc,
1300  AffineMap map,
1301  ArrayRef<OpFoldResult> operands) {
1302  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1303 }
1304 
1305 OpFoldResult
1306 mlir::affine::makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
1307  AffineMap map,
1308  ArrayRef<OpFoldResult> operands) {
1309  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1310 }
1311 
1312 // A symbol may appear as a dim in affine.apply operations. This function
1313 // canonicalizes dims that are valid symbols into actual symbols.
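// Illustrative example: with map (d0, d1)[s0] -> (d0 + d1 + s0) and operands
// (%i, %M)[%N], where %i is a loop IV but %M is a valid symbol appearing in a
// dim position, the map is rewritten to
//   (d0)[s0, s1] -> (d0 + s1 + s0)   with operands (%i)[%N, %M].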
1314 template <class MapOrSet>
1315 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1316  SmallVectorImpl<Value> *operands) {
1317  if (!mapOrSet || operands->empty())
1318  return;
1319 
1320  assert(mapOrSet->getNumInputs() == operands->size() &&
1321  "map/set inputs must match number of operands");
1322 
1323  auto *context = mapOrSet->getContext();
1324  SmallVector<Value, 8> resultOperands;
1325  resultOperands.reserve(operands->size());
1326  SmallVector<Value, 8> remappedSymbols;
1327  remappedSymbols.reserve(operands->size());
1328  unsigned nextDim = 0;
1329  unsigned nextSym = 0;
1330  unsigned oldNumSyms = mapOrSet->getNumSymbols();
1331  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1332  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1333  if (i < mapOrSet->getNumDims()) {
1334  if (isValidSymbol((*operands)[i])) {
1335  // This is a valid symbol that appears as a dim, canonicalize it.
1336  dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1337  remappedSymbols.push_back((*operands)[i]);
1338  } else {
1339  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1340  resultOperands.push_back((*operands)[i]);
1341  }
1342  } else {
1343  resultOperands.push_back((*operands)[i]);
1344  }
1345  }
1346 
1347  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1348  *operands = resultOperands;
1349  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1350  oldNumSyms + nextSym);
1351 
1352  assert(mapOrSet->getNumInputs() == operands->size() &&
1353  "map/set inputs must match number of operands");
1354 }
1355 
1356 // Works for either an affine map or an integer set.
1357 template <class MapOrSet>
1358 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1359  SmallVectorImpl<Value> *operands) {
1360  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1361  "Argument must be either of AffineMap or IntegerSet type");
1362 
1363  if (!mapOrSet || operands->empty())
1364  return;
1365 
1366  assert(mapOrSet->getNumInputs() == operands->size() &&
1367  "map/set inputs must match number of operands");
1368 
1369  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1370 
1371  // Check to see what dims are used.
1372  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1373  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1374  mapOrSet->walkExprs([&](AffineExpr expr) {
1375  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1376  usedDims[dimExpr.getPosition()] = true;
1377  else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1378  usedSyms[symExpr.getPosition()] = true;
1379  });
1380 
1381  auto *context = mapOrSet->getContext();
1382 
1383  SmallVector<Value, 8> resultOperands;
1384  resultOperands.reserve(operands->size());
1385 
1386  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1387  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1388  unsigned nextDim = 0;
1389  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1390  if (usedDims[i]) {
1391  // Remap dim positions for duplicate operands.
1392  auto it = seenDims.find((*operands)[i]);
1393  if (it == seenDims.end()) {
1394  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1395  resultOperands.push_back((*operands)[i]);
1396  seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1397  } else {
1398  dimRemapping[i] = it->second;
1399  }
1400  }
1401  }
1402  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1403  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1404  unsigned nextSym = 0;
1405  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1406  if (!usedSyms[i])
1407  continue;
1408  // Handle constant operands (only needed for symbolic operands since
1409  // constant operands in dimensional positions would have already been
1410  // promoted to symbolic positions above).
1411  IntegerAttr operandCst;
1412  if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1413  m_Constant(&operandCst))) {
1414  symRemapping[i] =
1415  getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1416  continue;
1417  }
1418  // Remap symbol positions for duplicate operands.
1419  auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1420  if (it == seenSymbols.end()) {
1421  symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1422  resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1423  seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1424  symRemapping[i]));
1425  } else {
1426  symRemapping[i] = it->second;
1427  }
1428  }
1429  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1430  nextDim, nextSym);
1431  *operands = resultOperands;
1432 }
1433 
1434 void mlir::affine::canonicalizeMapAndOperands(
1435  AffineMap *map, SmallVectorImpl<Value> *operands) {
1436  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1437 }
1438 
1439 void mlir::affine::canonicalizeSetAndOperands(
1440  IntegerSet *set, SmallVectorImpl<Value> *operands) {
1441  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1442 }
1443 
1444 namespace {
1445 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1446 /// maps that supply results into them.
1447 ///
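// Illustrative example (assuming %i is a loop IV):
//
//   %0 = affine.apply affine_map<(d0) -> (d0 * 4)>(%i)
//   %1 = affine.load %A[%0] : memref<128xf32>
//
// folds the producing apply into the load's map:
//
//   %1 = affine.load %A[%i * 4] : memref<128xf32>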
1448 template <typename AffineOpTy>
1449 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1450  using OpRewritePattern<AffineOpTy>::OpRewritePattern;
1451 
1452  /// Replace the affine op with another instance of it with the supplied
1453  /// map and mapOperands.
1454  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1455  AffineMap map, ArrayRef<Value> mapOperands) const;
1456 
1457  LogicalResult matchAndRewrite(AffineOpTy affineOp,
1458  PatternRewriter &rewriter) const override {
1459  static_assert(
1460  llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1461  AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1462  AffineVectorStoreOp, AffineVectorLoadOp>::value,
1463  "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1464  "expected");
1465  auto map = affineOp.getAffineMap();
1466  AffineMap oldMap = map;
1467  auto oldOperands = affineOp.getMapOperands();
1468  SmallVector<Value, 8> resultOperands(oldOperands);
1469  composeAffineMapAndOperands(&map, &resultOperands);
1470  canonicalizeMapAndOperands(&map, &resultOperands);
1471  simplifyMapWithOperands(map, resultOperands);
1472  if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1473  resultOperands.begin()))
1474  return failure();
1475 
1476  replaceAffineOp(rewriter, affineOp, map, resultOperands);
1477  return success();
1478  }
1479 };
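// Illustrative sketch of what this pattern achieves for a load (IR names are
// hypothetical): an affine.apply feeding the index is composed into the map,
//   %idx = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
//   %v = affine.load %A[%idx] : memref<100xf32>
// simplifies to
//   %v = affine.load %A[%i + 1] : memref<100xf32>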
1480 
1481 // Specialize the template to account for the different build signatures for
1482 // affine load, prefetch, store, and vector load/store ops.
1483 template <>
1484 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1485  PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1486  ArrayRef<Value> mapOperands) const {
1487  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1488  mapOperands);
1489 }
1490 template <>
1491 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1492  PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1493  ArrayRef<Value> mapOperands) const {
1494  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1495  prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1496  prefetch.getLocalityHint(), prefetch.getIsDataCache());
1497 }
1498 template <>
1499 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1500  PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1501  ArrayRef<Value> mapOperands) const {
1502  rewriter.replaceOpWithNewOp<AffineStoreOp>(
1503  store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1504 }
1505 template <>
1506 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1507  PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1508  ArrayRef<Value> mapOperands) const {
1509  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1510  vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1511  mapOperands);
1512 }
1513 template <>
1514 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1515  PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1516  ArrayRef<Value> mapOperands) const {
1517  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1518  vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1519  mapOperands);
1520 }
1521 
1522 // Generic version for ops that don't have extra operands.
1523 template <typename AffineOpTy>
1524 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1525  PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1526  ArrayRef<Value> mapOperands) const {
1527  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1528 }
1529 } // namespace
1530 
1531 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1532  MLIRContext *context) {
1533  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1534 }
1535 
1536 //===----------------------------------------------------------------------===//
1537 // AffineDmaStartOp
1538 //===----------------------------------------------------------------------===//
1539 
1540 // TODO: Check that map operands are loop IVs or symbols.
1541 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1542  Value srcMemRef, AffineMap srcMap,
1543  ValueRange srcIndices, Value destMemRef,
1544  AffineMap dstMap, ValueRange destIndices,
1545  Value tagMemRef, AffineMap tagMap,
1546  ValueRange tagIndices, Value numElements,
1547  Value stride, Value elementsPerStride) {
1548  result.addOperands(srcMemRef);
1549  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1550  result.addOperands(srcIndices);
1551  result.addOperands(destMemRef);
1552  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1553  result.addOperands(destIndices);
1554  result.addOperands(tagMemRef);
1555  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1556  result.addOperands(tagIndices);
1557  result.addOperands(numElements);
1558  if (stride) {
1559  result.addOperands({stride, elementsPerStride});
1560  }
1561 }
1562 
1563 void AffineDmaStartOp::print(OpAsmPrinter &p) {
1564  p << " " << getSrcMemRef() << '[';
1565  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1566  p << "], " << getDstMemRef() << '[';
1567  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1568  p << "], " << getTagMemRef() << '[';
1569  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1570  p << "], " << getNumElements();
1571  if (isStrided()) {
1572  p << ", " << getStride();
1573  p << ", " << getNumElementsPerStride();
1574  }
1575  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1576  << getTagMemRefType();
1577 }
1578 
1579 // Parse AffineDmaStartOp.
1580 // Ex:
1581 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1582 // %stride, %num_elt_per_stride
1583 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1584 //
1585 ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
1586  OperationState &result) {
1587  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1588  AffineMapAttr srcMapAttr;
1589  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcMapOperands;
1590  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1591  AffineMapAttr dstMapAttr;
1592  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstMapOperands;
1593  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1594  AffineMapAttr tagMapAttr;
1595  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagMapOperands;
1596  OpAsmParser::UnresolvedOperand numElementsInfo;
1597  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1598 
1599  SmallVector<Type, 3> types;
1600  auto indexType = parser.getBuilder().getIndexType();
1601 
1602  // Parse and resolve the following list of operands:
1603  // *) dst memref followed by its affine maps operands (in square brackets).
1604  // *) src memref followed by its affine map operands (in square brackets).
1605  // *) tag memref followed by its affine map operands (in square brackets).
1606  // *) number of elements transferred by DMA operation.
1607  if (parser.parseOperand(srcMemRefInfo) ||
1608  parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1609  getSrcMapAttrStrName(),
1610  result.attributes) ||
1611  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1612  parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1613  getDstMapAttrStrName(),
1614  result.attributes) ||
1615  parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1616  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1617  getTagMapAttrStrName(),
1618  result.attributes) ||
1619  parser.parseComma() || parser.parseOperand(numElementsInfo))
1620  return failure();
1621 
1622  // Parse optional stride and elements per stride.
1623  if (parser.parseTrailingOperandList(strideInfo))
1624  return failure();
1625 
1626  if (!strideInfo.empty() && strideInfo.size() != 2) {
1627  return parser.emitError(parser.getNameLoc(),
1628  "expected two stride related operands");
1629  }
1630  bool isStrided = strideInfo.size() == 2;
1631 
1632  if (parser.parseColonTypeList(types))
1633  return failure();
1634 
1635  if (types.size() != 3)
1636  return parser.emitError(parser.getNameLoc(), "expected three types");
1637 
1638  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1639  parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1640  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1641  parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1642  parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1643  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1644  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1645  return failure();
1646 
1647  if (isStrided) {
1648  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1649  return failure();
1650  }
1651 
1652  // Check that src/dst/tag operand counts match their map.numInputs.
1653  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1654  dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1655  tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1656  return parser.emitError(parser.getNameLoc(),
1657  "memref operand count not equal to map.numInputs");
1658  return success();
1659 }
1660 
1661 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1662  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1663  return emitOpError("expected DMA source to be of memref type");
1664  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1665  return emitOpError("expected DMA destination to be of memref type");
1666  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1667  return emitOpError("expected DMA tag to be of memref type");
1668 
1669  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1670  getDstMap().getNumInputs() +
1671  getTagMap().getNumInputs();
1672  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1673  getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1674  return emitOpError("incorrect number of operands");
1675  }
1676 
1677  Region *scope = getAffineScope(*this);
1678  for (auto idx : getSrcIndices()) {
1679  if (!idx.getType().isIndex())
1680  return emitOpError("src index to dma_start must have 'index' type");
1681  if (!isValidAffineIndexOperand(idx, scope))
1682  return emitOpError(
1683  "src index must be a valid dimension or symbol identifier");
1684  }
1685  for (auto idx : getDstIndices()) {
1686  if (!idx.getType().isIndex())
1687  return emitOpError("dst index to dma_start must have 'index' type");
1688  if (!isValidAffineIndexOperand(idx, scope))
1689  return emitOpError(
1690  "dst index must be a valid dimension or symbol identifier");
1691  }
1692  for (auto idx : getTagIndices()) {
1693  if (!idx.getType().isIndex())
1694  return emitOpError("tag index to dma_start must have 'index' type");
1695  if (!isValidAffineIndexOperand(idx, scope))
1696  return emitOpError(
1697  "tag index must be a valid dimension or symbol identifier");
1698  }
1699  return success();
1700 }
1701 
1702 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1703  SmallVectorImpl<OpFoldResult> &results) {
1704  /// dma_start(memrefcast) -> dma_start
1705  return memref::foldMemRefCast(*this);
1706 }
1707 
1708 void AffineDmaStartOp::getEffects(
1709  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1710  &effects) {
1711  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
1712  SideEffects::DefaultResource::get());
1713  effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
1714  SideEffects::DefaultResource::get());
1715  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1716  SideEffects::DefaultResource::get());
1717 }
1718 
1719 //===----------------------------------------------------------------------===//
1720 // AffineDmaWaitOp
1721 //===----------------------------------------------------------------------===//
1722 
1723 // TODO: Check that map operands are loop IVs or symbols.
1724 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1725  Value tagMemRef, AffineMap tagMap,
1726  ValueRange tagIndices, Value numElements) {
1727  result.addOperands(tagMemRef);
1728  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1729  result.addOperands(tagIndices);
1730  result.addOperands(numElements);
1731 }
1732 
1733 void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1734  p << " " << getTagMemRef() << '[';
1735  SmallVector<Value, 2> operands(getTagIndices());
1736  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1737  p << "], ";
1738  p.printOperand(getNumElements());
1739  p << " : " << getTagMemRef().getType();
1740 }
1741 
1742 // Parse AffineDmaWaitOp.
1743 // Eg:
1744 // affine.dma_wait %tag[%index], %num_elements
1745 // : memref<1 x i32, (d0) -> (d0), 4>
1746 //
1747 ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1748  OperationState &result) {
1749  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1750  AffineMapAttr tagMapAttr;
1751  SmallVector<OpAsmParser::UnresolvedOperand, 2> tagMapOperands;
1752  Type type;
1753  auto indexType = parser.getBuilder().getIndexType();
1754  OpAsmParser::UnresolvedOperand numElementsInfo;
1755 
1756  // Parse tag memref, its map operands, and dma size.
1757  if (parser.parseOperand(tagMemRefInfo) ||
1758  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1759  getTagMapAttrStrName(),
1760  result.attributes) ||
1761  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1762  parser.parseColonType(type) ||
1763  parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1764  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1765  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1766  return failure();
1767 
1768  if (!llvm::isa<MemRefType>(type))
1769  return parser.emitError(parser.getNameLoc(),
1770  "expected tag to be of memref type");
1771 
1772  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1773  return parser.emitError(parser.getNameLoc(),
1774  "tag memref operand count != to map.numInputs");
1775  return success();
1776 }
1777 
1778 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1779  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1780  return emitOpError("expected DMA tag to be of memref type");
1781  Region *scope = getAffineScope(*this);
1782  for (auto idx : getTagIndices()) {
1783  if (!idx.getType().isIndex())
1784  return emitOpError("index to dma_wait must have 'index' type");
1785  if (!isValidAffineIndexOperand(idx, scope))
1786  return emitOpError(
1787  "index must be a valid dimension or symbol identifier");
1788  }
1789  return success();
1790 }
1791 
1792 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1793  SmallVectorImpl<OpFoldResult> &results) {
1794  /// dma_wait(memrefcast) -> dma_wait
1795  return memref::foldMemRefCast(*this);
1796 }
1797 
1798 void AffineDmaWaitOp::getEffects(
1799  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1800  &effects) {
1801  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1802  SideEffects::DefaultResource::get());
1803 }
1804 
1805 //===----------------------------------------------------------------------===//
1806 // AffineForOp
1807 //===----------------------------------------------------------------------===//
1808 
1809 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1810 /// bodyBuilder are empty/null, the default terminator op is included.
1811 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1812  ValueRange lbOperands, AffineMap lbMap,
1813  ValueRange ubOperands, AffineMap ubMap, int64_t step,
1814  ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1815  assert(((!lbMap && lbOperands.empty()) ||
1816  lbOperands.size() == lbMap.getNumInputs()) &&
1817  "lower bound operand count does not match the affine map");
1818  assert(((!ubMap && ubOperands.empty()) ||
1819  ubOperands.size() == ubMap.getNumInputs()) &&
1820  "upper bound operand count does not match the affine map");
1821  assert(step > 0 && "step has to be a positive integer constant");
1822 
1823  OpBuilder::InsertionGuard guard(builder);
1824 
1825  // Set variadic segment sizes.
1826  result.addAttribute(
1827  getOperandSegmentSizeAttr(),
1828  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
1829  static_cast<int32_t>(ubOperands.size()),
1830  static_cast<int32_t>(iterArgs.size())}));
1831 
1832  for (Value val : iterArgs)
1833  result.addTypes(val.getType());
1834 
1835  // Add an attribute for the step.
1836  result.addAttribute(getStepAttrName(result.name),
1837  builder.getIntegerAttr(builder.getIndexType(), step));
1838 
1839  // Add the lower bound.
1840  result.addAttribute(getLowerBoundMapAttrName(result.name),
1841  AffineMapAttr::get(lbMap));
1842  result.addOperands(lbOperands);
1843 
1844  // Add the upper bound.
1845  result.addAttribute(getUpperBoundMapAttrName(result.name),
1846  AffineMapAttr::get(ubMap));
1847  result.addOperands(ubOperands);
1848 
1849  result.addOperands(iterArgs);
1850  // Create a region and a block for the body. The argument of the region is
1851  // the loop induction variable.
1852  Region *bodyRegion = result.addRegion();
1853  Block *bodyBlock = builder.createBlock(bodyRegion);
1854  Value inductionVar =
1855  bodyBlock->addArgument(builder.getIndexType(), result.location);
1856  for (Value val : iterArgs)
1857  bodyBlock->addArgument(val.getType(), val.getLoc());
1858 
1859  // Create the default terminator if the builder is not provided and if the
1860  // iteration arguments are not provided. Otherwise, leave this to the caller
1861  // because we don't know which values to return from the loop.
1862  if (iterArgs.empty() && !bodyBuilder) {
1863  ensureTerminator(*bodyRegion, builder, result.location);
1864  } else if (bodyBuilder) {
1865  OpBuilder::InsertionGuard guard(builder);
1866  builder.setInsertionPointToStart(bodyBlock);
1867  bodyBuilder(builder, result.location, inductionVar,
1868  bodyBlock->getArguments().drop_front());
1869  }
1870 }
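// A minimal usage sketch for this builder (assumes `builder`, `loc`, the bound
// maps, and the operand ranges exist in the caller; not taken from the source):
//
//   builder.create<AffineForOp>(
//       loc, lbOperands, lbMap, ubOperands, ubMap, /*step=*/1,
//       /*iterArgs=*/ValueRange(),
//       [](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
//         // ... build the loop body using `iv` ...
//         b.create<AffineYieldOp>(nestedLoc);
//       });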
1871 
1872 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1873  int64_t ub, int64_t step, ValueRange iterArgs,
1874  BodyBuilderFn bodyBuilder) {
1875  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1876  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1877  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1878  bodyBuilder);
1879 }
1880 
1881 LogicalResult AffineForOp::verifyRegions() {
1882  // Check that the body defines a single block argument for the induction
1883  // variable.
1884  auto *body = getBody();
1885  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1886  return emitOpError("expected body to have a single index argument for the "
1887  "induction variable");
1888 
1889  // Verify that the bound operands are valid dimension/symbols.
1890  /// Lower bound.
1891  if (getLowerBoundMap().getNumInputs() > 0)
1892  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(),
1893  getLowerBoundMap().getNumDims())))
1894  return failure();
1895  /// Upper bound.
1896  if (getUpperBoundMap().getNumInputs() > 0)
1897  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(),
1898  getUpperBoundMap().getNumDims())))
1899  return failure();
1900 
1901  unsigned opNumResults = getNumResults();
1902  if (opNumResults == 0)
1903  return success();
1904 
1905  // If ForOp defines values, check that the number and types of the defined
1906  // values match ForOp initial iter operands and backedge basic block
1907  // arguments.
1908  if (getNumIterOperands() != opNumResults)
1909  return emitOpError(
1910  "mismatch between the number of loop-carried values and results");
1911  if (getNumRegionIterArgs() != opNumResults)
1912  return emitOpError(
1913  "mismatch between the number of basic block args and results");
1914 
1915  return success();
1916 }
1917 
1918 /// Parse a for operation loop bounds.
1919 static ParseResult parseBound(bool isLower, OperationState &result,
1920  OpAsmParser &p) {
1921  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1922  // the map has multiple results.
1923  bool failedToParsedMinMax =
1924  failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1925 
1926  auto &builder = p.getBuilder();
1927  auto boundAttrStrName =
1928  isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
1929  : AffineForOp::getUpperBoundMapAttrName(result.name);
1930 
1931  // Parse ssa-id as identity map.
1932  SmallVector<OpAsmParser::UnresolvedOperand, 1> boundOpInfos;
1933  if (p.parseOperandList(boundOpInfos))
1934  return failure();
1935 
1936  if (!boundOpInfos.empty()) {
1937  // Check that only one operand was parsed.
1938  if (boundOpInfos.size() > 1)
1939  return p.emitError(p.getNameLoc(),
1940  "expected only one loop bound operand");
1941 
1942  // TODO: improve error message when SSA value is not of index type.
1943  // Currently it is 'use of value ... expects different type than prior uses'
1944  if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1945  result.operands))
1946  return failure();
1947 
1948  // Create an identity map using symbol id. This representation is optimized
1949  // for storage. Analysis passes may expand it into a multi-dimensional map
1950  // if desired.
1951  AffineMap map = builder.getSymbolIdentityMap();
1952  result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
1953  return success();
1954  }
1955 
1956  // Get the attribute location.
1957  SMLoc attrLoc = p.getCurrentLocation();
1958 
1959  Attribute boundAttr;
1960  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
1961  result.attributes))
1962  return failure();
1963 
1964  // Parse full form - affine map followed by dim and symbol list.
1965  if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
1966  unsigned currentNumOperands = result.operands.size();
1967  unsigned numDims;
1968  if (parseDimAndSymbolList(p, result.operands, numDims))
1969  return failure();
1970 
1971  auto map = affineMapAttr.getValue();
1972  if (map.getNumDims() != numDims)
1973  return p.emitError(
1974  p.getNameLoc(),
1975  "dim operand count and affine map dim count must match");
1976 
1977  unsigned numDimAndSymbolOperands =
1978  result.operands.size() - currentNumOperands;
1979  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1980  return p.emitError(
1981  p.getNameLoc(),
1982  "symbol operand count and affine map symbol count must match");
1983 
1984  // If the map has multiple results, make sure that we parsed the min/max
1985  // prefix.
1986  if (map.getNumResults() > 1 && failedToParsedMinMax) {
1987  if (isLower) {
1988  return p.emitError(attrLoc, "lower loop bound affine map with "
1989  "multiple results requires 'max' prefix");
1990  }
1991  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1992  "results requires 'min' prefix");
1993  }
1994  return success();
1995  }
1996 
1997  // Parse custom assembly form.
1998  if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
1999  result.attributes.pop_back();
2000  result.addAttribute(
2001  boundAttrStrName,
2002  AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2003  return success();
2004  }
2005 
2006  return p.emitError(
2007  p.getNameLoc(),
2008  "expected valid affine map representation for loop bounds");
2009 }
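// Illustrative bound forms handled above (examples are hypothetical): a bare
// SSA value (`%n`), a bare integer constant (`128`), or the full form with an
// affine map followed by its dim/symbol operands; a `max` (lower) or `min`
// (upper) prefix is required when the bound map has multiple results, e.g.
//   affine.for %i = max affine_map<()[s0] -> (0, s0)>()[%n] to 100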
2010 
2011 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2012  auto &builder = parser.getBuilder();
2013  OpAsmParser::Argument inductionVariable;
2014  inductionVariable.type = builder.getIndexType();
2015  // Parse the induction variable followed by '='.
2016  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2017  return failure();
2018 
2019  // Parse loop bounds.
2020  int64_t numOperands = result.operands.size();
2021  if (parseBound(/*isLower=*/true, result, parser))
2022  return failure();
2023  int64_t numLbOperands = result.operands.size() - numOperands;
2024  if (parser.parseKeyword("to", " between bounds"))
2025  return failure();
2026  numOperands = result.operands.size();
2027  if (parseBound(/*isLower=*/false, result, parser))
2028  return failure();
2029  int64_t numUbOperands = result.operands.size() - numOperands;
2030 
2031  // Parse the optional loop step; default to 1 if it is not present.
2032  if (parser.parseOptionalKeyword("step")) {
2033  result.addAttribute(
2034  getStepAttrName(result.name),
2035  builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2036  } else {
2037  SMLoc stepLoc = parser.getCurrentLocation();
2038  IntegerAttr stepAttr;
2039  if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2040  getStepAttrName(result.name).data(),
2041  result.attributes))
2042  return failure();
2043 
2044  if (stepAttr.getValue().isNegative())
2045  return parser.emitError(
2046  stepLoc,
2047  "expected step to be representable as a positive signed integer");
2048  }
2049 
2050  // Parse the optional initial iteration arguments.
2051  SmallVector<OpAsmParser::Argument, 4> regionArgs;
2052  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2053 
2054  // Induction variable.
2055  regionArgs.push_back(inductionVariable);
2056 
2057  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2058  // Parse assignment list and results type list.
2059  if (parser.parseAssignmentList(regionArgs, operands) ||
2060  parser.parseArrowTypeList(result.types))
2061  return failure();
2062  // Resolve input operands.
2063  for (auto argOperandType :
2064  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2065  Type type = std::get<2>(argOperandType);
2066  std::get<0>(argOperandType).type = type;
2067  if (parser.resolveOperand(std::get<1>(argOperandType), type,
2068  result.operands))
2069  return failure();
2070  }
2071  }
2072 
2073  result.addAttribute(
2074  getOperandSegmentSizeAttr(),
2075  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2076  static_cast<int32_t>(numUbOperands),
2077  static_cast<int32_t>(operands.size())}));
2078 
2079  // Parse the body region.
2080  Region *body = result.addRegion();
2081  if (regionArgs.size() != result.types.size() + 1)
2082  return parser.emitError(
2083  parser.getNameLoc(),
2084  "mismatch between the number of loop-carried values and results");
2085  if (parser.parseRegion(*body, regionArgs))
2086  return failure();
2087 
2088  AffineForOp::ensureTerminator(*body, builder, result.location);
2089 
2090  // Parse the optional attribute list.
2091  return parser.parseOptionalAttrDict(result.attributes);
2092 }
2093 
2094 static void printBound(AffineMapAttr boundMap,
2095  Operation::operand_range boundOperands,
2096  const char *prefix, OpAsmPrinter &p) {
2097  AffineMap map = boundMap.getValue();
2098 
2099  // Check if this bound should be printed using custom assembly form.
2100  // The decision to restrict printing custom assembly form to trivial cases
2101  // comes from the desire to round-trip MLIR binary -> text -> binary in a
2102  // lossless way.
2103  // Therefore, custom assembly form parsing and printing is only supported for
2104  // zero-operand constant maps and single symbol operand identity maps.
2105  if (map.getNumResults() == 1) {
2106  AffineExpr expr = map.getResult(0);
2107 
2108  // Print constant bound.
2109  if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2110  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2111  p << constExpr.getValue();
2112  return;
2113  }
2114  }
2115 
2116  // Print bound that consists of a single SSA symbol if the map is over a
2117  // single symbol.
2118  if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2119  if (dyn_cast<AffineSymbolExpr>(expr)) {
2120  p.printOperand(*boundOperands.begin());
2121  return;
2122  }
2123  }
2124  } else {
2125  // Map has multiple results. Print 'min' or 'max' prefix.
2126  p << prefix << ' ';
2127  }
2128 
2129  // Print the map and its operands.
2130  p << boundMap;
2131  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2132  map.getNumDims(), p);
2133 }
2134 
2135 unsigned AffineForOp::getNumIterOperands() {
2136  AffineMap lbMap = getLowerBoundMapAttr().getValue();
2137  AffineMap ubMap = getUpperBoundMapAttr().getValue();
2138 
2139  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2140 }
2141 
2142 std::optional<MutableArrayRef<OpOperand>>
2143 AffineForOp::getYieldedValuesMutable() {
2144  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2145 }
2146 
2147 void AffineForOp::print(OpAsmPrinter &p) {
2148  p << ' ';
2149  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2150  /*omitType=*/true);
2151  p << " = ";
2152  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2153  p << " to ";
2154  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2155 
2156  if (getStepAsInt() != 1)
2157  p << " step " << getStepAsInt();
2158 
2159  bool printBlockTerminators = false;
2160  if (getNumIterOperands() > 0) {
2161  p << " iter_args(";
2162  auto regionArgs = getRegionIterArgs();
2163  auto operands = getInits();
2164 
2165  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2166  p << std::get<0>(it) << " = " << std::get<1>(it);
2167  });
2168  p << ") -> (" << getResultTypes() << ")";
2169  printBlockTerminators = true;
2170  }
2171 
2172  p << ' ';
2173  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2174  printBlockTerminators);
2175  p.printOptionalAttrDict(
2176  (*this)->getAttrs(),
2177  /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2178  getUpperBoundMapAttrName(getOperation()->getName()),
2179  getStepAttrName(getOperation()->getName()),
2180  getOperandSegmentSizeAttr()});
2181 }
2182 
2183 /// Fold the constant bounds of a loop.
2184 static LogicalResult foldLoopBounds(AffineForOp forOp) {
2185  auto foldLowerOrUpperBound = [&forOp](bool lower) {
2186  // Check to see if each of the operands is the result of a constant. If
2187  // so, get the value. If not, ignore it.
2188  SmallVector<Attribute, 8> operandConstants;
2189  auto boundOperands =
2190  lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2191  for (auto operand : boundOperands) {
2192  Attribute operandCst;
2193  matchPattern(operand, m_Constant(&operandCst));
2194  operandConstants.push_back(operandCst);
2195  }
2196 
2197  AffineMap boundMap =
2198  lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2199  assert(boundMap.getNumResults() >= 1 &&
2200  "bound maps should have at least one result");
2201  SmallVector<Attribute, 4> foldedResults;
2202  if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2203  return failure();
2204 
2205  // Compute the max or min as applicable over the results.
2206  assert(!foldedResults.empty() && "bounds should have at least one result");
2207  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2208  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2209  auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2210  maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2211  : llvm::APIntOps::smin(maxOrMin, foldedResult);
2212  }
2213  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2214  : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2215  return success();
2216  };
2217 
2218  // Try to fold the lower bound.
2219  bool folded = false;
2220  if (!forOp.hasConstantLowerBound())
2221  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2222 
2223  // Try to fold the upper bound.
2224  if (!forOp.hasConstantUpperBound())
2225  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2226  return success(folded);
2227 }
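// Worked example (hypothetical values): a lower bound map (d0, d1) -> (d0, d1)
// over constant operands 2 and 5 folds to the constant lower bound 5 (smax of
// the folded results); the same map used as an upper bound would fold to 2
// (smin of the folded results).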
2228 
2229 /// Canonicalize the bounds of the given loop.
2230 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2231  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2232  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2233 
2234  auto lbMap = forOp.getLowerBoundMap();
2235  auto ubMap = forOp.getUpperBoundMap();
2236  auto prevLbMap = lbMap;
2237  auto prevUbMap = ubMap;
2238 
2239  composeAffineMapAndOperands(&lbMap, &lbOperands);
2240  canonicalizeMapAndOperands(&lbMap, &lbOperands);
2241  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
2242  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2243  lbMap = removeDuplicateExprs(lbMap);
2244 
2245  composeAffineMapAndOperands(&ubMap, &ubOperands);
2246  canonicalizeMapAndOperands(&ubMap, &ubOperands);
2247  ubMap = removeDuplicateExprs(ubMap);
2248 
2249  // Any canonicalization change always leads to updated map(s).
2250  if (lbMap == prevLbMap && ubMap == prevUbMap)
2251  return failure();
2252 
2253  if (lbMap != prevLbMap)
2254  forOp.setLowerBound(lbOperands, lbMap);
2255  if (ubMap != prevUbMap)
2256  forOp.setUpperBound(ubOperands, ubMap);
2257  return success();
2258 }
2259 
2260 namespace {
2261 /// Returns constant trip count in trivial cases.
2262 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2263  int64_t step = forOp.getStepAsInt();
2264  if (!forOp.hasConstantBounds() || step <= 0)
2265  return std::nullopt;
2266  int64_t lb = forOp.getConstantLowerBound();
2267  int64_t ub = forOp.getConstantUpperBound();
2268  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2269 }
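// Worked example of the formula above: lb = 0, ub = 10, step = 3 gives
// (10 - 0 + 3 - 1) / 3 = 4 iterations; any lb >= ub yields a trip count of 0.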
2270 
2271 /// This is a pattern to fold trivially empty loop bodies.
2272 /// TODO: This should be moved into the folding hook.
2273 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
2274  using OpRewritePattern<AffineForOp>::OpRewritePattern;
2275 
2276  LogicalResult matchAndRewrite(AffineForOp forOp,
2277  PatternRewriter &rewriter) const override {
2278  // Check that the body only contains a yield.
2279  if (!llvm::hasSingleElement(*forOp.getBody()))
2280  return failure();
2281  if (forOp.getNumResults() == 0)
2282  return success();
2283  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2284  if (tripCount && *tripCount == 0) {
2285  // The initial values of the iteration arguments would be the op's
2286  // results.
2287  rewriter.replaceOp(forOp, forOp.getInits());
2288  return success();
2289  }
2290  SmallVector<Value, 4> replacements;
2291  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2292  auto iterArgs = forOp.getRegionIterArgs();
2293  bool hasValDefinedOutsideLoop = false;
2294  bool iterArgsNotInOrder = false;
2295  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2296  Value val = yieldOp.getOperand(i);
2297  auto *iterArgIt = llvm::find(iterArgs, val);
2298  if (iterArgIt == iterArgs.end()) {
2299  // `val` is defined outside of the loop.
2300  assert(forOp.isDefinedOutsideOfLoop(val) &&
2301  "must be defined outside of the loop");
2302  hasValDefinedOutsideLoop = true;
2303  replacements.push_back(val);
2304  } else {
2305  unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2306  if (pos != i)
2307  iterArgsNotInOrder = true;
2308  replacements.push_back(forOp.getInits()[pos]);
2309  }
2310  }
2311  // Bail out when the trip count is unknown and the loop returns any value
2312  // defined outside of the loop or any iterArg out of order.
2313  if (!tripCount.has_value() &&
2314  (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2315  return failure();
2316  // Bail out when the loop iterates more than once and it returns any iterArg
2317  // out of order.
2318  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2319  return failure();
2320  rewriter.replaceOp(forOp, replacements);
2321  return success();
2322  }
2323 };
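// Illustrative fold performed by the pattern above (IR is hypothetical): a
// loop whose body is only a yield of its iter_args in their original order is
// replaced by its init values,
//   %r = affine.for %i = 0 to 10 iter_args(%acc = %init) -> (f32) {
//     affine.yield %acc : f32
//   }
// so all uses of %r are rewritten to use %init.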
2324 } // namespace
2325 
2326 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2327  MLIRContext *context) {
2328  results.add<AffineForEmptyLoopFolder>(context);
2329 }
2330 
2331 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2332  assert((point.isParent() || point == getRegion()) && "invalid region point");
2333 
2334  // The initial operands map to the loop arguments after the induction
2335  // variable or are forwarded to the results when the trip count is zero.
2336  return getInits();
2337 }
2338 
2339 void AffineForOp::getSuccessorRegions(
2340  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2341  assert((point.isParent() || point == getRegion()) && "expected loop region");
2342  // The loop may typically branch back to its body or to the parent operation.
2343  // If the predecessor is the parent op and the trip count is known to be at
2344  // least one, branch into the body using the iterator arguments. If the trip
2345  // count is known to be zero, control can only return to the parent op.
2346  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2347  if (point.isParent() && tripCount.has_value()) {
2348  if (tripCount.value() > 0) {
2349  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2350  return;
2351  }
2352  if (tripCount.value() == 0) {
2353  regions.push_back(RegionSuccessor(getResults()));
2354  return;
2355  }
2356  }
2357 
2358  // From the loop body, if the trip count is one, we can only branch back to
2359  // the parent.
2360  if (!point.isParent() && tripCount && *tripCount == 1) {
2361  regions.push_back(RegionSuccessor(getResults()));
2362  return;
2363  }
2364 
2365  // In all other cases, the loop may branch back to itself or the parent
2366  // operation.
2367  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2368  regions.push_back(RegionSuccessor(getResults()));
2369 }
2370 
2371 /// Returns true if the affine.for has zero iterations in trivial cases.
2372 static bool hasTrivialZeroTripCount(AffineForOp op) {
2373  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2374  return tripCount && *tripCount == 0;
2375 }
2376 
2377 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2378  SmallVectorImpl<OpFoldResult> &results) {
2379  bool folded = succeeded(foldLoopBounds(*this));
2380  folded |= succeeded(canonicalizeLoopBounds(*this));
2381  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2382  // The initial values of the loop-carried variables (iter_args) become the
2383  // results of the op. But this must be avoided for an affine.for op that
2384  // does not return any results. Since ops that do not return results cannot
2385  // be folded away, we would enter an infinite loop of folds on the same
2386  // affine.for op.
2387  results.assign(getInits().begin(), getInits().end());
2388  folded = true;
2389  }
2390  return success(folded);
2391 }
2392 
2393 AffineBound AffineForOp::getLowerBound() {
2394  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2395 }
2396 
2397 AffineBound AffineForOp::getUpperBound() {
2398  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2399 }
2400 
2401 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2402  assert(lbOperands.size() == map.getNumInputs());
2403  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2404  getLowerBoundOperandsMutable().assign(lbOperands);
2405  setLowerBoundMap(map);
2406 }
2407 
2408 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2409  assert(ubOperands.size() == map.getNumInputs());
2410  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2411  getUpperBoundOperandsMutable().assign(ubOperands);
2412  setUpperBoundMap(map);
2413 }
2414 
2415 bool AffineForOp::hasConstantLowerBound() {
2416  return getLowerBoundMap().isSingleConstant();
2417 }
2418 
2419 bool AffineForOp::hasConstantUpperBound() {
2420  return getUpperBoundMap().isSingleConstant();
2421 }
2422 
2423 int64_t AffineForOp::getConstantLowerBound() {
2424  return getLowerBoundMap().getSingleConstantResult();
2425 }
2426 
2427 int64_t AffineForOp::getConstantUpperBound() {
2428  return getUpperBoundMap().getSingleConstantResult();
2429 }
2430 
2431 void AffineForOp::setConstantLowerBound(int64_t value) {
2432  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2433 }
2434 
2435 void AffineForOp::setConstantUpperBound(int64_t value) {
2436  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2437 }
2438 
2439 AffineForOp::operand_range AffineForOp::getControlOperands() {
2440  return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2441  getUpperBoundOperands().size()};
2442 }
2443 
2444 bool AffineForOp::matchingBoundOperandList() {
2445  auto lbMap = getLowerBoundMap();
2446  auto ubMap = getUpperBoundMap();
2447  if (lbMap.getNumDims() != ubMap.getNumDims() ||
2448  lbMap.getNumSymbols() != ubMap.getNumSymbols())
2449  return false;
2450 
2451  unsigned numOperands = lbMap.getNumInputs();
2452  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2453  // Compare the operand Values.
2454  if (getOperand(i) != getOperand(numOperands + i))
2455  return false;
2456  }
2457  return true;
2458 }
2459 
2460 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2461 
2462 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2463  return SmallVector<Value>{getInductionVar()};
2464 }
2465 
2466 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2467  if (!hasConstantLowerBound())
2468  return std::nullopt;
2469  OpBuilder b(getContext());
2470  return SmallVector<OpFoldResult>{
2471  OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2472 }
2473 
2474 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2475  OpBuilder b(getContext());
2476  return SmallVector<OpFoldResult>{
2477  OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2478 }
2479 
2480 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2481  if (!hasConstantUpperBound())
2482  return {};
2483  OpBuilder b(getContext());
2484  return SmallVector<OpFoldResult>{
2485  OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2486 }
2487 
2488 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2489  RewriterBase &rewriter, ValueRange newInitOperands,
2490  bool replaceInitOperandUsesInLoop,
2491  const NewYieldValuesFn &newYieldValuesFn) {
2492  // Create a new loop before the existing one, with the extra operands.
2493  OpBuilder::InsertionGuard g(rewriter);
2494  rewriter.setInsertionPoint(getOperation());
2495  auto inits = llvm::to_vector(getInits());
2496  inits.append(newInitOperands.begin(), newInitOperands.end());
2497  AffineForOp newLoop = rewriter.create<AffineForOp>(
2498  getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2499  getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
2500 
2501  // Generate the new yield values and append them to the affine.yield operation.
2502  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2503  ArrayRef<BlockArgument> newIterArgs =
2504  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2505  {
2506  OpBuilder::InsertionGuard g(rewriter);
2507  rewriter.setInsertionPoint(yieldOp);
2508  SmallVector<Value> newYieldedValues =
2509  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2510  assert(newInitOperands.size() == newYieldedValues.size() &&
2511  "expected as many new yield values as new iter operands");
2512  rewriter.modifyOpInPlace(yieldOp, [&]() {
2513  yieldOp.getOperandsMutable().append(newYieldedValues);
2514  });
2515  }
2516 
2517  // Move the loop body to the new op.
2518  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2519  newLoop.getBody()->getArguments().take_front(
2520  getBody()->getNumArguments()));
2521 
2522  if (replaceInitOperandUsesInLoop) {
2523  // Replace all uses of `newInitOperands` with the corresponding basic block
2524  // arguments.
2525  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2526  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2527  [&](OpOperand &use) {
2528  Operation *user = use.getOwner();
2529  return newLoop->isProperAncestor(user);
2530  });
2531  }
2532  }
2533 
2534  // Replace the old loop.
2535  rewriter.replaceOp(getOperation(),
2536  newLoop->getResults().take_front(getNumResults()));
2537  return cast<LoopLikeOpInterface>(newLoop.getOperation());
2538 }
2539 
2540 Speculation::Speculatability AffineForOp::getSpeculatability() {
2541  // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2542  // Start and End.
2543  //
2544  // For Step != 1, the loop may not terminate. We can add more smarts here if
2545  // needed.
2546  return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2547  : Speculation::NotSpeculatable;
2548 }
2549 
2550 /// Returns true if the provided value is the induction variable of an
2551 /// AffineForOp.
2552 bool mlir::affine::isAffineForInductionVar(Value val) {
2553  return getForInductionVarOwner(val) != AffineForOp();
2554 }
2555 
2556 bool mlir::affine::isAffineParallelInductionVar(Value val) {
2557  return getAffineParallelInductionVarOwner(val) != nullptr;
2558 }
2559 
2560 bool mlir::affine::isAffineInductionVar(Value val) {
2561  return isAffineForInductionVar(val) || isAffineParallelInductionVar(val);
2562 }
2563 
2564 AffineForOp mlir::affine::getForInductionVarOwner(Value val) {
2565  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2566  if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2567  return AffineForOp();
2568  if (auto forOp =
2569  ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2570  // Check to make sure `val` is the induction variable, not an iter_arg.
2571  return forOp.getInductionVar() == val ? forOp : AffineForOp();
2572  return AffineForOp();
2573 }
2574 
2575 AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
2576  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2577  if (!ivArg || !ivArg.getOwner())
2578  return nullptr;
2579  Operation *containingOp = ivArg.getOwner()->getParentOp();
2580  auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2581  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2582  return parallelOp;
2583  return nullptr;
2584 }
2585 
2586 /// Extracts the induction variables from a list of AffineForOps and returns
2587 /// them.
2588 void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
2589  SmallVectorImpl<Value> *ivs) {
2590  ivs->reserve(forInsts.size());
2591  for (auto forInst : forInsts)
2592  ivs->push_back(forInst.getInductionVar());
2593 }
2594 
2595 void mlir::affine::extractInductionVars(ArrayRef<mlir::Operation *> affineOps,
2596  SmallVectorImpl<Value> &ivs) {
2597  ivs.reserve(affineOps.size());
2598  for (Operation *op : affineOps) {
2599  // Add constraints from forOp's bounds.
2600  if (auto forOp = dyn_cast<AffineForOp>(op))
2601  ivs.push_back(forOp.getInductionVar());
2602  else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2603  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2604  ivs.push_back(parallelOp.getBody()->getArgument(i));
2605  }
2606 }
2607 
2608 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2609 /// operations.
2610 template <typename BoundListTy, typename LoopCreatorTy>
2611 static void buildAffineLoopNestImpl(
2612  OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2613  ArrayRef<int64_t> steps,
2614  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2615  LoopCreatorTy &&loopCreatorFn) {
2616  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2617  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2618 
2619  // If there are no loops to be constructed, construct the body anyway.
2620  OpBuilder::InsertionGuard guard(builder);
2621  if (lbs.empty()) {
2622  if (bodyBuilderFn)
2623  bodyBuilderFn(builder, loc, ValueRange());
2624  return;
2625  }
2626 
2627  // Create the loops iteratively and store the induction variables.
2628  SmallVector<Value, 4> ivs;
2629  ivs.reserve(lbs.size());
2630  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2631  // Callback for creating the loop body, always creates the terminator.
2632  auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
2633  ValueRange iterArgs) {
2634  ivs.push_back(iv);
2635  // In the innermost loop, call the body builder.
2636  if (i == e - 1 && bodyBuilderFn) {
2637  OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
2638  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2639  }
2640  nestedBuilder.create<AffineYieldOp>(nestedLoc);
2641  };
2642 
2643  // Delegate actual loop creation to the callback in order to dispatch
2644  // between constant- and variable-bound loops.
2645  auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2646  builder.setInsertionPointToStart(loop.getBody());
2647  }
2648 }
2649 
2650 /// Creates an affine loop from the bounds known to be constants.
2651 static AffineForOp
2652 buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
2653  int64_t ub, int64_t step,
2654  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2655  return builder.create<AffineForOp>(loc, lb, ub, step,
2656  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2657 }
2658 
2659 /// Creates an affine loop from the bounds that may or may not be constants.
2660 static AffineForOp
2661 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
2662  int64_t step,
2663  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2664  std::optional<int64_t> lbConst = getConstantIntValue(lb);
2665  std::optional<int64_t> ubConst = getConstantIntValue(ub);
2666  if (lbConst && ubConst)
2667  return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2668  ubConst.value(), step, bodyBuilderFn);
2669  return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2670  builder.getDimIdentityMap(), step,
2671  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2672 }
2673 
2674 void mlir::affine::buildAffineLoopNest(
2675  OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2676  ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
2677  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2678  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2679  buildAffineLoopFromConstants);
2680 }
2681 
2682 void mlir::affine::buildAffineLoopNest(
2683  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2684  ArrayRef<int64_t> steps,
2685  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2686  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2687  buildAffineLoopFromValues);
2688 }
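// A minimal usage sketch for the constant-bound overload above (assumes
// `builder` and `loc` exist in the caller; not taken from the source):
//
//   mlir::affine::buildAffineLoopNest(
//       builder, loc, /*lbs=*/{0, 0}, /*ubs=*/{128, 64}, /*steps=*/{1, 1},
//       [](OpBuilder &b, Location nestedLoc, ValueRange ivs) {
//         // ... build the innermost body using ivs[0] and ivs[1] ...
//       });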
2689 
2690 //===----------------------------------------------------------------------===//
2691 // AffineIfOp
2692 //===----------------------------------------------------------------------===//
2693 
2694 namespace {
2695 /// Remove else blocks that contain nothing other than a zero-operand yield.
2696 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
2697  using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2698 
2699  LogicalResult matchAndRewrite(AffineIfOp ifOp,
2700  PatternRewriter &rewriter) const override {
2701  if (ifOp.getElseRegion().empty() ||
2702  !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2703  return failure();
2704 
2705  rewriter.startOpModification(ifOp);
2706  rewriter.eraseBlock(ifOp.getElseBlock());
2707  rewriter.finalizeOpModification(ifOp);
2708  return success();
2709  }
2710 };
2711 
2712 /// Removes an affine.if op whose condition is trivially always true or false,
2713 /// promoting the then/else block into the parent operation's block.
2714 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
2715  using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2716 
2717  LogicalResult matchAndRewrite(AffineIfOp op,
2718  PatternRewriter &rewriter) const override {
2719 
2720  auto isTriviallyFalse = [](IntegerSet iSet) {
2721  return iSet.isEmptyIntegerSet();
2722  };
2723 
2724  auto isTriviallyTrue = [](IntegerSet iSet) {
2725  return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2726  iSet.getConstraint(0) == 0);
2727  };
2728 
2729  IntegerSet affineIfConditions = op.getIntegerSet();
2730  Block *blockToMove;
2731  if (isTriviallyFalse(affineIfConditions)) {
2732  // The absence, or equivalently, the emptiness of the else region need not
2733  // be checked when affine.if is returning results because if an affine.if
2734  // operation is returning results, it always has a non-empty else region.
2735  if (op.getNumResults() == 0 && !op.hasElse()) {
2736  // If the else region is absent, or equivalently, empty, remove the
2737  // affine.if operation (which is not returning any results).
2738  rewriter.eraseOp(op);
2739  return success();
2740  }
2741  blockToMove = op.getElseBlock();
2742  } else if (isTriviallyTrue(affineIfConditions)) {
2743  blockToMove = op.getThenBlock();
2744  } else {
2745  return failure();
2746  }
2747  Operation *blockToMoveTerminator = blockToMove->getTerminator();
2748  // Promote the "blockToMove" block to the parent operation block between the
2749  // prologue and epilogue of "op".
2750  rewriter.inlineBlockBefore(blockToMove, op);
2751  // Replace the "op" operation with the operands of the
2752  // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2753  // the affine.yield operation present in the "blockToMove" block. It has no
2754  // operands when affine.if is not returning results and therefore, in that
2755  // case, replaceOp just erases "op". When affine.if is not returning
2756  // results, the affine.yield operation can be omitted. It gets inserted
2757  // implicitly.
2758  rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2759  // Erase the "blockToMoveTerminator" operation since it is now in the parent
2760  // operation block, which already has its own terminator.
2761  rewriter.eraseOp(blockToMoveTerminator);
2762  return success();
2763  }
2764 };
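// Illustrative conditions for the pattern above: a set whose single constraint
// is the trivially satisfied equality `0 == 0` selects the `then` block, while
// the canonical empty set (a single `1 == 0` constraint) selects the `else`
// block, or erases the op when it has no else region and no results.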
2765 } // namespace
2766 
2767 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
2768 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
2769 void AffineIfOp::getSuccessorRegions(
2770  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2771  // If the predecessor is an AffineIfOp, then branching into both `then` and
2772  // `else` region is valid.
2773  if (point.isParent()) {
2774  regions.reserve(2);
2775  regions.push_back(
2776  RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
2777  // If the "else" region is empty, branch back into the parent.
2778  if (getElseRegion().empty()) {
2779  regions.push_back(getResults());
2780  } else {
2781  regions.push_back(
2782  RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
2783  }
2784  return;
2785  }
2786 
2787  // If the predecessor is the `else`/`then` region, then branching into parent
2788  // op is valid.
2789  regions.push_back(RegionSuccessor(getResults()));
2790 }
2791 
2792 LogicalResult AffineIfOp::verify() {
2793  // Verify that we have a condition attribute.
2794  // FIXME: This should be specified in the arguments list in ODS.
2795  auto conditionAttr =
2796  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2797  if (!conditionAttr)
2798  return emitOpError("requires an integer set attribute named 'condition'");
2799 
2800  // Verify that there are enough operands for the condition.
2801  IntegerSet condition = conditionAttr.getValue();
2802  if (getNumOperands() != condition.getNumInputs())
2803  return emitOpError("operand count and condition integer set dimension and "
2804  "symbol count must match");
2805 
2806  // Verify that the operands are valid dimension/symbols.
2807  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2808  condition.getNumDims())))
2809  return failure();
2810 
2811  return success();
2812 }
2813 
2814 ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
2815  // Parse the condition attribute set.
2816  IntegerSetAttr conditionAttr;
2817  unsigned numDims;
2818  if (parser.parseAttribute(conditionAttr,
2819  AffineIfOp::getConditionAttrStrName(),
2820  result.attributes) ||
2821  parseDimAndSymbolList(parser, result.operands, numDims))
2822  return failure();
2823 
2824  // Verify the condition operands.
2825  auto set = conditionAttr.getValue();
2826  if (set.getNumDims() != numDims)
2827  return parser.emitError(
2828  parser.getNameLoc(),
2829  "dim operand count and integer set dim count must match");
2830  if (numDims + set.getNumSymbols() != result.operands.size())
2831  return parser.emitError(
2832  parser.getNameLoc(),
2833  "symbol operand count and integer set symbol count must match");
2834 
2835  if (parser.parseOptionalArrowTypeList(result.types))
2836  return failure();
2837 
2838  // Create the regions for 'then' and 'else'. The latter must be created even
2839  // if it remains empty for the validity of the operation.
2840  result.regions.reserve(2);
2841  Region *thenRegion = result.addRegion();
2842  Region *elseRegion = result.addRegion();
2843 
2844  // Parse the 'then' region.
2845  if (parser.parseRegion(*thenRegion, {}, {}))
2846  return failure();
2847  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2848  result.location);
2849 
2850  // If we find an 'else' keyword then parse the 'else' region.
2851  if (!parser.parseOptionalKeyword("else")) {
2852  if (parser.parseRegion(*elseRegion, {}, {}))
2853  return failure();
2854  AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2855  result.location);
2856  }
2857 
2858  // Parse the optional attribute list.
2859  if (parser.parseOptionalAttrDict(result.attributes))
2860  return failure();
2861 
2862  return success();
2863 }
2864 
2865 void AffineIfOp::print(OpAsmPrinter &p) {
2866  auto conditionAttr =
2867  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2868  p << " " << conditionAttr;
2869  printDimAndSymbolList(operand_begin(), operand_end(),
2870  conditionAttr.getValue().getNumDims(), p);
2871  p.printOptionalArrowTypeList(getResultTypes());
2872  p << ' ';
2873  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
2874  /*printBlockTerminators=*/getNumResults());
2875 
2876  // Print the 'else' region if it has any blocks.
2877  auto &elseRegion = this->getElseRegion();
2878  if (!elseRegion.empty()) {
2879  p << " else ";
2880  p.printRegion(elseRegion,
2881  /*printEntryBlockArgs=*/false,
2882  /*printBlockTerminators=*/getNumResults());
2883  }
2884 
2885  // Print the attribute list.
2886  p.printOptionalAttrDict((*this)->getAttrs(),
2887  /*elidedAttrs=*/getConditionAttrStrName());
2888 }
2889 
2890 IntegerSet AffineIfOp::getIntegerSet() {
2891  return (*this)
2892  ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2893  .getValue();
2894 }
2895 
2896 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2897  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
2898 }
2899 
2900 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2901  setIntegerSet(set);
2902  (*this)->setOperands(operands);
2903 }
2904 
2905 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2906  TypeRange resultTypes, IntegerSet set, ValueRange args,
2907  bool withElseRegion) {
2908  assert(resultTypes.empty() || withElseRegion);
2909  OpBuilder::InsertionGuard guard(builder);
2910 
2911  result.addTypes(resultTypes);
2912  result.addOperands(args);
2913  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
2914 
2915  Region *thenRegion = result.addRegion();
2916  builder.createBlock(thenRegion);
2917  if (resultTypes.empty())
2918  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2919 
2920  Region *elseRegion = result.addRegion();
2921  if (withElseRegion) {
2922  builder.createBlock(elseRegion);
2923  if (resultTypes.empty())
2924  AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2925  }
2926 }
2927 
2928 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2929  IntegerSet set, ValueRange args, bool withElseRegion) {
2930  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2931  withElseRegion);
2932 }
2933 
2934 /// Compose any affine.apply ops feeding into `operands` of the integer set
2935 /// `set` by composing the maps of such affine.apply ops with the integer
2936 /// set constraints.
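///
/// As an illustrative sketch, a use such as
///
///   %a = affine.apply affine_map<(d0) -> (d0 * 4)>(%i)
///   affine.if affine_set<(d0) : (d0 - 32 >= 0)>(%a) { ... }
///
/// composes the apply's map into the condition to give
///
///   affine.if affine_set<(d0) : (d0 * 4 - 32 >= 0)>(%i) { ... }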
2937 static void composeSetAndOperands(IntegerSet &set,
2938  SmallVectorImpl<Value> &operands) {
2939  // We will simply reuse the API of the map composition by viewing the LHSs of
2940  // the equalities and inequalities of `set` as the affine exprs of an affine
2941  // map. Convert to equivalent map, compose, and convert back to set.
2942  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
2943  set.getConstraints(), set.getContext());
2944  // Check if any composition is possible.
2945  if (llvm::none_of(operands,
2946  [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
2947  return;
2948 
2949  composeAffineMapAndOperands(&map, &operands);
2950  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
2951  set.getEqFlags());
2952 }
2953 
2954 /// Canonicalize an affine if op's conditional (integer set + operands).
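/// A minimal sketch of the effect: affine.apply ops feeding the condition are
/// composed into the set, and duplicate or unused condition operands are
/// dropped; if neither step changes anything, the fold reports failure.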
2955 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2956  auto set = getIntegerSet();
2957  SmallVector<Value, 4> operands(getOperands());
2958  composeSetAndOperands(set, operands);
2959  canonicalizeSetAndOperands(&set, &operands);
2960 
2961  // Check if the canonicalization or composition led to any change.
2962  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2963  return failure();
2964 
2965  setConditional(set, operands);
2966  return success();
2967 }
2968 
2969 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2970  MLIRContext *context) {
2971  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2972 }
2973 
2974 //===----------------------------------------------------------------------===//
2975 // AffineLoadOp
2976 //===----------------------------------------------------------------------===//
2977 
2978 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2979  AffineMap map, ValueRange operands) {
2980  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2981  result.addOperands(operands);
2982  if (map)
2983  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2984  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
2985  result.types.push_back(memrefType.getElementType());
2986 }
2987 
2988 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2989  Value memref, AffineMap map, ValueRange mapOperands) {
2990  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2991  result.addOperands(memref);
2992  result.addOperands(mapOperands);
2993  auto memrefType = llvm::cast<MemRefType>(memref.getType());
2994  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2995  result.types.push_back(memrefType.getElementType());
2996 }
2997 
2998 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2999  Value memref, ValueRange indices) {
3000  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3001  int64_t rank = memrefType.getRank();
3002  // Create identity map for memrefs with at least one dimension or () -> ()
3003  // for zero-dimensional memrefs.
3004  auto map =
3005  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3006  build(builder, result, memref, map, indices);
3007 }
3008 
3009 ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
3010  auto &builder = parser.getBuilder();
3011  auto indexTy = builder.getIndexType();
3012 
3013  MemRefType type;
3014  OpAsmParser::UnresolvedOperand memrefInfo;
3015  AffineMapAttr mapAttr;
3016  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3017  return failure(
3018  parser.parseOperand(memrefInfo) ||
3019  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3020  AffineLoadOp::getMapAttrStrName(),
3021  result.attributes) ||
3022  parser.parseOptionalAttrDict(result.attributes) ||
3023  parser.parseColonType(type) ||
3024  parser.resolveOperand(memrefInfo, type, result.operands) ||
3025  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3026  parser.addTypeToList(type.getElementType(), result.types));
3027 }
3028 
3030  p << " " << getMemRef() << '[';
3031  if (AffineMapAttr mapAttr =
3032  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3033  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3034  p << ']';
3035  p.printOptionalAttrDict((*this)->getAttrs(),
3036  /*elidedAttrs=*/{getMapAttrStrName()});
3037  p << " : " << getMemRefType();
3038 }
3039 
3040 /// Verify common indexing invariants of affine.load, affine.store,
3041 /// affine.vector_load and affine.vector_store.
3042 static LogicalResult
3043 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
3044  Operation::operand_range mapOperands,
3045  MemRefType memrefType, unsigned numIndexOperands) {
3046  AffineMap map = mapAttr.getValue();
3047  if (map.getNumResults() != memrefType.getRank())
3048  return op->emitOpError("affine map num results must equal memref rank");
3049  if (map.getNumInputs() != numIndexOperands)
3050  return op->emitOpError("expects as many subscripts as affine map inputs");
3051 
3052  Region *scope = getAffineScope(op);
3053  for (auto idx : mapOperands) {
3054  if (!idx.getType().isIndex())
3055  return op->emitOpError("index to load must have 'index' type");
3056  if (!isValidAffineIndexOperand(idx, scope))
3057  return op->emitOpError(
3058  "index must be a valid dimension or symbol identifier");
3059  }
3060 
3061  return success();
3062 }
3063 
3064 LogicalResult AffineLoadOp::verify() {
3065  auto memrefType = getMemRefType();
3066  if (getType() != memrefType.getElementType())
3067  return emitOpError("result type must match element type of memref");
3068 
3069  if (failed(verifyMemoryOpIndexing(
3070  getOperation(),
3071  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3072  getMapOperands(), memrefType,
3073  /*numIndexOperands=*/getNumOperands() - 1)))
3074  return failure();
3075 
3076  return success();
3077 }
3078 
3079 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3080  MLIRContext *context) {
3081  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3082 }
3083 
3084 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3085  /// load(memrefcast) -> load
3086  if (succeeded(memref::foldMemRefCast(*this)))
3087  return getResult();
3088 
3089  // Fold load from a global constant memref.
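  // As a sketch of the intent (names are illustrative): for
  //   memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
  //   %m = memref.get_global @cst : memref<2xf32>
  //   %v = affine.load %m[1] : memref<2xf32>
  // the load folds to the constant 2.0; splat initializers fold regardless of
  // the indices.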
3090  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3091  if (!getGlobalOp)
3092  return {};
3093  // Get to the memref.global defining the symbol.
3094  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3095  if (!symbolTableOp)
3096  return {};
3097  auto global = dyn_cast_or_null<memref::GlobalOp>(
3098  SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
3099  if (!global)
3100  return {};
3101 
3102  // Check if the global memref is a constant.
3103  auto cstAttr =
3104  llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3105  if (!cstAttr)
3106  return {};
3107  // If it's a splat constant, we can fold irrespective of indices.
3108  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3109  return splatAttr.getSplatValue<Attribute>();
3110  // Otherwise, we can fold only if we know the indices.
3111  if (!getAffineMap().isConstant())
3112  return {};
3113  auto indices = llvm::to_vector<4>(
3114  llvm::map_range(getAffineMap().getConstantResults(),
3115  [](int64_t v) -> uint64_t { return v; }));
3116  return cstAttr.getValues<Attribute>()[indices];
3117 }
3118 
3119 //===----------------------------------------------------------------------===//
3120 // AffineStoreOp
3121 //===----------------------------------------------------------------------===//
3122 
3123 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3124  Value valueToStore, Value memref, AffineMap map,
3125  ValueRange mapOperands) {
3126  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3127  result.addOperands(valueToStore);
3128  result.addOperands(memref);
3129  result.addOperands(mapOperands);
3130  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3131 }
3132 
3133 // Use identity map.
3134 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3135  Value valueToStore, Value memref,
3136  ValueRange indices) {
3137  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3138  int64_t rank = memrefType.getRank();
3139  // Create identity map for memrefs with at least one dimension or () -> ()
3140  // for zero-dimensional memrefs.
3141  auto map =
3142  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3143  build(builder, result, valueToStore, memref, map, indices);
3144 }
3145 
3146 ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3147  auto indexTy = parser.getBuilder().getIndexType();
3148 
3149  MemRefType type;
3150  OpAsmParser::UnresolvedOperand storeValueInfo;
3151  OpAsmParser::UnresolvedOperand memrefInfo;
3152  AffineMapAttr mapAttr;
3153  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3154  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3155  parser.parseOperand(memrefInfo) ||
3156  parser.parseAffineMapOfSSAIds(
3157  mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3158  result.attributes) ||
3159  parser.parseOptionalAttrDict(result.attributes) ||
3160  parser.parseColonType(type) ||
3161  parser.resolveOperand(storeValueInfo, type.getElementType(),
3162  result.operands) ||
3163  parser.resolveOperand(memrefInfo, type, result.operands) ||
3164  parser.resolveOperands(mapOperands, indexTy, result.operands));
3165 }
3166 
3168  p << " " << getValueToStore();
3169  p << ", " << getMemRef() << '[';
3170  if (AffineMapAttr mapAttr =
3171  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3172  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3173  p << ']';
3174  p.printOptionalAttrDict((*this)->getAttrs(),
3175  /*elidedAttrs=*/{getMapAttrStrName()});
3176  p << " : " << getMemRefType();
3177 }
3178 
3179 LogicalResult AffineStoreOp::verify() {
3180  // The value to store must have the same type as memref element type.
3181  auto memrefType = getMemRefType();
3182  if (getValueToStore().getType() != memrefType.getElementType())
3183  return emitOpError(
3184  "value to store must have the same type as memref element type");
3185 
3186  if (failed(verifyMemoryOpIndexing(
3187  getOperation(),
3188  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3189  getMapOperands(), memrefType,
3190  /*numIndexOperands=*/getNumOperands() - 2)))
3191  return failure();
3192 
3193  return success();
3194 }
3195 
3196 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3197  MLIRContext *context) {
3198  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3199 }
3200 
3201 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3202  SmallVectorImpl<OpFoldResult> &results) {
3203  /// store(memrefcast) -> store
3204  return memref::foldMemRefCast(*this, getValueToStore());
3205 }
3206 
3207 //===----------------------------------------------------------------------===//
3208 // AffineMinMaxOpBase
3209 //===----------------------------------------------------------------------===//
3210 
3211 template <typename T>
3212 static LogicalResult verifyAffineMinMaxOp(T op) {
3213  // Verify that operand count matches affine map dimension and symbol count.
3214  if (op.getNumOperands() !=
3215  op.getMap().getNumDims() + op.getMap().getNumSymbols())
3216  return op.emitOpError(
3217  "operand count and affine map dimension and symbol count must match");
3218 
3219  if (op.getMap().getNumResults() == 0)
3220  return op.emitOpError("affine map expect at least one result");
3221  return success();
3222 }
3223 
3224 template <typename T>
3225 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3226  p << ' ' << op->getAttr(T::getMapAttrStrName());
3227  auto operands = op.getOperands();
3228  unsigned numDims = op.getMap().getNumDims();
3229  p << '(' << operands.take_front(numDims) << ')';
3230 
3231  if (operands.size() != numDims)
3232  p << '[' << operands.drop_front(numDims) << ']';
3233  p.printOptionalAttrDict(op->getAttrs(),
3234  /*elidedAttrs=*/{T::getMapAttrStrName()});
3235 }
3236 
3237 template <typename T>
3238 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3239  OperationState &result) {
3240  auto &builder = parser.getBuilder();
3241  auto indexType = builder.getIndexType();
3242  SmallVector<OpAsmParser::UnresolvedOperand, 8> dimInfos;
3243  SmallVector<OpAsmParser::UnresolvedOperand, 8> symInfos;
3244  AffineMapAttr mapAttr;
3245  return failure(
3246  parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3247  result.attributes) ||
3248  parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
3249  parser.parseOperandList(symInfos,
3250  OpAsmParser::Delimiter::OptionalSquare) ||
3251  parser.parseOptionalAttrDict(result.attributes) ||
3252  parser.resolveOperands(dimInfos, indexType, result.operands) ||
3253  parser.resolveOperands(symInfos, indexType, result.operands) ||
3254  parser.addTypeToList(indexType, result.types));
3255 }
3256 
3257 /// Fold an affine min or max operation with the given operands. The operand
3258 /// list may contain nulls, which are interpreted as the operand not being a
3259 /// constant.
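///
/// For example (illustrative), if the single operand of
///
///   %0 = affine.min affine_map<(d0) -> (1000, d0 + 512)>(%i)
///
/// is known to be the constant 100, the map folds to (1000, 612) and the op
/// folds to the constant 612. When not every result folds to a constant, only
/// the partially simplified map is kept on the op.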
3260 template <typename T>
3261 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
3262  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3263  "expected affine min or max op");
3264 
3265  // Fold the affine map.
3266  // TODO: Fold more cases:
3267  // min(some_affine, some_affine + constant, ...), etc.
3268  SmallVector<int64_t, 2> results;
3269  auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3270 
3271  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3272  return op.getOperand(0);
3273 
3274  // If some of the map results are not constant, try changing the map in-place.
3275  if (results.empty()) {
3276  // If the map is the same, report that folding did not happen.
3277  if (foldedMap == op.getMap())
3278  return {};
3279  op->setAttr("map", AffineMapAttr::get(foldedMap));
3280  return op.getResult();
3281  }
3282 
3283  // Otherwise, completely fold the op into a constant.
3284  auto resultIt = std::is_same<T, AffineMinOp>::value
3285  ? llvm::min_element(results)
3286  : llvm::max_element(results);
3287  if (resultIt == results.end())
3288  return {};
3289  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3290 }
3291 
3292 /// Remove duplicated expressions in affine min/max ops.
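///
/// For example (illustrative):
///
///   %0 = affine.min affine_map<(d0) -> (d0 + 4, d0 + 4, 16)>(%i)
///
/// becomes:
///
///   %0 = affine.min affine_map<(d0) -> (d0 + 4, 16)>(%i)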
3293 template <typename T>
3294 struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
3295  using OpRewritePattern<T>::OpRewritePattern;
3296 
3297  LogicalResult matchAndRewrite(T affineOp,
3298  PatternRewriter &rewriter) const override {
3299  AffineMap oldMap = affineOp.getAffineMap();
3300 
3301  SmallVector<AffineExpr, 4> newExprs;
3302  for (AffineExpr expr : oldMap.getResults()) {
3303  // This is a linear scan over newExprs, but it should be fine given that
3304  // we typically just have a few expressions per op.
3305  if (!llvm::is_contained(newExprs, expr))
3306  newExprs.push_back(expr);
3307  }
3308 
3309  if (newExprs.size() == oldMap.getNumResults())
3310  return failure();
3311 
3312  auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3313  newExprs, rewriter.getContext());
3314  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3315 
3316  return success();
3317  }
3318 };
3319 
3320 /// Merge an affine min/max op into its consumer if the consumer is also an
3321 /// affine min/max op.
3322 ///
3323 /// This pattern requires that the producer affine min/max op be bound to a
3324 /// dimension/symbol that is used as a standalone expression in the consumer
3325 /// affine op's map.
3326 ///
3327 /// For example, a pattern like the following:
3328 ///
3329 /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3330 /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3331 ///
3332 /// Can be turned into:
3333 ///
3334 /// %1 = affine.min affine_map<
3335 /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3336 template <typename T>
3337 struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
3338  using OpRewritePattern<T>::OpRewritePattern;
3339 
3340  LogicalResult matchAndRewrite(T affineOp,
3341  PatternRewriter &rewriter) const override {
3342  AffineMap oldMap = affineOp.getAffineMap();
3343  ValueRange dimOperands =
3344  affineOp.getMapOperands().take_front(oldMap.getNumDims());
3345  ValueRange symOperands =
3346  affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3347 
3348  auto newDimOperands = llvm::to_vector<8>(dimOperands);
3349  auto newSymOperands = llvm::to_vector<8>(symOperands);
3350  SmallVector<AffineExpr, 4> newExprs;
3351  SmallVector<T, 4> producerOps;
3352 
3353  // Go over each expression to see whether it is a single dimension/symbol
3354  // whose corresponding operand is the result of another affine min/max op.
3355  // If so, it can be merged into this affine op.
3356  for (AffineExpr expr : oldMap.getResults()) {
3357  if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3358  Value symValue = symOperands[symExpr.getPosition()];
3359  if (auto producerOp = symValue.getDefiningOp<T>()) {
3360  producerOps.push_back(producerOp);
3361  continue;
3362  }
3363  } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3364  Value dimValue = dimOperands[dimExpr.getPosition()];
3365  if (auto producerOp = dimValue.getDefiningOp<T>()) {
3366  producerOps.push_back(producerOp);
3367  continue;
3368  }
3369  }
3370  // For the above cases we will remove the expression by merging the
3371  // producer affine min/max's affine expressions. Otherwise we need to
3372  // keep the existing expression.
3373  newExprs.push_back(expr);
3374  }
3375 
3376  if (producerOps.empty())
3377  return failure();
3378 
3379  unsigned numUsedDims = oldMap.getNumDims();
3380  unsigned numUsedSyms = oldMap.getNumSymbols();
3381 
3382  // Now go over all producer affine ops and merge their expressions.
3383  for (T producerOp : producerOps) {
3384  AffineMap producerMap = producerOp.getAffineMap();
3385  unsigned numProducerDims = producerMap.getNumDims();
3386  unsigned numProducerSyms = producerMap.getNumSymbols();
3387 
3388  // Collect all dimension/symbol values.
3389  ValueRange dimValues =
3390  producerOp.getMapOperands().take_front(numProducerDims);
3391  ValueRange symValues =
3392  producerOp.getMapOperands().take_back(numProducerSyms);
3393  newDimOperands.append(dimValues.begin(), dimValues.end());
3394  newSymOperands.append(symValues.begin(), symValues.end());
3395 
3396  // Shift the producer's expressions so its dims/symbols do not overlap.
3397  for (AffineExpr expr : producerMap.getResults()) {
3398  newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3399  .shiftSymbols(numProducerSyms, numUsedSyms));
3400  }
3401 
3402  numUsedDims += numProducerDims;
3403  numUsedSyms += numProducerSyms;
3404  }
3405 
3406  auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3407  rewriter.getContext());
3408  auto newOperands =
3409  llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3410  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3411 
3412  return success();
3413  }
3414 };
3415 
3416 /// Canonicalize the result expression order of an affine map and return success
3417 /// if the order changed.
3418 ///
3419 /// The function flattens the map's affine expressions to coefficient arrays and
3420 /// sorts them in lexicographic order. A coefficient array contains a multiplier
3421 /// for every dimension/symbol and a constant term. The canonicalization fails
3422 /// if a result expression is not pure or if the flattening requires local
3423 /// variables that, unlike dimensions and symbols, have no global order.
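///
/// Illustrative flattening for `(d0, d1) -> (d0 + d1, d1 + 16, 32)`: the
/// coefficient rows are (1, 1, 0), (0, 1, 16) and (0, 0, 32), so sorting them
/// lexicographically reorders the results to `(32, d1 + 16, d0 + d1)`.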
3424 static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3425  SmallVector<SmallVector<int64_t>> flattenedExprs;
3426  for (const AffineExpr &resultExpr : map.getResults()) {
3427  // Fail if the expression is not pure.
3428  if (!resultExpr.isPureAffine())
3429  return failure();
3430 
3431  SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3432  auto flattenResult = flattener.walkPostOrder(resultExpr);
3433  if (failed(flattenResult))
3434  return failure();
3435 
3436  // Fail if the flattened expression has local variables.
3437  if (flattener.operandExprStack.back().size() !=
3438  map.getNumDims() + map.getNumSymbols() + 1)
3439  return failure();
3440 
3441  flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3442  flattener.operandExprStack.back().end());
3443  }
3444 
3445  // Fail if sorting is not necessary.
3446  if (llvm::is_sorted(flattenedExprs))
3447  return failure();
3448 
3449  // Reorder the result expressions according to their flattened form.
3450  SmallVector<unsigned> resultPermutation =
3451  llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3452  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3453  return flattenedExprs[lhs] < flattenedExprs[rhs];
3454  });
3455  SmallVector<AffineExpr> newExprs;
3456  for (unsigned idx : resultPermutation)
3457  newExprs.push_back(map.getResult(idx));
3458 
3459  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3460  map.getContext());
3461  return success();
3462 }
3463 
3464 /// Canonicalize the affine map result expression order of an affine min/max
3465 /// operation.
3466 ///
3467 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3468 /// expressions and replaces the operation if the order changed.
3469 ///
3470 /// For example, the following operation:
3471 ///
3472 /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3473 ///
3474 /// Turns into:
3475 ///
3476 /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3477 template <typename T>
3478 struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> {
3479  using OpRewritePattern<T>::OpRewritePattern;
3480 
3481  LogicalResult matchAndRewrite(T affineOp,
3482  PatternRewriter &rewriter) const override {
3483  AffineMap map = affineOp.getAffineMap();
3484  if (failed(canonicalizeMapExprAndTermOrder(map)))
3485  return failure();
3486  rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3487  return success();
3488  }
3489 };
3490 
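/// Rewrite an affine min/max op whose map has a single result into an
/// equivalent affine.apply, e.g. (illustrative)
/// `affine.min affine_map<(d0) -> (d0 + 2)>(%i)` becomes
/// `affine.apply affine_map<(d0) -> (d0 + 2)>(%i)`.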
3491 template <typename T>
3492 struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
3493  using OpRewritePattern<T>::OpRewritePattern;
3494 
3495  LogicalResult matchAndRewrite(T affineOp,
3496  PatternRewriter &rewriter) const override {
3497  if (affineOp.getMap().getNumResults() != 1)
3498  return failure();
3499  rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3500  affineOp.getOperands());
3501  return success();
3502  }
3503 };
3504 
3505 //===----------------------------------------------------------------------===//
3506 // AffineMinOp
3507 //===----------------------------------------------------------------------===//
3508 //
3509 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3510 //
3511 
3512 OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3513  return foldMinMaxOp(*this, adaptor.getOperands());
3514 }
3515 
3516 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3517  MLIRContext *context) {
3518  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
3519  DeduplicateAffineMinMaxExpressions<AffineMinOp>,
3520  MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3521  CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMinOp>>(
3522  context);
3523 }
3524 
3525 LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3526 
3527 ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
3528  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3529 }
3530 
3531 void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3532 
3533 //===----------------------------------------------------------------------===//
3534 // AffineMaxOp
3535 //===----------------------------------------------------------------------===//
3536 //
3537 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3538 //
3539 
3540 OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3541  return foldMinMaxOp(*this, adaptor.getOperands());
3542 }
3543 
3544 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3545  MLIRContext *context) {
3546  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
3547  DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
3548  MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3549  CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMaxOp>>(
3550  context);
3551 }
3552 
3553 LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3554 
3555 ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
3556  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3557 }
3558 
3559 void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3560 
3561 //===----------------------------------------------------------------------===//
3562 // AffinePrefetchOp
3563 //===----------------------------------------------------------------------===//
3564 
3565 //
3566 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3567 //
3568 ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3569  OperationState &result) {
3570  auto &builder = parser.getBuilder();
3571  auto indexTy = builder.getIndexType();
3572 
3573  MemRefType type;
3574  OpAsmParser::UnresolvedOperand memrefInfo;
3575  IntegerAttr hintInfo;
3576  auto i32Type = parser.getBuilder().getIntegerType(32);
3577  StringRef readOrWrite, cacheType;
3578 
3579  AffineMapAttr mapAttr;
3580  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3581  if (parser.parseOperand(memrefInfo) ||
3582  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3583  AffinePrefetchOp::getMapAttrStrName(),
3584  result.attributes) ||
3585  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3586  parser.parseComma() || parser.parseKeyword("locality") ||
3587  parser.parseLess() ||
3588  parser.parseAttribute(hintInfo, i32Type,
3589  AffinePrefetchOp::getLocalityHintAttrStrName(),
3590  result.attributes) ||
3591  parser.parseGreater() || parser.parseComma() ||
3592  parser.parseKeyword(&cacheType) ||
3593  parser.parseOptionalAttrDict(result.attributes) ||
3594  parser.parseColonType(type) ||
3595  parser.resolveOperand(memrefInfo, type, result.operands) ||
3596  parser.resolveOperands(mapOperands, indexTy, result.operands))
3597  return failure();
3598 
3599  if (readOrWrite != "read" && readOrWrite != "write")
3600  return parser.emitError(parser.getNameLoc(),
3601  "rw specifier has to be 'read' or 'write'");
3602  result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3603  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
3604 
3605  if (cacheType != "data" && cacheType != "instr")
3606  return parser.emitError(parser.getNameLoc(),
3607  "cache type has to be 'data' or 'instr'");
3608 
3609  result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3610  parser.getBuilder().getBoolAttr(cacheType == "data"));
3611 
3612  return success();
3613 }
3614 
3616  p << " " << getMemref() << '[';
3617  AffineMapAttr mapAttr =
3618  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3619  if (mapAttr)
3620  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3621  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3622  << "locality<" << getLocalityHint() << ">, "
3623  << (getIsDataCache() ? "data" : "instr");
3624  p.printOptionalAttrDict(
3625  (*this)->getAttrs(),
3626  /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3627  getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3628  p << " : " << getMemRefType();
3629 }
3630 
3631 LogicalResult AffinePrefetchOp::verify() {
3632  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3633  if (mapAttr) {
3634  AffineMap map = mapAttr.getValue();
3635  if (map.getNumResults() != getMemRefType().getRank())
3636  return emitOpError("affine.prefetch affine map num results must equal"
3637  " memref rank");
3638  if (map.getNumInputs() + 1 != getNumOperands())
3639  return emitOpError("too few operands");
3640  } else {
3641  if (getNumOperands() != 1)
3642  return emitOpError("too few operands");
3643  }
3644 
3645  Region *scope = getAffineScope(*this);
3646  for (auto idx : getMapOperands()) {
3647  if (!isValidAffineIndexOperand(idx, scope))
3648  return emitOpError(
3649  "index must be a valid dimension or symbol identifier");
3650  }
3651  return success();
3652 }
3653 
3654 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3655  MLIRContext *context) {
3656  // prefetch(memrefcast) -> prefetch
3657  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3658 }
3659 
3660 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3661  SmallVectorImpl<OpFoldResult> &results) {
3662  /// prefetch(memrefcast) -> prefetch
3663  return memref::foldMemRefCast(*this);
3664 }
3665 
3666 //===----------------------------------------------------------------------===//
3667 // AffineParallelOp
3668 //===----------------------------------------------------------------------===//
3669 
3670 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3671  TypeRange resultTypes,
3672  ArrayRef<arith::AtomicRMWKind> reductions,
3673  ArrayRef<int64_t> ranges) {
3674  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3675  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3676  return builder.getConstantAffineMap(value);
3677  }));
3678  SmallVector<int64_t> steps(ranges.size(), 1);
3679  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3680  /*ubArgs=*/{}, steps);
3681 }
3682 
3683 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3684  TypeRange resultTypes,
3685  ArrayRef<arith::AtomicRMWKind> reductions,
3686  ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
3687  ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
3688  ArrayRef<int64_t> steps) {
3689  assert(llvm::all_of(lbMaps,
3690  [lbMaps](AffineMap m) {
3691  return m.getNumDims() == lbMaps[0].getNumDims() &&
3692  m.getNumSymbols() == lbMaps[0].getNumSymbols();
3693  }) &&
3694  "expected all lower bounds maps to have the same number of dimensions "
3695  "and symbols");
3696  assert(llvm::all_of(ubMaps,
3697  [ubMaps](AffineMap m) {
3698  return m.getNumDims() == ubMaps[0].getNumDims() &&
3699  m.getNumSymbols() == ubMaps[0].getNumSymbols();
3700  }) &&
3701  "expected all upper bounds maps to have the same number of dimensions "
3702  "and symbols");
3703  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3704  "expected lower bound maps to have as many inputs as lower bound "
3705  "operands");
3706  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3707  "expected upper bound maps to have as many inputs as upper bound "
3708  "operands");
3709 
3710  OpBuilder::InsertionGuard guard(builder);
3711  result.addTypes(resultTypes);
3712 
3713  // Convert the reductions to integer attributes.
3714  SmallVector<Attribute, 4> reductionAttrs;
3715  for (arith::AtomicRMWKind reduction : reductions)
3716  reductionAttrs.push_back(
3717  builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
3718  result.addAttribute(getReductionsAttrStrName(),
3719  builder.getArrayAttr(reductionAttrs));
3720 
3721  // Concatenates maps defined in the same input space (same dimensions and
3722  // symbols); returns an empty map if no maps are provided.
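  // Illustrative sketch: concatenating (d0) -> (d0) and (d0) -> (d0 + 1, 42)
  // yields (d0) -> (d0, d0 + 1, 42) with groups [1, 2].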
3723  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
3724  SmallVectorImpl<int32_t> &groups) {
3725  if (maps.empty())
3726  return AffineMap::get(builder.getContext());
3727  SmallVector<AffineExpr> exprs;
3728  groups.reserve(groups.size() + maps.size());
3729  exprs.reserve(maps.size());
3730  for (AffineMap m : maps) {
3731  llvm::append_range(exprs, m.getResults());
3732  groups.push_back(m.getNumResults());
3733  }
3734  return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3735  maps[0].getContext());
3736  };
3737 
3738  // Set up the bounds.
3739  SmallVector<int32_t> lbGroups, ubGroups;
3740  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3741  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3742  result.addAttribute(getLowerBoundsMapAttrStrName(),
3743  AffineMapAttr::get(lbMap));
3744  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
3745  builder.getI32TensorAttr(lbGroups));
3746  result.addAttribute(getUpperBoundsMapAttrStrName(),
3747  AffineMapAttr::get(ubMap));
3748  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
3749  builder.getI32TensorAttr(ubGroups));
3750  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
3751  result.addOperands(lbArgs);
3752  result.addOperands(ubArgs);
3753 
3754  // Create a region and a block for the body.
3755  auto *bodyRegion = result.addRegion();
3756  Block *body = builder.createBlock(bodyRegion);
3757 
3758  // Add all the block arguments.
3759  for (unsigned i = 0, e = steps.size(); i < e; ++i)
3760  body->addArgument(IndexType::get(builder.getContext()), result.location);
3761  if (resultTypes.empty())
3762  ensureTerminator(*bodyRegion, builder, result.location);
3763 }
3764 
3765 SmallVector<Region *> AffineParallelOp::getLoopRegions() {
3766  return {&getRegion()};
3767 }
3768 
3769 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3770 
3771 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3772  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3773 }
3774 
3775 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3776  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3777 }
3778 
3779 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3780  auto values = getLowerBoundsGroups().getValues<int32_t>();
3781  unsigned start = 0;
3782  for (unsigned i = 0; i < pos; ++i)
3783  start += values[i];
3784  return getLowerBoundsMap().getSliceMap(start, values[pos]);
3785 }
3786 
3787 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3788  auto values = getUpperBoundsGroups().getValues<int32_t>();
3789  unsigned start = 0;
3790  for (unsigned i = 0; i < pos; ++i)
3791  start += values[i];
3792  return getUpperBoundsMap().getSliceMap(start, values[pos]);
3793 }
3794 
3795 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
3796  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3797 }
3798 
3799 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
3800  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3801 }
3802 
3803 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3804  if (hasMinMaxBounds())
3805  return std::nullopt;
3806 
3807  // Try to convert all the ranges to constant expressions.
3808  SmallVector<int64_t, 8> out;
3809  AffineValueMap rangesValueMap;
3810  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3811  &rangesValueMap);
3812  out.reserve(rangesValueMap.getNumResults());
3813  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3814  auto expr = rangesValueMap.getResult(i);
3815  auto cst = dyn_cast<AffineConstantExpr>(expr);
3816  if (!cst)
3817  return std::nullopt;
3818  out.push_back(cst.getValue());
3819  }
3820  return out;
3821 }
3822 
3823 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3824 
3825 OpBuilder AffineParallelOp::getBodyBuilder() {
3826  return OpBuilder(getBody(), std::prev(getBody()->end()));
3827 }
3828 
3829 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3830  assert(lbOperands.size() == map.getNumInputs() &&
3831  "operands to map must match number of inputs");
3832 
3833  auto ubOperands = getUpperBoundsOperands();
3834 
3835  SmallVector<Value, 4> newOperands(lbOperands);
3836  newOperands.append(ubOperands.begin(), ubOperands.end());
3837  (*this)->setOperands(newOperands);
3838 
3839  setLowerBoundsMapAttr(AffineMapAttr::get(map));
3840 }
3841 
3842 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3843  assert(ubOperands.size() == map.getNumInputs() &&
3844  "operands to map must match number of inputs");
3845 
3846  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3847  newOperands.append(ubOperands.begin(), ubOperands.end());
3848  (*this)->setOperands(newOperands);
3849 
3850  setUpperBoundsMapAttr(AffineMapAttr::get(map));
3851 }
3852 
3853 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
3854  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3855 }
3856 
3857 // Check whether `resultType` matches the reduction kind `op` in affine.parallel.
3858 static bool isResultTypeMatchAtomicRMWKind(Type resultType,
3859  arith::AtomicRMWKind op) {
3860  switch (op) {
3861  case arith::AtomicRMWKind::addf:
3862  return isa<FloatType>(resultType);
3863  case arith::AtomicRMWKind::addi:
3864  return isa<IntegerType>(resultType);
3865  case arith::AtomicRMWKind::assign:
3866  return true;
3867  case arith::AtomicRMWKind::mulf:
3868  return isa<FloatType>(resultType);
3869  case arith::AtomicRMWKind::muli:
3870  return isa<IntegerType>(resultType);
3871  case arith::AtomicRMWKind::maximumf:
3872  return isa<FloatType>(resultType);
3873  case arith::AtomicRMWKind::minimumf:
3874  return isa<FloatType>(resultType);
3875  case arith::AtomicRMWKind::maxs: {
3876  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3877  return intType && intType.isSigned();
3878  }
3879  case arith::AtomicRMWKind::mins: {
3880  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3881  return intType && intType.isSigned();
3882  }
3883  case arith::AtomicRMWKind::maxu: {
3884  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3885  return intType && intType.isUnsigned();
3886  }
3887  case arith::AtomicRMWKind::minu: {
3888  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3889  return intType && intType.isUnsigned();
3890  }
3891  case arith::AtomicRMWKind::ori:
3892  return isa<IntegerType>(resultType);
3893  case arith::AtomicRMWKind::andi:
3894  return isa<IntegerType>(resultType);
3895  default:
3896  return false;
3897  }
3898 }
3899 
3900 LogicalResult AffineParallelOp::verify() {
3901  auto numDims = getNumDims();
3902  if (getLowerBoundsGroups().getNumElements() != numDims ||
3903  getUpperBoundsGroups().getNumElements() != numDims ||
3904  getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3905  return emitOpError() << "the number of region arguments ("
3906  << getBody()->getNumArguments()
3907  << ") and the number of map groups for lower ("
3908  << getLowerBoundsGroups().getNumElements()
3909  << ") and upper bound ("
3910  << getUpperBoundsGroups().getNumElements()
3911  << "), and the number of steps (" << getSteps().size()
3912  << ") must all match";
3913  }
3914 
3915  unsigned expectedNumLBResults = 0;
3916  for (APInt v : getLowerBoundsGroups())
3917  expectedNumLBResults += v.getZExtValue();
3918  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
3919  return emitOpError() << "expected lower bounds map to have "
3920  << expectedNumLBResults << " results";
3921  unsigned expectedNumUBResults = 0;
3922  for (APInt v : getUpperBoundsGroups())
3923  expectedNumUBResults += v.getZExtValue();
3924  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
3925  return emitOpError() << "expected upper bounds map to have "
3926  << expectedNumUBResults << " results";
3927 
3928  if (getReductions().size() != getNumResults())
3929  return emitOpError("a reduction must be specified for each output");
3930 
3931  // Verify that all reduction ops are valid and that each result type matches
3932  // its reduction op.
3933  for (auto it : llvm::enumerate((getReductions()))) {
3934  Attribute attr = it.value();
3935  auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3936  if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
3937  return emitOpError("invalid reduction attribute");
3938  auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
3939  if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
3940  return emitOpError("result type cannot match reduction attribute");
3941  }
3942 
3943  // Verify that the bound operands are valid dimension/symbols.
3944  /// Lower bounds.
3945  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
3946  getLowerBoundsMap().getNumDims())))
3947  return failure();
3948  /// Upper bounds.
3949  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
3950  getUpperBoundsMap().getNumDims())))
3951  return failure();
3952  return success();
3953 }
3954 
3955 LogicalResult AffineValueMap::canonicalize() {
3956  SmallVector<Value, 4> newOperands{operands};
3957  auto newMap = getAffineMap();
3958  composeAffineMapAndOperands(&newMap, &newOperands);
3959  if (newMap == getAffineMap() && newOperands == operands)
3960  return failure();
3961  reset(newMap, newOperands);
3962  return success();
3963 }
3964 
3965 /// Canonicalize the bounds of the given loop.
3966 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3967  AffineValueMap lb = op.getLowerBoundsValueMap();
3968  bool lbCanonicalized = succeeded(lb.canonicalize());
3969 
3970  AffineValueMap ub = op.getUpperBoundsValueMap();
3971  bool ubCanonicalized = succeeded(ub.canonicalize());
3972 
3973  // Any canonicalization change always leads to updated map(s).
3974  if (!lbCanonicalized && !ubCanonicalized)
3975  return failure();
3976 
3977  if (lbCanonicalized)
3978  op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
3979  if (ubCanonicalized)
3980  op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
3981 
3982  return success();
3983 }
3984 
3985 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
3986  SmallVectorImpl<OpFoldResult> &results) {
3987  return canonicalizeLoopBounds(*this);
3988 }
3989 
3990 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
3991 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
3992 /// identifies which of those expressions form max/min groups. `operands`
3993 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
3994 /// or "max".
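///
/// For instance (illustrative), a bound with groups [1, 2] over the map
/// `(d0)[s0] -> (0, d0, s0 + 4)` prints as `0, max(d0, s0 + 4)`, with the SSA
/// operands substituted for d0 and s0.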
3995 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
3996  DenseIntElementsAttr group, ValueRange operands,
3997  StringRef keyword) {
3998  AffineMap map = mapAttr.getValue();
3999  unsigned numDims = map.getNumDims();
4000  ValueRange dimOperands = operands.take_front(numDims);
4001  ValueRange symOperands = operands.drop_front(numDims);
4002  unsigned start = 0;
4003  for (llvm::APInt groupSize : group) {
4004  if (start != 0)
4005  p << ", ";
4006 
4007  unsigned size = groupSize.getZExtValue();
4008  if (size == 1) {
4009  p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4010  ++start;
4011  } else {
4012  p << keyword << '(';
4013  AffineMap submap = map.getSliceMap(start, size);
4014  p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4015  p << ')';
4016  start += size;
4017  }
4018  }
4019 }
4020 
4022  p << " (" << getBody()->getArguments() << ") = (";
4023  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4024  getLowerBoundsOperands(), "max");
4025  p << ") to (";
4026  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4027  getUpperBoundsOperands(), "min");
4028  p << ')';
4029  SmallVector<int64_t, 8> steps = getSteps();
4030  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4031  if (!elideSteps) {
4032  p << " step (";
4033  llvm::interleaveComma(steps, p);
4034  p << ')';
4035  }
4036  if (getNumResults()) {
4037  p << " reduce (";
4038  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4039  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4040  llvm::cast<IntegerAttr>(attr).getInt());
4041  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4042  });
4043  p << ") -> (" << getResultTypes() << ")";
4044  }
4045 
4046  p << ' ';
4047  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4048  /*printBlockTerminators=*/getNumResults());
4049  p.printOptionalAttrDict(
4050  (*this)->getAttrs(),
4051  /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4052  AffineParallelOp::getLowerBoundsMapAttrStrName(),
4053  AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4054  AffineParallelOp::getUpperBoundsMapAttrStrName(),
4055  AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4056  AffineParallelOp::getStepsAttrStrName()});
4057 }
4058 
4059 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
4060 /// unique operands. Also populates `replacements` with affine expressions of
4061 /// `kind` that can be used to update affine maps previously accepting
4062 /// `operands` so that they accept `uniqueOperands` instead.
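///
/// For example (illustrative), the operand lists [%a, %b] and [%b, %c] with
/// dimension kind produce `uniqueOperands` = [%a, %b, %c] and `replacements`
/// = [d0, d1, d1, d2].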
4063 static ParseResult deduplicateAndResolveOperands(
4064  OpAsmParser &parser,
4065  ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
4066  SmallVectorImpl<Value> &uniqueOperands,
4067  SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4068  assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4069  "expected operands to be dim or symbol expression");
4070 
4071  Type indexType = parser.getBuilder().getIndexType();
4072  for (const auto &list : operands) {
4073  SmallVector<Value> valueOperands;
4074  if (parser.resolveOperands(list, indexType, valueOperands))
4075  return failure();
4076  for (Value operand : valueOperands) {
4077  unsigned pos = std::distance(uniqueOperands.begin(),
4078  llvm::find(uniqueOperands, operand));
4079  if (pos == uniqueOperands.size())
4080  uniqueOperands.push_back(operand);
4081  replacements.push_back(
4082  kind == AffineExprKind::DimId
4083  ? getAffineDimExpr(pos, parser.getContext())
4084  : getAffineSymbolExpr(pos, parser.getContext()));
4085  }
4086  }
4087  return success();
4088 }
4089 
4090 namespace {
4091 enum class MinMaxKind { Min, Max };
4092 } // namespace
4093 
4094 /// Parses an affine map that can contain a min/max for groups of its results,
4095 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4096 /// `result` attributes with the map (flat list of expressions) and the grouping
4097 /// (list of integers that specify how many expressions to put into each
4098 /// min/max) attributes. Deduplicates repeated operands.
4099 ///
4100 /// parallel-bound ::= `(` parallel-group-list `)`
4101 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4102 /// parallel-group ::= simple-group | min-max-group
4103 /// simple-group ::= expr-of-ssa-ids
4104 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4105 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4106 ///
4107 /// Examples:
4108 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4109 /// (%0, max(%1 - 2 * %2))
4110 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4111  OperationState &result,
4112  MinMaxKind kind) {
4113  // Using `const` not `constexpr` below to work around a MSVC optimizer bug,
4114  // see: https://reviews.llvm.org/D134227#3821753
4115  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4116 
4117  StringRef mapName = kind == MinMaxKind::Min
4118  ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4119  : AffineParallelOp::getLowerBoundsMapAttrStrName();
4120  StringRef groupsName =
4121  kind == MinMaxKind::Min
4122  ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4123  : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4124 
4125  if (failed(parser.parseLParen()))
4126  return failure();
4127 
4128  if (succeeded(parser.parseOptionalRParen())) {
4129  result.addAttribute(
4130  mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4131  result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4132  return success();
4133  }
4134 
4135  SmallVector<AffineExpr> flatExprs;
4136  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
4137  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
4138  SmallVector<int32_t> numMapsPerGroup;
4139  SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
4140  auto parseOperands = [&]() {
4141  if (succeeded(parser.parseOptionalKeyword(
4142  kind == MinMaxKind::Min ? "min" : "max"))) {
4143  mapOperands.clear();
4144  AffineMapAttr map;
4145  if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4146  result.attributes,
4147  OpAsmParser::Delimiter::Paren)))
4148  return failure();
4149  result.attributes.erase(tmpAttrStrName);
4150  llvm::append_range(flatExprs, map.getValue().getResults());
4151  auto operandsRef = llvm::ArrayRef(mapOperands);
4152  auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4153  SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef.begin(), dimsRef.end());
4154  auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4155  SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef.begin(), symsRef.end());
4156  flatDimOperands.append(map.getValue().getNumResults(), dims);
4157  flatSymOperands.append(map.getValue().getNumResults(), syms);
4158  numMapsPerGroup.push_back(map.getValue().getNumResults());
4159  } else {
4160  if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4161  flatSymOperands.emplace_back(),
4162  flatExprs.emplace_back())))
4163  return failure();
4164  numMapsPerGroup.push_back(1);
4165  }
4166  return success();
4167  };
4168  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4169  return failure();
4170 
4171  unsigned totalNumDims = 0;
4172  unsigned totalNumSyms = 0;
4173  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4174  unsigned numDims = flatDimOperands[i].size();
4175  unsigned numSyms = flatSymOperands[i].size();
4176  flatExprs[i] = flatExprs[i]
4177  .shiftDims(numDims, totalNumDims)
4178  .shiftSymbols(numSyms, totalNumSyms);
4179  totalNumDims += numDims;
4180  totalNumSyms += numSyms;
4181  }
4182 
4183  // Deduplicate map operands.
4184  SmallVector<Value> dimOperands, symOperands;
4185  SmallVector<AffineExpr> dimRplacements, symRepacements;
4186  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4187  dimRplacements, AffineExprKind::DimId) ||
4188  deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4189  symRepacements, AffineExprKind::SymbolId))
4190  return failure();
4191 
4192  result.operands.append(dimOperands.begin(), dimOperands.end());
4193  result.operands.append(symOperands.begin(), symOperands.end());
4194 
4195  Builder &builder = parser.getBuilder();
4196  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4197  parser.getContext());
4198  flatMap = flatMap.replaceDimsAndSymbols(
4199  dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4200 
4201  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4202  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4203  return success();
4204 }
4205 
4206 //
4207 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4208 // `to` parallel-bound steps? region attr-dict?
4209 // steps ::= `step` `(` integer-literals `)`
4210 //
4211 ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4212  OperationState &result) {
4213  auto &builder = parser.getBuilder();
4214  auto indexType = builder.getIndexType();
4215  SmallVector<OpAsmParser::Argument> ivs;
4216  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
4217  parser.parseEqual() ||
4218  parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4219  parser.parseKeyword("to") ||
4220  parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4221  return failure();
4222 
4223  AffineMapAttr stepsMapAttr;
4224  NamedAttrList stepsAttrs;
4225  SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
4226  if (failed(parser.parseOptionalKeyword("step"))) {
4227  SmallVector<int64_t, 4> steps(ivs.size(), 1);
4228  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4229  builder.getI64ArrayAttr(steps));
4230  } else {
4231  if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4232  AffineParallelOp::getStepsAttrStrName(),
4233  stepsAttrs,
4234  OpAsmParser::Delimiter::Paren))
4235  return failure();
4236 
4237  // Convert steps from an AffineMap into an I64ArrayAttr.
4238  SmallVector<int64_t, 4> steps;
4239  auto stepsMap = stepsMapAttr.getValue();
4240  for (const auto &result : stepsMap.getResults()) {
4241  auto constExpr = dyn_cast<AffineConstantExpr>(result);
4242  if (!constExpr)
4243  return parser.emitError(parser.getNameLoc(),
4244  "steps must be constant integers");
4245  steps.push_back(constExpr.getValue());
4246  }
4247  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4248  builder.getI64ArrayAttr(steps));
4249  }
4250 
4251  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4252  // quoted strings are members of the enum AtomicRMWKind.
4253  SmallVector<Attribute, 4> reductions;
4254  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4255  if (parser.parseLParen())
4256  return failure();
4257  auto parseAttributes = [&]() -> ParseResult {
4258  // Parse a single quoted string via the attribute parsing, and then
4259  // verify it is a member of the enum and convert to it's integer
4260  // representation.
4261  StringAttr attrVal;
4262  NamedAttrList attrStorage;
4263  auto loc = parser.getCurrentLocation();
4264  if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4265  attrStorage))
4266  return failure();
4267  std::optional<arith::AtomicRMWKind> reduction =
4268  arith::symbolizeAtomicRMWKind(attrVal.getValue());
4269  if (!reduction)
4270  return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4271  reductions.push_back(
4272  builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4273  // While we keep getting commas, keep parsing.
4274  return success();
4275  };
4276  if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4277  return failure();
4278  }
4279  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4280  builder.getArrayAttr(reductions));
4281 
4282  // Parse return types of reductions (if any)
4283  if (parser.parseOptionalArrowTypeList(result.types))
4284  return failure();
4285 
4286  // Now parse the body.
4287  Region *body = result.addRegion();
4288  for (auto &iv : ivs)
4289  iv.type = indexType;
4290  if (parser.parseRegion(*body, ivs) ||
4291  parser.parseOptionalAttrDict(result.attributes))
4292  return failure();
4293 
4294  // Add a terminator if none was parsed.
4295  AffineParallelOp::ensureTerminator(*body, builder, result.location);
4296  return success();
4297 }
4298 
4299 //===----------------------------------------------------------------------===//
4300 // AffineYieldOp
4301 //===----------------------------------------------------------------------===//
4302 
4303 LogicalResult AffineYieldOp::verify() {
4304  auto *parentOp = (*this)->getParentOp();
4305  auto results = parentOp->getResults();
4306  auto operands = getOperands();
4307 
4308  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4309  return emitOpError() << "only terminates affine.if/for/parallel regions";
4310  if (parentOp->getNumResults() != getNumOperands())
4311  return emitOpError() << "parent of yield must have same number of "
4312  "results as the yield operands";
4313  for (auto it : llvm::zip(results, operands)) {
4314  if (std::get<0>(it).getType() != std::get<1>(it).getType())
4315  return emitOpError() << "types mismatch between yield op and its parent";
4316  }
4317 
4318  return success();
4319 }
4320 
4321 //===----------------------------------------------------------------------===//
4322 // AffineVectorLoadOp
4323 //===----------------------------------------------------------------------===//
4324 
4325 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4326  VectorType resultType, AffineMap map,
4327  ValueRange operands) {
4328  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4329  result.addOperands(operands);
4330  if (map)
4331  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4332  result.types.push_back(resultType);
4333 }
4334 
4335 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4336  VectorType resultType, Value memref,
4337  AffineMap map, ValueRange mapOperands) {
4338  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4339  result.addOperands(memref);
4340  result.addOperands(mapOperands);
4341  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4342  result.types.push_back(resultType);
4343 }
4344 
4345 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4346  VectorType resultType, Value memref,
4347  ValueRange indices) {
4348  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4349  int64_t rank = memrefType.getRank();
4350  // Create identity map for memrefs with at least one dimension or () -> ()
4351  // for zero-dimensional memrefs.
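  // (Editorial note, illustrative only and not from the original source: for a
  // rank-2 memref the multi-dim identity map would be (d0, d1) -> (d0, d1),
  // while a rank-0 memref gets the empty map () -> ().)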
4352  auto map =
4353  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4354  build(builder, result, resultType, memref, map, indices);
4355 }
4356 
4357 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4358  MLIRContext *context) {
4359  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4360 }
4361 
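// Illustrative sketch (not part of the original file): the parser and printer
// below are expected to round-trip the affine.vector_load custom form, e.g.
//   %v = affine.vector_load %mem[%i, %j] : memref<100x100xf32>, vector<8xf32>
// where the bracketed indices may also be affine expressions of SSA ids.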
4362 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4363  OperationState &result) {
4364  auto &builder = parser.getBuilder();
4365  auto indexTy = builder.getIndexType();
4366 
4367  MemRefType memrefType;
4368  VectorType resultType;
4369  OpAsmParser::UnresolvedOperand memrefInfo;
4370  AffineMapAttr mapAttr;
4371  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4372  return failure(
4373  parser.parseOperand(memrefInfo) ||
4374  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4375  AffineVectorLoadOp::getMapAttrStrName(),
4376  result.attributes) ||
4377  parser.parseOptionalAttrDict(result.attributes) ||
4378  parser.parseColonType(memrefType) || parser.parseComma() ||
4379  parser.parseType(resultType) ||
4380  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4381  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4382  parser.addTypeToList(resultType, result.types));
4383 }
4384 
4385 void AffineVectorLoadOp::print(OpAsmPrinter &p) {
4386  p << " " << getMemRef() << '[';
4387  if (AffineMapAttr mapAttr =
4388  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4389  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4390  p << ']';
4391  p.printOptionalAttrDict((*this)->getAttrs(),
4392  /*elidedAttrs=*/{getMapAttrStrName()});
4393  p << " : " << getMemRefType() << ", " << getType();
4394 }
4395 
4396 /// Verify common invariants of affine.vector_load and affine.vector_store.
4397 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4398  VectorType vectorType) {
4399  // Check that memref and vector element types match.
4400  if (memrefType.getElementType() != vectorType.getElementType())
4401  return op->emitOpError(
4402  "requires memref and vector types of the same elemental type");
4403  return success();
4404 }
4405 
4406 LogicalResult AffineVectorLoadOp::verify() {
4407  MemRefType memrefType = getMemRefType();
4408  if (failed(verifyMemoryOpIndexing(
4409  getOperation(),
4410  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4411  getMapOperands(), memrefType,
4412  /*numIndexOperands=*/getNumOperands() - 1)))
4413  return failure();
4414 
4415  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4416  return failure();
4417 
4418  return success();
4419 }
4420 
4421 //===----------------------------------------------------------------------===//
4422 // AffineVectorStoreOp
4423 //===----------------------------------------------------------------------===//
4424 
4425 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4426  Value valueToStore, Value memref, AffineMap map,
4427  ValueRange mapOperands) {
4428  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4429  result.addOperands(valueToStore);
4430  result.addOperands(memref);
4431  result.addOperands(mapOperands);
4432  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4433 }
4434 
4435 // Use identity map.
4436 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4437  Value valueToStore, Value memref,
4438  ValueRange indices) {
4439  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4440  int64_t rank = memrefType.getRank();
4441  // Create identity map for memrefs with at least one dimension or () -> ()
4442  // for zero-dimensional memrefs.
4443  auto map =
4444  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4445  build(builder, result, valueToStore, memref, map, indices);
4446 }
4447 void AffineVectorStoreOp::getCanonicalizationPatterns(
4448  RewritePatternSet &results, MLIRContext *context) {
4449  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4450 }
4451 
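// Illustrative sketch (not part of the original file): the parser and printer
// below handle the affine.vector_store custom form, e.g.
//   affine.vector_store %v, %mem[%i, %j] : memref<100x100xf32>, vector<8xf32>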
4452 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4453  OperationState &result) {
4454  auto indexTy = parser.getBuilder().getIndexType();
4455 
4456  MemRefType memrefType;
4457  VectorType resultType;
4458  OpAsmParser::UnresolvedOperand storeValueInfo;
4459  OpAsmParser::UnresolvedOperand memrefInfo;
4460  AffineMapAttr mapAttr;
4461  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4462  return failure(
4463  parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4464  parser.parseOperand(memrefInfo) ||
4465  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4466  AffineVectorStoreOp::getMapAttrStrName(),
4467  result.attributes) ||
4468  parser.parseOptionalAttrDict(result.attributes) ||
4469  parser.parseColonType(memrefType) || parser.parseComma() ||
4470  parser.parseType(resultType) ||
4471  parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4472  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4473  parser.resolveOperands(mapOperands, indexTy, result.operands));
4474 }
4475 
4476 void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4477  p << " " << getValueToStore();
4478  p << ", " << getMemRef() << '[';
4479  if (AffineMapAttr mapAttr =
4480  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4481  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4482  p << ']';
4483  p.printOptionalAttrDict((*this)->getAttrs(),
4484  /*elidedAttrs=*/{getMapAttrStrName()});
4485  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4486 }
4487 
4488 LogicalResult AffineVectorStoreOp::verify() {
4489  MemRefType memrefType = getMemRefType();
4490  if (failed(verifyMemoryOpIndexing(
4491  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4492  getMapOperands(), memrefType,
4493  /*numIndexOperands=*/getNumOperands() - 2)))
4494  return failure();
4495 
4496  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4497  return failure();
4498 
4499  return success();
4500 }
4501 
4502 //===----------------------------------------------------------------------===//
4503 // DelinearizeIndexOp
4504 //===----------------------------------------------------------------------===//
4505 
4506 LogicalResult AffineDelinearizeIndexOp::inferReturnTypes(
4507  MLIRContext *context, std::optional<::mlir::Location> location,
4508  ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
4509  RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
4510  AffineDelinearizeIndexOpAdaptor adaptor(operands, attributes, properties,
4511  regions);
4512  inferredReturnTypes.assign(adaptor.getStaticBasis().size(),
4513  IndexType::get(context));
4514  return success();
4515 }
4516 
4517 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4518  OperationState &odsState,
4519  Value linearIndex, ValueRange basis) {
4520  SmallVector<Value> dynamicBasis;
4521  SmallVector<int64_t> staticBasis;
4522  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4523  staticBasis);
4524  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
4525 }
4526 
4527 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4528  OperationState &odsState,
4529  Value linearIndex,
4530  ArrayRef<OpFoldResult> basis) {
4531  SmallVector<Value> dynamicBasis;
4532  SmallVector<int64_t> staticBasis;
4533  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4534  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis);
4535 }
4536 
4537 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4538  OperationState &odsState,
4539  Value linearIndex,
4540  ArrayRef<int64_t> basis) {
4541  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis);
4542 }
4543 
4544 LogicalResult AffineDelinearizeIndexOp::verify() {
4545  if (getStaticBasis().empty())
4546  return emitOpError("basis should not be empty");
4547  if (getNumResults() != getStaticBasis().size())
4548  return emitOpError("should return an index for each basis element");
4549  auto dynamicMarkersCount =
4550  llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
4551  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4552  return emitOpError(
4553  "mismatch between dynamic and static basis (kDynamic marker but no "
4554  "corresponding dynamic basis entry) -- this can only happen due to an "
4555  "incorrect fold/rewrite");
4556  return success();
4557 }
4558 
4559 namespace {
4560 
4561 // Drops delinearization indices that correspond to unit-extent basis entries.
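// Illustrative sketch of the rewrite (hypothetical IR, not from this file):
//   %r:3 = affine.delinearize_index %i into (%c4, %c1, %c8) : index, index, index
// becomes
//   %s:2 = affine.delinearize_index %i into (%c4, %c8) : index, index
// with the middle result replaced by a constant 0 of index type.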
4562 struct DropUnitExtentBasis
4563  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4564  using OpRewritePattern::OpRewritePattern;
4565 
4566  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4567  PatternRewriter &rewriter) const override {
4568  SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
4569  std::optional<Value> zero = std::nullopt;
4570  Location loc = delinearizeOp->getLoc();
4571  auto getZero = [&]() -> Value {
4572  if (!zero)
4573  zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4574  return zero.value();
4575  };
4576 
4577  // Replace all indices corresponding to unit-extent basis with 0.
4578  // Remaining basis can be used to get a new `affine.delinearize_index` op.
4579  SmallVector<OpFoldResult> newOperands;
4580  for (auto [index, basis] : llvm::enumerate(delinearizeOp.getMixedBasis())) {
4581  std::optional<int64_t> basisVal = getConstantIntValue(basis);
4582  if (basisVal && *basisVal == 1)
4583  replacements[index] = getZero();
4584  else
4585  newOperands.push_back(basis);
4586  }
4587 
4588  if (newOperands.size() == delinearizeOp.getStaticBasis().size())
4589  return failure();
4590 
4591  if (!newOperands.empty()) {
4592  auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4593  loc, delinearizeOp.getLinearIndex(), newOperands);
4594  int newIndex = 0;
4595  // Map back the new delinearized indices to the values they replace.
4596  for (auto &replacement : replacements) {
4597  if (replacement)
4598  continue;
4599  replacement = newDelinearizeOp->getResult(newIndex++);
4600  }
4601  }
4602 
4603  rewriter.replaceOp(delinearizeOp, replacements);
4604  return success();
4605  }
4606 };
4607 
4608 /// Drop delinearization with a single basis element
4609 ///
4610 /// By definition, `delinearize_index %linear into (%basis)` is
4611 /// `%linear floorDiv 1` (since `1` is the product of the basis elements,
4612 /// ignoring the 0th one, and since there is no previous division we need
4613 /// to use the remainder of). Therefore, a single-element `delinearize`
4614 /// can be replaced by the underlying linear index.
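///
/// Illustrative sketch (hypothetical IR, not from this file):
///   %0 = affine.delinearize_index %idx into (%n) : index
/// is replaced by uses of %idx directly.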
4615 struct DropDelinearizeOneBasisElement
4616  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4617  using OpRewritePattern::OpRewritePattern;
4618 
4619  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4620  PatternRewriter &rewriter) const override {
4621  if (delinearizeOp.getStaticBasis().size() != 1)
4622  return failure();
4623  rewriter.replaceOp(delinearizeOp, delinearizeOp.getLinearIndex());
4624  return success();
4625  }
4626 };
4627 
4628 } // namespace
4629 
4630 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4631  RewritePatternSet &patterns, MLIRContext *context) {
4632  patterns.insert<DropDelinearizeOneBasisElement, DropUnitExtentBasis>(context);
4633 }
4634 
4635 //===----------------------------------------------------------------------===//
4636 // LinearizeIndexOp
4637 //===----------------------------------------------------------------------===//
4638 
4639 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4640  OperationState &odsState,
4641  ValueRange multiIndex, ValueRange basis,
4642  bool disjoint) {
4643  SmallVector<Value> dynamicBasis;
4644  SmallVector<int64_t> staticBasis;
4645  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4646  staticBasis);
4647  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4648 }
4649 
4650 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4651  OperationState &odsState,
4652  ValueRange multiIndex,
4653  ArrayRef<OpFoldResult> basis,
4654  bool disjoint) {
4655  SmallVector<Value> dynamicBasis;
4656  SmallVector<int64_t> staticBasis;
4657  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4658  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4659 }
4660 
4661 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4662  OperationState &odsState,
4663  ValueRange multiIndex,
4664  ArrayRef<int64_t> basis, bool disjoint) {
4665  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
4666 }
4667 
4668 LogicalResult AffineLinearizeIndexOp::verify() {
4669  if (getStaticBasis().empty())
4670  return emitOpError("basis should not be empty");
4671 
4672  if (getMultiIndex().size() != getStaticBasis().size())
4673  return emitOpError("should be passed an index for each basis element");
4674 
4675  auto dynamicMarkersCount =
4676  llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
4677  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4678  return emitOpError(
4679  "mismatch between dynamic and static basis (kDynamic marker but no "
4680  "corresponding dynamic basis entry) -- this can only happen due to an "
4681  "incorrect fold/rewrite");
4682 
4683  return success();
4684 }
4685 
4686 namespace {
4687 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
4688 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
4689 /// %...d)`.
4690 
4691 /// Note that `disjoint` is required here: without it,
4692 /// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
4693 /// is a valid operation in which the `%c64` cannot be trivially dropped.
4694 ///
4695 /// Alternatively, if `%x` in the above is a known constant 0, remove it even if
4696 /// the operation isn't asserted to be `disjoint`.
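///
/// Illustrative sketch (hypothetical IR, not from this file):
///   %0 = affine.linearize_index disjoint [%i, %x, %j] by (%a, 1, %b) : index
/// would become
///   %0 = affine.linearize_index disjoint [%i, %j] by (%a, %b) : index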
4697 struct DropLinearizeUnitComponentsIfDisjointOrZero final
4698  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
4699  using OpRewritePattern::OpRewritePattern;
4700 
4701  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
4702  PatternRewriter &rewriter) const override {
4703  size_t numIndices = op.getMultiIndex().size();
4704  SmallVector<Value> newIndices;
4705  newIndices.reserve(numIndices);
4706  SmallVector<OpFoldResult> newBasis;
4707  newBasis.reserve(numIndices);
4708 
4709  SmallVector<OpFoldResult> basis = op.getMixedBasis();
4710  for (auto [index, basisElem] : llvm::zip_equal(op.getMultiIndex(), basis)) {
4711  std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
4712  if (!basisEntry || *basisEntry != 1) {
4713  newIndices.push_back(index);
4714  newBasis.push_back(basisElem);
4715  continue;
4716  }
4717 
4718  std::optional<int64_t> indexValue = getConstantIntValue(index);
4719  if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
4720  newIndices.push_back(index);
4721  newBasis.push_back(basisElem);
4722  continue;
4723  }
4724  }
4725  if (newIndices.size() == numIndices)
4726  return failure();
4727 
4728  if (newIndices.empty()) {
4729  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
4730  return success();
4731  }
4732  rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
4733  op, newIndices, newBasis, op.getDisjoint());
4734  return success();
4735  }
4736 };
4737 } // namespace
4738 
4739 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
4740  RewritePatternSet &patterns, MLIRContext *context) {
4741  patterns.add<DropLinearizeUnitComponentsIfDisjointOrZero>(context);
4742 }
4743 
4744 //===----------------------------------------------------------------------===//
4745 // TableGen'd op method definitions
4746 //===----------------------------------------------------------------------===//
4747 
4748 #define GET_OP_CLASSES
4749 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
Create a region that should be attached to the operation.