MLIR  21.0.0git
AffineOps.cpp
Go to the documentation of this file.
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/IntegerSet.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 #include "llvm/ADT/SmallVectorExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/MathExtras.h"
30 #include <numeric>
31 #include <optional>
32 
33 using namespace mlir;
34 using namespace mlir::affine;
35 
36 using llvm::divideCeilSigned;
37 using llvm::divideFloorSigned;
38 using llvm::mod;
39 
40 #define DEBUG_TYPE "affine-ops"
41 
42 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
43 
44 /// A utility function to check if a value is defined at the top level of
45 /// `region` or is an argument of `region`. A value of index type defined at the
46 /// top level of a `AffineScope` region is always a valid symbol for all
47 /// uses in that region.
49  if (auto arg = llvm::dyn_cast<BlockArgument>(value))
50  return arg.getParentRegion() == region;
51  return value.getDefiningOp()->getParentRegion() == region;
52 }
53 
54 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
55 /// region remains legal if the operation that uses it is inlined into `dest`
56 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
57 /// `isValidSymbol`, depending on the value being required to remain a valid
58 /// dimension or symbol.
59 static bool
61  const IRMapping &mapping,
62  function_ref<bool(Value, Region *)> legalityCheck) {
63  // If the value is a valid dimension for any other reason than being
64  // a top-level value, it will remain valid: constants get inlined
65  // with the function, transitive affine applies also get inlined and
66  // will be checked themselves, etc.
67  if (!isTopLevelValue(value, src))
68  return true;
69 
70  // If it's a top-level value because it's a block operand, i.e. a
71  // function argument, check whether the value replacing it after
72  // inlining is a valid dimension in the new region.
73  if (llvm::isa<BlockArgument>(value))
74  return legalityCheck(mapping.lookup(value), dest);
75 
76  // If it's a top-level value because it's defined in the region,
77  // it can only be inlined if the defining op is a constant or a
78  // `dim`, which can appear anywhere and be valid, since the defining
79  // op won't be top-level anymore after inlining.
80  Attribute operandCst;
81  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
82  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
83  isDimLikeOp;
84 }
85 
86 /// Checks if all values known to be legal affine dimensions or symbols in `src`
87 /// remain so if their respective users are inlined into `dest`.
88 static bool
90  const IRMapping &mapping,
91  function_ref<bool(Value, Region *)> legalityCheck) {
92  return llvm::all_of(values, [&](Value v) {
93  return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
94  });
95 }
96 
97 /// Checks if an affine read or write operation remains legal after inlining
98 /// from `src` to `dest`.
99 template <typename OpTy>
100 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
101  const IRMapping &mapping) {
102  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
103  AffineWriteOpInterface>::value,
104  "only ops with affine read/write interface are supported");
105 
106  AffineMap map = op.getAffineMap();
107  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
108  ValueRange symbolOperands =
109  op.getMapOperands().take_back(map.getNumSymbols());
111  dimOperands, src, dest, mapping,
112  static_cast<bool (*)(Value, Region *)>(isValidDim)))
113  return false;
115  symbolOperands, src, dest, mapping,
116  static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
117  return false;
118  return true;
119 }
120 
121 /// Checks if an affine apply operation remains legal after inlining from `src`
122 /// to `dest`.
123 // Use "unused attribute" marker to silence clang-tidy warning stemming from
124 // the inability to see through "llvm::TypeSwitch".
125 template <>
126 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
127  Region *src, Region *dest,
128  const IRMapping &mapping) {
129  // If it's a valid dimension, we need to check that it remains so.
130  if (isValidDim(op.getResult(), src))
132  op.getMapOperands(), src, dest, mapping,
133  static_cast<bool (*)(Value, Region *)>(isValidDim));
134 
135  // Otherwise it must be a valid symbol, check that it remains so.
137  op.getMapOperands(), src, dest, mapping,
138  static_cast<bool (*)(Value, Region *)>(isValidSymbol));
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // AffineDialect Interfaces
143 //===----------------------------------------------------------------------===//
144 
145 namespace {
146 /// This class defines the interface for handling inlining with affine
147 /// operations.
148 struct AffineInlinerInterface : public DialectInlinerInterface {
150 
151  //===--------------------------------------------------------------------===//
152  // Analysis Hooks
153  //===--------------------------------------------------------------------===//
154 
155  /// Returns true if the given region 'src' can be inlined into the region
156  /// 'dest' that is attached to an operation registered to the current dialect.
157  /// 'wouldBeCloned' is set if the region is cloned into its new location
158  /// rather than moved, indicating there may be other users.
159  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
160  IRMapping &valueMapping) const final {
161  // We can inline into affine loops and conditionals if this doesn't break
162  // affine value categorization rules.
163  Operation *destOp = dest->getParentOp();
164  if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
165  return false;
166 
167  // Multi-block regions cannot be inlined into affine constructs, all of
168  // which require single-block regions.
169  if (!llvm::hasSingleElement(*src))
170  return false;
171 
172  // Side-effecting operations that the affine dialect cannot understand
173  // should not be inlined.
174  Block &srcBlock = src->front();
175  for (Operation &op : srcBlock) {
176  // Ops with no side effects are fine,
177  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
178  if (iface.hasNoEffect())
179  continue;
180  }
181 
182  // Assuming the inlined region is valid, we only need to check if the
183  // inlining would change it.
184  bool remainsValid =
186  .Case<AffineApplyOp, AffineReadOpInterface,
187  AffineWriteOpInterface>([&](auto op) {
188  return remainsLegalAfterInline(op, src, dest, valueMapping);
189  })
190  .Default([](Operation *) {
191  // Conservatively disallow inlining ops we cannot reason about.
192  return false;
193  });
194 
195  if (!remainsValid)
196  return false;
197  }
198 
199  return true;
200  }
201 
202  /// Returns true if the given operation 'op', that is registered to this
203  /// dialect, can be inlined into the given region, false otherwise.
204  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
205  IRMapping &valueMapping) const final {
206  // Always allow inlining affine operations into a region that is marked as
207  // affine scope, or into affine loops and conditionals. There are some edge
208  // cases when inlining *into* affine structures, but that is handled in the
209  // other 'isLegalToInline' hook above.
210  Operation *parentOp = region->getParentOp();
211  return parentOp->hasTrait<OpTrait::AffineScope>() ||
212  isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
213  }
214 
215  /// Affine regions should be analyzed recursively.
216  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
217 };
218 } // namespace
219 
220 //===----------------------------------------------------------------------===//
221 // AffineDialect
222 //===----------------------------------------------------------------------===//
223 
224 void AffineDialect::initialize() {
225  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
226 #define GET_OP_LIST
227 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
228  >();
229  addInterfaces<AffineInlinerInterface>();
230  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
231  AffineMinOp>();
232 }
233 
234 /// Materialize a single constant operation from a given attribute value with
235 /// the desired resultant type.
237  Attribute value, Type type,
238  Location loc) {
239  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
240  return builder.create<ub::PoisonOp>(loc, type, poison);
241  return arith::ConstantOp::materialize(builder, value, type, loc);
242 }
243 
244 /// A utility function to check if a value is defined at the top level of an
245 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
246 /// conservatively assume it is not top-level. A value of index type defined at
247 /// the top level is always a valid symbol.
249  if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
250  // The block owning the argument may be unlinked, e.g. when the surrounding
251  // region has not yet been attached to an Op, at which point the parent Op
252  // is null.
253  Operation *parentOp = arg.getOwner()->getParentOp();
254  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
255  }
256  // The defining Op may live in an unlinked block so its parent Op may be null.
257  Operation *parentOp = value.getDefiningOp()->getParentOp();
258  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
259 }
260 
261 /// Returns the closest region enclosing `op` that is held by an operation with
262 /// trait `AffineScope`; `nullptr` if there is no such region.
264  auto *curOp = op;
265  while (auto *parentOp = curOp->getParentOp()) {
266  if (parentOp->hasTrait<OpTrait::AffineScope>())
267  return curOp->getParentRegion();
268  curOp = parentOp;
269  }
270  return nullptr;
271 }
272 
274  Operation *curOp = op;
275  while (auto *parentOp = curOp->getParentOp()) {
276  if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
277  return curOp->getParentRegion();
278  curOp = parentOp;
279  }
280  return nullptr;
281 }
282 
283 // A Value can be used as a dimension id iff it meets one of the following
284 // conditions:
285 // *) It is valid as a symbol.
286 // *) It is an induction variable.
287 // *) It is the result of affine apply operation with dimension id arguments.
289  // The value must be an index type.
290  if (!value.getType().isIndex())
291  return false;
292 
293  if (auto *defOp = value.getDefiningOp())
294  return isValidDim(value, getAffineScope(defOp));
295 
296  // This value has to be a block argument for an op that has the
297  // `AffineScope` trait or an induction var of an affine.for or
298  // affine.parallel.
299  if (isAffineInductionVar(value))
300  return true;
301  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
302  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
303 }
304 
305 // Value can be used as a dimension id iff it meets one of the following
306 // conditions:
307 // *) It is valid as a symbol.
308 // *) It is an induction variable.
309 // *) It is the result of an affine apply operation with dimension id operands.
310 // *) It is the result of a more specialized index transformation (ex.
311 // delinearize_index or linearize_index) with dimension id operands.
312 bool mlir::affine::isValidDim(Value value, Region *region) {
313  // The value must be an index type.
314  if (!value.getType().isIndex())
315  return false;
316 
317  // All valid symbols are okay.
318  if (isValidSymbol(value, region))
319  return true;
320 
321  auto *op = value.getDefiningOp();
322  if (!op) {
323  // This value has to be an induction var for an affine.for or an
324  // affine.parallel.
325  return isAffineInductionVar(value);
326  }
327 
328  // Affine apply operation is ok if all of its operands are ok.
329  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
330  return applyOp.isValidDim(region);
331  // delinearize_index and linearize_index are special forms of apply
332  // and so are valid dimensions if all their arguments are valid dimensions.
333  if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
334  return llvm::all_of(op->getOperands(),
335  [&](Value arg) { return ::isValidDim(arg, region); });
336  // The dim op is okay if its operand memref/tensor is defined at the top
337  // level.
338  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
339  return isTopLevelValue(dimOp.getShapedValue());
340  return false;
341 }
342 
343 /// Returns true if the 'index' dimension of the `memref` defined by
344 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
345 /// for `region`.
346 template <typename AnyMemRefDefOp>
347 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
348  Region *region) {
349  MemRefType memRefType = memrefDefOp.getType();
350 
351  // Dimension index is out of bounds.
352  if (index >= memRefType.getRank()) {
353  return false;
354  }
355 
356  // Statically shaped.
357  if (!memRefType.isDynamicDim(index))
358  return true;
359  // Get the position of the dimension among dynamic dimensions;
360  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
361  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
362  region);
363 }
364 
365 /// Returns true if the result of the dim op is a valid symbol for `region`.
366 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
367  // The dim op is okay if its source is defined at the top level.
368  if (isTopLevelValue(dimOp.getShapedValue()))
369  return true;
370 
371  // Conservatively handle remaining BlockArguments as non-valid symbols.
372  // E.g. scf.for iterArgs.
373  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
374  return false;
375 
376  // The dim op is also okay if its operand memref is a view/subview whose
377  // corresponding size is a valid symbol.
378  std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
379 
380  // Be conservative if we can't understand the dimension.
381  if (!index.has_value())
382  return false;
383 
384  // Skip over all memref.cast ops (if any).
385  Operation *op = dimOp.getShapedValue().getDefiningOp();
386  while (auto castOp = dyn_cast<memref::CastOp>(op)) {
387  // Bail on unranked memrefs.
388  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
389  return false;
390  op = castOp.getSource().getDefiningOp();
391  if (!op)
392  return false;
393  }
394 
395  int64_t i = index.value();
397  .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
398  [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
399  .Default([](Operation *) { return false; });
400 }
401 
402 // A value can be used as a symbol (at all its use sites) iff it meets one of
403 // the following conditions:
404 // *) It is a constant.
405 // *) Its defining op or block arg appearance is immediately enclosed by an op
406 // with `AffineScope` trait.
407 // *) It is the result of an affine.apply operation with symbol operands.
408 // *) It is a result of the dim op on a memref whose corresponding size is a
409 // valid symbol.
411  if (!value)
412  return false;
413 
414  // The value must be an index type.
415  if (!value.getType().isIndex())
416  return false;
417 
418  // Check that the value is a top level value.
419  if (isTopLevelValue(value))
420  return true;
421 
422  if (auto *defOp = value.getDefiningOp())
423  return isValidSymbol(value, getAffineScope(defOp));
424 
425  return false;
426 }
427 
428 /// A value can be used as a symbol for `region` iff it meets one of the
429 /// following conditions:
430 /// *) It is a constant.
431 /// *) It is a result of a `Pure` operation whose operands are valid symbolic
432 /// *) identifiers.
433 /// *) It is a result of the dim op on a memref whose corresponding size is
434 /// a valid symbol.
435 /// *) It is defined at the top level of 'region' or is its argument.
436 /// *) It dominates `region`'s parent op.
437 /// If `region` is null, conservatively assume the symbol definition scope does
438 /// not exist and only accept the values that would be symbols regardless of
439 /// the surrounding region structure, i.e. the first three cases above.
441  // The value must be an index type.
442  if (!value.getType().isIndex())
443  return false;
444 
445  // A top-level value is a valid symbol.
446  if (region && ::isTopLevelValue(value, region))
447  return true;
448 
449  auto *defOp = value.getDefiningOp();
450  if (!defOp) {
451  // A block argument that is not a top-level value is a valid symbol if it
452  // dominates region's parent op.
453  Operation *regionOp = region ? region->getParentOp() : nullptr;
454  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
455  if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
456  return isValidSymbol(value, parentOpRegion);
457  return false;
458  }
459 
460  // Constant operation is ok.
461  Attribute operandCst;
462  if (matchPattern(defOp, m_Constant(&operandCst)))
463  return true;
464 
465  // `Pure` operation that whose operands are valid symbolic identifiers.
466  if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
467  return affine::isValidSymbol(operand, region);
468  })) {
469  return true;
470  }
471 
472  // Dim op results could be valid symbols at any level.
473  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
474  return isDimOpValidSymbol(dimOp, region);
475 
476  // Check for values dominating `region`'s parent op.
477  Operation *regionOp = region ? region->getParentOp() : nullptr;
478  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
479  if (auto *parentRegion = region->getParentOp()->getParentRegion())
480  return isValidSymbol(value, parentRegion);
481 
482  return false;
483 }
484 
485 // Returns true if 'value' is a valid index to an affine operation (e.g.
486 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
487 // `region` provides the polyhedral symbol scope. Returns false otherwise.
488 static bool isValidAffineIndexOperand(Value value, Region *region) {
489  return isValidDim(value, region) || isValidSymbol(value, region);
490 }
491 
492 /// Prints dimension and symbol list.
495  unsigned numDims, OpAsmPrinter &printer) {
496  OperandRange operands(begin, end);
497  printer << '(' << operands.take_front(numDims) << ')';
498  if (operands.size() > numDims)
499  printer << '[' << operands.drop_front(numDims) << ']';
500 }
501 
502 /// Parses dimension and symbol list and returns true if parsing failed.
504  OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
506  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
507  return failure();
508  // Store number of dimensions for validation by caller.
509  numDims = opInfos.size();
510 
511  // Parse the optional symbol operands.
512  auto indexTy = parser.getBuilder().getIndexType();
513  return failure(parser.parseOperandList(
515  parser.resolveOperands(opInfos, indexTy, operands));
516 }
517 
518 /// Utility function to verify that a set of operands are valid dimension and
519 /// symbol identifiers. The operands should be laid out such that the dimension
520 /// operands are before the symbol operands. This function returns failure if
521 /// there was an invalid operand. An operation is provided to emit any necessary
522 /// errors.
523 template <typename OpTy>
524 static LogicalResult
526  unsigned numDims) {
527  unsigned opIt = 0;
528  for (auto operand : operands) {
529  if (opIt++ < numDims) {
530  if (!isValidDim(operand, getAffineScope(op)))
531  return op.emitOpError("operand cannot be used as a dimension id");
532  } else if (!isValidSymbol(operand, getAffineScope(op))) {
533  return op.emitOpError("operand cannot be used as a symbol");
534  }
535  }
536  return success();
537 }
538 
539 //===----------------------------------------------------------------------===//
540 // AffineApplyOp
541 //===----------------------------------------------------------------------===//
542 
543 AffineValueMap AffineApplyOp::getAffineValueMap() {
544  return AffineValueMap(getAffineMap(), getOperands(), getResult());
545 }
546 
547 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
548  auto &builder = parser.getBuilder();
549  auto indexTy = builder.getIndexType();
550 
551  AffineMapAttr mapAttr;
552  unsigned numDims;
553  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
554  parseDimAndSymbolList(parser, result.operands, numDims) ||
555  parser.parseOptionalAttrDict(result.attributes))
556  return failure();
557  auto map = mapAttr.getValue();
558 
559  if (map.getNumDims() != numDims ||
560  numDims + map.getNumSymbols() != result.operands.size()) {
561  return parser.emitError(parser.getNameLoc(),
562  "dimension or symbol index mismatch");
563  }
564 
565  result.types.append(map.getNumResults(), indexTy);
566  return success();
567 }
568 
570  p << " " << getMapAttr();
571  printDimAndSymbolList(operand_begin(), operand_end(),
572  getAffineMap().getNumDims(), p);
573  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
574 }
575 
576 LogicalResult AffineApplyOp::verify() {
577  // Check input and output dimensions match.
578  AffineMap affineMap = getMap();
579 
580  // Verify that operand count matches affine map dimension and symbol count.
581  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
582  return emitOpError(
583  "operand count and affine map dimension and symbol count must match");
584 
585  // Verify that the map only produces one result.
586  if (affineMap.getNumResults() != 1)
587  return emitOpError("mapping must produce one value");
588 
589  // Do not allow valid dims to be used in symbol positions. We do allow
590  // affine.apply to use operands for values that may neither qualify as affine
591  // dims or affine symbols due to usage outside of affine ops, analyses, etc.
592  Region *region = getAffineScope(*this);
593  for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
594  if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
595  return emitError("dimensional operand cannot be used as a symbol");
596  }
597 
598  return success();
599 }
600 
601 // The result of the affine apply operation can be used as a dimension id if all
602 // its operands are valid dimension ids.
604  return llvm::all_of(getOperands(),
605  [](Value op) { return affine::isValidDim(op); });
606 }
607 
608 // The result of the affine apply operation can be used as a dimension id if all
609 // its operands are valid dimension ids with the parent operation of `region`
610 // defining the polyhedral scope for symbols.
611 bool AffineApplyOp::isValidDim(Region *region) {
612  return llvm::all_of(getOperands(),
613  [&](Value op) { return ::isValidDim(op, region); });
614 }
615 
616 // The result of the affine apply operation can be used as a symbol if all its
617 // operands are symbols.
619  return llvm::all_of(getOperands(),
620  [](Value op) { return affine::isValidSymbol(op); });
621 }
622 
623 // The result of the affine apply operation can be used as a symbol in `region`
624 // if all its operands are symbols in `region`.
625 bool AffineApplyOp::isValidSymbol(Region *region) {
626  return llvm::all_of(getOperands(), [&](Value operand) {
627  return affine::isValidSymbol(operand, region);
628  });
629 }
630 
631 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
632  auto map = getAffineMap();
633 
634  // Fold dims and symbols to existing values.
635  auto expr = map.getResult(0);
636  if (auto dim = dyn_cast<AffineDimExpr>(expr))
637  return getOperand(dim.getPosition());
638  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
639  return getOperand(map.getNumDims() + sym.getPosition());
640 
641  // Otherwise, default to folding the map.
643  bool hasPoison = false;
644  auto foldResult =
645  map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
646  if (hasPoison)
648  if (failed(foldResult))
649  return {};
650  return result[0];
651 }
652 
653 /// Returns the largest known divisor of `e`. Exploits information from the
654 /// values in `operands`.
655 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
656  // This method isn't aware of `operands`.
657  int64_t div = e.getLargestKnownDivisor();
658 
659  // We now make use of operands for the case `e` is a dim expression.
660  // TODO: More powerful simplification would have to modify
661  // getLargestKnownDivisor to take `operands` and exploit that information as
662  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
663  // can't be part of the IR library but of the `Analysis` library. The IR
664  // library can only really depend on simple O(1) checks.
665  auto dimExpr = dyn_cast<AffineDimExpr>(e);
666  // If it's not a dim expr, `div` is the best we have.
667  if (!dimExpr)
668  return div;
669 
670  // We simply exploit information from loop IVs.
671  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
672  // desired simplifications are expected to be part of other
673  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
674  // LoopAnalysis library.
675  Value operand = operands[dimExpr.getPosition()];
676  int64_t operandDivisor = 1;
677  // TODO: With the right accessors, this can be extended to
678  // LoopLikeOpInterface.
679  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
680  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
681  operandDivisor = forOp.getStepAsInt();
682  } else {
683  uint64_t lbLargestKnownDivisor =
684  forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
685  operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
686  }
687  }
688  return operandDivisor;
689 }
690 
691 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
692 /// being an affine dim expression or a constant.
694  int64_t k) {
695  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
696  int64_t constVal = constExpr.getValue();
697  return constVal >= 0 && constVal < k;
698  }
699  auto dimExpr = dyn_cast<AffineDimExpr>(e);
700  if (!dimExpr)
701  return false;
702  Value operand = operands[dimExpr.getPosition()];
703  // TODO: With the right accessors, this can be extended to
704  // LoopLikeOpInterface.
705  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
706  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
707  forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
708  return true;
709  }
710  }
711 
712  // We don't consider other cases like `operand` being defined by a constant or
713  // an affine.apply op since such cases will already be handled by other
714  // patterns and propagation of loop IVs or constant would happen.
715  return false;
716 }
717 
718 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
719 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
720 /// expression is in that form.
721 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
722  AffineExpr &quotientTimesDiv, AffineExpr &rem) {
723  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
724  if (!bin || bin.getKind() != AffineExprKind::Add)
725  return false;
726 
727  AffineExpr llhs = bin.getLHS();
728  AffineExpr rlhs = bin.getRHS();
729  div = getLargestKnownDivisor(llhs, operands);
730  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
731  quotientTimesDiv = llhs;
732  rem = rlhs;
733  return true;
734  }
735  div = getLargestKnownDivisor(rlhs, operands);
736  if (isNonNegativeBoundedBy(llhs, operands, div)) {
737  quotientTimesDiv = rlhs;
738  rem = llhs;
739  return true;
740  }
741  return false;
742 }
743 
744 /// Gets the constant lower bound on an `iv`.
745 static std::optional<int64_t> getLowerBound(Value iv) {
746  AffineForOp forOp = getForInductionVarOwner(iv);
747  if (forOp && forOp.hasConstantLowerBound())
748  return forOp.getConstantLowerBound();
749  return std::nullopt;
750 }
751 
752 /// Gets the constant upper bound on an affine.for `iv`.
753 static std::optional<int64_t> getUpperBound(Value iv) {
754  AffineForOp forOp = getForInductionVarOwner(iv);
755  if (!forOp || !forOp.hasConstantUpperBound())
756  return std::nullopt;
757 
758  // If its lower bound is also known, we can get a more precise bound
759  // whenever the step is not one.
760  if (forOp.hasConstantLowerBound()) {
761  return forOp.getConstantUpperBound() - 1 -
762  (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
763  forOp.getStepAsInt();
764  }
765  return forOp.getConstantUpperBound() - 1;
766 }
767 
768 /// Determine a constant upper bound for `expr` if one exists while exploiting
769 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
770 /// is guaranteed to be less than or equal to it.
771 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
772  unsigned numSymbols,
773  ArrayRef<Value> operands) {
774  // Get the constant lower or upper bounds on the operands.
775  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
776  constLowerBounds.reserve(operands.size());
777  constUpperBounds.reserve(operands.size());
778  for (Value operand : operands) {
779  constLowerBounds.push_back(getLowerBound(operand));
780  constUpperBounds.push_back(getUpperBound(operand));
781  }
782 
783  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
784  return constExpr.getValue();
785 
786  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
787  constUpperBounds,
788  /*isUpper=*/true);
789 }
790 
791 /// Determine a constant lower bound for `expr` if one exists while exploiting
792 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
793 /// is guaranteed to be less than or equal to it.
794 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
795  unsigned numSymbols,
796  ArrayRef<Value> operands) {
797  // Get the constant lower or upper bounds on the operands.
798  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
799  constLowerBounds.reserve(operands.size());
800  constUpperBounds.reserve(operands.size());
801  for (Value operand : operands) {
802  constLowerBounds.push_back(getLowerBound(operand));
803  constUpperBounds.push_back(getUpperBound(operand));
804  }
805 
806  std::optional<int64_t> lowerBound;
807  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
808  lowerBound = constExpr.getValue();
809  } else {
810  lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
811  constLowerBounds, constUpperBounds,
812  /*isUpper=*/false);
813  }
814  return lowerBound;
815 }
816 
/// Simplify `expr` while exploiting information from the values in `operands`.
/// Only binary floordiv/ceildiv/mod expressions with a positive constant RHS
/// are simplified here; everything else is left untouched.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
                                    unsigned numSymbols,
                                    ArrayRef<Value> operands) {
  // We do this only for certain floordiv/mod expressions.
  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binExpr)
    return;

  // Simplify the child expressions first, recursively, then rebuild this
  // expression from the simplified children.
  AffineExpr lhs = binExpr.getLHS();
  AffineExpr rhs = binExpr.getRHS();
  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);

  // The rebuilt expression may no longer be a binary op; only
  // floordiv/ceildiv/mod are handled below.
  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
                   expr.getKind() != AffineExprKind::CeilDiv &&
                   expr.getKind() != AffineExprKind::Mod)) {
    return;
  }

  // The `lhs` and `rhs` may be different post construction of simplified expr.
  lhs = binExpr.getLHS();
  rhs = binExpr.getRHS();
  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
  if (!rhsConst)
    return;

  int64_t rhsConstVal = rhsConst.getValue();
  // Undefined expressions aren't touched; IR can still be valid with them.
  if (rhsConstVal <= 0)
    return;

  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
  MLIRContext *context = expr.getContext();
  std::optional<int64_t> lhsLbConst =
      getLowerBound(lhs, numDims, numSymbols, operands);
  std::optional<int64_t> lhsUbConst =
      getUpperBound(lhs, numDims, numSymbols, operands);
  if (lhsLbConst && lhsUbConst) {
    int64_t lhsLbConstVal = *lhsLbConst;
    int64_t lhsUbConstVal = *lhsUbConst;
    // lhs floordiv c is a single value if lhs is bounded in a range of size
    // `c` that has the same quotient.
    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
        divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
            divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
      expr = getAffineConstantExpr(
          divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
      return;
    }
    // lhs ceildiv c is a single value if the entire range has the same ceil
    // quotient.
    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
        divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
            divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
      expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
                                   context);
      return;
    }
    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
      expr = lhs;
      return;
    }
  }

  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
  AffineExpr quotientTimesDiv, rem;
  int64_t divisor;
  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
    if (rhsConstVal % divisor == 0 &&
        binExpr.getKind() == AffineExprKind::FloorDiv) {
      expr = quotientTimesDiv.floorDiv(rhsConst);
    } else if (divisor % rhsConstVal == 0 &&
               binExpr.getKind() == AffineExprKind::Mod) {
      expr = rem % rhsConst;
    }
    return;
  }

  // Handle the simple case when the LHS expression can be either upper
  // bounded or is a known multiple of RHS constant.
  // lhs floordiv c -> 0 if 0 <= lhs < c,
  // lhs mod c -> 0 if lhs % c = 0.
  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
       binExpr.getKind() == AffineExprKind::FloorDiv) ||
      (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
       binExpr.getKind() == AffineExprKind::Mod)) {
    expr = getAffineConstantExpr(0, expr.getContext());
  }
}
915 
916 /// Simplify the expressions in `map` while making use of lower or upper bounds
917 /// of its operands. If `isMax` is true, the map is to be treated as a max of
918 /// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
919 /// d1) can be simplified to (8) if the operands are respectively lower bounded
920 /// by 2 and 0 (the second expression can't be lower than 8).
922  ArrayRef<Value> operands,
923  bool isMax) {
924  // Can't simplify.
925  if (operands.empty())
926  return;
927 
928  // Get the upper or lower bound on an affine.for op IV using its range.
929  // Get the constant lower or upper bounds on the operands.
930  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
931  constLowerBounds.reserve(operands.size());
932  constUpperBounds.reserve(operands.size());
933  for (Value operand : operands) {
934  constLowerBounds.push_back(getLowerBound(operand));
935  constUpperBounds.push_back(getUpperBound(operand));
936  }
937 
938  // We will compute the lower and upper bounds on each of the expressions
939  // Then, we will check (depending on max or min) as to whether a specific
940  // bound is redundant by checking if its highest (in case of max) and its
941  // lowest (in the case of min) value is already lower than (or higher than)
942  // the lower bound (or upper bound in the case of min) of another bound.
943  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
944  lowerBounds.reserve(map.getNumResults());
945  upperBounds.reserve(map.getNumResults());
946  for (AffineExpr e : map.getResults()) {
947  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
948  lowerBounds.push_back(constExpr.getValue());
949  upperBounds.push_back(constExpr.getValue());
950  } else {
951  lowerBounds.push_back(
953  constLowerBounds, constUpperBounds,
954  /*isUpper=*/false));
955  upperBounds.push_back(
957  constLowerBounds, constUpperBounds,
958  /*isUpper=*/true));
959  }
960  }
961 
962  // Collect expressions that are not redundant.
963  SmallVector<AffineExpr, 4> irredundantExprs;
964  for (auto exprEn : llvm::enumerate(map.getResults())) {
965  AffineExpr e = exprEn.value();
966  unsigned i = exprEn.index();
967  // Some expressions can be turned into constants.
968  if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
969  e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
970 
971  // Check if the expression is redundant.
972  if (isMax) {
973  if (!upperBounds[i]) {
974  irredundantExprs.push_back(e);
975  continue;
976  }
977  // If there exists another expression such that its lower bound is greater
978  // than this expression's upper bound, it's redundant.
979  if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
980  auto otherLowerBound = en.value();
981  unsigned pos = en.index();
982  if (pos == i || !otherLowerBound)
983  return false;
984  if (*otherLowerBound > *upperBounds[i])
985  return true;
986  if (*otherLowerBound < *upperBounds[i])
987  return false;
988  // Equality case. When both expressions are considered redundant, we
989  // don't want to get both of them. We keep the one that appears
990  // first.
991  if (upperBounds[pos] && lowerBounds[i] &&
992  lowerBounds[i] == upperBounds[i] &&
993  otherLowerBound == *upperBounds[pos] && i < pos)
994  return false;
995  return true;
996  }))
997  irredundantExprs.push_back(e);
998  } else {
999  if (!lowerBounds[i]) {
1000  irredundantExprs.push_back(e);
1001  continue;
1002  }
1003  // Likewise for the `min` case. Use the complement of the condition above.
1004  if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
1005  auto otherUpperBound = en.value();
1006  unsigned pos = en.index();
1007  if (pos == i || !otherUpperBound)
1008  return false;
1009  if (*otherUpperBound < *lowerBounds[i])
1010  return true;
1011  if (*otherUpperBound > *lowerBounds[i])
1012  return false;
1013  if (lowerBounds[pos] && upperBounds[i] &&
1014  lowerBounds[i] == upperBounds[i] &&
1015  otherUpperBound == lowerBounds[pos] && i < pos)
1016  return false;
1017  return true;
1018  }))
1019  irredundantExprs.push_back(e);
1020  }
1021  }
1022 
1023  // Create the map without the redundant expressions.
1024  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
1025  map.getContext());
1026 }
1027 
1028 /// Simplify the map while exploiting information on the values in `operands`.
1029 // Use "unused attribute" marker to silence warning stemming from the inability
1030 // to see through the template expansion.
1031 static void LLVM_ATTRIBUTE_UNUSED
1033  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1034  SmallVector<AffineExpr> newResults;
1035  newResults.reserve(map.getNumResults());
1036  for (AffineExpr expr : map.getResults()) {
1038  operands);
1039  newResults.push_back(expr);
1040  }
1041  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1042  map.getContext());
1043 }
1044 
1045 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
1046 /// defining AffineApplyOp expression and operands.
1047 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
1048 /// When `dimOrSymbolPosition >= dims.size()`,
1049 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
1050 /// Mutate `map`,`dims` and `syms` in place as follows:
1051 /// 1. `dims` and `syms` are only appended to.
1052 /// 2. `map` dim and symbols are gradually shifted to higher positions.
1053 /// 3. Old `dim` and `sym` entries are replaced by nullptr
1054 /// This avoids the need for any bookkeeping.
1055 static LogicalResult replaceDimOrSym(AffineMap *map,
1056  unsigned dimOrSymbolPosition,
1057  SmallVectorImpl<Value> &dims,
1058  SmallVectorImpl<Value> &syms) {
1059  MLIRContext *ctx = map->getContext();
1060  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1061  unsigned pos = isDimReplacement ? dimOrSymbolPosition
1062  : dimOrSymbolPosition - dims.size();
1063  Value &v = isDimReplacement ? dims[pos] : syms[pos];
1064  if (!v)
1065  return failure();
1066 
1067  auto affineApply = v.getDefiningOp<AffineApplyOp>();
1068  if (!affineApply)
1069  return failure();
1070 
1071  // At this point we will perform a replacement of `v`, set the entry in `dim`
1072  // or `sym` to nullptr immediately.
1073  v = nullptr;
1074 
1075  // Compute the map, dims and symbols coming from the AffineApplyOp.
1076  AffineMap composeMap = affineApply.getAffineMap();
1077  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1078  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1079  affineApply.getMapOperands().end());
1080  // Canonicalize the map to promote dims to symbols when possible. This is to
1081  // avoid generating invalid maps.
1082  canonicalizeMapAndOperands(&composeMap, &composeOperands);
1083  AffineExpr replacementExpr =
1084  composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
1085  ValueRange composeDims =
1086  ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
1087  ValueRange composeSyms =
1088  ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
1089  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
1090  : getAffineSymbolExpr(pos, ctx);
1091 
1092  // Append the dims and symbols where relevant and perform the replacement.
1093  dims.append(composeDims.begin(), composeDims.end());
1094  syms.append(composeSyms.begin(), composeSyms.end());
1095  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1096 
1097  return success();
1098 }
1099 
1100 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1101 /// iteratively. Perform canonicalization of map and operands as well as
1102 /// AffineMap simplification. `map` and `operands` are mutated in place.
1104  SmallVectorImpl<Value> *operands) {
1105  if (map->getNumResults() == 0) {
1106  canonicalizeMapAndOperands(map, operands);
1107  *map = simplifyAffineMap(*map);
1108  return;
1109  }
1110 
1111  MLIRContext *ctx = map->getContext();
1112  SmallVector<Value, 4> dims(operands->begin(),
1113  operands->begin() + map->getNumDims());
1114  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1115  operands->end());
1116 
1117  // Iterate over dims and symbols coming from AffineApplyOp and replace until
1118  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1119  // and `syms` can only increase by construction.
1120  // The implementation uses a `while` loop to support the case of symbols
1121  // that may be constructed from dims ;this may be overkill.
1122  while (true) {
1123  bool changed = false;
1124  for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1125  if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
1126  break;
1127  if (!changed)
1128  break;
1129  }
1130 
1131  // Clear operands so we can fill them anew.
1132  operands->clear();
1133 
1134  // At this point we may have introduced null operands, prune them out before
1135  // canonicalizing map and operands.
1136  unsigned nDims = 0, nSyms = 0;
1137  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1138  dimReplacements.reserve(dims.size());
1139  symReplacements.reserve(syms.size());
1140  for (auto *container : {&dims, &syms}) {
1141  bool isDim = (container == &dims);
1142  auto &repls = isDim ? dimReplacements : symReplacements;
1143  for (const auto &en : llvm::enumerate(*container)) {
1144  Value v = en.value();
1145  if (!v) {
1146  assert(isDim ? !map->isFunctionOfDim(en.index())
1147  : !map->isFunctionOfSymbol(en.index()) &&
1148  "map is function of unexpected expr@pos");
1149  repls.push_back(getAffineConstantExpr(0, ctx));
1150  continue;
1151  }
1152  repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1153  : getAffineSymbolExpr(nSyms++, ctx));
1154  operands->push_back(v);
1155  }
1156  }
1157  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1158  nSyms);
1159 
1160  // Canonicalize and simplify before returning.
1161  canonicalizeMapAndOperands(map, operands);
1162  *map = simplifyAffineMap(*map);
1163 }
1164 
1166  AffineMap *map, SmallVectorImpl<Value> *operands) {
1167  while (llvm::any_of(*operands, [](Value v) {
1168  return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1169  })) {
1170  composeAffineMapAndOperands(map, operands);
1171  }
1172 }
1173 
1174 AffineApplyOp
1176  ArrayRef<OpFoldResult> operands) {
1177  SmallVector<Value> valueOperands;
1178  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1179  composeAffineMapAndOperands(&map, &valueOperands);
1180  assert(map);
1181  return b.create<AffineApplyOp>(loc, map, valueOperands);
1182 }
1183 
1184 AffineApplyOp
1186  ArrayRef<OpFoldResult> operands) {
1187  return makeComposedAffineApply(
1188  b, loc,
1190  .front(),
1191  operands);
1192 }
1193 
1194 /// Composes the given affine map with the given list of operands, pulling in
1195 /// the maps from any affine.apply operations that supply the operands.
1197  SmallVectorImpl<Value> &operands) {
1198  // Compose and canonicalize each expression in the map individually because
1199  // composition only applies to single-result maps, collecting potentially
1200  // duplicate operands in a single list with shifted dimensions and symbols.
1201  SmallVector<Value> dims, symbols;
1203  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1204  SmallVector<Value> submapOperands(operands.begin(), operands.end());
1205  AffineMap submap = map.getSubMap({i});
1206  fullyComposeAffineMapAndOperands(&submap, &submapOperands);
1207  canonicalizeMapAndOperands(&submap, &submapOperands);
1208  unsigned numNewDims = submap.getNumDims();
1209  submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1210  llvm::append_range(dims,
1211  ArrayRef<Value>(submapOperands).take_front(numNewDims));
1212  llvm::append_range(symbols,
1213  ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1214  exprs.push_back(submap.getResult(0));
1215  }
1216 
1217  // Canonicalize the map created from composed expressions to deduplicate the
1218  // dimension and symbol operands.
1219  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1220  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1221  canonicalizeMapAndOperands(&map, &operands);
1222 }
1223 
1226  AffineMap map,
1227  ArrayRef<OpFoldResult> operands) {
1228  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1229 
1230  // Create new builder without a listener, so that no notification is
1231  // triggered if the op is folded.
1232  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1233  // workaround is no longer needed.
1234  OpBuilder newBuilder(b.getContext());
1236 
1237  // Create op.
1238  AffineApplyOp applyOp =
1239  makeComposedAffineApply(newBuilder, loc, map, operands);
1240 
1241  // Get constant operands.
1242  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1243  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1244  matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1245 
1246  // Try to fold the operation.
1247  SmallVector<OpFoldResult> foldResults;
1248  if (failed(applyOp->fold(constOperands, foldResults)) ||
1249  foldResults.empty()) {
1250  if (OpBuilder::Listener *listener = b.getListener())
1251  listener->notifyOperationInserted(applyOp, /*previous=*/{});
1252  return applyOp.getResult();
1253  }
1254 
1255  applyOp->erase();
1256  return llvm::getSingleElement(foldResults);
1257 }
1258 
1261  AffineExpr expr,
1262  ArrayRef<OpFoldResult> operands) {
1264  b, loc,
1266  .front(),
1267  operands);
1268 }
1269 
1272  OpBuilder &b, Location loc, AffineMap map,
1273  ArrayRef<OpFoldResult> operands) {
1274  return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
1275  [&](unsigned i) {
1276  return makeComposedFoldedAffineApply(
1277  b, loc, map.getSubMap({i}), operands);
1278  });
1279 }
1280 
1281 template <typename OpTy>
1283  ArrayRef<OpFoldResult> operands) {
1284  SmallVector<Value> valueOperands;
1285  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1286  composeMultiResultAffineMap(map, valueOperands);
1287  return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1288 }
1289 
1290 AffineMinOp
1292  ArrayRef<OpFoldResult> operands) {
1293  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1294 }
1295 
1296 template <typename OpTy>
1298  AffineMap map,
1299  ArrayRef<OpFoldResult> operands) {
1300  // Create new builder without a listener, so that no notification is
1301  // triggered if the op is folded.
1302  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1303  // workaround is no longer needed.
1304  OpBuilder newBuilder(b.getContext());
1306 
1307  // Create op.
1308  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1309 
1310  // Get constant operands.
1311  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1312  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1313  matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1314 
1315  // Try to fold the operation.
1316  SmallVector<OpFoldResult> foldResults;
1317  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1318  foldResults.empty()) {
1319  if (OpBuilder::Listener *listener = b.getListener())
1320  listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1321  return minMaxOp.getResult();
1322  }
1323 
1324  minMaxOp->erase();
1325  return llvm::getSingleElement(foldResults);
1326 }
1327 
1330  AffineMap map,
1331  ArrayRef<OpFoldResult> operands) {
1332  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1333 }
1334 
1337  AffineMap map,
1338  ArrayRef<OpFoldResult> operands) {
1339  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1340 }
1341 
1342 // A symbol may appear as a dim in affine.apply operations. This function
1343 // canonicalizes dims that are valid symbols into actual symbols.
1344 template <class MapOrSet>
1345 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1346  SmallVectorImpl<Value> *operands) {
1347  if (!mapOrSet || operands->empty())
1348  return;
1349 
1350  assert(mapOrSet->getNumInputs() == operands->size() &&
1351  "map/set inputs must match number of operands");
1352 
1353  auto *context = mapOrSet->getContext();
1354  SmallVector<Value, 8> resultOperands;
1355  resultOperands.reserve(operands->size());
1356  SmallVector<Value, 8> remappedSymbols;
1357  remappedSymbols.reserve(operands->size());
1358  unsigned nextDim = 0;
1359  unsigned nextSym = 0;
1360  unsigned oldNumSyms = mapOrSet->getNumSymbols();
1361  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1362  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1363  if (i < mapOrSet->getNumDims()) {
1364  if (isValidSymbol((*operands)[i])) {
1365  // This is a valid symbol that appears as a dim, canonicalize it.
1366  dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1367  remappedSymbols.push_back((*operands)[i]);
1368  } else {
1369  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1370  resultOperands.push_back((*operands)[i]);
1371  }
1372  } else {
1373  resultOperands.push_back((*operands)[i]);
1374  }
1375  }
1376 
1377  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1378  *operands = resultOperands;
1379  *mapOrSet = mapOrSet->replaceDimsAndSymbols(
1380  dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
1381 
1382  assert(mapOrSet->getNumInputs() == operands->size() &&
1383  "map/set inputs must match number of operands");
1384 }
1385 
/// A valid affine dimension may appear as a symbol in affine.apply operations.
/// Given an application of `operands` to an affine map or integer set
/// `mapOrSet`, this function canonicalizes symbols of `mapOrSet` that are valid
/// dims, but not valid symbols into actual dims. Without such a legalization,
/// the affine.apply will be invalid. This method is the exact inverse of
/// canonicalizePromotedSymbols.
template <class MapOrSet>
static void legalizeDemotedDims(MapOrSet &mapOrSet,
                                SmallVectorImpl<Value> &operands) {
  if (!mapOrSet || operands.empty())
    return;

  unsigned numOperands = operands.size();

  assert(mapOrSet.getNumInputs() == numOperands &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet.getContext();
  // Final operand order built below: original dims, demoted dims, remaining
  // symbols.
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(numOperands);
  // Symbol operands being turned back into dims.
  SmallVector<Value, 8> remappedDims;
  remappedDims.reserve(numOperands);
  // Symbol operands that stay symbols.
  SmallVector<Value, 8> symOperands;
  symOperands.reserve(mapOrSet.getNumSymbols());
  unsigned nextSym = 0;
  unsigned nextDim = 0;
  unsigned oldNumDims = mapOrSet.getNumDims();
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet.getNumSymbols());
  // Existing dim operands are kept as-is at the front.
  resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
  for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
    // The null check guards against entries nulled out by earlier rewriting;
    // only non-null operands can be classified.
    if (operands[i] && isValidDim(operands[i]) && !isValidSymbol(operands[i])) {
      // This is a valid dim that appears as a symbol, legalize it: it becomes
      // a new dim appended after the existing ones.
      symRemapping[i - oldNumDims] =
          getAffineDimExpr(oldNumDims + nextDim++, context);
      remappedDims.push_back(operands[i]);
    } else {
      symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
      symOperands.push_back(operands[i]);
    }
  }

  // Append demoted dims, then the remaining symbols, matching the remapped
  // positions chosen above.
  append_range(resultOperands, remappedDims);
  append_range(resultOperands, symOperands);
  operands = resultOperands;
  mapOrSet = mapOrSet.replaceDimsAndSymbols(
      /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);

  assert(mapOrSet.getNumInputs() == operands.size() &&
         "map/set inputs must match number of operands");
}
1436 
// Works for either an affine map or an integer set.
// Canonicalizes `mapOrSet` and its operands: promotes symbol-valid dims to
// symbols, legalizes dim-valid symbols back to dims, drops unused dims and
// symbols, folds constant symbol operands into the expressions, and
// deduplicates repeated operands.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
  legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());

  // Compact used dims to the front, deduplicating repeated operands. Unused
  // dims keep a default-null remapping entry and are dropped.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}
1515 
1517  AffineMap *map, SmallVectorImpl<Value> *operands) {
1518  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1519 }
1520 
1522  IntegerSet *set, SmallVectorImpl<Value> *operands) {
1523  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1524 }
1525 
1526 namespace {
1527 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1528 /// maps that supply results into them.
1529 ///
1530 template <typename AffineOpTy>
1531 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1533 
1534  /// Replace the affine op with another instance of it with the supplied
1535  /// map and mapOperands.
1536  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1537  AffineMap map, ArrayRef<Value> mapOperands) const;
1538 
1539  LogicalResult matchAndRewrite(AffineOpTy affineOp,
1540  PatternRewriter &rewriter) const override {
1541  static_assert(
1542  llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1543  AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1544  AffineVectorStoreOp, AffineVectorLoadOp>::value,
1545  "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1546  "expected");
1547  auto map = affineOp.getAffineMap();
1548  AffineMap oldMap = map;
1549  auto oldOperands = affineOp.getMapOperands();
1550  SmallVector<Value, 8> resultOperands(oldOperands);
1551  composeAffineMapAndOperands(&map, &resultOperands);
1552  canonicalizeMapAndOperands(&map, &resultOperands);
1553  simplifyMapWithOperands(map, resultOperands);
1554  if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1555  resultOperands.begin()))
1556  return failure();
1557 
1558  replaceAffineOp(rewriter, affineOp, map, resultOperands);
1559  return success();
1560  }
1561 };
1562 
1563 // Specialize the template to account for the different build signatures for
1564 // affine load, store, and apply ops.
1565 template <>
1566 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1567  PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1568  ArrayRef<Value> mapOperands) const {
1569  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1570  mapOperands);
1571 }
1572 template <>
1573 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1574  PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1575  ArrayRef<Value> mapOperands) const {
1576  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1577  prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1578  prefetch.getLocalityHint(), prefetch.getIsDataCache());
1579 }
1580 template <>
1581 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1582  PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1583  ArrayRef<Value> mapOperands) const {
1584  rewriter.replaceOpWithNewOp<AffineStoreOp>(
1585  store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1586 }
1587 template <>
1588 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1589  PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1590  ArrayRef<Value> mapOperands) const {
1591  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1592  vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1593  mapOperands);
1594 }
1595 template <>
1596 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1597  PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1598  ArrayRef<Value> mapOperands) const {
1599  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1600  vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1601  mapOperands);
1602 }
1603 
1604 // Generic version for ops that don't have extra operands.
1605 template <typename AffineOpTy>
1606 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1607  PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1608  ArrayRef<Value> mapOperands) const {
1609  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1610 }
1611 } // namespace
1612 
/// Registers the canonicalization pattern that composes producing
/// affine.apply ops into this one (see SimplifyAffineOp above).
void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
}
1617 
1618 //===----------------------------------------------------------------------===//
1619 // AffineDmaStartOp
1620 //===----------------------------------------------------------------------===//
1621 
// TODO: Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_start op. Operands are appended in a fixed layout:
/// src memref + its map operands, dst memref + its map operands, tag memref +
/// its map operands, the number of elements, and finally the optional
/// stride/elements-per-stride pair.
void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
                             Value srcMemRef, AffineMap srcMap,
                             ValueRange srcIndices, Value destMemRef,
                             AffineMap dstMap, ValueRange destIndices,
                             Value tagMemRef, AffineMap tagMap,
                             ValueRange tagIndices, Value numElements,
                             Value stride, Value elementsPerStride) {
  // Source memref and its indexing map/operands.
  result.addOperands(srcMemRef);
  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
  result.addOperands(srcIndices);
  // Destination memref and its indexing map/operands.
  result.addOperands(destMemRef);
  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
  result.addOperands(destIndices);
  // Tag memref (used for DMA completion) and its indexing map/operands.
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
  // Stride operands are optional; when a stride is given, both the stride and
  // the elements-per-stride must be appended together.
  if (stride) {
    result.addOperands({stride, elementsPerStride});
  }
}
1644 
1646  p << " " << getSrcMemRef() << '[';
1647  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1648  p << "], " << getDstMemRef() << '[';
1649  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1650  p << "], " << getTagMemRef() << '[';
1651  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1652  p << "], " << getNumElements();
1653  if (isStrided()) {
1654  p << ", " << getStride();
1655  p << ", " << getNumElementsPerStride();
1656  }
1657  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1658  << getTagMemRefType();
1659 }
1660 
1661 // Parse AffineDmaStartOp.
1662 // Ex:
1663 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1664 // %stride, %num_elt_per_stride
1665 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1666 //
1668  OperationState &result) {
1669  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1670  AffineMapAttr srcMapAttr;
1672  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1673  AffineMapAttr dstMapAttr;
1675  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1676  AffineMapAttr tagMapAttr;
1678  OpAsmParser::UnresolvedOperand numElementsInfo;
1680 
1681  SmallVector<Type, 3> types;
1682  auto indexType = parser.getBuilder().getIndexType();
1683 
1684  // Parse and resolve the following list of operands:
1685  // *) dst memref followed by its affine maps operands (in square brackets).
1686  // *) src memref followed by its affine map operands (in square brackets).
1687  // *) tag memref followed by its affine map operands (in square brackets).
1688  // *) number of elements transferred by DMA operation.
1689  if (parser.parseOperand(srcMemRefInfo) ||
1690  parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1691  getSrcMapAttrStrName(),
1692  result.attributes) ||
1693  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1694  parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1695  getDstMapAttrStrName(),
1696  result.attributes) ||
1697  parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1698  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1699  getTagMapAttrStrName(),
1700  result.attributes) ||
1701  parser.parseComma() || parser.parseOperand(numElementsInfo))
1702  return failure();
1703 
1704  // Parse optional stride and elements per stride.
1705  if (parser.parseTrailingOperandList(strideInfo))
1706  return failure();
1707 
1708  if (!strideInfo.empty() && strideInfo.size() != 2) {
1709  return parser.emitError(parser.getNameLoc(),
1710  "expected two stride related operands");
1711  }
1712  bool isStrided = strideInfo.size() == 2;
1713 
1714  if (parser.parseColonTypeList(types))
1715  return failure();
1716 
1717  if (types.size() != 3)
1718  return parser.emitError(parser.getNameLoc(), "expected three types");
1719 
1720  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1721  parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1722  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1723  parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1724  parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1725  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1726  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1727  return failure();
1728 
1729  if (isStrided) {
1730  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1731  return failure();
1732  }
1733 
1734  // Check that src/dst/tag operand counts match their map.numInputs.
1735  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1736  dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1737  tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1738  return parser.emitError(parser.getNameLoc(),
1739  "memref operand count not equal to map.numInputs");
1740  return success();
1741 }
1742 
/// Verifies that the three memref operands have memref type, that the total
/// operand count matches the maps (with or without the stride pair), and
/// that every src/dst/tag index is a valid affine dim/symbol in the
/// enclosing affine scope.
LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA source to be of memref type");
  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA destination to be of memref type");
  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
    return emitOpError("expected DMA tag to be of memref type");

  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
                              getDstMap().getNumInputs() +
                              getTagMap().getNumInputs();
  // Expected: map inputs + 3 memrefs + numElements, optionally + 2 stride
  // operands (matches the operand layout produced by build()).
  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
      getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
    return emitOpError("incorrect number of operands");
  }

  Region *scope = getAffineScope(*this);
  for (auto idx : getSrcIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("src index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError(
          "src index must be a valid dimension or symbol identifier");
  }
  for (auto idx : getDstIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("dst index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError(
          "dst index must be a valid dimension or symbol identifier");
  }
  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("tag index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError(
          "tag index must be a valid dimension or symbol identifier");
  }
  return success();
}
1783 
/// The only folding performed is absorbing memref.cast producers into the
/// op's memref operands.
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return memref::foldMemRefCast(*this);
}
1789 
1790 void AffineDmaStartOp::getEffects(
1792  &effects) {
1793  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
1795  effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
1797  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1799 }
1800 
1801 //===----------------------------------------------------------------------===//
1802 // AffineDmaWaitOp
1803 //===----------------------------------------------------------------------===//
1804 
// TODO: Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_wait op. Operand order (tag memref, tag map
/// operands, numElements) must match what the parser and accessors expect.
void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
                            Value tagMemRef, AffineMap tagMap,
                            ValueRange tagIndices, Value numElements) {
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}
1814 
1816  p << " " << getTagMemRef() << '[';
1817  SmallVector<Value, 2> operands(getTagIndices());
1818  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1819  p << "], ";
1821  p << " : " << getTagMemRef().getType();
1822 }
1823 
1824 // Parse AffineDmaWaitOp.
1825 // Eg:
1826 // affine.dma_wait %tag[%index], %num_elements
1827 // : memref<1 x i32, (d0) -> (d0), 4>
1828 //
1830  OperationState &result) {
1831  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1832  AffineMapAttr tagMapAttr;
1834  Type type;
1835  auto indexType = parser.getBuilder().getIndexType();
1836  OpAsmParser::UnresolvedOperand numElementsInfo;
1837 
1838  // Parse tag memref, its map operands, and dma size.
1839  if (parser.parseOperand(tagMemRefInfo) ||
1840  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1841  getTagMapAttrStrName(),
1842  result.attributes) ||
1843  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1844  parser.parseColonType(type) ||
1845  parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1846  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1847  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1848  return failure();
1849 
1850  if (!llvm::isa<MemRefType>(type))
1851  return parser.emitError(parser.getNameLoc(),
1852  "expected tag to be of memref type");
1853 
1854  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1855  return parser.emitError(parser.getNameLoc(),
1856  "tag memref operand count != to map.numInputs");
1857  return success();
1858 }
1859 
1860 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1861  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1862  return emitOpError("expected DMA tag to be of memref type");
1863  Region *scope = getAffineScope(*this);
1864  for (auto idx : getTagIndices()) {
1865  if (!idx.getType().isIndex())
1866  return emitOpError("index to dma_wait must have 'index' type");
1867  if (!isValidAffineIndexOperand(idx, scope))
1868  return emitOpError(
1869  "index must be a valid dimension or symbol identifier");
1870  }
1871  return success();
1872 }
1873 
/// The only folding performed is absorbing memref.cast producers into the
/// op's memref operand.
LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return memref::foldMemRefCast(*this);
}
1879 
1880 void AffineDmaWaitOp::getEffects(
1882  &effects) {
1883  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1885 }
1886 
1887 //===----------------------------------------------------------------------===//
1888 // AffineForOp
1889 //===----------------------------------------------------------------------===//
1890 
/// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
/// bodyBuilder are empty/null, we include default terminator op.
/// Operand/attribute insertion order (segment sizes, step, lb map +
/// operands, ub map + operands, iter args) matches the op's declared
/// operand segments and must not be reordered.
void AffineForOp::build(OpBuilder &builder, OperationState &result,
                        ValueRange lbOperands, AffineMap lbMap,
                        ValueRange ubOperands, AffineMap ubMap, int64_t step,
                        ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
  assert(((!lbMap && lbOperands.empty()) ||
          lbOperands.size() == lbMap.getNumInputs()) &&
         "lower bound operand count does not match the affine map");
  assert(((!ubMap && ubOperands.empty()) ||
          ubOperands.size() == ubMap.getNumInputs()) &&
         "upper bound operand count does not match the affine map");
  assert(step > 0 && "step has to be a positive integer constant");

  OpBuilder::InsertionGuard guard(builder);

  // Set variadic segment sizes.
  result.addAttribute(
      getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
                                    static_cast<int32_t>(ubOperands.size()),
                                    static_cast<int32_t>(iterArgs.size())}));

  // One result per loop-carried value.
  for (Value val : iterArgs)
    result.addTypes(val.getType());

  // Add an attribute for the step.
  result.addAttribute(getStepAttrName(result.name),
                      builder.getIntegerAttr(builder.getIndexType(), step));

  // Add the lower bound.
  result.addAttribute(getLowerBoundMapAttrName(result.name),
                      AffineMapAttr::get(lbMap));
  result.addOperands(lbOperands);

  // Add the upper bound.
  result.addAttribute(getUpperBoundMapAttrName(result.name),
                      AffineMapAttr::get(ubMap));
  result.addOperands(ubOperands);

  result.addOperands(iterArgs);
  // Create a region and a block for the body. The argument of the region is
  // the loop induction variable.
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(bodyRegion);
  Value inductionVar =
      bodyBlock->addArgument(builder.getIndexType(), result.location);
  // One block argument per iter_arg, mirroring the loop-carried values.
  for (Value val : iterArgs)
    bodyBlock->addArgument(val.getType(), val.getLoc());

  // Create the default terminator if the builder is not provided and if the
  // iteration arguments are not provided. Otherwise, leave this to the caller
  // because we don't know which values to return from the loop.
  if (iterArgs.empty() && !bodyBuilder) {
    ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(bodyBlock);
    bodyBuilder(builder, result.location, inductionVar,
                bodyBlock->getArguments().drop_front());
  }
}
1953 
1954 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1955  int64_t ub, int64_t step, ValueRange iterArgs,
1956  BodyBuilderFn bodyBuilder) {
1957  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1958  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1959  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1960  bodyBuilder);
1961 }
1962 
1963 LogicalResult AffineForOp::verifyRegions() {
1964  // Check that the body defines as single block argument for the induction
1965  // variable.
1966  auto *body = getBody();
1967  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1968  return emitOpError("expected body to have a single index argument for the "
1969  "induction variable");
1970 
1971  // Verify that the bound operands are valid dimension/symbols.
1972  /// Lower bound.
1973  if (getLowerBoundMap().getNumInputs() > 0)
1975  getLowerBoundMap().getNumDims())))
1976  return failure();
1977  /// Upper bound.
1978  if (getUpperBoundMap().getNumInputs() > 0)
1980  getUpperBoundMap().getNumDims())))
1981  return failure();
1982  if (getLowerBoundMap().getNumResults() < 1)
1983  return emitOpError("expected lower bound map to have at least one result");
1984  if (getUpperBoundMap().getNumResults() < 1)
1985  return emitOpError("expected upper bound map to have at least one result");
1986 
1987  unsigned opNumResults = getNumResults();
1988  if (opNumResults == 0)
1989  return success();
1990 
1991  // If ForOp defines values, check that the number and types of the defined
1992  // values match ForOp initial iter operands and backedge basic block
1993  // arguments.
1994  if (getNumIterOperands() != opNumResults)
1995  return emitOpError(
1996  "mismatch between the number of loop-carried values and results");
1997  if (getNumRegionIterArgs() != opNumResults)
1998  return emitOpError(
1999  "mismatch between the number of basic block args and results");
2000 
2001  return success();
2002 }
2003 
2004 /// Parse a for operation loop bounds.
2005 static ParseResult parseBound(bool isLower, OperationState &result,
2006  OpAsmParser &p) {
2007  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
2008  // the map has multiple results.
2009  bool failedToParsedMinMax =
2010  failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
2011 
2012  auto &builder = p.getBuilder();
2013  auto boundAttrStrName =
2014  isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
2015  : AffineForOp::getUpperBoundMapAttrName(result.name);
2016 
2017  // Parse ssa-id as identity map.
2019  if (p.parseOperandList(boundOpInfos))
2020  return failure();
2021 
2022  if (!boundOpInfos.empty()) {
2023  // Check that only one operand was parsed.
2024  if (boundOpInfos.size() > 1)
2025  return p.emitError(p.getNameLoc(),
2026  "expected only one loop bound operand");
2027 
2028  // TODO: improve error message when SSA value is not of index type.
2029  // Currently it is 'use of value ... expects different type than prior uses'
2030  if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
2031  result.operands))
2032  return failure();
2033 
2034  // Create an identity map using symbol id. This representation is optimized
2035  // for storage. Analysis passes may expand it into a multi-dimensional map
2036  // if desired.
2037  AffineMap map = builder.getSymbolIdentityMap();
2038  result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
2039  return success();
2040  }
2041 
2042  // Get the attribute location.
2043  SMLoc attrLoc = p.getCurrentLocation();
2044 
2045  Attribute boundAttr;
2046  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
2047  result.attributes))
2048  return failure();
2049 
2050  // Parse full form - affine map followed by dim and symbol list.
2051  if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
2052  unsigned currentNumOperands = result.operands.size();
2053  unsigned numDims;
2054  if (parseDimAndSymbolList(p, result.operands, numDims))
2055  return failure();
2056 
2057  auto map = affineMapAttr.getValue();
2058  if (map.getNumDims() != numDims)
2059  return p.emitError(
2060  p.getNameLoc(),
2061  "dim operand count and affine map dim count must match");
2062 
2063  unsigned numDimAndSymbolOperands =
2064  result.operands.size() - currentNumOperands;
2065  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2066  return p.emitError(
2067  p.getNameLoc(),
2068  "symbol operand count and affine map symbol count must match");
2069 
2070  // If the map has multiple results, make sure that we parsed the min/max
2071  // prefix.
2072  if (map.getNumResults() > 1 && failedToParsedMinMax) {
2073  if (isLower) {
2074  return p.emitError(attrLoc, "lower loop bound affine map with "
2075  "multiple results requires 'max' prefix");
2076  }
2077  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
2078  "results requires 'min' prefix");
2079  }
2080  return success();
2081  }
2082 
2083  // Parse custom assembly form.
2084  if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
2085  result.attributes.pop_back();
2086  result.addAttribute(
2087  boundAttrStrName,
2088  AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2089  return success();
2090  }
2091 
2092  return p.emitError(
2093  p.getNameLoc(),
2094  "expected valid affine map representation for loop bounds");
2095 }
2096 
2097 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2098  auto &builder = parser.getBuilder();
2099  OpAsmParser::Argument inductionVariable;
2100  inductionVariable.type = builder.getIndexType();
2101  // Parse the induction variable followed by '='.
2102  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2103  return failure();
2104 
2105  // Parse loop bounds.
2106  int64_t numOperands = result.operands.size();
2107  if (parseBound(/*isLower=*/true, result, parser))
2108  return failure();
2109  int64_t numLbOperands = result.operands.size() - numOperands;
2110  if (parser.parseKeyword("to", " between bounds"))
2111  return failure();
2112  numOperands = result.operands.size();
2113  if (parseBound(/*isLower=*/false, result, parser))
2114  return failure();
2115  int64_t numUbOperands = result.operands.size() - numOperands;
2116 
2117  // Parse the optional loop step, we default to 1 if one is not present.
2118  if (parser.parseOptionalKeyword("step")) {
2119  result.addAttribute(
2120  getStepAttrName(result.name),
2121  builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2122  } else {
2123  SMLoc stepLoc = parser.getCurrentLocation();
2124  IntegerAttr stepAttr;
2125  if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2126  getStepAttrName(result.name).data(),
2127  result.attributes))
2128  return failure();
2129 
2130  if (stepAttr.getValue().isNegative())
2131  return parser.emitError(
2132  stepLoc,
2133  "expected step to be representable as a positive signed integer");
2134  }
2135 
2136  // Parse the optional initial iteration arguments.
2139 
2140  // Induction variable.
2141  regionArgs.push_back(inductionVariable);
2142 
2143  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2144  // Parse assignment list and results type list.
2145  if (parser.parseAssignmentList(regionArgs, operands) ||
2146  parser.parseArrowTypeList(result.types))
2147  return failure();
2148  // Resolve input operands.
2149  for (auto argOperandType :
2150  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2151  Type type = std::get<2>(argOperandType);
2152  std::get<0>(argOperandType).type = type;
2153  if (parser.resolveOperand(std::get<1>(argOperandType), type,
2154  result.operands))
2155  return failure();
2156  }
2157  }
2158 
2159  result.addAttribute(
2160  getOperandSegmentSizeAttr(),
2161  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2162  static_cast<int32_t>(numUbOperands),
2163  static_cast<int32_t>(operands.size())}));
2164 
2165  // Parse the body region.
2166  Region *body = result.addRegion();
2167  if (regionArgs.size() != result.types.size() + 1)
2168  return parser.emitError(
2169  parser.getNameLoc(),
2170  "mismatch between the number of loop-carried values and results");
2171  if (parser.parseRegion(*body, regionArgs))
2172  return failure();
2173 
2174  AffineForOp::ensureTerminator(*body, builder, result.location);
2175 
2176  // Parse the optional attribute list.
2177  return parser.parseOptionalAttrDict(result.attributes);
2178 }
2179 
/// Prints one loop bound: the custom form (bare constant, or a single SSA
/// symbol) when it round-trips losslessly, otherwise the full affine map
/// with its operand list, prefixed with 'max'/'min' for multi-result maps.
static void printBound(AffineMapAttr boundMap,
                       Operation::operand_range boundOperands,
                       const char *prefix, OpAsmPrinter &p) {
  AffineMap map = boundMap.getValue();

  // Check if this bound should be printed using custom assembly form.
  // The decision to restrict printing custom assembly form to trivial cases
  // comes from the will to roundtrip MLIR binary -> text -> binary in a
  // lossless way.
  // Therefore, custom assembly form parsing and printing is only supported for
  // zero-operand constant maps and single symbol operand identity maps.
  if (map.getNumResults() == 1) {
    AffineExpr expr = map.getResult(0);

    // Print constant bound.
    if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
      if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
        p << constExpr.getValue();
        return;
      }
    }

    // Print bound that consists of a single SSA symbol if the map is over a
    // single symbol.
    if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
      if (isa<AffineSymbolExpr>(expr)) {
        p.printOperand(*boundOperands.begin());
        return;
      }
    }
  } else {
    // Map has multiple results. Print 'min' or 'max' prefix.
    p << prefix << ' ';
  }

  // Print the map and its operands.
  p << boundMap;
  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
                        map.getNumDims(), p);
}
2220 
2221 unsigned AffineForOp::getNumIterOperands() {
2222  AffineMap lbMap = getLowerBoundMapAttr().getValue();
2223  AffineMap ubMap = getUpperBoundMapAttr().getValue();
2224 
2225  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2226 }
2227 
2228 std::optional<MutableArrayRef<OpOperand>>
2229 AffineForOp::getYieldedValuesMutable() {
2230  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2231 }
2232 
2234  p << ' ';
2235  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2236  /*omitType=*/true);
2237  p << " = ";
2238  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2239  p << " to ";
2240  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2241 
2242  if (getStepAsInt() != 1)
2243  p << " step " << getStepAsInt();
2244 
2245  bool printBlockTerminators = false;
2246  if (getNumIterOperands() > 0) {
2247  p << " iter_args(";
2248  auto regionArgs = getRegionIterArgs();
2249  auto operands = getInits();
2250 
2251  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2252  p << std::get<0>(it) << " = " << std::get<1>(it);
2253  });
2254  p << ") -> (" << getResultTypes() << ")";
2255  printBlockTerminators = true;
2256  }
2257 
2258  p << ' ';
2259  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2260  printBlockTerminators);
2262  (*this)->getAttrs(),
2263  /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2264  getUpperBoundMapAttrName(getOperation()->getName()),
2265  getStepAttrName(getOperation()->getName()),
2266  getOperandSegmentSizeAttr()});
2267 }
2268 
/// Fold the constant bounds of a loop: if every operand of a non-constant
/// bound map folds to a constant, the bound is replaced by a single constant
/// (max of results for the lower bound, min for the upper). Returns
/// success() iff at least one bound was folded.
static LogicalResult foldLoopBounds(AffineForOp forOp) {
  auto foldLowerOrUpperBound = [&forOp](bool lower) {
    // Check to see if each of the operands is the result of a constant. If
    // so, get the value. If not, ignore it.
    SmallVector<Attribute, 8> operandConstants;
    auto boundOperands =
        lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
    for (auto operand : boundOperands) {
      Attribute operandCst;
      // Leaves operandCst null for non-constant operands; constantFold
      // below fails in that case.
      matchPattern(operand, m_Constant(&operandCst));
      operandConstants.push_back(operandCst);
    }

    AffineMap boundMap =
        lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
    assert(boundMap.getNumResults() >= 1 &&
           "bound maps should have at least one result");
    SmallVector<Attribute, 4> foldedResults;
    if (failed(boundMap.constantFold(operandConstants, foldedResults)))
      return failure();

    // Compute the max or min as applicable over the results.
    assert(!foldedResults.empty() && "bounds should have at least one result");
    auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
    for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
      auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
      maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                       : llvm::APIntOps::smin(maxOrMin, foldedResult);
    }
    lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
          : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
    return success();
  };

  // Try to fold the lower bound.
  bool folded = false;
  if (!forOp.hasConstantLowerBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));

  // Try to fold the upper bound.
  if (!forOp.hasConstantUpperBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
  return success(folded);
}
2314 
/// Canonicalize the bounds of the given loop: compose bound maps with their
/// operands, canonicalize, simplify min/max expressions (lb is a max, ub is
/// a min), and drop duplicate results. Returns success() iff a map changed.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());

  auto lbMap = forOp.getLowerBoundMap();
  auto ubMap = forOp.getUpperBoundMap();
  // Keep the originals around to detect whether anything changed.
  auto prevLbMap = lbMap;
  auto prevUbMap = ubMap;

  composeAffineMapAndOperands(&lbMap, &lbOperands);
  canonicalizeMapAndOperands(&lbMap, &lbOperands);
  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
  // NOTE(review): the ub map is min-simplified here, before being composed
  // with its operands below -- looks intentional but confirm against
  // upstream.
  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
  lbMap = removeDuplicateExprs(lbMap);

  composeAffineMapAndOperands(&ubMap, &ubOperands);
  canonicalizeMapAndOperands(&ubMap, &ubOperands);
  ubMap = removeDuplicateExprs(ubMap);

  // Any canonicalization change always leads to updated map(s).
  if (lbMap == prevLbMap && ubMap == prevUbMap)
    return failure();

  if (lbMap != prevLbMap)
    forOp.setLowerBound(lbOperands, lbMap);
  if (ubMap != prevUbMap)
    forOp.setUpperBound(ubOperands, ubMap);
  return success();
}
2345 
2346 namespace {
2347 /// Returns constant trip count in trivial cases.
2348 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2349  int64_t step = forOp.getStepAsInt();
2350  if (!forOp.hasConstantBounds() || step <= 0)
2351  return std::nullopt;
2352  int64_t lb = forOp.getConstantLowerBound();
2353  int64_t ub = forOp.getConstantUpperBound();
2354  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2355 }
2356 
2357 /// This is a pattern to fold trivially empty loop bodies.
2358 /// TODO: This should be moved into the folding hook.
2359 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
2361 
2362  LogicalResult matchAndRewrite(AffineForOp forOp,
2363  PatternRewriter &rewriter) const override {
2364  // Check that the body only contains a yield.
2365  if (!llvm::hasSingleElement(*forOp.getBody()))
2366  return failure();
2367  if (forOp.getNumResults() == 0)
2368  return success();
2369  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2370  if (tripCount == 0) {
2371  // The initial values of the iteration arguments would be the op's
2372  // results.
2373  rewriter.replaceOp(forOp, forOp.getInits());
2374  return success();
2375  }
2376  SmallVector<Value, 4> replacements;
2377  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2378  auto iterArgs = forOp.getRegionIterArgs();
2379  bool hasValDefinedOutsideLoop = false;
2380  bool iterArgsNotInOrder = false;
2381  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2382  Value val = yieldOp.getOperand(i);
2383  auto *iterArgIt = llvm::find(iterArgs, val);
2384  // TODO: It should be possible to perform a replacement by computing the
2385  // last value of the IV based on the bounds and the step.
2386  if (val == forOp.getInductionVar())
2387  return failure();
2388  if (iterArgIt == iterArgs.end()) {
2389  // `val` is defined outside of the loop.
2390  assert(forOp.isDefinedOutsideOfLoop(val) &&
2391  "must be defined outside of the loop");
2392  hasValDefinedOutsideLoop = true;
2393  replacements.push_back(val);
2394  } else {
2395  unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2396  if (pos != i)
2397  iterArgsNotInOrder = true;
2398  replacements.push_back(forOp.getInits()[pos]);
2399  }
2400  }
2401  // Bail out when the trip count is unknown and the loop returns any value
2402  // defined outside of the loop or any iterArg out of order.
2403  if (!tripCount.has_value() &&
2404  (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2405  return failure();
2406  // Bail out when the loop iterates more than once and it returns any iterArg
2407  // out of order.
2408  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2409  return failure();
2410  rewriter.replaceOp(forOp, replacements);
2411  return success();
2412  }
2413 };
2414 } // namespace
2415 
/// Registers the empty-loop-body folding pattern for affine.for.
void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<AffineForEmptyLoopFolder>(context);
}
2420 
/// Returns the operands forwarded on entry to the loop region (or to the
/// results when the trip count is zero): the iter_args init values.
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert((point.isParent() || point == getRegion()) && "invalid region point");

  // The initial operands map to the loop arguments after the induction
  // variable or are forwarded to the results when the trip count is zero.
  return getInits();
}
2428 
2429 void AffineForOp::getSuccessorRegions(
2431  assert((point.isParent() || point == getRegion()) && "expected loop region");
2432  // The loop may typically branch back to its body or to the parent operation.
2433  // If the predecessor is the parent op and the trip count is known to be at
2434  // least one, branch into the body using the iterator arguments. And in cases
2435  // we know the trip count is zero, it can only branch back to its parent.
2436  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2437  if (point.isParent() && tripCount.has_value()) {
2438  if (tripCount.value() > 0) {
2439  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2440  return;
2441  }
2442  if (tripCount.value() == 0) {
2443  regions.push_back(RegionSuccessor(getResults()));
2444  return;
2445  }
2446  }
2447 
2448  // From the loop body, if the trip count is one, we can only branch back to
2449  // the parent.
2450  if (!point.isParent() && tripCount == 1) {
2451  regions.push_back(RegionSuccessor(getResults()));
2452  return;
2453  }
2454 
2455  // In all other cases, the loop may branch back to itself or the parent
2456  // operation.
2457  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2458  regions.push_back(RegionSuccessor(getResults()));
2459 }
2460 
2461 /// Returns true if the affine.for has zero iterations in trivial cases.
2462 static bool hasTrivialZeroTripCount(AffineForOp op) {
2463  return getTrivialConstantTripCount(op) == 0;
2464 }
2465 
2466 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2467  SmallVectorImpl<OpFoldResult> &results) {
2468  bool folded = succeeded(foldLoopBounds(*this));
2469  folded |= succeeded(canonicalizeLoopBounds(*this));
2470  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2471  // The initial values of the loop-carried variables (iter_args) are the
2472  // results of the op. But this must be avoided for an affine.for op that
2473  // does not return any results. Since ops that do not return results cannot
2474  // be folded away, we would enter an infinite loop of folds on the same
2475  // affine.for op.
2476  results.assign(getInits().begin(), getInits().end());
2477  folded = true;
2478  }
2479  return success(folded);
2480 }
2481 
2483  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2484 }
2485 
2487  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2488 }
2489 
/// Replaces the lower bound with `map` applied to `lbOperands`. The operand
/// count must match the map's inputs, and a bound map always carries at
/// least one result expression.
void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  getLowerBoundOperandsMutable().assign(lbOperands);
  setLowerBoundMap(map);
}

/// Replaces the upper bound with `map` applied to `ubOperands`; see
/// setLowerBound for the invariants.
void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  getUpperBoundOperandsMutable().assign(ubOperands);
  setUpperBoundMap(map);
}
2503 
/// Returns true if the lower bound map is a single constant expression.
bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();
}

/// Returns true if the upper bound map is a single constant expression.
bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();
}

/// Returns the constant lower bound; only valid when hasConstantLowerBound()
/// holds.
int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();
}

/// Returns the constant upper bound; only valid when hasConstantUpperBound()
/// holds.
int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();
}

/// Sets the lower bound to the constant `value` via an operand-less constant
/// map.
void AffineForOp::setConstantLowerBound(int64_t value) {
  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}

/// Sets the upper bound to the constant `value` via an operand-less constant
/// map.
void AffineForOp::setConstantUpperBound(int64_t value) {
  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
2527 
/// Returns the operands controlling the loop bounds: the lower bound
/// operands immediately followed by the upper bound operands (they are laid
/// out first in the operand list).
AffineForOp::operand_range AffineForOp::getControlOperands() {
  return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
                               getUpperBoundOperands().size()};
}
2532 
2533 bool AffineForOp::matchingBoundOperandList() {
2534  auto lbMap = getLowerBoundMap();
2535  auto ubMap = getUpperBoundMap();
2536  if (lbMap.getNumDims() != ubMap.getNumDims() ||
2537  lbMap.getNumSymbols() != ubMap.getNumSymbols())
2538  return false;
2539 
2540  unsigned numOperands = lbMap.getNumInputs();
2541  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2542  // Compare Value 's.
2543  if (getOperand(i) != getOperand(numOperands + i))
2544  return false;
2545  }
2546  return true;
2547 }
2548 
/// An affine.for has exactly one loop region: its body.
SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }

/// Returns the single induction variable wrapped in a vector (LoopLikeOp
/// interface form).
std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
  return SmallVector<Value>{getInductionVar()};
}
2554 
2555 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2556  if (!hasConstantLowerBound())
2557  return std::nullopt;
2558  OpBuilder b(getContext());
2560  OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2561 }
2562 
2563 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2564  OpBuilder b(getContext());
2566  OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2567 }
2568 
2569 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2570  if (!hasConstantUpperBound())
2571  return {};
2572  OpBuilder b(getContext());
2574  OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2575 }
2576 
/// Rebuilds this loop with `newInitOperands` appended to the iter_args,
/// appending the values produced by `newYieldValuesFn` to the terminator.
/// When `replaceInitOperandUsesInLoop` is set, in-loop uses of the new init
/// values are redirected to the new block arguments. Returns the new loop;
/// the old loop is replaced by the leading results of the new one.
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    RewriterBase &rewriter, ValueRange newInitOperands,
    bool replaceInitOperandUsesInLoop,
    const NewYieldValuesFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(getOperation());
  auto inits = llvm::to_vector(getInits());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  AffineForOp newLoop = rewriter.create<AffineForOp>(
      getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
      getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);

  // Generate the new yield values and append them to the affine.yield
  // operation.
  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
  {
    // Scope the insertion point change to the yield-building region.
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().append(newYieldedValues);
    });
  }

  // Move the loop body to the new op, remapping the old block arguments
  // (IV + original iter_args) to the leading arguments of the new body.
  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // Replace all uses of `newInitOperands` with the corresponding basic block
    // arguments, but only for uses nested inside the new loop.
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
                                 [&](OpOperand &use) {
                                   Operation *user = use.getOwner();
                                   return newLoop->isProperAncestor(user);
                                 });
    }
  }

  // Replace the old loop; the new loop's extra trailing results are dropped
  // from the mapping.
  rewriter.replaceOp(getOperation(),
                     newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
2628 
2629 Speculation::Speculatability AffineForOp::getSpeculatability() {
2630  // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2631  // Start and End.
2632  //
2633  // For Step != 1, the loop may not terminate. We can add more smarts here if
2634  // needed.
2635  return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2637 }
2638 
2639 /// Returns true if the provided value is the induction variable of a
2640 /// AffineForOp.
2642  return getForInductionVarOwner(val) != AffineForOp();
2643 }
2644 
2646  return getAffineParallelInductionVarOwner(val) != nullptr;
2647 }
2648 
2651 }
2652 
2654  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2655  if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2656  return AffineForOp();
2657  if (auto forOp =
2658  ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2659  // Check to make sure `val` is the induction variable, not an iter_arg.
2660  return forOp.getInductionVar() == val ? forOp : AffineForOp();
2661  return AffineForOp();
2662 }
2663 
2665  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2666  if (!ivArg || !ivArg.getOwner())
2667  return nullptr;
2668  Operation *containingOp = ivArg.getOwner()->getParentOp();
2669  auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
2670  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2671  return parallelOp;
2672  return nullptr;
2673 }
2674 
2675 /// Extracts the induction variables from a list of AffineForOps and returns
2676 /// them.
2678  SmallVectorImpl<Value> *ivs) {
2679  ivs->reserve(forInsts.size());
2680  for (auto forInst : forInsts)
2681  ivs->push_back(forInst.getInductionVar());
2682 }
2683 
2686  ivs.reserve(affineOps.size());
2687  for (Operation *op : affineOps) {
2688  // Add constraints from forOp's bounds.
2689  if (auto forOp = dyn_cast<AffineForOp>(op))
2690  ivs.push_back(forOp.getInductionVar());
2691  else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2692  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2693  ivs.push_back(parallelOp.getBody()->getArgument(i));
2694  }
2695 }
2696 
2697 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2698 /// operations.
2699 template <typename BoundListTy, typename LoopCreatorTy>
2701  OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2702  ArrayRef<int64_t> steps,
2703  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2704  LoopCreatorTy &&loopCreatorFn) {
2705  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2706  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2707 
2708  // If there are no loops to be constructed, construct the body anyway.
2709  OpBuilder::InsertionGuard guard(builder);
2710  if (lbs.empty()) {
2711  if (bodyBuilderFn)
2712  bodyBuilderFn(builder, loc, ValueRange());
2713  return;
2714  }
2715 
2716  // Create the loops iteratively and store the induction variables.
2718  ivs.reserve(lbs.size());
2719  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2720  // Callback for creating the loop body, always creates the terminator.
2721  auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
2722  ValueRange iterArgs) {
2723  ivs.push_back(iv);
2724  // In the innermost loop, call the body builder.
2725  if (i == e - 1 && bodyBuilderFn) {
2726  OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
2727  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2728  }
2729  nestedBuilder.create<AffineYieldOp>(nestedLoc);
2730  };
2731 
2732  // Delegate actual loop creation to the callback in order to dispatch
2733  // between constant- and variable-bound loops.
2734  auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2735  builder.setInsertionPointToStart(loop.getBody());
2736  }
2737 }
2738 
2739 /// Creates an affine loop from the bounds known to be constants.
2740 static AffineForOp
2742  int64_t ub, int64_t step,
2743  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2744  return builder.create<AffineForOp>(loc, lb, ub, step,
2745  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2746 }
2747 
2748 /// Creates an affine loop from the bounds that may or may not be constants.
2749 static AffineForOp
2751  int64_t step,
2752  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2753  std::optional<int64_t> lbConst = getConstantIntValue(lb);
2754  std::optional<int64_t> ubConst = getConstantIntValue(ub);
2755  if (lbConst && ubConst)
2756  return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2757  ubConst.value(), step, bodyBuilderFn);
2758  return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2759  builder.getDimIdentityMap(), step,
2760  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2761 }
2762 
2764  OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2766  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2767  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2769 }
2770 
2772  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2773  ArrayRef<int64_t> steps,
2774  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2775  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2777 }
2778 
2779 //===----------------------------------------------------------------------===//
2780 // AffineIfOp
2781 //===----------------------------------------------------------------------===//
2782 
2783 namespace {
2784 /// Remove else blocks that have nothing other than a zero value yield.
2785 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
2787 
2788  LogicalResult matchAndRewrite(AffineIfOp ifOp,
2789  PatternRewriter &rewriter) const override {
2790  if (ifOp.getElseRegion().empty() ||
2791  !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2792  return failure();
2793 
2794  rewriter.startOpModification(ifOp);
2795  rewriter.eraseBlock(ifOp.getElseBlock());
2796  rewriter.finalizeOpModification(ifOp);
2797  return success();
2798  }
2799 };
2800 
2801 /// Removes affine.if cond if the condition is always true or false in certain
2802 /// trivial cases. Promotes the then/else block in the parent operation block.
2803 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
2805 
2806  LogicalResult matchAndRewrite(AffineIfOp op,
2807  PatternRewriter &rewriter) const override {
2808 
2809  auto isTriviallyFalse = [](IntegerSet iSet) {
2810  return iSet.isEmptyIntegerSet();
2811  };
2812 
2813  auto isTriviallyTrue = [](IntegerSet iSet) {
2814  return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2815  iSet.getConstraint(0) == 0);
2816  };
2817 
2818  IntegerSet affineIfConditions = op.getIntegerSet();
2819  Block *blockToMove;
2820  if (isTriviallyFalse(affineIfConditions)) {
2821  // The absence, or equivalently, the emptiness of the else region need not
2822  // be checked when affine.if is returning results because if an affine.if
2823  // operation is returning results, it always has a non-empty else region.
2824  if (op.getNumResults() == 0 && !op.hasElse()) {
2825  // If the else region is absent, or equivalently, empty, remove the
2826  // affine.if operation (which is not returning any results).
2827  rewriter.eraseOp(op);
2828  return success();
2829  }
2830  blockToMove = op.getElseBlock();
2831  } else if (isTriviallyTrue(affineIfConditions)) {
2832  blockToMove = op.getThenBlock();
2833  } else {
2834  return failure();
2835  }
2836  Operation *blockToMoveTerminator = blockToMove->getTerminator();
2837  // Promote the "blockToMove" block to the parent operation block between the
2838  // prologue and epilogue of "op".
2839  rewriter.inlineBlockBefore(blockToMove, op);
2840  // Replace the "op" operation with the operands of the
2841  // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2842  // the affine.yield operation present in the "blockToMove" block. It has no
2843  // operands when affine.if is not returning results and therefore, in that
2844  // case, replaceOp just erases "op". When affine.if is not returning
2845  // results, the affine.yield operation can be omitted. It gets inserted
2846  // implicitly.
2847  rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2848  // Erase the "blockToMoveTerminator" operation since it is now in the parent
2849  // operation block, which already has its own terminator.
2850  rewriter.eraseOp(blockToMoveTerminator);
2851  return success();
2852  }
2853 };
2854 } // namespace
2855 
2856 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
2857 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
2858 void AffineIfOp::getSuccessorRegions(
2860  // If the predecessor is an AffineIfOp, then branching into both `then` and
2861  // `else` region is valid.
2862  if (point.isParent()) {
2863  regions.reserve(2);
2864  regions.push_back(
2865  RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
2866  // If the "else" region is empty, branch bach into parent.
2867  if (getElseRegion().empty()) {
2868  regions.push_back(getResults());
2869  } else {
2870  regions.push_back(
2871  RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
2872  }
2873  return;
2874  }
2875 
2876  // If the predecessor is the `else`/`then` region, then branching into parent
2877  // op is valid.
2878  regions.push_back(RegionSuccessor(getResults()));
2879 }
2880 
2881 LogicalResult AffineIfOp::verify() {
2882  // Verify that we have a condition attribute.
2883  // FIXME: This should be specified in the arguments list in ODS.
2884  auto conditionAttr =
2885  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2886  if (!conditionAttr)
2887  return emitOpError("requires an integer set attribute named 'condition'");
2888 
2889  // Verify that there are enough operands for the condition.
2890  IntegerSet condition = conditionAttr.getValue();
2891  if (getNumOperands() != condition.getNumInputs())
2892  return emitOpError("operand count and condition integer set dimension and "
2893  "symbol count must match");
2894 
2895  // Verify that the operands are valid dimension/symbols.
2896  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2897  condition.getNumDims())))
2898  return failure();
2899 
2900  return success();
2901 }
2902 
/// Parses: <integer-set-attr> `(`dims`)` `[`symbols`]` (`->` types)?
/// then-region (`else` else-region)? attr-dict?
ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the condition attribute set.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(conditionAttr,
                            AffineIfOp::getConditionAttrStrName(),
                            result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();

  // Verify the condition operands: the parsed dim/symbol operand counts must
  // line up with the set's declared dims and symbols.
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        parser.getNameLoc(),
        "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        parser.getNameLoc(),
        "symbol operand count and integer set symbol count must match");

  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Create the regions for 'then' and 'else'. The latter must be created even
  // if it remains empty for the validity of the operation.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  // Parse the 'then' region; insert the implicit terminator if it was elided
  // in the textual form.
  if (parser.parseRegion(*thenRegion, {}, {}))
    return failure();
  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
                               result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return failure();
    AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
                                 result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
2953 
2955  auto conditionAttr =
2956  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2957  p << " " << conditionAttr;
2958  printDimAndSymbolList(operand_begin(), operand_end(),
2959  conditionAttr.getValue().getNumDims(), p);
2960  p.printOptionalArrowTypeList(getResultTypes());
2961  p << ' ';
2962  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
2963  /*printBlockTerminators=*/getNumResults());
2964 
2965  // Print the 'else' regions if it has any blocks.
2966  auto &elseRegion = this->getElseRegion();
2967  if (!elseRegion.empty()) {
2968  p << " else ";
2969  p.printRegion(elseRegion,
2970  /*printEntryBlockArgs=*/false,
2971  /*printBlockTerminators=*/getNumResults());
2972  }
2973 
2974  // Print the attribute list.
2975  p.printOptionalAttrDict((*this)->getAttrs(),
2976  /*elidedAttrs=*/getConditionAttrStrName());
2977 }
2978 
/// Returns the integer set encoding the if condition.
IntegerSet AffineIfOp::getIntegerSet() {
  return (*this)
      ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
      .getValue();
}

/// Replaces the condition set, keeping the current operands.
void AffineIfOp::setIntegerSet(IntegerSet newSet) {
  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
}

/// Replaces both the condition set and the operands feeding it.
void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
  setIntegerSet(set);
  (*this)->setOperands(operands);
}
2993 
/// Builds an affine.if with the given condition `set` over `args`. An op with
/// results must have an else region (to yield the alternative values), hence
/// the assertion.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       TypeRange resultTypes, IntegerSet set, ValueRange args,
                       bool withElseRegion) {
  assert(resultTypes.empty() || withElseRegion);
  OpBuilder::InsertionGuard guard(builder);

  result.addTypes(resultTypes);
  result.addOperands(args);
  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));

  // Create the 'then' region; when there are no results, terminate it with
  // the implicit empty yield.
  Region *thenRegion = result.addRegion();
  builder.createBlock(thenRegion);
  if (resultTypes.empty())
    AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);

  // The 'else' region is always added (required for op validity) but only
  // gets a block when requested.
  Region *elseRegion = result.addRegion();
  if (withElseRegion) {
    builder.createBlock(elseRegion);
    if (resultTypes.empty())
      AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
  }
}

/// Convenience builder for an affine.if that returns no results.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       IntegerSet set, ValueRange args, bool withElseRegion) {
  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
                    withElseRegion);
}
3022 
3023 /// Compose any affine.apply ops feeding into `operands` of the integer set
3024 /// `set` by composing the maps of such affine.apply ops with the integer
3025 /// set constraints.
3027  SmallVectorImpl<Value> &operands) {
3028  // We will simply reuse the API of the map composition by viewing the LHSs of
3029  // the equalities and inequalities of `set` as the affine exprs of an affine
3030  // map. Convert to equivalent map, compose, and convert back to set.
3031  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
3032  set.getConstraints(), set.getContext());
3033  // Check if any composition is possible.
3034  if (llvm::none_of(operands,
3035  [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
3036  return;
3037 
3038  composeAffineMapAndOperands(&map, &operands);
3039  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
3040  set.getEqFlags());
3041 }
3042 
3043 /// Canonicalize an affine if op's conditional (integer set + operands).
3044 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3045  auto set = getIntegerSet();
3046  SmallVector<Value, 4> operands(getOperands());
3047  composeSetAndOperands(set, operands);
3048  canonicalizeSetAndOperands(&set, &operands);
3049 
3050  // Check if the canonicalization or composition led to any change.
3051  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
3052  return failure();
3053 
3054  setConditional(set, operands);
3055  return success();
3056 }
3057 
/// Registers affine.if canonicalizations: dead-else elision and folding of
/// trivially true/false conditions.
void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
}
3062 
3063 //===----------------------------------------------------------------------===//
3064 // AffineLoadOp
3065 //===----------------------------------------------------------------------===//
3066 
/// Builder taking a prebuilt operand list where operands[0] is the memref and
/// the remainder are the access-map operands.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         AffineMap map, ValueRange operands) {
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  result.addOperands(operands);
  if (map)
    result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
  // The load produces the memref's element type.
  result.types.push_back(memrefType.getElementType());
}

/// Builder taking the memref and the access map with its operands separately.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         Value memref, AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(memref);
  result.addOperands(mapOperands);
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  result.types.push_back(memrefType.getElementType());
}
3086 
3087 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3088  Value memref, ValueRange indices) {
3089  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3090  int64_t rank = memrefType.getRank();
3091  // Create identity map for memrefs with at least one dimension or () -> ()
3092  // for zero-dimensional memrefs.
3093  auto map =
3094  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3095  build(builder, result, memref, map, indices);
3096 }
3097 
3098 ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
3099  auto &builder = parser.getBuilder();
3100  auto indexTy = builder.getIndexType();
3101 
3102  MemRefType type;
3103  OpAsmParser::UnresolvedOperand memrefInfo;
3104  AffineMapAttr mapAttr;
3106  return failure(
3107  parser.parseOperand(memrefInfo) ||
3108  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3109  AffineLoadOp::getMapAttrStrName(),
3110  result.attributes) ||
3111  parser.parseOptionalAttrDict(result.attributes) ||
3112  parser.parseColonType(type) ||
3113  parser.resolveOperand(memrefInfo, type, result.operands) ||
3114  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3115  parser.addTypeToList(type.getElementType(), result.types));
3116 }
3117 
3119  p << " " << getMemRef() << '[';
3120  if (AffineMapAttr mapAttr =
3121  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3122  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3123  p << ']';
3124  p.printOptionalAttrDict((*this)->getAttrs(),
3125  /*elidedAttrs=*/{getMapAttrStrName()});
3126  p << " : " << getMemRefType();
3127 }
3128 
3129 /// Verify common indexing invariants of affine.load, affine.store,
3130 /// affine.vector_load and affine.vector_store.
3131 template <typename AffineMemOpTy>
3132 static LogicalResult
3133 verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr,
3134  Operation::operand_range mapOperands,
3135  MemRefType memrefType, unsigned numIndexOperands) {
3136  AffineMap map = mapAttr.getValue();
3137  if (map.getNumResults() != memrefType.getRank())
3138  return op->emitOpError("affine map num results must equal memref rank");
3139  if (map.getNumInputs() != numIndexOperands)
3140  return op->emitOpError("expects as many subscripts as affine map inputs");
3141 
3142  for (auto idx : mapOperands) {
3143  if (!idx.getType().isIndex())
3144  return op->emitOpError("index to load must have 'index' type");
3145  }
3146  if (failed(verifyDimAndSymbolIdentifiers(op, mapOperands, map.getNumDims())))
3147  return failure();
3148 
3149  return success();
3150 }
3151 
/// Verifies that the result type equals the memref element type and that the
/// indexing map/operands are well-formed.
LogicalResult AffineLoadOp::verify() {
  auto memrefType = getMemRefType();
  if (getType() != memrefType.getElementType())
    return emitOpError("result type must match element type of memref");

  // Shared indexing checks; the load has one non-index operand (the memref),
  // hence `getNumOperands() - 1` subscripts.
  if (failed(verifyMemoryOpIndexing(
          *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
          getMapOperands(), memrefType,
          /*numIndexOperands=*/getNumOperands() - 1)))
    return failure();

  return success();
}

/// Registers the access-map simplification pattern shared by affine memory
/// ops.
void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
}
3170 
/// Folds loads through memref.cast and loads from constant memref.globals
/// (splat constants always; non-splat only with fully constant indices).
OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
  /// load(memrefcast) -> load
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();

  // Fold load from a global constant memref.
  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobalOp)
    return {};
  // Get to the memref.global defining the symbol.
  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
  if (!symbolTableOp)
    return {};
  auto global = dyn_cast_or_null<memref::GlobalOp>(
      SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
  if (!global)
    return {};

  // Check if the global memref is a constant.
  auto cstAttr =
      llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
  if (!cstAttr)
    return {};
  // If it's a splat constant, we can fold irrespective of indices.
  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
    return splatAttr.getSplatValue<Attribute>();
  // Otherwise, we can fold only if we know the indices.
  if (!getAffineMap().isConstant())
    return {};
  // The constant map results are the accessed coordinates.
  auto indices = llvm::to_vector<4>(
      llvm::map_range(getAffineMap().getConstantResults(),
                      [](int64_t v) -> uint64_t { return v; }));
  return cstAttr.getValues<Attribute>()[indices];
}
3205 
3206 //===----------------------------------------------------------------------===//
3207 // AffineStoreOp
3208 //===----------------------------------------------------------------------===//
3209 
/// Builder taking the value to store, the target memref, and the access map
/// with its operands. Stores produce no results.
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref, AffineMap map,
                          ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(valueToStore);
  result.addOperands(memref);
  result.addOperands(mapOperands);
  // The map is stored as an inherent property rather than a discardable attr.
  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
}
3219 
// Use identity map.
/// Builder that derives an identity access map from the memref's rank.
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref,
                          ValueRange indices) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
  // Create identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, valueToStore, memref, map, indices);
}
3232 
/// Parses the custom assembly form:
///   affine.store %value, %memref[affine-map-of-ssa-ids] {attrs} : memref-type
ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  // NOTE(review): the declaration of `mapOperands` was elided in this view.
  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
                 parser.parseOperand(memrefInfo) ||
                 parser.parseAffineMapOfSSAIds(
                     mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
                     result.attributes) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 // The stored value has the memref's element type; all map
                 // operands are of index type.
                 parser.resolveOperand(storeValueInfo, type.getElementType(),
                                       result.operands) ||
                 parser.resolveOperand(memrefInfo, type, result.operands) ||
                 parser.resolveOperands(mapOperands, indexTy, result.operands));
}
3253 
  // NOTE(review): the `AffineStoreOp::print` signature line was elided in this
  // view; the statements below emit:
  //   affine.store %value, %memref[map-of-ssa-ids] {attrs} : memref-type
  p << " " << getValueToStore();
  p << ", " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  // The map attribute is printed inline above, so elide it from the dict.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType();
}
3265 
3266 LogicalResult AffineStoreOp::verify() {
3267  // The value to store must have the same type as memref element type.
3268  auto memrefType = getMemRefType();
3269  if (getValueToStore().getType() != memrefType.getElementType())
3270  return emitOpError(
3271  "value to store must have the same type as memref element type");
3272 
3273  if (failed(verifyMemoryOpIndexing(
3274  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3275  getMapOperands(), memrefType,
3276  /*numIndexOperands=*/getNumOperands() - 2)))
3277  return failure();
3278 
3279  return success();
3280 }
3281 
void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  // Compose affine.apply producers of index operands into the access map.
  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
}
3286 
LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  // Folds a memref.cast feeding the memref operand into the store in place.
  return memref::foldMemRefCast(*this, getValueToStore());
}
3292 
3293 //===----------------------------------------------------------------------===//
3294 // AffineMinMaxOpBase
3295 //===----------------------------------------------------------------------===//
3296 
3297 template <typename T>
3298 static LogicalResult verifyAffineMinMaxOp(T op) {
3299  // Verify that operand count matches affine map dimension and symbol count.
3300  if (op.getNumOperands() !=
3301  op.getMap().getNumDims() + op.getMap().getNumSymbols())
3302  return op.emitOpError(
3303  "operand count and affine map dimension and symbol count must match");
3304 
3305  if (op.getMap().getNumResults() == 0)
3306  return op.emitOpError("affine map expect at least one result");
3307  return success();
3308 }
3309 
3310 template <typename T>
3311 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3312  p << ' ' << op->getAttr(T::getMapAttrStrName());
3313  auto operands = op.getOperands();
3314  unsigned numDims = op.getMap().getNumDims();
3315  p << '(' << operands.take_front(numDims) << ')';
3316 
3317  if (operands.size() != numDims)
3318  p << '[' << operands.drop_front(numDims) << ']';
3319  p.printOptionalAttrDict(op->getAttrs(),
3320  /*elidedAttrs=*/{T::getMapAttrStrName()});
3321 }
3322 
/// Parses affine.min/affine.max: map-attr `(` dims `)` [`[` syms `]`]
/// attr-dict; the single result is always of index type.
template <typename T>
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
                                       OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
  // NOTE(review): the declarations of `dimInfos` and `symInfos` were elided in
  // this view.
  AffineMapAttr mapAttr;
  return failure(
      parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
                            result.attributes) ||
      parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
      // NOTE(review): the delimiter argument of this call was elided in this
      // view (symbols presumably use an optional square-bracket delimiter).
      parser.parseOperandList(symInfos,
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.resolveOperands(dimInfos, indexType, result.operands) ||
      parser.resolveOperands(symInfos, indexType, result.operands) ||
      parser.addTypeToList(indexType, result.types));
}
3342 
/// Fold an affine min or max operation with the given operands. The operand
/// list may contain nulls, which are interpreted as the operand not being a
/// constant.
template <typename T>
// NOTE(review): the `foldMinMaxOp` signature line was elided in this view.
  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
                "expected affine min or max op");

  // Fold the affine map.
  // TODO: Fold more cases:
  // min(some_affine, some_affine + constant, ...), etc.
  SmallVector<int64_t, 2> results;
  auto foldedMap = op.getMap().partialConstantFold(operands, &results);

  // A single-symbol identity map means the op simply forwards its operand.
  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
    return op.getOperand(0);

  // If some of the map results are not constant, try changing the map in-place.
  if (results.empty()) {
    // If the map is the same, report that folding did not happen.
    if (foldedMap == op.getMap())
      return {};
    op->setAttr("map", AffineMapAttr::get(foldedMap));
    return op.getResult();
  }

  // Otherwise, completely fold the op into a constant.
  auto resultIt = std::is_same<T, AffineMinOp>::value
                      ? llvm::min_element(results)
                      : llvm::max_element(results);
  if (resultIt == results.end())
    return {};
  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
}
3377 
/// Remove duplicated expressions in affine min/max ops.
template <typename T>
// NOTE(review): the pattern's struct declaration was elided in this view; the
// body below is its `matchAndRewrite`.

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap oldMap = affineOp.getAffineMap();

    // Keep only the first occurrence of each result expression.
    SmallVector<AffineExpr, 4> newExprs;
    for (AffineExpr expr : oldMap.getResults()) {
      // This is a linear scan over newExprs, but it should be fine given that
      // we typically just have a few expressions per op.
      if (!llvm::is_contained(newExprs, expr))
        newExprs.push_back(expr);
    }

    // Nothing was deduplicated; report no match.
    if (newExprs.size() == oldMap.getNumResults())
      return failure();

    auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
                                 newExprs, rewriter.getContext());
    rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());

    return success();
  }
};
3405 
/// Merge an affine min/max op to its consumers if its consumer is also an
/// affine min/max op.
///
/// This pattern requires the producer affine min/max op is bound to a
/// dimension/symbol that is used as a standalone expression in the consumer
/// affine op's map.
///
/// For example, a pattern like the following:
///
///   %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
///   %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
///
/// Can be turned into:
///
///   %1 = affine.min affine_map<
///          ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
template <typename T>
// NOTE(review): the pattern's struct declaration was elided in this view; the
// body below is its `matchAndRewrite`.

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap oldMap = affineOp.getAffineMap();
    // Map operands are laid out as [dims..., syms...].
    ValueRange dimOperands =
        affineOp.getMapOperands().take_front(oldMap.getNumDims());
    ValueRange symOperands =
        affineOp.getMapOperands().take_back(oldMap.getNumSymbols());

    auto newDimOperands = llvm::to_vector<8>(dimOperands);
    auto newSymOperands = llvm::to_vector<8>(symOperands);
    SmallVector<AffineExpr, 4> newExprs;
    SmallVector<T, 4> producerOps;

    // Go over each expression to see whether it's a single dimension/symbol
    // with the corresponding operand which is the result of another affine
    // min/max op. If So it can be merged into this affine op.
    for (AffineExpr expr : oldMap.getResults()) {
      if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
        Value symValue = symOperands[symExpr.getPosition()];
        if (auto producerOp = symValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
        Value dimValue = dimOperands[dimExpr.getPosition()];
        if (auto producerOp = dimValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      }
      // For the above cases we will remove the expression by merging the
      // producer affine min/max's affine expressions. Otherwise we need to
      // keep the existing expression.
      newExprs.push_back(expr);
    }

    // Without a mergeable producer there is nothing to rewrite.
    if (producerOps.empty())
      return failure();

    unsigned numUsedDims = oldMap.getNumDims();
    unsigned numUsedSyms = oldMap.getNumSymbols();

    // Now go over all producer affine ops and merge their expressions.
    for (T producerOp : producerOps) {
      AffineMap producerMap = producerOp.getAffineMap();
      unsigned numProducerDims = producerMap.getNumDims();
      unsigned numProducerSyms = producerMap.getNumSymbols();

      // Collect all dimension/symbol values.
      ValueRange dimValues =
          producerOp.getMapOperands().take_front(numProducerDims);
      ValueRange symValues =
          producerOp.getMapOperands().take_back(numProducerSyms);
      newDimOperands.append(dimValues.begin(), dimValues.end());
      newSymOperands.append(symValues.begin(), symValues.end());

      // For expressions we need to shift to avoid overlap.
      for (AffineExpr expr : producerMap.getResults()) {
        newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
                               .shiftSymbols(numProducerSyms, numUsedSyms));
      }

      numUsedDims += numProducerDims;
      numUsedSyms += numProducerSyms;
    }

    auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
                                 rewriter.getContext());
    // The new operand list is merged dims followed by merged symbols.
    auto newOperands =
        llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
    rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);

    return success();
  }
};
3501 
3502 /// Canonicalize the result expression order of an affine map and return success
3503 /// if the order changed.
3504 ///
3505 /// The function flattens the map's affine expressions to coefficient arrays and
3506 /// sorts them in lexicographic order. A coefficient array contains a multiplier
3507 /// for every dimension/symbol and a constant term. The canonicalization fails
3508 /// if a result expression is not pure or if the flattening requires local
3509 /// variables that, unlike dimensions and symbols, have no global order.
3510 static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3511  SmallVector<SmallVector<int64_t>> flattenedExprs;
3512  for (const AffineExpr &resultExpr : map.getResults()) {
3513  // Fail if the expression is not pure.
3514  if (!resultExpr.isPureAffine())
3515  return failure();
3516 
3517  SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3518  auto flattenResult = flattener.walkPostOrder(resultExpr);
3519  if (failed(flattenResult))
3520  return failure();
3521 
3522  // Fail if the flattened expression has local variables.
3523  if (flattener.operandExprStack.back().size() !=
3524  map.getNumDims() + map.getNumSymbols() + 1)
3525  return failure();
3526 
3527  flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3528  flattener.operandExprStack.back().end());
3529  }
3530 
3531  // Fail if sorting is not necessary.
3532  if (llvm::is_sorted(flattenedExprs))
3533  return failure();
3534 
3535  // Reorder the result expressions according to their flattened form.
3536  SmallVector<unsigned> resultPermutation =
3537  llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3538  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3539  return flattenedExprs[lhs] < flattenedExprs[rhs];
3540  });
3541  SmallVector<AffineExpr> newExprs;
3542  for (unsigned idx : resultPermutation)
3543  newExprs.push_back(map.getResult(idx));
3544 
3545  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3546  map.getContext());
3547  return success();
3548 }
3549 
/// Canonicalize the affine map result expression order of an affine min/max
/// operation.
///
/// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
/// expressions and replaces the operation if the order changed.
///
/// For example, the following operation:
///
///   %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
///
/// Turns into:
///
///   %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
template <typename T>
// NOTE(review): the pattern's struct declaration was elided in this view.

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap map = affineOp.getAffineMap();
    // Failure means the order was already canonical (or not computable);
    // leave the op unchanged.
    if (failed(canonicalizeMapExprAndTermOrder(map)))
      return failure();
    rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
    return success();
  }
};
3576 
template <typename T>
// NOTE(review): the pattern's struct declaration was elided in this view;
// presumably it rewrites single-result affine.min/affine.max into
// affine.apply — confirm against the upstream source.

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    // Only a single-result map is trivially equivalent to an affine.apply.
    if (affineOp.getMap().getNumResults() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
                                               affineOp.getOperands());
    return success();
  }
};
3590 
3591 //===----------------------------------------------------------------------===//
3592 // AffineMinOp
3593 //===----------------------------------------------------------------------===//
3594 //
3595 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3596 //
3597 
OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
  // Delegate to the shared min/max constant-folding helper.
  return foldMinMaxOp(*this, adaptor.getOperands());
}
3601 
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // NOTE(review): the `patterns.add<...>` call head and part of the pattern
  // list were elided in this view.
      MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
      context);
}
3610 
// Defer to the shared affine min/max verifier.
LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3612 
ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
  // Shared custom-assembly parser for affine.min/affine.max.
  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
}
3616 
// Shared custom-assembly printer for affine.min/affine.max.
void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3618 
3619 //===----------------------------------------------------------------------===//
3620 // AffineMaxOp
3621 //===----------------------------------------------------------------------===//
3622 //
3623 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3624 //
3625 
OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
  // Delegate to the shared min/max constant-folding helper.
  return foldMinMaxOp(*this, adaptor.getOperands());
}
3629 
void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // NOTE(review): the `patterns.add<...>` call head and part of the pattern
  // list were elided in this view.
      MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
      context);
}
3638 
// Defer to the shared affine min/max verifier.
LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3640 
ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  // Shared custom-assembly parser for affine.min/affine.max.
  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
}
3644 
// Shared custom-assembly printer for affine.min/affine.max.
void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3646 
3647 //===----------------------------------------------------------------------===//
3648 // AffinePrefetchOp
3649 //===----------------------------------------------------------------------===//
3650 
3651 //
3652 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3653 //
/// Parses: memref[affine-map-of-ssa-ids], (read|write), locality<hint>,
/// (data|instr), optional attr-dict, `:` memref-type.
ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  IntegerAttr hintInfo;
  // The locality hint is stored as an i32 attribute.
  auto i32Type = parser.getBuilder().getIntegerType(32);
  StringRef readOrWrite, cacheType;

  AffineMapAttr mapAttr;
  // NOTE(review): the declaration of `mapOperands` was elided in this view.
  if (parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffinePrefetchOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(hintInfo, i32Type,
                            AffinePrefetchOp::getLocalityHintAttrStrName(),
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands))
    return failure();

  // The rw specifier and cache type keywords become boolean attributes.
  if (readOrWrite != "read" && readOrWrite != "write")
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
                      parser.getBuilder().getBoolAttr(readOrWrite == "write"));

  if (cacheType != "data" && cacheType != "instr")
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
                      parser.getBuilder().getBoolAttr(cacheType == "data"));

  return success();
}
3700 
  // NOTE(review): the `AffinePrefetchOp::print` signature line was elided in
  // this view.
  p << " " << getMemref() << '[';
  AffineMapAttr mapAttr =
      (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
  if (mapAttr)
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
    << "locality<" << getLocalityHint() << ">, "
    << (getIsDataCache() ? "data" : "instr");
  // NOTE(review): the `p.printOptionalAttrDict(` call head was elided in this
  // view; the attributes printed inline above are elided from the dict.
      (*this)->getAttrs(),
      /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
                       getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
  p << " : " << getMemRefType();
}
3716 
3717 LogicalResult AffinePrefetchOp::verify() {
3718  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3719  if (mapAttr) {
3720  AffineMap map = mapAttr.getValue();
3721  if (map.getNumResults() != getMemRefType().getRank())
3722  return emitOpError("affine.prefetch affine map num results must equal"
3723  " memref rank");
3724  if (map.getNumInputs() + 1 != getNumOperands())
3725  return emitOpError("too few operands");
3726  } else {
3727  if (getNumOperands() != 1)
3728  return emitOpError("too few operands");
3729  }
3730 
3731  Region *scope = getAffineScope(*this);
3732  for (auto idx : getMapOperands()) {
3733  if (!isValidAffineIndexOperand(idx, scope))
3734  return emitOpError(
3735  "index must be a valid dimension or symbol identifier");
3736  }
3737  return success();
3738 }
3739 
void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // prefetch(memrefcast) -> prefetch
  // Composes affine.apply producers into the prefetch access map.
  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
}
3745 
LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// prefetch(memrefcast) -> prefetch
  // Folds a memref.cast feeding the memref operand in place.
  return memref::foldMemRefCast(*this);
}
3751 
3752 //===----------------------------------------------------------------------===//
3753 // AffineParallelOp
3754 //===----------------------------------------------------------------------===//
3755 
3756 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3757  TypeRange resultTypes,
3758  ArrayRef<arith::AtomicRMWKind> reductions,
3759  ArrayRef<int64_t> ranges) {
3760  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3761  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3762  return builder.getConstantAffineMap(value);
3763  }));
3764  SmallVector<int64_t> steps(ranges.size(), 1);
3765  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3766  /*ubArgs=*/{}, steps);
3767 }
3768 
/// Builds an affine.parallel with explicit per-dimension lower/upper bound
/// maps and steps. All lower (resp. upper) bound maps must share dim/symbol
/// counts because they are concatenated into one flat bounds map below.
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<arith::AtomicRMWKind> reductions,
                             ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
                             ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
                             ArrayRef<int64_t> steps) {
  assert(llvm::all_of(lbMaps,
                      [lbMaps](AffineMap m) {
                        return m.getNumDims() == lbMaps[0].getNumDims() &&
                               m.getNumSymbols() == lbMaps[0].getNumSymbols();
                      }) &&
         "expected all lower bounds maps to have the same number of dimensions "
         "and symbols");
  assert(llvm::all_of(ubMaps,
                      [ubMaps](AffineMap m) {
                        return m.getNumDims() == ubMaps[0].getNumDims() &&
                               m.getNumSymbols() == ubMaps[0].getNumSymbols();
                      }) &&
         "expected all upper bounds maps to have the same number of dimensions "
         "and symbols");
  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
         "expected lower bound maps to have as many inputs as lower bound "
         "operands");
  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
         "expected upper bound maps to have as many inputs as upper bound "
         "operands");

  OpBuilder::InsertionGuard guard(builder);
  result.addTypes(resultTypes);

  // Convert the reductions to integer attributes.
  SmallVector<Attribute, 4> reductionAttrs;
  for (arith::AtomicRMWKind reduction : reductions)
    reductionAttrs.push_back(
        builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
  result.addAttribute(getReductionsAttrStrName(),
                      builder.getArrayAttr(reductionAttrs));

  // Concatenates maps defined in the same input space (same dimensions and
  // symbols), assumes there is at least one map.
  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
                                        SmallVectorImpl<int32_t> &groups) {
    if (maps.empty())
      return AffineMap::get(builder.getContext());
    // NOTE(review): the declaration of `exprs` was elided in this view.
    groups.reserve(groups.size() + maps.size());
    exprs.reserve(maps.size());
    for (AffineMap m : maps) {
      llvm::append_range(exprs, m.getResults());
      // Record how many results each original map contributed.
      groups.push_back(m.getNumResults());
    }
    return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
                          maps[0].getContext());
  };

  // Set up the bounds.
  SmallVector<int32_t> lbGroups, ubGroups;
  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
  result.addAttribute(getLowerBoundsMapAttrStrName(),
                      AffineMapAttr::get(lbMap));
  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(lbGroups));
  result.addAttribute(getUpperBoundsMapAttrStrName(),
                      AffineMapAttr::get(ubMap));
  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(ubGroups));
  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
  result.addOperands(lbArgs);
  result.addOperands(ubArgs);

  // Create a region and a block for the body.
  auto *bodyRegion = result.addRegion();
  Block *body = builder.createBlock(bodyRegion);

  // Add all the block arguments.
  for (unsigned i = 0, e = steps.size(); i < e; ++i)
    body->addArgument(IndexType::get(builder.getContext()), result.location);
  // Ops without results get an implicit terminator.
  if (resultTypes.empty())
    ensureTerminator(*bodyRegion, builder, result.location);
}
3850 
// An affine.parallel has exactly one loop region.
SmallVector<Region *> AffineParallelOp::getLoopRegions() {
  return {&getRegion()};
}
3854 
// One parallel dimension per step entry.
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3856 
// Lower-bound operands are stored first in the flat operand list.
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
}
3860 
// Upper-bound operands follow the lower-bound operands.
AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
}
3864 
3865 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3866  auto values = getLowerBoundsGroups().getValues<int32_t>();
3867  unsigned start = 0;
3868  for (unsigned i = 0; i < pos; ++i)
3869  start += values[i];
3870  return getLowerBoundsMap().getSliceMap(start, values[pos]);
3871 }
3872 
3873 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3874  auto values = getUpperBoundsGroups().getValues<int32_t>();
3875  unsigned start = 0;
3876  for (unsigned i = 0; i < pos; ++i)
3877  start += values[i];
3878  return getUpperBoundsMap().getSliceMap(start, values[pos]);
3879 }
3880 
// Bundles the flat lower-bounds map with its operands.
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
}
3884 
// Bundles the flat upper-bounds map with its operands.
AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
}
3888 
/// Returns the constant trip extent of each dimension (ub - lb), or
/// std::nullopt if any range is not a compile-time constant.
std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
  // Min/max bounds provide several candidate expressions per dimension, so a
  // single constant range cannot be derived.
  if (hasMinMaxBounds())
    return std::nullopt;

  // Try to convert all the ranges to constant expressions.
  // NOTE(review): the declaration of `out` was elided in this view.
  AffineValueMap rangesValueMap;
  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
                             &rangesValueMap);
  out.reserve(rangesValueMap.getNumResults());
  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
    auto expr = rangesValueMap.getResult(i);
    auto cst = dyn_cast<AffineConstantExpr>(expr);
    // A single non-constant range makes the whole query fail.
    if (!cst)
      return std::nullopt;
    out.push_back(cst.getValue());
  }
  return out;
}
3908 
// The body is the single block of the loop region.
Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3910 
// Returns a builder positioned just before the body's terminator.
OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));
}
3914 
3915 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3916  assert(lbOperands.size() == map.getNumInputs() &&
3917  "operands to map must match number of inputs");
3918 
3919  auto ubOperands = getUpperBoundsOperands();
3920 
3921  SmallVector<Value, 4> newOperands(lbOperands);
3922  newOperands.append(ubOperands.begin(), ubOperands.end());
3923  (*this)->setOperands(newOperands);
3924 
3925  setLowerBoundsMapAttr(AffineMapAttr::get(map));
3926 }
3927 
3928 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3929  assert(ubOperands.size() == map.getNumInputs() &&
3930  "operands to map must match number of inputs");
3931 
3932  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3933  newOperands.append(ubOperands.begin(), ubOperands.end());
3934  (*this)->setOperands(newOperands);
3935 
3936  setUpperBoundsMapAttr(AffineMapAttr::get(map));
3937 }
3938 
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
  // Steps are stored as an i64 array attribute.
  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
3942 
3943 // check whether resultType match op or not in affine.parallel
3944 static bool isResultTypeMatchAtomicRMWKind(Type resultType,
3945  arith::AtomicRMWKind op) {
3946  switch (op) {
3947  case arith::AtomicRMWKind::addf:
3948  return isa<FloatType>(resultType);
3949  case arith::AtomicRMWKind::addi:
3950  return isa<IntegerType>(resultType);
3951  case arith::AtomicRMWKind::assign:
3952  return true;
3953  case arith::AtomicRMWKind::mulf:
3954  return isa<FloatType>(resultType);
3955  case arith::AtomicRMWKind::muli:
3956  return isa<IntegerType>(resultType);
3957  case arith::AtomicRMWKind::maximumf:
3958  return isa<FloatType>(resultType);
3959  case arith::AtomicRMWKind::minimumf:
3960  return isa<FloatType>(resultType);
3961  case arith::AtomicRMWKind::maxs: {
3962  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3963  return intType && intType.isSigned();
3964  }
3965  case arith::AtomicRMWKind::mins: {
3966  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3967  return intType && intType.isSigned();
3968  }
3969  case arith::AtomicRMWKind::maxu: {
3970  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3971  return intType && intType.isUnsigned();
3972  }
3973  case arith::AtomicRMWKind::minu: {
3974  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3975  return intType && intType.isUnsigned();
3976  }
3977  case arith::AtomicRMWKind::ori:
3978  return isa<IntegerType>(resultType);
3979  case arith::AtomicRMWKind::andi:
3980  return isa<IntegerType>(resultType);
3981  default:
3982  return false;
3983  }
3984 }
3985 
3986 LogicalResult AffineParallelOp::verify() {
3987  auto numDims = getNumDims();
3988  if (getLowerBoundsGroups().getNumElements() != numDims ||
3989  getUpperBoundsGroups().getNumElements() != numDims ||
3990  getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3991  return emitOpError() << "the number of region arguments ("
3992  << getBody()->getNumArguments()
3993  << ") and the number of map groups for lower ("
3994  << getLowerBoundsGroups().getNumElements()
3995  << ") and upper bound ("
3996  << getUpperBoundsGroups().getNumElements()
3997  << "), and the number of steps (" << getSteps().size()
3998  << ") must all match";
3999  }
4000 
4001  unsigned expectedNumLBResults = 0;
4002  for (APInt v : getLowerBoundsGroups()) {
4003  unsigned results = v.getZExtValue();
4004  if (results == 0)
4005  return emitOpError()
4006  << "expected lower bound map to have at least one result";
4007  expectedNumLBResults += results;
4008  }
4009  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
4010  return emitOpError() << "expected lower bounds map to have "
4011  << expectedNumLBResults << " results";
4012  unsigned expectedNumUBResults = 0;
4013  for (APInt v : getUpperBoundsGroups()) {
4014  unsigned results = v.getZExtValue();
4015  if (results == 0)
4016  return emitOpError()
4017  << "expected upper bound map to have at least one result";
4018  expectedNumUBResults += results;
4019  }
4020  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
4021  return emitOpError() << "expected upper bounds map to have "
4022  << expectedNumUBResults << " results";
4023 
4024  if (getReductions().size() != getNumResults())
4025  return emitOpError("a reduction must be specified for each output");
4026 
4027  // Verify reduction ops are all valid and each result type matches reduction
4028  // ops
4029  for (auto it : llvm::enumerate((getReductions()))) {
4030  Attribute attr = it.value();
4031  auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
4032  if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
4033  return emitOpError("invalid reduction attribute");
4034  auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
4035  if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
4036  return emitOpError("result type cannot match reduction attribute");
4037  }
4038 
4039  // Verify that the bound operands are valid dimension/symbols.
4040  /// Lower bounds.
4041  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
4042  getLowerBoundsMap().getNumDims())))
4043  return failure();
4044  /// Upper bounds.
4045  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
4046  getUpperBoundsMap().getNumDims())))
4047  return failure();
4048  return success();
4049 }
4050 
4051 LogicalResult AffineValueMap::canonicalize() {
4052  SmallVector<Value, 4> newOperands{operands};
4053  auto newMap = getAffineMap();
4054  composeAffineMapAndOperands(&newMap, &newOperands);
4055  if (newMap == getAffineMap() && newOperands == operands)
4056  return failure();
4057  reset(newMap, newOperands);
4058  return success();
4059 }
4060 
4061 /// Canonicalize the bounds of the given loop.
4062 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
4063  AffineValueMap lb = op.getLowerBoundsValueMap();
4064  bool lbCanonicalized = succeeded(lb.canonicalize());
4065 
4066  AffineValueMap ub = op.getUpperBoundsValueMap();
4067  bool ubCanonicalized = succeeded(ub.canonicalize());
4068 
4069  // Any canonicalization change always leads to updated map(s).
4070  if (!lbCanonicalized && !ubCanonicalized)
4071  return failure();
4072 
4073  if (lbCanonicalized)
4074  op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
4075  if (ubCanonicalized)
4076  op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
4077 
4078  return success();
4079 }
4080 
LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  // This op produces no foldable results here; "folding" only canonicalizes
  // the bound maps/operands in place. Succeeds iff a bound was simplified.
  return canonicalizeLoopBounds(*this);
}
4085 
4086 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
4087 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
4088 /// identifies which of the those expressions form max/min groups. `operands`
4089 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
4090 /// or "max".
4091 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
4092  DenseIntElementsAttr group, ValueRange operands,
4093  StringRef keyword) {
4094  AffineMap map = mapAttr.getValue();
4095  unsigned numDims = map.getNumDims();
4096  ValueRange dimOperands = operands.take_front(numDims);
4097  ValueRange symOperands = operands.drop_front(numDims);
4098  unsigned start = 0;
4099  for (llvm::APInt groupSize : group) {
4100  if (start != 0)
4101  p << ", ";
4102 
4103  unsigned size = groupSize.getZExtValue();
4104  if (size == 1) {
4105  p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4106  ++start;
4107  } else {
4108  p << keyword << '(';
4109  AffineMap submap = map.getSliceMap(start, size);
4110  p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4111  p << ')';
4112  start += size;
4113  }
4114  }
4115 }
4116 
4118  p << " (" << getBody()->getArguments() << ") = (";
4119  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4120  getLowerBoundsOperands(), "max");
4121  p << ") to (";
4122  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4123  getUpperBoundsOperands(), "min");
4124  p << ')';
4125  SmallVector<int64_t, 8> steps = getSteps();
4126  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4127  if (!elideSteps) {
4128  p << " step (";
4129  llvm::interleaveComma(steps, p);
4130  p << ')';
4131  }
4132  if (getNumResults()) {
4133  p << " reduce (";
4134  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4135  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4136  llvm::cast<IntegerAttr>(attr).getInt());
4137  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4138  });
4139  p << ") -> (" << getResultTypes() << ")";
4140  }
4141 
4142  p << ' ';
4143  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4144  /*printBlockTerminators=*/getNumResults());
4146  (*this)->getAttrs(),
4147  /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4148  AffineParallelOp::getLowerBoundsMapAttrStrName(),
4149  AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4150  AffineParallelOp::getUpperBoundsMapAttrStrName(),
4151  AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4152  AffineParallelOp::getStepsAttrStrName()});
4153 }
4154 
4155 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
4156 /// unique operands. Also populates `replacements with affine expressions of
4157 /// `kind` that can be used to update affine maps previously accepting a
4158 /// `operands` to accept `uniqueOperands` instead.
4160  OpAsmParser &parser,
4162  SmallVectorImpl<Value> &uniqueOperands,
4165  "expected operands to be dim or symbol expression");
4166 
4167  Type indexType = parser.getBuilder().getIndexType();
4168  for (const auto &list : operands) {
4169  SmallVector<Value> valueOperands;
4170  if (parser.resolveOperands(list, indexType, valueOperands))
4171  return failure();
4172  for (Value operand : valueOperands) {
4173  unsigned pos = std::distance(uniqueOperands.begin(),
4174  llvm::find(uniqueOperands, operand));
4175  if (pos == uniqueOperands.size())
4176  uniqueOperands.push_back(operand);
4177  replacements.push_back(
4179  ? getAffineDimExpr(pos, parser.getContext())
4180  : getAffineSymbolExpr(pos, parser.getContext()));
4181  }
4182  }
4183  return success();
4184 }
4185 
namespace {
// Discriminates whether a parsed parallel bound group uses the `min` keyword
// (upper bounds) or the `max` keyword (lower bounds).
enum class MinMaxKind { Min, Max };
} // namespace
4189 
4190 /// Parses an affine map that can contain a min/max for groups of its results,
4191 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4192 /// `result` attributes with the map (flat list of expressions) and the grouping
4193 /// (list of integers that specify how many expressions to put into each
4194 /// min/max) attributes. Deduplicates repeated operands.
4195 ///
4196 /// parallel-bound ::= `(` parallel-group-list `)`
4197 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4198 /// parallel-group ::= simple-group | min-max-group
4199 /// simple-group ::= expr-of-ssa-ids
4200 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4201 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4202 ///
4203 /// Examples:
4204 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4205 /// (%0, max(%1 - 2 * %2))
4206 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4207  OperationState &result,
4208  MinMaxKind kind) {
4209  // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4210  // see: https://reviews.llvm.org/D134227#3821753
4211  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4212 
4213  StringRef mapName = kind == MinMaxKind::Min
4214  ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4215  : AffineParallelOp::getLowerBoundsMapAttrStrName();
4216  StringRef groupsName =
4217  kind == MinMaxKind::Min
4218  ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4219  : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4220 
4221  if (failed(parser.parseLParen()))
4222  return failure();
4223 
4224  if (succeeded(parser.parseOptionalRParen())) {
4225  result.addAttribute(
4226  mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4227  result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4228  return success();
4229  }
4230 
4231  SmallVector<AffineExpr> flatExprs;
4234  SmallVector<int32_t> numMapsPerGroup;
4236  auto parseOperands = [&]() {
4237  if (succeeded(parser.parseOptionalKeyword(
4238  kind == MinMaxKind::Min ? "min" : "max"))) {
4239  mapOperands.clear();
4240  AffineMapAttr map;
4241  if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4242  result.attributes,
4244  return failure();
4245  result.attributes.erase(tmpAttrStrName);
4246  llvm::append_range(flatExprs, map.getValue().getResults());
4247  auto operandsRef = llvm::ArrayRef(mapOperands);
4248  auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4250  auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4252  flatDimOperands.append(map.getValue().getNumResults(), dims);
4253  flatSymOperands.append(map.getValue().getNumResults(), syms);
4254  numMapsPerGroup.push_back(map.getValue().getNumResults());
4255  } else {
4256  if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4257  flatSymOperands.emplace_back(),
4258  flatExprs.emplace_back())))
4259  return failure();
4260  numMapsPerGroup.push_back(1);
4261  }
4262  return success();
4263  };
4264  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4265  return failure();
4266 
4267  unsigned totalNumDims = 0;
4268  unsigned totalNumSyms = 0;
4269  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4270  unsigned numDims = flatDimOperands[i].size();
4271  unsigned numSyms = flatSymOperands[i].size();
4272  flatExprs[i] = flatExprs[i]
4273  .shiftDims(numDims, totalNumDims)
4274  .shiftSymbols(numSyms, totalNumSyms);
4275  totalNumDims += numDims;
4276  totalNumSyms += numSyms;
4277  }
4278 
4279  // Deduplicate map operands.
4280  SmallVector<Value> dimOperands, symOperands;
4281  SmallVector<AffineExpr> dimRplacements, symRepacements;
4282  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4283  dimRplacements, AffineExprKind::DimId) ||
4284  deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4285  symRepacements, AffineExprKind::SymbolId))
4286  return failure();
4287 
4288  result.operands.append(dimOperands.begin(), dimOperands.end());
4289  result.operands.append(symOperands.begin(), symOperands.end());
4290 
4291  Builder &builder = parser.getBuilder();
4292  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4293  parser.getContext());
4294  flatMap = flatMap.replaceDimsAndSymbols(
4295  dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4296 
4297  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4298  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4299  return success();
4300 }
4301 
4302 //
4303 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4304 // `to` parallel-bound steps? region attr-dict?
4305 // steps ::= `steps` `(` integer-literals `)`
4306 //
4307 ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4308  OperationState &result) {
4309  auto &builder = parser.getBuilder();
4310  auto indexType = builder.getIndexType();
4313  parser.parseEqual() ||
4314  parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4315  parser.parseKeyword("to") ||
4316  parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4317  return failure();
4318 
4319  AffineMapAttr stepsMapAttr;
4320  NamedAttrList stepsAttrs;
4322  if (failed(parser.parseOptionalKeyword("step"))) {
4323  SmallVector<int64_t, 4> steps(ivs.size(), 1);
4324  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4325  builder.getI64ArrayAttr(steps));
4326  } else {
4327  if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4328  AffineParallelOp::getStepsAttrStrName(),
4329  stepsAttrs,
4331  return failure();
4332 
4333  // Convert steps from an AffineMap into an I64ArrayAttr.
4335  auto stepsMap = stepsMapAttr.getValue();
4336  for (const auto &result : stepsMap.getResults()) {
4337  auto constExpr = dyn_cast<AffineConstantExpr>(result);
4338  if (!constExpr)
4339  return parser.emitError(parser.getNameLoc(),
4340  "steps must be constant integers");
4341  steps.push_back(constExpr.getValue());
4342  }
4343  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4344  builder.getI64ArrayAttr(steps));
4345  }
4346 
4347  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4348  // quoted strings are a member of the enum AtomicRMWKind.
4349  SmallVector<Attribute, 4> reductions;
4350  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4351  if (parser.parseLParen())
4352  return failure();
4353  auto parseAttributes = [&]() -> ParseResult {
4354  // Parse a single quoted string via the attribute parsing, and then
4355  // verify it is a member of the enum and convert to it's integer
4356  // representation.
4357  StringAttr attrVal;
4358  NamedAttrList attrStorage;
4359  auto loc = parser.getCurrentLocation();
4360  if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4361  attrStorage))
4362  return failure();
4363  std::optional<arith::AtomicRMWKind> reduction =
4364  arith::symbolizeAtomicRMWKind(attrVal.getValue());
4365  if (!reduction)
4366  return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4367  reductions.push_back(
4368  builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4369  // While we keep getting commas, keep parsing.
4370  return success();
4371  };
4372  if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4373  return failure();
4374  }
4375  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4376  builder.getArrayAttr(reductions));
4377 
4378  // Parse return types of reductions (if any)
4379  if (parser.parseOptionalArrowTypeList(result.types))
4380  return failure();
4381 
4382  // Now parse the body.
4383  Region *body = result.addRegion();
4384  for (auto &iv : ivs)
4385  iv.type = indexType;
4386  if (parser.parseRegion(*body, ivs) ||
4387  parser.parseOptionalAttrDict(result.attributes))
4388  return failure();
4389 
4390  // Add a terminator if none was parsed.
4391  AffineParallelOp::ensureTerminator(*body, builder, result.location);
4392  return success();
4393 }
4394 
4395 //===----------------------------------------------------------------------===//
4396 // AffineYieldOp
4397 //===----------------------------------------------------------------------===//
4398 
4399 LogicalResult AffineYieldOp::verify() {
4400  auto *parentOp = (*this)->getParentOp();
4401  auto results = parentOp->getResults();
4402  auto operands = getOperands();
4403 
4404  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4405  return emitOpError() << "only terminates affine.if/for/parallel regions";
4406  if (parentOp->getNumResults() != getNumOperands())
4407  return emitOpError() << "parent of yield must have same number of "
4408  "results as the yield operands";
4409  for (auto it : llvm::zip(results, operands)) {
4410  if (std::get<0>(it).getType() != std::get<1>(it).getType())
4411  return emitOpError() << "types mismatch between yield op and its parent";
4412  }
4413 
4414  return success();
4415 }
4416 
4417 //===----------------------------------------------------------------------===//
4418 // AffineVectorLoadOp
4419 //===----------------------------------------------------------------------===//
4420 
4421 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4422  VectorType resultType, AffineMap map,
4423  ValueRange operands) {
4424  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4425  result.addOperands(operands);
4426  if (map)
4427  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4428  result.types.push_back(resultType);
4429 }
4430 
4431 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4432  VectorType resultType, Value memref,
4433  AffineMap map, ValueRange mapOperands) {
4434  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4435  result.addOperands(memref);
4436  result.addOperands(mapOperands);
4437  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4438  result.types.push_back(resultType);
4439 }
4440 
4441 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4442  VectorType resultType, Value memref,
4443  ValueRange indices) {
4444  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4445  int64_t rank = memrefType.getRank();
4446  // Create identity map for memrefs with at least one dimension or () -> ()
4447  // for zero-dimensional memrefs.
4448  auto map =
4449  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4450  build(builder, result, resultType, memref, map, indices);
4451 }
4452 
void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  // Reuse the generic map/operand simplification pattern shared by affine
  // memory ops.
  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
}
4457 
4458 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4459  OperationState &result) {
4460  auto &builder = parser.getBuilder();
4461  auto indexTy = builder.getIndexType();
4462 
4463  MemRefType memrefType;
4464  VectorType resultType;
4465  OpAsmParser::UnresolvedOperand memrefInfo;
4466  AffineMapAttr mapAttr;
4468  return failure(
4469  parser.parseOperand(memrefInfo) ||
4470  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4471  AffineVectorLoadOp::getMapAttrStrName(),
4472  result.attributes) ||
4473  parser.parseOptionalAttrDict(result.attributes) ||
4474  parser.parseColonType(memrefType) || parser.parseComma() ||
4475  parser.parseType(resultType) ||
4476  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4477  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4478  parser.addTypeToList(resultType, result.types));
4479 }
4480 
4482  p << " " << getMemRef() << '[';
4483  if (AffineMapAttr mapAttr =
4484  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4485  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4486  p << ']';
4487  p.printOptionalAttrDict((*this)->getAttrs(),
4488  /*elidedAttrs=*/{getMapAttrStrName()});
4489  p << " : " << getMemRefType() << ", " << getType();
4490 }
4491 
4492 /// Verify common invariants of affine.vector_load and affine.vector_store.
4493 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4494  VectorType vectorType) {
4495  // Check that memref and vector element types match.
4496  if (memrefType.getElementType() != vectorType.getElementType())
4497  return op->emitOpError(
4498  "requires memref and vector types of the same elemental type");
4499  return success();
4500 }
4501 
4502 LogicalResult AffineVectorLoadOp::verify() {
4503  MemRefType memrefType = getMemRefType();
4504  if (failed(verifyMemoryOpIndexing(
4505  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4506  getMapOperands(), memrefType,
4507  /*numIndexOperands=*/getNumOperands() - 1)))
4508  return failure();
4509 
4510  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4511  return failure();
4512 
4513  return success();
4514 }
4515 
4516 //===----------------------------------------------------------------------===//
4517 // AffineVectorStoreOp
4518 //===----------------------------------------------------------------------===//
4519 
4520 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4521  Value valueToStore, Value memref, AffineMap map,
4522  ValueRange mapOperands) {
4523  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4524  result.addOperands(valueToStore);
4525  result.addOperands(memref);
4526  result.addOperands(mapOperands);
4527  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4528 }
4529 
4530 // Use identity map.
4531 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4532  Value valueToStore, Value memref,
4533  ValueRange indices) {
4534  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4535  int64_t rank = memrefType.getRank();
4536  // Create identity map for memrefs with at least one dimension or () -> ()
4537  // for zero-dimensional memrefs.
4538  auto map =
4539  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4540  build(builder, result, valueToStore, memref, map, indices);
4541 }
void AffineVectorStoreOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Reuse the generic map/operand simplification pattern shared by affine
  // memory ops.
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
}
4546 
4547 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4548  OperationState &result) {
4549  auto indexTy = parser.getBuilder().getIndexType();
4550 
4551  MemRefType memrefType;
4552  VectorType resultType;
4553  OpAsmParser::UnresolvedOperand storeValueInfo;
4554  OpAsmParser::UnresolvedOperand memrefInfo;
4555  AffineMapAttr mapAttr;
4557  return failure(
4558  parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4559  parser.parseOperand(memrefInfo) ||
4560  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4561  AffineVectorStoreOp::getMapAttrStrName(),
4562  result.attributes) ||
4563  parser.parseOptionalAttrDict(result.attributes) ||
4564  parser.parseColonType(memrefType) || parser.parseComma() ||
4565  parser.parseType(resultType) ||
4566  parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4567  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4568  parser.resolveOperands(mapOperands, indexTy, result.operands));
4569 }
4570 
4572  p << " " << getValueToStore();
4573  p << ", " << getMemRef() << '[';
4574  if (AffineMapAttr mapAttr =
4575  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4576  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4577  p << ']';
4578  p.printOptionalAttrDict((*this)->getAttrs(),
4579  /*elidedAttrs=*/{getMapAttrStrName()});
4580  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4581 }
4582 
4583 LogicalResult AffineVectorStoreOp::verify() {
4584  MemRefType memrefType = getMemRefType();
4585  if (failed(verifyMemoryOpIndexing(
4586  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4587  getMapOperands(), memrefType,
4588  /*numIndexOperands=*/getNumOperands() - 2)))
4589  return failure();
4590 
4591  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4592  return failure();
4593 
4594  return success();
4595 }
4596 
4597 //===----------------------------------------------------------------------===//
4598 // DelinearizeIndexOp
4599 //===----------------------------------------------------------------------===//
4600 
4601 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4602  OperationState &odsState,
4603  Value linearIndex, ValueRange dynamicBasis,
4604  ArrayRef<int64_t> staticBasis,
4605  bool hasOuterBound) {
4606  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4607  : staticBasis.size() + 1,
4608  linearIndex.getType());
4609  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4610  staticBasis);
4611 }
4612 
4613 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4614  OperationState &odsState,
4615  Value linearIndex, ValueRange basis,
4616  bool hasOuterBound) {
4617  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4618  hasOuterBound = false;
4619  basis = basis.drop_front();
4620  }
4621  SmallVector<Value> dynamicBasis;
4622  SmallVector<int64_t> staticBasis;
4623  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4624  staticBasis);
4625  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4626  hasOuterBound);
4627 }
4628 
4629 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4630  OperationState &odsState,
4631  Value linearIndex,
4632  ArrayRef<OpFoldResult> basis,
4633  bool hasOuterBound) {
4634  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4635  hasOuterBound = false;
4636  basis = basis.drop_front();
4637  }
4638  SmallVector<Value> dynamicBasis;
4639  SmallVector<int64_t> staticBasis;
4640  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4641  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4642  hasOuterBound);
4643 }
4644 
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex, ArrayRef<int64_t> basis,
                                     bool hasOuterBound) {
  // Fully static basis: no dynamic basis operands are needed.
  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
}
4651 
4652 LogicalResult AffineDelinearizeIndexOp::verify() {
4653  ArrayRef<int64_t> staticBasis = getStaticBasis();
4654  if (getNumResults() != staticBasis.size() &&
4655  getNumResults() != staticBasis.size() + 1)
4656  return emitOpError("should return an index for each basis element and up "
4657  "to one extra index");
4658 
4659  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4660  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4661  return emitOpError(
4662  "mismatch between dynamic and static basis (kDynamic marker but no "
4663  "corresponding dynamic basis entry) -- this can only happen due to an "
4664  "incorrect fold/rewrite");
4665 
4666  if (!llvm::all_of(staticBasis, [](int64_t v) {
4667  return v > 0 || ShapedType::isDynamic(v);
4668  }))
4669  return emitOpError("no basis element may be statically non-positive");
4670 
4671  return success();
4672 }
4673 
4674 /// Given mixed basis of affine.delinearize_index/linearize_index replace
4675 /// constant SSA values with the constant integer value and return the new
4676 /// static basis. In case no such candidate for replacement exists, this utility
4677 /// returns std::nullopt.
4678 static std::optional<SmallVector<int64_t>>
4680  MutableOperandRange mutableDynamicBasis,
4681  ArrayRef<Attribute> dynamicBasis) {
4682  uint64_t dynamicBasisIndex = 0;
4683  for (OpFoldResult basis : dynamicBasis) {
4684  if (basis) {
4685  mutableDynamicBasis.erase(dynamicBasisIndex);
4686  } else {
4687  ++dynamicBasisIndex;
4688  }
4689  }
4690 
4691  // No constant SSA value exists.
4692  if (dynamicBasisIndex == dynamicBasis.size())
4693  return std::nullopt;
4694 
4695  SmallVector<int64_t> staticBasis;
4696  for (OpFoldResult basis : mixedBasis) {
4697  std::optional<int64_t> basisVal = getConstantIntValue(basis);
4698  if (!basisVal)
4699  staticBasis.push_back(ShapedType::kDynamic);
4700  else
4701  staticBasis.push_back(*basisVal);
4702  }
4703 
4704  return staticBasis;
4705 }
4706 
4707 LogicalResult
4708 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4710  std::optional<SmallVector<int64_t>> maybeStaticBasis =
4711  foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
4712  adaptor.getDynamicBasis());
4713  if (maybeStaticBasis) {
4714  setStaticBasis(*maybeStaticBasis);
4715  return success();
4716  }
4717  // If we won't be doing any division or modulo (no basis or the one basis
4718  // element is purely advisory), simply return the input value.
4719  if (getNumResults() == 1) {
4720  result.push_back(getLinearIndex());
4721  return success();
4722  }
4723 
4724  if (adaptor.getLinearIndex() == nullptr)
4725  return failure();
4726 
4727  if (!adaptor.getDynamicBasis().empty())
4728  return failure();
4729 
4730  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
4731  Type attrType = getLinearIndex().getType();
4732 
4733  ArrayRef<int64_t> staticBasis = getStaticBasis();
4734  if (hasOuterBound())
4735  staticBasis = staticBasis.drop_front();
4736  for (int64_t modulus : llvm::reverse(staticBasis)) {
4737  result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
4738  highPart = llvm::divideFloorSigned(highPart, modulus);
4739  }
4740  result.push_back(IntegerAttr::get(attrType, highPart));
4741  std::reverse(result.begin(), result.end());
4742  return success();
4743 }
4744 
4745 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
4746  OpBuilder builder(getContext());
4747  if (hasOuterBound()) {
4748  if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4749  return getMixedValues(getStaticBasis().drop_front(),
4750  getDynamicBasis().drop_front(), builder);
4751 
4752  return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4753  builder);
4754  }
4755 
4756  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4757 }
4758 
4759 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
4760  SmallVector<OpFoldResult> ret = getMixedBasis();
4761  if (!hasOuterBound())
4762  ret.insert(ret.begin(), OpFoldResult());
4763  return ret;
4764 }
4765 
4766 namespace {
4767 
4768 // Drops delinearization indices that correspond to unit-extent basis
4769 struct DropUnitExtentBasis
4770  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4772 
4773  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4774  PatternRewriter &rewriter) const override {
4775  SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
4776  std::optional<Value> zero = std::nullopt;
4777  Location loc = delinearizeOp->getLoc();
4778  auto getZero = [&]() -> Value {
4779  if (!zero)
4780  zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4781  return zero.value();
4782  };
4783 
4784  // Replace all indices corresponding to unit-extent basis with 0.
4785  // Remaining basis can be used to get a new `affine.delinearize_index` op.
4786  SmallVector<OpFoldResult> newBasis;
4787  for (auto [index, basis] :
4788  llvm::enumerate(delinearizeOp.getPaddedBasis())) {
4789  std::optional<int64_t> basisVal =
4790  basis ? getConstantIntValue(basis) : std::nullopt;
4791  if (basisVal == 1)
4792  replacements[index] = getZero();
4793  else
4794  newBasis.push_back(basis);
4795  }
4796 
4797  if (newBasis.size() == delinearizeOp.getNumResults())
4798  return rewriter.notifyMatchFailure(delinearizeOp,
4799  "no unit basis elements");
4800 
4801  if (!newBasis.empty()) {
4802  // Will drop the leading nullptr from `basis` if there was no outer bound.
4803  auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4804  loc, delinearizeOp.getLinearIndex(), newBasis);
4805  int newIndex = 0;
4806  // Map back the new delinearized indices to the values they replace.
4807  for (auto &replacement : replacements) {
4808  if (replacement)
4809  continue;
4810  replacement = newDelinearizeOp->getResult(newIndex++);
4811  }
4812  }
4813 
4814  rewriter.replaceOp(delinearizeOp, replacements);
4815  return success();
4816  }
4817 };
4818 
4819 /// If a `affine.delinearize_index`'s input is a `affine.linearize_index
4820 /// disjoint` and the two operations end with the same basis elements,
4821 /// cancel those parts of the operations out because they are inverses
4822 /// of each other.
4823 ///
4824 /// If the operations have the same basis, cancel them entirely.
4825 ///
4826 /// The `disjoint` flag is needed on the `affine.linearize_index` because
4827 /// otherwise, there is no guarantee that the inputs to the linearization are
4828 /// in-bounds the way the outputs of the delinearization would be.
4829 struct CancelDelinearizeOfLinearizeDisjointExactTail
4830  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4832 
4833  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4834  PatternRewriter &rewriter) const override {
4835  auto linearizeOp = delinearizeOp.getLinearIndex()
4836  .getDefiningOp<affine::AffineLinearizeIndexOp>();
4837  if (!linearizeOp)
4838  return rewriter.notifyMatchFailure(delinearizeOp,
4839  "index doesn't come from linearize");
4840 
4841  if (!linearizeOp.getDisjoint())
4842  return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
4843 
4844  ValueRange linearizeIns = linearizeOp.getMultiIndex();
4845  // Note: we use the full basis so we don't lose outer bounds later.
4846  SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
4847  SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
4848  size_t numMatches = 0;
4849  for (auto [linSize, delinSize] : llvm::zip(
4850  llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
4851  if (linSize != delinSize)
4852  break;
4853  ++numMatches;
4854  }
4855 
4856  if (numMatches == 0)
4857  return rewriter.notifyMatchFailure(
4858  delinearizeOp, "final basis element doesn't match linearize");
4859 
4860  // The easy case: everything lines up and the basis match sup completely.
4861  if (numMatches == linearizeBasis.size() &&
4862  numMatches == delinearizeBasis.size() &&
4863  linearizeIns.size() == delinearizeOp.getNumResults()) {
4864  rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4865  return success();
4866  }
4867 
4868  Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
4869  linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
4870  ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
4871  linearizeOp.getDisjoint());
4872  auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
4873  delinearizeOp.getLoc(), newLinearize,
4874  ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
4875  delinearizeOp.hasOuterBound());
4876  SmallVector<Value> mergedResults(newDelinearize.getResults());
4877  mergedResults.append(linearizeIns.take_back(numMatches).begin(),
4878  linearizeIns.take_back(numMatches).end());
4879  rewriter.replaceOp(delinearizeOp, mergedResults);
4880  return success();
4881  }
4882 };
4883 
4884 /// If the input to a delinearization is a disjoint linearization, and the
4885 /// last k > 1 components of the delinearization basis multiply to the
4886 /// last component of the linearization basis, break the linearization and
4887 /// delinearization into two parts, peeling off the last input to linearization.
4888 ///
4889 /// For example:
4890 /// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
4891 /// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
4892 /// becomes
4893 /// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
4894 /// %1:2 = affine.delinearize_index %0 by (2, 3) : index
4895 /// %2:2 = affine.delinearize_index %x by (8, 4) : index
4896 /// where the original %1:4 is replaced by %1:2 ++ %2:2
4897 struct SplitDelinearizeSpanningLastLinearizeArg final
4898  : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4900 
4901  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4902  PatternRewriter &rewriter) const override {
4903  auto linearizeOp = delinearizeOp.getLinearIndex()
4904  .getDefiningOp<affine::AffineLinearizeIndexOp>();
4905  if (!linearizeOp)
4906  return rewriter.notifyMatchFailure(delinearizeOp,
4907  "index doesn't come from linearize");
4908 
4909  if (!linearizeOp.getDisjoint())
4910  return rewriter.notifyMatchFailure(linearizeOp,
4911  "linearize isn't disjoint");
4912 
4913  int64_t target = linearizeOp.getStaticBasis().back();
4914  if (ShapedType::isDynamic(target))
4915  return rewriter.notifyMatchFailure(
4916  linearizeOp, "linearize ends with dynamic basis value");
4917 
4918  int64_t sizeToSplit = 1;
4919  size_t elemsToSplit = 0;
4920  ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
4921  for (int64_t basisElem : llvm::reverse(basis)) {
4922  if (ShapedType::isDynamic(basisElem))
4923  return rewriter.notifyMatchFailure(
4924  delinearizeOp, "dynamic basis element while scanning for split");
4925  sizeToSplit *= basisElem;
4926  elemsToSplit += 1;
4927 
4928  if (sizeToSplit > target)
4929  return rewriter.notifyMatchFailure(delinearizeOp,
4930  "overshot last argument size");
4931  if (sizeToSplit == target)
4932  break;
4933  }
4934 
4935  if (sizeToSplit < target)
4936  return rewriter.notifyMatchFailure(
4937  delinearizeOp, "product of known basis elements doesn't exceed last "
4938  "linearize argument");
4939 
4940  if (elemsToSplit < 2)
4941  return rewriter.notifyMatchFailure(
4942  delinearizeOp,
4943  "need at least two elements to form the basis product");
4944 
4945  Value linearizeWithoutBack =
4946  rewriter.create<affine::AffineLinearizeIndexOp>(
4947  linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
4948  linearizeOp.getDynamicBasis(),
4949  linearizeOp.getStaticBasis().drop_back(),
4950  linearizeOp.getDisjoint());
4951  auto delinearizeWithoutSplitPart =
4952  rewriter.create<affine::AffineDelinearizeIndexOp>(
4953  delinearizeOp.getLoc(), linearizeWithoutBack,
4954  delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
4955  delinearizeOp.hasOuterBound());
4956  auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
4957  delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
4958  basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
4959  SmallVector<Value> results = llvm::to_vector(
4960  llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
4961  delinearizeBack.getResults()));
4962  rewriter.replaceOp(delinearizeOp, results);
4963 
4964  return success();
4965  }
4966 };
4967 } // namespace
4968 
4969 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4970  RewritePatternSet &patterns, MLIRContext *context) {
4971  patterns
4972  .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
4973  DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
4974  context);
4975 }
4976 
4977 //===----------------------------------------------------------------------===//
4978 // LinearizeIndexOp
4979 //===----------------------------------------------------------------------===//
4980 
4981 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4982  OperationState &odsState,
4983  ValueRange multiIndex, ValueRange basis,
4984  bool disjoint) {
4985  if (!basis.empty() && basis.front() == Value())
4986  basis = basis.drop_front();
4987  SmallVector<Value> dynamicBasis;
4988  SmallVector<int64_t> staticBasis;
4989  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4990  staticBasis);
4991  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4992 }
4993 
4994 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4995  OperationState &odsState,
4996  ValueRange multiIndex,
4997  ArrayRef<OpFoldResult> basis,
4998  bool disjoint) {
4999  if (!basis.empty() && basis.front() == OpFoldResult())
5000  basis = basis.drop_front();
5001  SmallVector<Value> dynamicBasis;
5002  SmallVector<int64_t> staticBasis;
5003  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
5004  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5005 }
5006 
5007 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5008  OperationState &odsState,
5009  ValueRange multiIndex,
5010  ArrayRef<int64_t> basis, bool disjoint) {
5011  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
5012 }
5013 
5014 LogicalResult AffineLinearizeIndexOp::verify() {
5015  size_t numIndexes = getMultiIndex().size();
5016  size_t numBasisElems = getStaticBasis().size();
5017  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5018  return emitOpError("should be passed a basis element for each index except "
5019  "possibly the first");
5020 
5021  auto dynamicMarkersCount =
5022  llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
5023  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5024  return emitOpError(
5025  "mismatch between dynamic and static basis (kDynamic marker but no "
5026  "corresponding dynamic basis entry) -- this can only happen due to an "
5027  "incorrect fold/rewrite");
5028 
5029  return success();
5030 }
5031 
/// Folder for affine.linearize_index.
///
/// First canonicalizes constant SSA basis operands into the static basis
/// attribute (an in-place update), then attempts to fold the op to a constant
/// when all indices are constant and the basis is fully static.
OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
  // Move constant SSA basis values into the static basis. If anything moved,
  // the op was updated in place; report that by returning its own result.
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
                                 adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);
    return getResult();
  }
  // No indices linearizes to zero.
  if (getMultiIndex().empty())
    return IntegerAttr::get(getResult().getType(), 0);

  // One single index linearizes to itself.
  if (getMultiIndex().size() == 1)
    return getMultiIndex().front();

  // Constant folding requires every index to be a known constant...
  if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))
    return nullptr;

  // ...and the basis to be fully static.
  if (!adaptor.getDynamicBasis().empty())
    return nullptr;

  // Accumulate sum(index[i] * stride[i]) from the innermost term outward.
  // zip_first stops when the basis - which may be one element shorter than
  // the multi-index - runs out.
  int64_t result = 0;
  int64_t stride = 1;
  for (auto [length, indexAttr] :
       llvm::zip_first(llvm::reverse(getStaticBasis()),
                       llvm::reverse(adaptor.getMultiIndex()))) {
    result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
    stride = stride * length;
  }
  // Handle the index element with no basis element.
  if (!hasOuterBound())
    result =
        result +
        cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;

  return IntegerAttr::get(getResult().getType(), result);
}
5070 
5071 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
5072  OpBuilder builder(getContext());
5073  if (hasOuterBound()) {
5074  if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5075  return getMixedValues(getStaticBasis().drop_front(),
5076  getDynamicBasis().drop_front(), builder);
5077 
5078  return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5079  builder);
5080  }
5081 
5082  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5083 }
5084 
5085 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
5086  SmallVector<OpFoldResult> ret = getMixedBasis();
5087  if (!hasOuterBound())
5088  ret.insert(ret.begin(), OpFoldResult());
5089  return ret;
5090 }
5091 
5092 namespace {
5093 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
5094 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
5095 /// %...d)`.
5096 
5097 /// Note that `disjoint` is required here, because, without it, we could have
5098 /// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
5099 /// is a valid operation where the `%c64` cannot be trivially dropped.
5100 ///
5101 /// Alternatively, if `%x` in the above is a known constant 0, remove it even if
5102 /// the operation isn't asserted to be `disjoint`.
5103 struct DropLinearizeUnitComponentsIfDisjointOrZero final
5104  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5106 
5107  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5108  PatternRewriter &rewriter) const override {
5109  ValueRange multiIndex = op.getMultiIndex();
5110  size_t numIndices = multiIndex.size();
5111  SmallVector<Value> newIndices;
5112  newIndices.reserve(numIndices);
5113  SmallVector<OpFoldResult> newBasis;
5114  newBasis.reserve(numIndices);
5115 
5116  if (!op.hasOuterBound()) {
5117  newIndices.push_back(multiIndex.front());
5118  multiIndex = multiIndex.drop_front();
5119  }
5120 
5121  SmallVector<OpFoldResult> basis = op.getMixedBasis();
5122  for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5123  std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
5124  if (!basisEntry || *basisEntry != 1) {
5125  newIndices.push_back(index);
5126  newBasis.push_back(basisElem);
5127  continue;
5128  }
5129 
5130  std::optional<int64_t> indexValue = getConstantIntValue(index);
5131  if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5132  newIndices.push_back(index);
5133  newBasis.push_back(basisElem);
5134  continue;
5135  }
5136  }
5137  if (newIndices.size() == numIndices)
5138  return rewriter.notifyMatchFailure(op,
5139  "no unit basis entries to replace");
5140 
5141  if (newIndices.size() == 0) {
5142  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
5143  return success();
5144  }
5145  rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5146  op, newIndices, newBasis, op.getDisjoint());
5147  return success();
5148  }
5149 };
5150 
5152  ArrayRef<OpFoldResult> terms) {
5153  int64_t nDynamic = 0;
5154  SmallVector<Value> dynamicPart;
5155  AffineExpr result = builder.getAffineConstantExpr(1);
5156  for (OpFoldResult term : terms) {
5157  if (!term)
5158  return term;
5159  std::optional<int64_t> maybeConst = getConstantIntValue(term);
5160  if (maybeConst) {
5161  result = result * builder.getAffineConstantExpr(*maybeConst);
5162  } else {
5163  dynamicPart.push_back(cast<Value>(term));
5164  result = result * builder.getAffineSymbolExpr(nDynamic++);
5165  }
5166  }
5167  if (auto constant = dyn_cast<AffineConstantExpr>(result))
5168  return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
5169  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
5170 }
5171 
5172 /// If conseceutive outputs of a delinearize_index are linearized with the same
5173 /// bounds, canonicalize away the redundant arithmetic.
5174 ///
5175 /// That is, if we have
5176 /// ```
5177 /// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
5178 /// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
5179 /// by (...e, B1, B2, ..., BK, ...f)
5180 /// ```
5181 ///
5182 /// We can rewrite this to
5183 /// ```
5184 /// B = B1 * B2 ... BK
5185 /// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b)
5186 /// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
5187 /// ```
5188 /// where we replace all results of %s unaffected by the change with results
5189 /// from %sMerged.
5190 ///
5191 /// As a special case, if all results of the delinearize are merged in this way
5192 /// we can replace those usages with %x, thus cancelling the delinearization
5193 /// entirely, as in
5194 /// ```
5195 /// %s:3 = affine.delinearize_index %x into (2, 4, 8)
5196 /// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
5197 /// ```
5198 /// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
5199 struct CancelLinearizeOfDelinearizePortion final
5200  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5202 
5203 private:
5204  // Struct representing a case where the cancellation pattern
5205  // applies. A `Match` means that `length` inputs to the linearize operation
5206  // starting at `linStart` can be cancelled with `length` outputs of
5207  // `delinearize`, starting from `delinStart`.
5208  struct Match {
5209  AffineDelinearizeIndexOp delinearize;
5210  unsigned linStart = 0;
5211  unsigned delinStart = 0;
5212  unsigned length = 0;
5213  };
5214 
5215 public:
5216  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
5217  PatternRewriter &rewriter) const override {
5218  SmallVector<Match> matches;
5219 
5220  const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
5221  ArrayRef<OpFoldResult> linBasisRef = linBasis;
5222 
5223  ValueRange multiIndex = linearizeOp.getMultiIndex();
5224  unsigned numLinArgs = multiIndex.size();
5225  unsigned linArgIdx = 0;
5226  // We only want to replace one run from the same delinearize op per
5227  // pattern invocation lest we run into invalidation issues.
5228  llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
5229  while (linArgIdx < numLinArgs) {
5230  auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
5231  if (!asResult) {
5232  linArgIdx++;
5233  continue;
5234  }
5235 
5236  auto delinearizeOp =
5237  dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
5238  if (!delinearizeOp) {
5239  linArgIdx++;
5240  continue;
5241  }
5242 
5243  /// Result 0 of the delinearize and argument 0 of the linearize can
5244  /// leave their maximum value unspecified. However, even if this happens
5245  /// we can still sometimes start the match process. Specifically, if
5246  /// - The argument we're matching is result 0 and argument 0 (so the
5247  /// bounds don't matter). For example,
5248  ///
5249  /// %0:2 = affine.delinearize_index %x into (8) : index, index
5250  /// %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...)
5251  /// allows cancellation
5252  /// - The delinearization doesn't specify a bound, but the linearization
5253  /// is `disjoint`, which asserts that the bound on the linearization is
5254  /// correct.
5255  unsigned delinArgIdx = asResult.getResultNumber();
5256  SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
5257  OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
5258  OpFoldResult firstLinBound = linBasis[linArgIdx];
5259  bool boundsMatch = firstDelinBound == firstLinBound;
5260  bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
5261  bool knownByDisjoint =
5262  linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
5263  if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5264  linArgIdx++;
5265  continue;
5266  }
5267 
5268  unsigned j = 1;
5269  unsigned numDelinOuts = delinearizeOp.getNumResults();
5270  for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
5271  ++j) {
5272  if (multiIndex[linArgIdx + j] !=
5273  delinearizeOp.getResult(delinArgIdx + j))
5274  break;
5275  if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
5276  break;
5277  }
5278  // If there're multiple matches against the same delinearize_index,
5279  // only rewrite the first one we find to prevent invalidations. The next
5280  // ones will be taken care of by subsequent pattern invocations.
5281  if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
5282  linArgIdx++;
5283  continue;
5284  }
5285  matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
5286  linArgIdx += j;
5287  }
5288 
5289  if (matches.empty())
5290  return rewriter.notifyMatchFailure(
5291  linearizeOp, "no run of delinearize outputs to deal with");
5292 
5293  // Record all the delinearize replacements so we can do them after creating
5294  // the new linearization operation, since the new operation might use
5295  // outputs of something we're replacing.
5296  SmallVector<SmallVector<Value>> delinearizeReplacements;
5297 
5298  SmallVector<Value> newIndex;
5299  newIndex.reserve(numLinArgs);
5300  SmallVector<OpFoldResult> newBasis;
5301  newBasis.reserve(numLinArgs);
5302  unsigned prevMatchEnd = 0;
5303  for (Match m : matches) {
5304  unsigned gap = m.linStart - prevMatchEnd;
5305  llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
5306  llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
5307  // Update here so we don't forget this during early continues
5308  prevMatchEnd = m.linStart + m.length;
5309 
5310  PatternRewriter::InsertionGuard g(rewriter);
5311  rewriter.setInsertionPoint(m.delinearize);
5312 
5313  ArrayRef<OpFoldResult> basisToMerge =
5314  linBasisRef.slice(m.linStart, m.length);
5315  // We use the slice from the linearize's basis above because of the
5316  // "bounds inferred from `disjoint`" case above.
5317  OpFoldResult newSize =
5318  computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
5319 
5320  // Trivial case where we can just skip past the delinearize all together
5321  if (m.length == m.delinearize.getNumResults()) {
5322  newIndex.push_back(m.delinearize.getLinearIndex());
5323  newBasis.push_back(newSize);
5324  // Pad out set of replacements so we don't do anything with this one.
5325  delinearizeReplacements.push_back(SmallVector<Value>());
5326  continue;
5327  }
5328 
5329  SmallVector<Value> newDelinResults;
5330  SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
5331  newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
5332  newDelinBasis.begin() + m.delinStart + m.length);
5333  newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
5334  auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5335  m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5336  newDelinBasis);
5337 
5338  // Since there may be other uses of the indices we just merged together,
5339  // create a residual affine.delinearize_index that delinearizes the
5340  // merged output into its component parts.
5341  Value combinedElem = newDelinearize.getResult(m.delinStart);
5342  auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5343  m.delinearize.getLoc(), combinedElem, basisToMerge);
5344 
5345  // Swap all the uses of the unaffected delinearize outputs to the new
5346  // delinearization so that the old code can be removed if this
5347  // linearize_index is the only user of the merged results.
5348  llvm::append_range(newDelinResults,
5349  newDelinearize.getResults().take_front(m.delinStart));
5350  llvm::append_range(newDelinResults, residualDelinearize.getResults());
5351  llvm::append_range(
5352  newDelinResults,
5353  newDelinearize.getResults().drop_front(m.delinStart + 1));
5354 
5355  delinearizeReplacements.push_back(newDelinResults);
5356  newIndex.push_back(combinedElem);
5357  newBasis.push_back(newSize);
5358  }
5359  llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
5360  llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
5361  rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
5362  linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
5363 
5364  for (auto [m, newResults] :
5365  llvm::zip_equal(matches, delinearizeReplacements)) {
5366  if (newResults.empty())
5367  continue;
5368  rewriter.replaceOp(m.delinearize, newResults);
5369  }
5370 
5371  return success();
5372  }
5373 };
5374 
5375 /// Strip leading zero from affine.linearize_index.
5376 ///
5377 /// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
5378 /// to `affine.linearize_index [...a] by (...b)` in all cases.
5379 struct DropLinearizeLeadingZero final
5380  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5382 
5383  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5384  PatternRewriter &rewriter) const override {
5385  Value leadingIdx = op.getMultiIndex().front();
5386  if (!matchPattern(leadingIdx, m_Zero()))
5387  return failure();
5388 
5389  if (op.getMultiIndex().size() == 1) {
5390  rewriter.replaceOp(op, leadingIdx);
5391  return success();
5392  }
5393 
5394  SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
5395  ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
5396  if (op.hasOuterBound())
5397  newMixedBasis = newMixedBasis.drop_front();
5398 
5399  rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5400  op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5401  return success();
5402  }
5403 };
5404 } // namespace
5405 
5406 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5407  RewritePatternSet &patterns, MLIRContext *context) {
5408  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5409  DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5410 }
5411 
5412 //===----------------------------------------------------------------------===//
5413 // TableGen'd op method definitions
5414 //===----------------------------------------------------------------------===//
5415 
5416 #define GET_OP_CLASSES
5417 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
Definition: AffineOps.cpp:2741
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
Definition: AffineOps.cpp:2462
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
Definition: AffineOps.cpp:1196
static LogicalResult verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
Definition: AffineOps.cpp:3133
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
Definition: AffineOps.cpp:3311
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
Definition: AffineOps.cpp:3944
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
Definition: AffineOps.cpp:60
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
Definition: AffineOps.cpp:4091
static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
Definition: AffineOps.cpp:1032
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
Definition: AffineOps.cpp:3347
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
Definition: AffineOps.cpp:2316
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
Definition: AffineOps.cpp:818
static bool isValidAffineIndexOperand(Value value, Region *region)
Definition: AffineOps.cpp:488
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1439
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
Definition: AffineOps.cpp:1103
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:753
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
Definition: AffineOps.cpp:2005
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:745
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
Definition: AffineOps.cpp:3026
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
Definition: AffineOps.cpp:1055
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1345
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
Definition: AffineOps.cpp:4493
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
Definition: AffineOps.cpp:921
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
Definition: AffineOps.cpp:3324
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
Definition: AffineOps.cpp:347
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
Definition: AffineOps.cpp:693
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
Definition: AffineOps.cpp:4206
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
Definition: AffineOps.cpp:2750
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
Definition: AffineOps.cpp:493
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
Definition: AffineOps.cpp:655
static void legalizeDemotedDims(MapOrSet &mapOrSet, SmallVectorImpl< Value > &operands)
A valid affine dimension may appear as a symbol in affine.apply operations.
Definition: AffineOps.cpp:1393
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1282
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
Definition: AffineOps.cpp:2700
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
Definition: AffineOps.cpp:2270
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
Definition: AffineOps.cpp:525
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1297
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
Definition: AffineOps.cpp:366
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
Definition: AffineOps.cpp:721
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Definition: AffineOps.cpp:4159
static LogicalResult verifyAffineMinMaxOp(T op)
Definition: AffineOps.cpp:3298
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
Definition: AffineOps.cpp:2180
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
Definition: AffineOps.cpp:4679
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
Definition: AffineOps.cpp:3510
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
union mlir::linalg::@1204::ArityGroupAndKind::Kind kind
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:76
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:81
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:921
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:243
MLIRContext * getContext() const
Definition: AffineExpr.cpp:33
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:659
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
Definition: AffineMap.h:221
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:267
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
Definition: AffineMap.h:228
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineMap.cpp:500
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineMap.h:280
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:515
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
Definition: AffineMap.cpp:128
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:651
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
Definition: AffineMap.cpp:434
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:161
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:226
AffineMap getDimIdentityMap()
Definition: Builders.cpp:381
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:385
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:366
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:370
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition: Builders.cpp:177
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:110
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:69
NoneType getNoneType()
Definition: Builders.cpp:86
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:98
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
Definition: Builders.cpp:374
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition: Builders.cpp:376
MLIRContext * getContext() const
Definition: Builders.h:55
AffineMap getSymbolIdentityMap()
Definition: Builders.cpp:394
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:264
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:279
IndexType getIndexType()
Definition: Builders.cpp:53
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44
unsigned getNumDims() const
Definition: IntegerSet.cpp:15
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
Definition: IntegerSet.cpp:57
unsigned getNumInputs() const
Definition: IntegerSet.cpp:17
ArrayRef< AffineExpr > getConstraints() const
Definition: IntegerSet.cpp:41
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
Definition: IntegerSet.cpp:51
unsigned getNumSymbols() const
Definition: IntegerSet.cpp:16
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:428
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:257
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:451
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
operand_range::iterator operand_iterator
Definition: Operation.h:372
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:811
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:594
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:578
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
AffineBound represents a lower or upper bound in the for operation.
Definition: AffineOps.h:523
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
Definition: AffineOps.h:106
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Definition: AffineOps.h:315
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
Definition: AffineOps.cpp:4051
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
unsigned getNumResults() const
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
Definition: AffineOps.cpp:2763
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
Definition: AffineOps.cpp:1165
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
Definition: AffineOps.cpp:2677
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
Definition: AffineOps.cpp:288
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1271
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
Definition: AffineOps.cpp:2649
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: AffineOps.cpp:2653
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1175
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1516
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1336
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Definition: AffineOps.cpp:2641
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1329
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Definition: AffineOps.cpp:248
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
Definition: AffineOps.cpp:273
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
Definition: AffineOps.cpp:1521
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
Definition: AffineOps.cpp:2684
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Definition: AffineOps.cpp:410
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1225
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp whose induction variables include the provided value, or null if there is none.
Definition: AffineOps.cpp:2664
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
Definition: AffineOps.cpp:263
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
Definition: AffineOps.cpp:503
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
Definition: AffineOps.cpp:2645
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
Definition: AffineOps.cpp:1291
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:770
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
Definition: AffineMap.cpp:780
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
AffineExprKind
Definition: AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:70
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
Definition: AffineMap.cpp:742
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:631
Canonicalize the affine map result expression order of an affine min/max operation.
Definition: AffineOps.cpp:3564
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3567
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3581
Remove duplicated expressions in affine min/max ops.
Definition: AffineOps.cpp:3380
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3383
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
Definition: AffineOps.cpp:3423
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3426
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.