MLIR  20.0.0git
AffineOps.cpp
Go to the documentation of this file.
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/IntegerSet.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 #include "llvm/ADT/SmallVectorExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/MathExtras.h"
30 #include <numeric>
31 #include <optional>
32 
33 using namespace mlir;
34 using namespace mlir::affine;
35 
36 using llvm::divideCeilSigned;
37 using llvm::divideFloorSigned;
38 using llvm::mod;
39 
40 #define DEBUG_TYPE "affine-ops"
41 
42 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
43 
44 /// Returns true if `value` is an argument of a block in `region`, or is the
45 /// result of an operation whose parent region is `region` (i.e. it is defined
46 /// at the top level of `region`). A value of index type defined at the top
47 /// level of an `AffineScope` region is always a valid symbol for all uses in
// that region.
// NOTE(review): the declaration line (listing line 48) is absent from this
// listing; `value` and `region` are presumably the parameters -- confirm.
49  if (auto arg = llvm::dyn_cast<BlockArgument>(value))
50  return arg.getParentRegion() == region; // Block argument: compare its region.
51  return value.getDefiningOp()->getParentRegion() == region; // Op result.
52 }
53 
54 /// Checks if `value`, known to be a legal affine dimension or symbol in the
55 /// `src` region, remains legal if the operation that uses it is inlined into
56 /// `dest` with the given value mapping. `legalityCheck` is either `isValidDim`
57 /// or `isValidSymbol`, depending on whether the value is required to remain a
58 /// valid dimension or a valid symbol.
// NOTE(review): the line carrying the function name and its leading parameters
// (listing line 60) is absent from this listing; the body uses `value`, `src`
// and `dest`, which are presumably declared there -- confirm.
59 static bool
61  const IRMapping &mapping,
62  function_ref<bool(Value, Region *)> legalityCheck) {
63  // If the value is a valid dimension for any other reason than being
64  // a top-level value, it will remain valid: constants get inlined
65  // with the function, transitive affine applies also get inlined and
66  // will be checked themselves, etc.
67  if (!isTopLevelValue(value, src))
68  return true;
69 
70  // If it's a top-level value because it's a block operand, i.e. a
71  // function argument, check whether the value replacing it after
72  // inlining is a valid dimension in the new region.
73  if (llvm::isa<BlockArgument>(value))
74  return legalityCheck(mapping.lookup(value), dest);
75 
76  // If it's a top-level value because it's defined in the region,
77  // it can only be inlined if the defining op is a constant or a
78  // `dim`, which can appear anywhere and be valid, since the defining
79  // op won't be top-level anymore after inlining.
80  Attribute operandCst;
81  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
82  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
83  isDimLikeOp;
84 }
85 
86 /// Checks if all values known to be legal affine dimensions or symbols in `src`
87 /// remain so if their respective users are inlined into `dest`. Returns true
// only if the per-value check succeeds for every element of `values`.
// NOTE(review): the line carrying the function name and its leading parameters
// (listing line 89) is absent from this listing; `values`, `src` and `dest`
// are presumably declared there -- confirm.
88 static bool
90  const IRMapping &mapping,
91  function_ref<bool(Value, Region *)> legalityCheck) {
92  return llvm::all_of(values, [&](Value v) {
93  return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
94  });
95 }
96 
97 /// Checks if an affine read or write operation remains legal after inlining
98 /// from `src` to `dest`. `OpTy` must be `AffineReadOpInterface` or
// `AffineWriteOpInterface`; this is enforced at compile time.
99 template <typename OpTy>
100 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
101  const IRMapping &mapping) {
102  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
103  AffineWriteOpInterface>::value,
104  "only ops with affine read/write interface are supported");
105 
// The map operands are split into the leading `getNumDims()` dimension
// operands and the trailing `getNumSymbols()` symbol operands.
106  AffineMap map = op.getAffineMap();
107  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
108  ValueRange symbolOperands =
109  op.getMapOperands().take_back(map.getNumSymbols());
// NOTE(review): listing lines 110 and 114 (the heads of the two checks below)
// are absent from this listing. The surviving arguments show the dimension
// operands being checked with `isValidDim` and the symbol operands with
// `isValidSymbol`; the casts select the (Value, Region *) overloads.
111  dimOperands, src, dest, mapping,
112  static_cast<bool (*)(Value, Region *)>(isValidDim)))
113  return false;
115  symbolOperands, src, dest, mapping,
116  static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
117  return false;
118  return true;
119 }
120 
121 /// Checks if an affine apply operation remains legal after inlining from `src`
122 /// to `dest`. The operands are re-checked as dimensions when the result is a
// valid dimension in `src`, and as symbols otherwise.
123 // Use "unused attribute" marker to silence clang-tidy warning stemming from
124 // the inability to see through "llvm::TypeSwitch".
125 template <>
126 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
127  Region *src, Region *dest,
128  const IRMapping &mapping) {
129  // If it's a valid dimension, we need to check that it remains so.
130  if (isValidDim(op.getResult(), src))
// NOTE(review): listing lines 131 and 136 (the heads of the two calls whose
// argument lists follow) are absent from this listing.
132  op.getMapOperands(), src, dest, mapping,
133  static_cast<bool (*)(Value, Region *)>(isValidDim));
134 
135  // Otherwise it must be a valid symbol, check that it remains so.
137  op.getMapOperands(), src, dest, mapping,
138  static_cast<bool (*)(Value, Region *)>(isValidSymbol));
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // AffineDialect Interfaces
143 //===----------------------------------------------------------------------===//
144 
145 namespace {
146 /// This class defines the interface for handling inlining with affine
147 /// operations.
148 struct AffineInlinerInterface : public DialectInlinerInterface {
// NOTE(review): listing line 149 is absent from this listing.
150 
151  //===--------------------------------------------------------------------===//
152  // Analysis Hooks
153  //===--------------------------------------------------------------------===//
154 
155  /// Returns true if the given region 'src' can be inlined into the region
156  /// 'dest' that is attached to an operation registered to the current dialect.
157  /// 'wouldBeCloned' is set if the region is cloned into its new location
158  /// rather than moved, indicating there may be other users.
// 'wouldBeCloned' is not consulted by this implementation.
159  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
160  IRMapping &valueMapping) const final {
161  // We can inline into affine loops and conditionals if this doesn't break
162  // affine value categorization rules.
163  Operation *destOp = dest->getParentOp();
164  if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
165  return false;
166 
167  // Multi-block regions cannot be inlined into affine constructs, all of
168  // which require single-block regions.
169  if (!llvm::hasSingleElement(*src))
170  return false;
171 
172  // Side-effecting operations that the affine dialect cannot understand
173  // should not be inlined.
174  Block &srcBlock = src->front();
175  for (Operation &op : srcBlock) {
176  // Ops with no side effects are fine,
177  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
178  if (iface.hasNoEffect())
179  continue;
180  }
181 
182  // Assuming the inlined region is valid, we only need to check if the
183  // inlining would change it.
184  bool remainsValid =
// NOTE(review): listing line 185 (the head of the dispatch expression whose
// `.Case`/`.Default` clauses follow) is absent from this listing.
186  .Case<AffineApplyOp, AffineReadOpInterface,
187  AffineWriteOpInterface>([&](auto op) {
188  return remainsLegalAfterInline(op, src, dest, valueMapping);
189  })
190  .Default([](Operation *) {
191  // Conservatively disallow inlining ops we cannot reason about.
192  return false;
193  });
194 
195  if (!remainsValid)
196  return false;
197  }
198 
199  return true;
200  }
201 
202  /// Returns true if the given operation 'op', that is registered to this
203  /// dialect, can be inlined into the given region, false otherwise.
// Only the parent operation of 'region' is inspected; 'op', 'wouldBeCloned'
// and 'valueMapping' are not consulted.
204  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
205  IRMapping &valueMapping) const final {
206  // Always allow inlining affine operations into a region that is marked as
207  // affine scope, or into affine loops and conditionals. There are some edge
208  // cases when inlining *into* affine structures, but that is handled in the
209  // other 'isLegalToInline' hook above.
210  Operation *parentOp = region->getParentOp();
211  return parentOp->hasTrait<OpTrait::AffineScope>() ||
212  isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
213  }
214 
215  /// Affine regions should be analyzed recursively.
216  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
217 };
218 } // namespace
219 
220 //===----------------------------------------------------------------------===//
221 // AffineDialect
222 //===----------------------------------------------------------------------===//
223 
// Registers the dialect's operations and interfaces.
224 void AffineDialect::initialize() {
// The DMA ops are listed explicitly; the remaining operations come from the
// included op list, which is expanded in place inside the template argument
// list (the define/include pair must stay between the angle brackets).
225  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
226 #define GET_OP_LIST
227 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
228  >();
// Inlining legality for affine ops is provided by AffineInlinerInterface.
229  addInterfaces<AffineInlinerInterface>();
// ValueBoundsOpInterface is declared as promised for the apply/max/min ops.
230  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
231  AffineMinOp>();
232 }
233 
234 /// Materialize a single constant operation from a given attribute value with
235 /// the desired resultant type. A `ub::PoisonAttr` value yields a
// `ub::PoisonOp`; every other attribute is forwarded to
// `arith::ConstantOp::materialize`.
// NOTE(review): the line carrying the return type, function name and the
// `builder` parameter (listing line 236) is absent from this listing.
237  Attribute value, Type type,
238  Location loc) {
239  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
240  return builder.create<ub::PoisonOp>(loc, type, poison);
241  return arith::ConstantOp::materialize(builder, value, type, loc);
242 }
243 
244 /// A utility function to check if a value is defined at the top level of an
245 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
246 /// conservatively assume it is not top-level. A value of index type defined at
247 /// the top level is always a valid symbol.
// NOTE(review): the declaration line (listing line 248) is absent from this
// listing; `value` is presumably the sole parameter -- confirm.
249  if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
250  // The block owning the argument may be unlinked, e.g. when the surrounding
251  // region has not yet been attached to an Op, at which point the parent Op
252  // is null.
253  Operation *parentOp = arg.getOwner()->getParentOp();
254  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
255  }
256  // The defining Op may live in an unlinked block so its parent Op may be null.
257  Operation *parentOp = value.getDefiningOp()->getParentOp();
258  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
259 }
260 
261 /// Returns the closest region enclosing `op` that is held by an operation with
262 /// trait `AffineScope`; `nullptr` if there is no such region.
// NOTE(review): the declaration line (listing line 263) is absent from this
// listing; `op` is presumably the sole parameter -- confirm.
264  auto *curOp = op;
// Walk up the parent chain; the region returned is the one of the nearest
// `AffineScope` ancestor that contains the current op.
265  while (auto *parentOp = curOp->getParentOp()) {
266  if (parentOp->hasTrait<OpTrait::AffineScope>())
267  return curOp->getParentRegion();
268  curOp = parentOp;
269  }
270  return nullptr;
271 }
272 
273 // A Value can be used as a dimension id iff it meets one of the following
274 // conditions:
275 // *) It is valid as a symbol.
276 // *) It is an induction variable.
277 // *) It is the result of affine apply operation with dimension id arguments.
// NOTE(review): the declaration line (listing line 278) is absent from this
// listing; `value` is presumably the sole parameter -- confirm.
279  // The value must be an index type.
280  if (!value.getType().isIndex())
281  return false;
282 
// Op results are checked against the affine scope enclosing the defining op.
283  if (auto *defOp = value.getDefiningOp())
284  return isValidDim(value, getAffineScope(defOp));
285 
286  // This value has to be a block argument for an op that has the
287  // `AffineScope` trait or for an affine.for or affine.parallel.
// A null parent op (unlinked block) yields false.
288  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
289  return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
290  isa<AffineForOp, AffineParallelOp>(parentOp));
291 }
292 
293 // Value can be used as a dimension id iff it meets one of the following
294 // conditions:
295 // *) It is valid as a symbol.
296 // *) It is an induction variable.
297 // *) It is the result of an affine apply operation with dimension id operands.
298 bool mlir::affine::isValidDim(Value value, Region *region) {
299  // The value must be an index type.
300  if (!value.getType().isIndex())
301  return false;
302 
303  // All valid symbols are okay.
304  if (isValidSymbol(value, region))
305  return true;
306 
307  auto *op = value.getDefiningOp();
308  if (!op) {
309  // This value has to be a block argument for an affine.for or an
310  // affine.parallel.
311  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
312  return isa<AffineForOp, AffineParallelOp>(parentOp);
313  }
314 
315  // Affine apply operation is ok if all of its operands are ok.
316  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
317  return applyOp.isValidDim(region);
318  // The dim op is okay if its operand memref/tensor is defined at the top
319  // level.
320  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
321  return isTopLevelValue(dimOp.getShapedValue());
322  return false;
323 }
324 
325 /// Returns true if the 'index' dimension of the `memref` defined by
326 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
327 /// for `region`.
328 template <typename AnyMemRefDefOp>
329 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
330  Region *region) {
331  MemRefType memRefType = memrefDefOp.getType();
332 
333  // Dimension index is out of bounds.
334  if (index >= memRefType.getRank()) {
335  return false;
336  }
337 
338  // Statically shaped.
339  if (!memRefType.isDynamicDim(index))
340  return true;
341  // Get the position of the dimension among dynamic dimensions;
342  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
343  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
344  region);
345 }
346 
347 /// Returns true if the result of the dim op is a valid symbol for `region`.
348 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
349  // The dim op is okay if its source is defined at the top level.
350  if (isTopLevelValue(dimOp.getShapedValue()))
351  return true;
352 
353  // Conservatively handle remaining BlockArguments as non-valid symbols.
354  // E.g. scf.for iterArgs.
355  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
356  return false;
357 
358  // The dim op is also okay if its operand memref is a view/subview whose
359  // corresponding size is a valid symbol.
360  std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
361 
362  // Be conservative if we can't understand the dimension.
363  if (!index.has_value())
364  return false;
365 
366  // Skip over all memref.cast ops (if any).
// The shaped value is not a block argument here (handled above), so it has a
// defining op to start from.
367  Operation *op = dimOp.getShapedValue().getDefiningOp();
368  while (auto castOp = dyn_cast<memref::CastOp>(op)) {
369  // Bail on unranked memrefs.
370  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
371  return false;
372  op = castOp.getSource().getDefiningOp();
// The cast source may itself be a block argument; give up in that case.
373  if (!op)
374  return false;
375  }
376 
377  int64_t i = index.value();
// NOTE(review): listing line 378 (the head of the dispatch expression whose
// `.Case`/`.Default` clauses follow) is absent from this listing. View,
// subview and alloc ops defer to isMemRefSizeValidSymbol; any other defining
// op yields false.
379  .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
380  [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
381  .Default([](Operation *) { return false; });
382 }
383 
384 // A value can be used as a symbol (at all its use sites) iff it meets one of
385 // the following conditions:
386 // *) It is a constant.
387 // *) Its defining op or block arg appearance is immediately enclosed by an op
388 // with `AffineScope` trait.
389 // *) It is the result of an affine.apply operation with symbol operands.
390 // *) It is a result of the dim op on a memref whose corresponding size is a
391 // valid symbol.
// NOTE(review): the declaration line (listing line 392) is absent from this
// listing; `value` is presumably the sole parameter -- confirm.
// A null value is rejected up front.
393  if (!value)
394  return false;
395 
396  // The value must be an index type.
397  if (!value.getType().isIndex())
398  return false;
399 
400  // Check that the value is a top level value.
401  if (isTopLevelValue(value))
402  return true;
403 
// Op results are checked against the affine scope enclosing the defining op.
404  if (auto *defOp = value.getDefiningOp())
405  return isValidSymbol(value, getAffineScope(defOp));
406 
// A block argument that is not top-level is not a valid symbol.
407  return false;
408 }
409 
410 /// A value can be used as a symbol for `region` iff it meets one of the
411 /// following conditions:
412 /// *) It is a constant.
413 /// *) It is the result of an affine apply operation with symbol arguments.
414 /// *) It is a result of the dim op on a memref whose corresponding size is
415 /// a valid symbol.
416 /// *) It is defined at the top level of 'region' or is its argument.
417 /// *) It dominates `region`'s parent op.
418 /// If `region` is null, conservatively assume the symbol definition scope does
419 /// not exist and only accept the values that would be symbols regardless of
420 /// the surrounding region structure, i.e. the first three cases above.
// NOTE(review): the declaration line (listing line 421) is absent from this
// listing; `value` and `region` are presumably the parameters -- confirm.
422  // The value must be an index type.
423  if (!value.getType().isIndex())
424  return false;
425 
426  // A top-level value is a valid symbol.
427  if (region && ::isTopLevelValue(value, region))
428  return true;
429 
430  auto *defOp = value.getDefiningOp();
431  if (!defOp) {
432  // A block argument that is not a top-level value is a valid symbol if it
433  // dominates region's parent op.
// The check recurses into the region enclosing `region`'s parent op, and
// stops at ops that are isolated from above.
434  Operation *regionOp = region ? region->getParentOp() : nullptr;
435  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
436  if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
437  return isValidSymbol(value, parentOpRegion);
438  return false;
439  }
440 
441  // Constant operation is ok.
442  Attribute operandCst;
443  if (matchPattern(defOp, m_Constant(&operandCst)))
444  return true;
445 
446  // Affine apply operation is ok if all of its operands are ok.
447  if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
448  return applyOp.isValidSymbol(region);
449 
450  // Dim op results could be valid symbols at any level.
451  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
452  return isDimOpValidSymbol(dimOp, region);
453 
454  // Check for values dominating `region`'s parent op.
// Same outward recursion as for block arguments above.
455  Operation *regionOp = region ? region->getParentOp() : nullptr;
456  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
457  if (auto *parentRegion = region->getParentOp()->getParentRegion())
458  return isValidSymbol(value, parentRegion);
459 
460  return false;
461 }
462 
463 // Returns true if 'value' is a valid index to an affine operation (e.g.
464 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
465 // `region` provides the polyhedral symbol scope. Returns false otherwise.
466 static bool isValidAffineIndexOperand(Value value, Region *region) {
467  return isValidDim(value, region) || isValidSymbol(value, region);
468 }
469 
470 /// Prints dimension and symbol list: the first `numDims` operands are printed
// in parentheses, and any remaining operands follow in square brackets. The
// bracketed list is omitted when there are no operands beyond `numDims`.
// NOTE(review): the lines carrying the function name and the `begin`/`end`
// parameters (listing lines 471-472) are absent from this listing.
473  unsigned numDims, OpAsmPrinter &printer) {
474  OperandRange operands(begin, end);
475  printer << '(' << operands.take_front(numDims) << ')';
476  if (operands.size() > numDims)
477  printer << '[' << operands.drop_front(numDims) << ']';
478 }
479 
480 /// Parses dimension and symbol list and returns true if parsing failed. On
// return, `numDims` holds the number of operands parsed from the parenthesized
// (dimension) list, and the resolved values are appended to `operands` with
// index type.
// NOTE(review): listing lines 481 (declaration head), 483 (declaration of
// `opInfos`) and 492 (arguments of the second parseOperandList call) are
// absent from this listing.
482  OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
484  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
485  return failure();
486  // Store number of dimensions for validation by caller.
487  numDims = opInfos.size();
488 
489  // Parse the optional symbol operands.
490  auto indexTy = parser.getBuilder().getIndexType();
491  return failure(parser.parseOperandList(
493  parser.resolveOperands(opInfos, indexTy, operands));
494 }
495 
496 /// Utility function to verify that a set of operands are valid dimension and
497 /// symbol identifiers. The operands should be laid out such that the dimension
498 /// operands are before the symbol operands. This function returns failure if
499 /// there was an invalid operand. An operation is provided to emit any necessary
500 /// errors.
501 template <typename OpTy>
502 static LogicalResult
// NOTE(review): the line carrying the function name and the `op`/`operands`
// parameters (listing line 503) is absent from this listing.
504  unsigned numDims) {
// `opIt` counts operands seen so far: the first `numDims` are checked as
// dimensions, the rest as symbols, both against the affine scope of `op`.
505  unsigned opIt = 0;
506  for (auto operand : operands) {
507  if (opIt++ < numDims) {
508  if (!isValidDim(operand, getAffineScope(op)))
509  return op.emitOpError("operand cannot be used as a dimension id");
510  } else if (!isValidSymbol(operand, getAffineScope(op))) {
511  return op.emitOpError("operand cannot be used as a symbol");
512  }
513  }
514  return success();
515 }
516 
517 //===----------------------------------------------------------------------===//
518 // AffineApplyOp
519 //===----------------------------------------------------------------------===//
520 
521 AffineValueMap AffineApplyOp::getAffineValueMap() {
522  return AffineValueMap(getAffineMap(), getOperands(), getResult());
523 }
524 
525 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
526  auto &builder = parser.getBuilder();
527  auto indexTy = builder.getIndexType();
528 
529  AffineMapAttr mapAttr;
530  unsigned numDims;
531  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
532  parseDimAndSymbolList(parser, result.operands, numDims) ||
533  parser.parseOptionalAttrDict(result.attributes))
534  return failure();
535  auto map = mapAttr.getValue();
536 
537  if (map.getNumDims() != numDims ||
538  numDims + map.getNumSymbols() != result.operands.size()) {
539  return parser.emitError(parser.getNameLoc(),
540  "dimension or symbol index mismatch");
541  }
542 
543  result.types.append(map.getNumResults(), indexTy);
544  return success();
545 }
546 
// Prints the map attribute, then the dimension and symbol operand lists, then
// the remaining attributes with "map" elided (it was already printed).
// NOTE(review): the declaration line (listing line 547) is absent from this
// listing; `p` is presumably the printer parameter -- confirm.
548  p << " " << getMapAttr();
549  printDimAndSymbolList(operand_begin(), operand_end(),
550  getAffineMap().getNumDims(), p);
551  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
552 }
553 
554 LogicalResult AffineApplyOp::verify() {
555  // Check input and output dimensions match.
556  AffineMap affineMap = getMap();
557 
558  // Verify that operand count matches affine map dimension and symbol count.
559  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
560  return emitOpError(
561  "operand count and affine map dimension and symbol count must match");
562 
563  // Verify that the map only produces one result.
564  if (affineMap.getNumResults() != 1)
565  return emitOpError("mapping must produce one value");
566 
567  return success();
568 }
569 
570 // The result of the affine apply operation can be used as a dimension id if all
571 // its operands are valid dimension ids.
// NOTE(review): the declaration line (listing line 572) is absent from this
// listing. The lambda calls the single-argument affine::isValidDim overload.
573  return llvm::all_of(getOperands(),
574  [](Value op) { return affine::isValidDim(op); });
575 }
576 
577 // The result of the affine apply operation can be used as a dimension id if all
578 // its operands are valid dimension ids with the parent operation of `region`
579 // defining the polyhedral scope for symbols.
580 bool AffineApplyOp::isValidDim(Region *region) {
581  return llvm::all_of(getOperands(),
582  [&](Value op) { return ::isValidDim(op, region); });
583 }
584 
585 // The result of the affine apply operation can be used as a symbol if all its
586 // operands are symbols.
// NOTE(review): the declaration line (listing line 587) is absent from this
// listing. The lambda calls the single-argument affine::isValidSymbol overload.
588  return llvm::all_of(getOperands(),
589  [](Value op) { return affine::isValidSymbol(op); });
590 }
591 
592 // The result of the affine apply operation can be used as a symbol in `region`
593 // if all its operands are symbols in `region`.
594 bool AffineApplyOp::isValidSymbol(Region *region) {
595  return llvm::all_of(getOperands(), [&](Value operand) {
596  return affine::isValidSymbol(operand, region);
597  });
598 }
599 
// Folds the op: a map whose single result is a bare dimension or symbol folds
// to the corresponding operand; otherwise the map is constant-folded over the
// adaptor's operand attributes. Returns a null result when folding fails.
600 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
601  auto map = getAffineMap();
602 
603  // Fold dims and symbols to existing values.
604  auto expr = map.getResult(0);
605  if (auto dim = dyn_cast<AffineDimExpr>(expr))
606  return getOperand(dim.getPosition());
// Symbol operands follow the dimension operands, hence the offset.
607  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
608  return getOperand(map.getNumDims() + sym.getPosition());
609 
610  // Otherwise, default to folding the map.
// NOTE(review): listing lines 611 (declaration of `result`) and 616 (the
// statement guarded by `if (hasPoison)`) are absent from this listing.
612  bool hasPoison = false;
613  auto foldResult =
614  map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
615  if (hasPoison)
617  if (failed(foldResult))
618  return {};
619  return result[0];
620 }
621 
622 /// Returns the largest known divisor of `e`. Exploits information from the
623 /// values in `operands`.
624 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
625  // This method isn't aware of `operands`.
626  int64_t div = e.getLargestKnownDivisor();
627 
628  // We now make use of operands for the case `e` is a dim expression.
629  // TODO: More powerful simplification would have to modify
630  // getLargestKnownDivisor to take `operands` and exploit that information as
631  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
632  // can't be part of the IR library but of the `Analysis` library. The IR
633  // library can only really depend on simple O(1) checks.
634  auto dimExpr = dyn_cast<AffineDimExpr>(e);
635  // If it's not a dim expr, `div` is the best we have.
636  if (!dimExpr)
637  return div;
638 
639  // We simply exploit information from loop IVs.
640  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
641  // desired simplifications are expected to be part of other
642  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
643  // LoopAnalysis library.
644  Value operand = operands[dimExpr.getPosition()];
645  int64_t operandDivisor = 1;
646  // TODO: With the right accessors, this can be extended to
647  // LoopLikeOpInterface.
648  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
649  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
650  operandDivisor = forOp.getStepAsInt();
651  } else {
652  uint64_t lbLargestKnownDivisor =
653  forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
654  operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
655  }
656  }
657  return operandDivisor;
658 }
659 
660 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
661 /// being an affine dim expression or a constant.
// NOTE(review): the line carrying the function name and the `e`/`operands`
// parameters (listing line 662) is absent from this listing.
663  int64_t k) {
664  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
665  int64_t constVal = constExpr.getValue();
666  return constVal >= 0 && constVal < k;
667  }
668  auto dimExpr = dyn_cast<AffineDimExpr>(e);
669  if (!dimExpr)
670  return false;
671  Value operand = operands[dimExpr.getPosition()];
672  // TODO: With the right accessors, this can be extended to
673  // LoopLikeOpInterface.
// An affine.for IV qualifies when both loop bounds are constant, the lower
// bound is non-negative and the upper bound does not exceed `k`.
674  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
675  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
676  forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
677  return true;
678  }
679  }
680 
681  // We don't consider other cases like `operand` being defined by a constant or
682  // an affine.apply op since such cases will already be handled by other
683  // patterns and propagation of loop IVs or constant would happen.
684  return false;
685 }
686 
687 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
688 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
689 /// expression is in that form.
690 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
691  AffineExpr &quotientTimesDiv, AffineExpr &rem) {
692  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
693  if (!bin || bin.getKind() != AffineExprKind::Add)
694  return false;
695 
696  AffineExpr llhs = bin.getLHS();
697  AffineExpr rlhs = bin.getRHS();
698  div = getLargestKnownDivisor(llhs, operands);
699  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
700  quotientTimesDiv = llhs;
701  rem = rlhs;
702  return true;
703  }
704  div = getLargestKnownDivisor(rlhs, operands);
705  if (isNonNegativeBoundedBy(llhs, operands, div)) {
706  quotientTimesDiv = rlhs;
707  rem = llhs;
708  return true;
709  }
710  return false;
711 }
712 
713 /// Gets the constant lower bound on an `iv`.
714 static std::optional<int64_t> getLowerBound(Value iv) {
715  AffineForOp forOp = getForInductionVarOwner(iv);
716  if (forOp && forOp.hasConstantLowerBound())
717  return forOp.getConstantLowerBound();
718  return std::nullopt;
719 }
720 
721 /// Gets the constant upper bound on an affine.for `iv`.
722 static std::optional<int64_t> getUpperBound(Value iv) {
723  AffineForOp forOp = getForInductionVarOwner(iv);
724  if (!forOp || !forOp.hasConstantUpperBound())
725  return std::nullopt;
726 
727  // If its lower bound is also known, we can get a more precise bound
728  // whenever the step is not one.
729  if (forOp.hasConstantLowerBound()) {
730  return forOp.getConstantUpperBound() - 1 -
731  (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
732  forOp.getStepAsInt();
733  }
734  return forOp.getConstantUpperBound() - 1;
735 }
736 
737 /// Determine a constant upper bound for `expr` if one exists while exploiting
738 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
739 /// is guaranteed to be less than or equal to it.
740 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
741  unsigned numSymbols,
742  ArrayRef<Value> operands) {
743  // Get the constant lower or upper bounds on the operands.
744  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
745  constLowerBounds.reserve(operands.size());
746  constUpperBounds.reserve(operands.size());
747  for (Value operand : operands) {
748  constLowerBounds.push_back(getLowerBound(operand));
749  constUpperBounds.push_back(getUpperBound(operand));
750  }
751 
752  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
753  return constExpr.getValue();
754 
755  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
756  constUpperBounds,
757  /*isUpper=*/true);
758 }
759 
760 /// Determine a constant lower bound for `expr` if one exists while exploiting
761 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
762 /// is guaranteed to be less than or equal to it.
763 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
764  unsigned numSymbols,
765  ArrayRef<Value> operands) {
766  // Get the constant lower or upper bounds on the operands.
767  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
768  constLowerBounds.reserve(operands.size());
769  constUpperBounds.reserve(operands.size());
770  for (Value operand : operands) {
771  constLowerBounds.push_back(getLowerBound(operand));
772  constUpperBounds.push_back(getUpperBound(operand));
773  }
774 
775  std::optional<int64_t> lowerBound;
776  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
777  lowerBound = constExpr.getValue();
778  } else {
779  lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
780  constLowerBounds, constUpperBounds,
781  /*isUpper=*/false);
782  }
783  return lowerBound;
784 }
785 
786 /// Simplify `expr` while exploiting information from the values in `operands`.
787 static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
788  unsigned numSymbols,
789  ArrayRef<Value> operands) {
790  // We do this only for certain floordiv/mod expressions.
791  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
792  if (!binExpr)
793  return;
794 
795  // Simplify the child expressions first.
796  AffineExpr lhs = binExpr.getLHS();
797  AffineExpr rhs = binExpr.getRHS();
798  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
799  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
800  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
801 
802  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
803  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
804  expr.getKind() != AffineExprKind::CeilDiv &&
805  expr.getKind() != AffineExprKind::Mod)) {
806  return;
807  }
808 
809  // The `lhs` and `rhs` may be different post construction of simplified expr.
810  lhs = binExpr.getLHS();
811  rhs = binExpr.getRHS();
812  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
813  if (!rhsConst)
814  return;
815 
816  int64_t rhsConstVal = rhsConst.getValue();
817  // Undefined expressions aren't touched; IR can still be valid with them.
818  if (rhsConstVal <= 0)
819  return;
820 
821  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
822  MLIRContext *context = expr.getContext();
823  std::optional<int64_t> lhsLbConst =
824  getLowerBound(lhs, numDims, numSymbols, operands);
825  std::optional<int64_t> lhsUbConst =
826  getUpperBound(lhs, numDims, numSymbols, operands);
827  if (lhsLbConst && lhsUbConst) {
828  int64_t lhsLbConstVal = *lhsLbConst;
829  int64_t lhsUbConstVal = *lhsUbConst;
830  // lhs floordiv c is a single value lhs is bounded in a range `c` that has
831  // the same quotient.
832  if (binExpr.getKind() == AffineExprKind::FloorDiv &&
833  divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
834  divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
835  expr = getAffineConstantExpr(
836  divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
837  return;
838  }
839  // lhs ceildiv c is a single value if the entire range has the same ceil
840  // quotient.
841  if (binExpr.getKind() == AffineExprKind::CeilDiv &&
842  divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
843  divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
844  expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
845  context);
846  return;
847  }
848  // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
849  if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
850  lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
851  expr = lhs;
852  return;
853  }
854  }
855 
856  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
857  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
858  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
859  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
860  AffineExpr quotientTimesDiv, rem;
861  int64_t divisor;
862  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
863  if (rhsConstVal % divisor == 0 &&
864  binExpr.getKind() == AffineExprKind::FloorDiv) {
865  expr = quotientTimesDiv.floorDiv(rhsConst);
866  } else if (divisor % rhsConstVal == 0 &&
867  binExpr.getKind() == AffineExprKind::Mod) {
868  expr = rem % rhsConst;
869  }
870  return;
871  }
872 
873  // Handle the simple case when the LHS expression can be either upper
874  // bounded or is a known multiple of RHS constant.
875  // lhs floordiv c -> 0 if 0 <= lhs < c,
876  // lhs mod c -> 0 if lhs % c = 0.
877  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
878  binExpr.getKind() == AffineExprKind::FloorDiv) ||
879  (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
880  binExpr.getKind() == AffineExprKind::Mod)) {
881  expr = getAffineConstantExpr(0, expr.getContext());
882  }
883 }
884 
885 /// Simplify the expressions in `map` while making use of lower or upper bounds
886 /// of its operands. If `isMax` is true, the map is to be treated as a max of
887 /// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
888 /// d1) can be simplified to (8) if the operands are respectively lower bounded
889 /// by 2 and 0 (the second expression can't be lower than 8).
/// Results whose computed lower and upper bounds coincide are replaced by that
/// constant. At the end `map` is reassigned to a map that keeps only the
/// results not found redundant. Nothing is done when `operands` is empty.
// NOTE(review): the line that opens this definition (return type, name and the
// `map` parameter) is absent from this listing (inner numbering skips 890);
// restore it from the original file.
891  ArrayRef<Value> operands,
892  bool isMax) {
893  // Can't simplify.
894  if (operands.empty())
895  return;
896 
897  // Get the upper or lower bound on an affine.for op IV using its range.
898  // Get the constant lower or upper bounds on the operands.
899  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
900  constLowerBounds.reserve(operands.size());
901  constUpperBounds.reserve(operands.size());
902  for (Value operand : operands) {
903  constLowerBounds.push_back(getLowerBound(operand));
904  constUpperBounds.push_back(getUpperBound(operand));
905  }
906 
907  // We will compute the lower and upper bounds on each of the expressions.
908  // Then, we will check (depending on max or min) as to whether a specific
909  // bound is redundant by checking if its highest (in case of max) and its
910  // lowest (in the case of min) value is already lower than (or higher than)
911  // the lower bound (or upper bound in the case of min) of another bound.
912  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
913  lowerBounds.reserve(map.getNumResults());
914  upperBounds.reserve(map.getNumResults());
915  for (AffineExpr e : map.getResults()) {
916  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
  // A constant result is its own lower and upper bound.
917  lowerBounds.push_back(constExpr.getValue());
918  upperBounds.push_back(constExpr.getValue());
919  } else {
  // NOTE(review): the first line of each of the two calls below (callee and
  // leading arguments) is missing from this listing (inner 921 and 925).
920  lowerBounds.push_back(
922  constLowerBounds, constUpperBounds,
923  /*isUpper=*/false));
924  upperBounds.push_back(
926  constLowerBounds, constUpperBounds,
927  /*isUpper=*/true));
928  }
929  }
930 
931  // Collect expressions that are not redundant.
932  SmallVector<AffineExpr, 4> irredundantExprs;
933  for (auto exprEn : llvm::enumerate(map.getResults())) {
934  AffineExpr e = exprEn.value();
935  unsigned i = exprEn.index();
936  // Some expressions can be turned into constants.
937  if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
938  e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
939 
940  // Check if the expression is redundant.
941  if (isMax) {
  // Without a known upper bound nothing can be proven; keep the expression.
942  if (!upperBounds[i]) {
943  irredundantExprs.push_back(e);
944  continue;
945  }
946  // If there exists another expression such that its lower bound is greater
947  // than this expression's upper bound, it's redundant.
948  if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
949  auto otherLowerBound = en.value();
950  unsigned pos = en.index();
951  if (pos == i || !otherLowerBound)
952  return false;
953  if (*otherLowerBound > *upperBounds[i])
954  return true;
955  if (*otherLowerBound < *upperBounds[i])
956  return false;
957  // Equality case. When both expressions are considered redundant, we
958  // don't want to get both of them. We keep the one that appears
959  // first.
960  if (upperBounds[pos] && lowerBounds[i] &&
961  lowerBounds[i] == upperBounds[i] &&
962  otherLowerBound == *upperBounds[pos] && i < pos)
963  return false;
964  return true;
965  }))
966  irredundantExprs.push_back(e);
967  } else {
  // Without a known lower bound nothing can be proven; keep the expression.
968  if (!lowerBounds[i]) {
969  irredundantExprs.push_back(e);
970  continue;
971  }
972  // Likewise for the `min` case. Use the complement of the condition above.
973  if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
974  auto otherUpperBound = en.value();
975  unsigned pos = en.index();
976  if (pos == i || !otherUpperBound)
977  return false;
978  if (*otherUpperBound < *lowerBounds[i])
979  return true;
980  if (*otherUpperBound > *lowerBounds[i])
981  return false;
  // Equality case: of two mutually redundant constant expressions, keep
  // the one that appears first.
982  if (lowerBounds[pos] && upperBounds[i] &&
983  lowerBounds[i] == upperBounds[i] &&
984  otherUpperBound == lowerBounds[pos] && i < pos)
985  return false;
986  return true;
987  }))
988  irredundantExprs.push_back(e);
989  }
990  }
991 
992  // Create the map without the redundant expressions.
993  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
994  map.getContext());
995 }
996 
997 /// Simplify the map while exploiting information on the values in `operands`.
/// Every result expression is processed in turn and `map` is then reassigned
/// to a map with the same dim and symbol counts and the processed results.
/// Asserts that `operands` has one entry per map input.
998 // Use "unused attribute" marker to silence warning stemming from the inability
999 // to see through the template expansion.
// NOTE(review): the line carrying the function name and parameter list is
// missing from this listing (inner numbering skips 1001).
1000 static void LLVM_ATTRIBUTE_UNUSED
1002  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1003  SmallVector<AffineExpr> newResults;
1004  newResults.reserve(map.getNumResults());
1005  for (AffineExpr expr : map.getResults()) {
  // NOTE(review): the head of the call that processes `expr` is missing here
  // (inner 1006); only its trailing argument is visible.
1007  operands);
1008  newResults.push_back(expr);
1009  }
1010  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1011  map.getContext());
1012 }
1013 
1014 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
1015 /// defining AffineApplyOp expression and operands.
1016 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
1017 /// When `dimOrSymbolPosition >= dims.size()`,
1018 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
1019 /// Mutate `map`,`dims` and `syms` in place as follows:
1020 /// 1. `dims` and `syms` are only appended to.
1021 /// 2. `map` dim and symbols are gradually shifted to higher positions.
1022 /// 3. Old `dim` and `sym` entries are replaced by nullptr
1023 /// This avoids the need for any bookkeeping.
1024 static LogicalResult replaceDimOrSym(AffineMap *map,
1025  unsigned dimOrSymbolPosition,
1026  SmallVectorImpl<Value> &dims,
1027  SmallVectorImpl<Value> &syms) {
1028  MLIRContext *ctx = map->getContext();
1029  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1030  unsigned pos = isDimReplacement ? dimOrSymbolPosition
1031  : dimOrSymbolPosition - dims.size();
1032  Value &v = isDimReplacement ? dims[pos] : syms[pos];
1033  if (!v)
1034  return failure();
1035 
1036  auto affineApply = v.getDefiningOp<AffineApplyOp>();
1037  if (!affineApply)
1038  return failure();
1039 
1040  // At this point we will perform a replacement of `v`, set the entry in `dim`
1041  // or `sym` to nullptr immediately.
1042  v = nullptr;
1043 
1044  // Compute the map, dims and symbols coming from the AffineApplyOp.
1045  AffineMap composeMap = affineApply.getAffineMap();
1046  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1047  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1048  affineApply.getMapOperands().end());
1049  // Canonicalize the map to promote dims to symbols when possible. This is to
1050  // avoid generating invalid maps.
1051  canonicalizeMapAndOperands(&composeMap, &composeOperands);
1052  AffineExpr replacementExpr =
1053  composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
1054  ValueRange composeDims =
1055  ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
1056  ValueRange composeSyms =
1057  ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
1058  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
1059  : getAffineSymbolExpr(pos, ctx);
1060 
1061  // Append the dims and symbols where relevant and perform the replacement.
1062  dims.append(composeDims.begin(), composeDims.end());
1063  syms.append(composeSyms.begin(), composeSyms.end());
1064  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1065 
1066  return success();
1067 }
1068 
1069 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1070 /// iteratively. Perform canonicalization of map and operands as well as
1071 /// AffineMap simplification. `map` and `operands` are mutated in place.
/// A map without results is only canonicalized and simplified.
// NOTE(review): the line that opens this definition (return type, name and the
// `map` parameter) is missing from this listing (inner numbering skips 1072).
1073  SmallVectorImpl<Value> *operands) {
1074  if (map->getNumResults() == 0) {
1075  canonicalizeMapAndOperands(map, operands);
1076  *map = simplifyAffineMap(*map);
1077  return;
1078  }
1079 
1080  MLIRContext *ctx = map->getContext();
  // Split the operand list into its leading dims and trailing symbols.
1081  SmallVector<Value, 4> dims(operands->begin(),
1082  operands->begin() + map->getNumDims());
1083  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1084  operands->end());
1085 
1086  // Iterate over dims and symbols coming from AffineApplyOp and replace until
1087  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1088  // and `syms` can only increase by construction.
1089  // The implementation uses a `while` loop to support the case of symbols
1090  // that may be constructed from dims; this may be overkill.
1091  while (true) {
1092  bool changed = false;
  // Restart the scan after every successful replacement, since the lists grew.
1093  for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1094  if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
1095  break;
1096  if (!changed)
1097  break;
1098  }
1099 
1100  // Clear operands so we can fill them anew.
1101  operands->clear();
1102 
1103  // At this point we may have introduced null operands, prune them out before
1104  // canonicalizing map and operands.
1105  unsigned nDims = 0, nSyms = 0;
1106  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1107  dimReplacements.reserve(dims.size());
1108  symReplacements.reserve(syms.size());
1109  for (auto *container : {&dims, &syms}) {
1110  bool isDim = (container == &dims);
1111  auto &repls = isDim ? dimReplacements : symReplacements;
1112  for (const auto &en : llvm::enumerate(*container)) {
1113  Value v = en.value();
1114  if (!v) {
  // A nulled entry was substituted away, so the map must no longer depend on
  // it; it is remapped to the constant 0.
  // NOTE(review): `&&` binds tighter than `?:`, so the message string is
  // attached to the symbol branch only; the asserted condition is unaffected.
1115  assert(isDim ? !map->isFunctionOfDim(en.index())
1116  : !map->isFunctionOfSymbol(en.index()) &&
1117  "map is function of unexpected expr@pos");
1118  repls.push_back(getAffineConstantExpr(0, ctx));
1119  continue;
1120  }
  // Surviving entries are renumbered densely, in their original order.
1121  repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1122  : getAffineSymbolExpr(nSyms++, ctx));
1123  operands->push_back(v);
1124  }
1125  }
1126  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1127  nSyms);
1128 
1129  // Canonicalize and simplify before returning.
1130  canonicalizeMapAndOperands(map, operands);
1131  *map = simplifyAffineMap(*map);
1132 }
1133 
// Keeps calling composeAffineMapAndOperands on `map` and `operands` for as
// long as at least one operand is defined by an AffineApplyOp.
// NOTE(review): the line with the return type and function name is missing
// from this listing (inner numbering skips 1134).
1135  AffineMap *map, SmallVectorImpl<Value> *operands) {
1136  while (llvm::any_of(*operands, [](Value v) {
  // `isa_and_nonnull` tolerates values that have no defining op.
1137  return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1138  })) {
1139  composeAffineMapAndOperands(map, operands);
1140  }
1141 }
1142 
// Creates an AffineApplyOp at `loc` through builder `b`. Before the op is
// created, `operands` are passed through foldAttributesIntoMap (which yields
// the Value operands and an updated map), and the map is then composed with
// the affine.apply ops defining those operands.
// NOTE(review): the line with the function name and leading parameters is
// missing from this listing (inner numbering skips 1144).
1143 AffineApplyOp
1145  ArrayRef<OpFoldResult> operands) {
1146  SmallVector<Value> valueOperands;
1147  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1148  composeAffineMapAndOperands(&map, &valueOperands);
1149  assert(map);
1150  return b.create<AffineApplyOp>(loc, map, valueOperands);
1151 }
1152 
// Overload that forwards to makeComposedAffineApply with `b`, `loc`, the
// front element of a container built on a line not shown here, and `operands`.
// NOTE(review): the line with the function name and leading parameters (inner
// 1154) and the expression whose `.front()` is taken (inner 1158) are missing
// from this listing.
1153 AffineApplyOp
1155  ArrayRef<OpFoldResult> operands) {
1156  return makeComposedAffineApply(
1157  b, loc,
1159  .front(),
1160  operands);
1161 }
1162 
1163 /// Composes the given affine map with the given list of operands, pulling in
1164 /// the maps from any affine.apply operations that supply the operands.
/// Both `map` and `operands` are overwritten with the composed, canonicalized
/// result.
// NOTE(review): the line that opens this definition (inner 1165) and the
// declaration of `exprs` (inner 1171) are missing from this listing.
1166  SmallVectorImpl<Value> &operands) {
1167  // Compose and canonicalize each expression in the map individually because
1168  // composition only applies to single-result maps, collecting potentially
1169  // duplicate operands in a single list with shifted dimensions and symbols.
1170  SmallVector<Value> dims, symbols;
1172  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
  // Each result starts from a fresh copy of the original operand list.
1173  SmallVector<Value> submapOperands(operands.begin(), operands.end());
1174  AffineMap submap = map.getSubMap({i});
1175  fullyComposeAffineMapAndOperands(&submap, &submapOperands);
1176  canonicalizeMapAndOperands(&submap, &submapOperands);
  // Read the dim count before shifting; it splits `submapOperands` below.
1177  unsigned numNewDims = submap.getNumDims();
1178  submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1179  llvm::append_range(dims,
1180  ArrayRef<Value>(submapOperands).take_front(numNewDims));
1181  llvm::append_range(symbols,
1182  ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1183  exprs.push_back(submap.getResult(0));
1184  }
1185 
1186  // Canonicalize the map created from composed expressions to deduplicate the
1187  // dimension and symbol operands.
1188  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1189  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1190  canonicalizeMapAndOperands(&map, &operands);
1191 }
1192 
// Builds a composed affine.apply for the single-result `map` and tries to fold
// it right away. If folding fails or yields nothing, the listener of `b` (if
// any) is notified of the insertion and the op's result is returned; otherwise
// the op is erased and the single fold result is returned.
// NOTE(review): the lines with the return type, function name and leading
// parameters (inner 1193-1194) are missing from this listing.
1195  AffineMap map,
1196  ArrayRef<OpFoldResult> operands) {
1197  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1198 
1199  // Create new builder without a listener, so that no notification is
1200  // triggered if the op is folded.
1201  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1202  // workaround is no longer needed.
1203  OpBuilder newBuilder(b.getContext());
  // NOTE(review): one statement is missing from this listing here (inner
  // 1204).
1205 
1206  // Create op.
1207  AffineApplyOp applyOp =
1208  makeComposedAffineApply(newBuilder, loc, map, operands);
1209 
1210  // Get constant operands.
  // Entries stay null for operands that do not match a constant.
1211  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1212  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1213  matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1214 
1215  // Try to fold the operation.
1216  SmallVector<OpFoldResult> foldResults;
1217  if (failed(applyOp->fold(constOperands, foldResults)) ||
1218  foldResults.empty()) {
  // Not folded: the op stays, so report it to the original builder's
  // listener, which the listener-free builder bypassed.
1219  if (OpBuilder::Listener *listener = b.getListener())
1220  listener->notifyOperationInserted(applyOp, /*previous=*/{});
1221  return applyOp.getResult();
1222  }
1223 
  // Folded: the temporary op is no longer needed.
1224  applyOp->erase();
1225  assert(foldResults.size() == 1 && "expected 1 folded result");
1226  return foldResults.front();
1227 }
1228 
// Overload taking a single `expr`; the visible part forwards `b`, `loc`, the
// front element of a container built on a line not shown here, and
// `operands` to a callee that is also not shown.
// NOTE(review): most of this definition is missing from the listing: the
// opening lines (inner 1229-1230), the callee (inner 1233) and the expression
// whose `.front()` is taken (inner 1235).
1231  AffineExpr expr,
1232  ArrayRef<OpFoldResult> operands) {
1234  b, loc,
1236  .front(),
1237  operands);
1238 }
1239 
// Applies makeComposedFoldedAffineApply to every result of `map` separately
// (one single-result sub-map per result index) and returns the collected
// values in result order.
// NOTE(review): the lines with the return type and function name (inner
// 1240-1241) are missing from this listing.
1242  OpBuilder &b, Location loc, AffineMap map,
1243  ArrayRef<OpFoldResult> operands) {
1244  return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
1245  [&](unsigned i) {
1246  return makeComposedFoldedAffineApply(
1247  b, loc, map.getSubMap({i}), operands);
1248  });
1249 }
1250 
// Creates an op of type `OpTy` with index result type from `map` and
// `operands`, after passing `operands` through foldAttributesIntoMap and
// composing the multi-result map with the affine.apply ops that define its
// operands.
// NOTE(review): the line with the return type, function name and leading
// parameters is missing from this listing (inner numbering skips 1252).
1251 template <typename OpTy>
1253  ArrayRef<OpFoldResult> operands) {
1254  SmallVector<Value> valueOperands;
1255  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1256  composeMultiResultAffineMap(map, valueOperands);
1257  return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1258 }
1259 
// Thin wrapper: instantiates makeComposedMinMax for AffineMinOp.
// NOTE(review): the line with the function name and leading parameters is
// missing from this listing (inner numbering skips 1261).
1260 AffineMinOp
1262  ArrayRef<OpFoldResult> operands) {
1263  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1264 }
1265 
// Builds a composed min/max op of type `OpTy` and tries to fold it right away.
// If folding fails or yields nothing, the listener of `b` (if any) is notified
// of the insertion and the op's result is returned; otherwise the op is erased
// and the single fold result is returned.
// NOTE(review): the line with the return type, function name and leading
// parameters is missing from this listing (inner numbering skips 1267).
1266 template <typename OpTy>
1268  AffineMap map,
1269  ArrayRef<OpFoldResult> operands) {
1270  // Create new builder without a listener, so that no notification is
1271  // triggered if the op is folded.
1272  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1273  // workaround is no longer needed.
1274  OpBuilder newBuilder(b.getContext());
  // NOTE(review): one statement is missing from this listing here (inner
  // 1275).
1276 
1277  // Create op.
1278  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1279 
1280  // Get constant operands.
  // Entries stay null for operands that do not match a constant.
1281  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1282  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1283  matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1284 
1285  // Try to fold the operation.
1286  SmallVector<OpFoldResult> foldResults;
1287  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1288  foldResults.empty()) {
  // Not folded: the op stays, so report it to the original builder's
  // listener, which the listener-free builder bypassed.
1289  if (OpBuilder::Listener *listener = b.getListener())
1290  listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1291  return minMaxOp.getResult();
1292  }
1293 
  // Folded: the temporary op is no longer needed.
1294  minMaxOp->erase();
1295  assert(foldResults.size() == 1 && "expected 1 folded result");
1296  return foldResults.front();
1297 }
1298 
// Thin wrapper: instantiates makeComposedFoldedMinMax for AffineMinOp.
// NOTE(review): the lines with the return type, function name and leading
// parameters (inner 1299-1300) are missing from this listing.
1301  AffineMap map,
1302  ArrayRef<OpFoldResult> operands) {
1303  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1304 }
1305 
// Thin wrapper: instantiates makeComposedFoldedMinMax for AffineMaxOp.
// NOTE(review): the lines with the return type, function name and leading
// parameters (inner 1306-1307) are missing from this listing.
1308  AffineMap map,
1309  ArrayRef<OpFoldResult> operands) {
1310  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1311 }
1312 
1313 // A symbol may appear as a dim in affine.apply operations. This function
1314 // canonicalizes dims that are valid symbols into actual symbols.
1315 template <class MapOrSet>
1316 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1317  SmallVectorImpl<Value> *operands) {
1318  if (!mapOrSet || operands->empty())
1319  return;
1320 
1321  assert(mapOrSet->getNumInputs() == operands->size() &&
1322  "map/set inputs must match number of operands");
1323 
1324  auto *context = mapOrSet->getContext();
1325  SmallVector<Value, 8> resultOperands;
1326  resultOperands.reserve(operands->size());
1327  SmallVector<Value, 8> remappedSymbols;
1328  remappedSymbols.reserve(operands->size());
1329  unsigned nextDim = 0;
1330  unsigned nextSym = 0;
1331  unsigned oldNumSyms = mapOrSet->getNumSymbols();
1332  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1333  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1334  if (i < mapOrSet->getNumDims()) {
1335  if (isValidSymbol((*operands)[i])) {
1336  // This is a valid symbol that appears as a dim, canonicalize it.
1337  dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1338  remappedSymbols.push_back((*operands)[i]);
1339  } else {
1340  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1341  resultOperands.push_back((*operands)[i]);
1342  }
1343  } else {
1344  resultOperands.push_back((*operands)[i]);
1345  }
1346  }
1347 
1348  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1349  *operands = resultOperands;
1350  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1351  oldNumSyms + nextSym);
1352 
1353  assert(mapOrSet->getNumInputs() == operands->size() &&
1354  "map/set inputs must match number of operands");
1355 }
1356 
1357 // Works for either an affine map or an integer set.
1358 template <class MapOrSet>
1359 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1360  SmallVectorImpl<Value> *operands) {
1361  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1362  "Argument must be either of AffineMap or IntegerSet type");
1363 
1364  if (!mapOrSet || operands->empty())
1365  return;
1366 
1367  assert(mapOrSet->getNumInputs() == operands->size() &&
1368  "map/set inputs must match number of operands");
1369 
1370  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1371 
1372  // Check to see what dims are used.
1373  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1374  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1375  mapOrSet->walkExprs([&](AffineExpr expr) {
1376  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1377  usedDims[dimExpr.getPosition()] = true;
1378  else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1379  usedSyms[symExpr.getPosition()] = true;
1380  });
1381 
1382  auto *context = mapOrSet->getContext();
1383 
1384  SmallVector<Value, 8> resultOperands;
1385  resultOperands.reserve(operands->size());
1386 
1387  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1388  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1389  unsigned nextDim = 0;
1390  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1391  if (usedDims[i]) {
1392  // Remap dim positions for duplicate operands.
1393  auto it = seenDims.find((*operands)[i]);
1394  if (it == seenDims.end()) {
1395  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1396  resultOperands.push_back((*operands)[i]);
1397  seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1398  } else {
1399  dimRemapping[i] = it->second;
1400  }
1401  }
1402  }
1403  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1404  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1405  unsigned nextSym = 0;
1406  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1407  if (!usedSyms[i])
1408  continue;
1409  // Handle constant operands (only needed for symbolic operands since
1410  // constant operands in dimensional positions would have already been
1411  // promoted to symbolic positions above).
1412  IntegerAttr operandCst;
1413  if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1414  m_Constant(&operandCst))) {
1415  symRemapping[i] =
1416  getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1417  continue;
1418  }
1419  // Remap symbol positions for duplicate operands.
1420  auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1421  if (it == seenSymbols.end()) {
1422  symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1423  resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1424  seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1425  symRemapping[i]));
1426  } else {
1427  symRemapping[i] = it->second;
1428  }
1429  }
1430  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1431  nextDim, nextSym);
1432  *operands = resultOperands;
1433 }
1434 
// Thin wrapper: instantiates canonicalizeMapOrSetAndOperands for AffineMap.
// NOTE(review): the line with the return type and function name is missing
// from this listing (inner numbering skips 1435).
1436  AffineMap *map, SmallVectorImpl<Value> *operands) {
1437  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1438 }
1439 
// Thin wrapper: instantiates canonicalizeMapOrSetAndOperands for IntegerSet.
// NOTE(review): the line with the return type and function name is missing
// from this listing (inner numbering skips 1440).
1441  IntegerSet *set, SmallVectorImpl<Value> *operands) {
1442  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1443 }
1444 
1445 namespace {
1446 /// Simplify affine load, store, vector load, vector store, prefetch, apply,
1447 /// min and max operations by composing maps that supply results into them.
1448 ///
1449 template <typename AffineOpTy>
1450 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
  // NOTE(review): one line is missing from this listing here (inner numbering
  // skips 1451).
1452 
1453  /// Replace the affine op with another instance of it with the supplied
1454  /// map and mapOperands.
1455  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1456  AffineMap map, ArrayRef<Value> mapOperands) const;
1457 
  /// Composes, canonicalizes and simplifies the op's map together with its
  /// map operands. Returns failure when neither the map nor the operands
  /// changed; otherwise replaces the op and returns success.
1458  LogicalResult matchAndRewrite(AffineOpTy affineOp,
1459  PatternRewriter &rewriter) const override {
1460  static_assert(
1461  llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1462  AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1463  AffineVectorStoreOp, AffineVectorLoadOp>::value,
1464  "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1465  "expected");
1466  auto map = affineOp.getAffineMap();
1467  AffineMap oldMap = map;
1468  auto oldOperands = affineOp.getMapOperands();
1469  SmallVector<Value, 8> resultOperands(oldOperands);
1470  composeAffineMapAndOperands(&map, &resultOperands);
1471  canonicalizeMapAndOperands(&map, &resultOperands);
1472  simplifyMapWithOperands(map, resultOperands);
  // Nothing changed: report failure so the pattern is not applied again.
1473  if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1474  resultOperands.begin()))
1475  return failure();
1476 
1477  replaceAffineOp(rewriter, affineOp, map, resultOperands);
1478  return success();
1479  }
1480 };
1481 
1482 // Specialize the template to account for the different build signatures of
1483 // the affine load, prefetch, store, vector load and vector store ops.
1484 template <>
1485 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1486  PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1487  ArrayRef<Value> mapOperands) const {
  // The memref is carried over; only the map and its operands are new.
1488  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1489  mapOperands);
1490 }
1491 template <>
1492 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1493  PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1494  ArrayRef<Value> mapOperands) const {
  // The memref, is-write flag, locality hint and is-data-cache flag are
  // carried over from the original op.
1495  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1496  prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1497  prefetch.getLocalityHint(), prefetch.getIsDataCache());
1498 }
1499 template <>
1500 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1501  PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1502  ArrayRef<Value> mapOperands) const {
  // The stored value and the memref are carried over from the original op.
1503  rewriter.replaceOpWithNewOp<AffineStoreOp>(
1504  store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1505 }
1506 template <>
1507 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1508  PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1509  ArrayRef<Value> mapOperands) const {
  // The vector type and the memref are carried over from the original op.
1510  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1511  vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1512  mapOperands);
1513 }
1514 template <>
1515 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1516  PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1517  ArrayRef<Value> mapOperands) const {
  // The stored value and the memref are carried over from the original op.
1518  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1519  vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1520  mapOperands);
1521 }
1522 
1523 // Generic version for ops that don't have extra operands.
1524 template <typename AffineOpTy>
1525 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1526  PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1527  ArrayRef<Value> mapOperands) const {
1528  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1529 }
1530 } // namespace
1531 
1532 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1533  MLIRContext *context) {
1534  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1535 }
1536 
1537 //===----------------------------------------------------------------------===//
1538 // AffineDmaStartOp
1539 //===----------------------------------------------------------------------===//
1540 
1541 // TODO: Check that map operands are loop IVs or symbols.
1542 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1543  Value srcMemRef, AffineMap srcMap,
1544  ValueRange srcIndices, Value destMemRef,
1545  AffineMap dstMap, ValueRange destIndices,
1546  Value tagMemRef, AffineMap tagMap,
1547  ValueRange tagIndices, Value numElements,
1548  Value stride, Value elementsPerStride) {
1549  result.addOperands(srcMemRef);
1550  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1551  result.addOperands(srcIndices);
1552  result.addOperands(destMemRef);
1553  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1554  result.addOperands(destIndices);
1555  result.addOperands(tagMemRef);
1556  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1557  result.addOperands(tagIndices);
1558  result.addOperands(numElements);
1559  if (stride) {
1560  result.addOperands({stride, elementsPerStride});
1561  }
1562 }
1563 
// Prints the op as: source memref with its map of SSA ids in brackets, then
// the destination and tag accesses in the same form, then the number of
// elements, then (only when strided) the stride and the number of elements per
// stride, and finally " : " followed by the three memref types.
// NOTE(review): the line that opens this definition is missing from this
// listing (inner numbering skips 1564); `p` is the printer it declares.
1565  p << " " << getSrcMemRef() << '[';
1566  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1567  p << "], " << getDstMemRef() << '[';
1568  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1569  p << "], " << getTagMemRef() << '[';
1570  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1571  p << "], " << getNumElements();
1572  if (isStrided()) {
1573  p << ", " << getStride();
1574  p << ", " << getNumElementsPerStride();
1575  }
1576  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1577  << getTagMemRefType();
1578 }
1579 
1580 // Parse AffineDmaStartOp.
1581 // Ex:
1582 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1583 // %stride, %num_elt_per_stride
1584 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1585 //
// Returns failure (emitting an error for the stride-count, type-count and
// operand-count checks) when the input does not have this form.
// NOTE(review): several lines are missing from this listing: the first line of
// the signature (inner 1586) and the declarations of `srcMapOperands`,
// `dstMapOperands`, `tagMapOperands` and `strideInfo` (inner 1590, 1593, 1596,
// 1598), all of which are used below.
1587  OperationState &result) {
1588  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1589  AffineMapAttr srcMapAttr;
1591  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1592  AffineMapAttr dstMapAttr;
1594  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1595  AffineMapAttr tagMapAttr;
1597  OpAsmParser::UnresolvedOperand numElementsInfo;
1599 
1600  SmallVector<Type, 3> types;
1601  auto indexType = parser.getBuilder().getIndexType();
1602 
1603  // Parse and resolve the following list of operands:
1604  // *) src memref followed by its affine map operands (in square brackets).
1605  // *) dst memref followed by its affine map operands (in square brackets).
1606  // *) tag memref followed by its affine map operands (in square brackets).
1607  // *) number of elements transferred by DMA operation.
1608  if (parser.parseOperand(srcMemRefInfo) ||
1609  parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1610  getSrcMapAttrStrName(),
1611  result.attributes) ||
1612  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1613  parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1614  getDstMapAttrStrName(),
1615  result.attributes) ||
1616  parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1617  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1618  getTagMapAttrStrName(),
1619  result.attributes) ||
1620  parser.parseComma() || parser.parseOperand(numElementsInfo))
1621  return failure();
1622 
1623  // Parse optional stride and elements per stride.
1624  if (parser.parseTrailingOperandList(strideInfo))
1625  return failure();
1626 
  // The stride operands are all-or-nothing: either none or exactly two.
1627  if (!strideInfo.empty() && strideInfo.size() != 2) {
1628  return parser.emitError(parser.getNameLoc(),
1629  "expected two stride related operands");
1630  }
1631  bool isStrided = strideInfo.size() == 2;
1632 
1633  if (parser.parseColonTypeList(types))
1634  return failure();
1635 
  // One type each for the source, destination and tag memrefs, in that order.
1636  if (types.size() != 3)
1637  return parser.emitError(parser.getNameLoc(), "expected three types");
1638 
  // Memrefs take the parsed types; all map operands and the element count are
  // resolved as index values.
1639  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1640  parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1641  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1642  parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1643  parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1644  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1645  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1646  return failure();
1647 
1648  if (isStrided) {
1649  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1650  return failure();
1651  }
1652 
1653  // Check that src/dst/tag operand counts match their map.numInputs.
1654  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1655  dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1656  tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1657  return parser.emitError(parser.getNameLoc(),
1658  "memref operand count not equal to map.numInputs");
1659  return success();
1660 }
1661 
1662 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1663  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1664  return emitOpError("expected DMA source to be of memref type");
1665  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1666  return emitOpError("expected DMA destination to be of memref type");
1667  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1668  return emitOpError("expected DMA tag to be of memref type");
1669 
1670  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1671  getDstMap().getNumInputs() +
1672  getTagMap().getNumInputs();
1673  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1674  getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1675  return emitOpError("incorrect number of operands");
1676  }
1677 
1678  Region *scope = getAffineScope(*this);
1679  for (auto idx : getSrcIndices()) {
1680  if (!idx.getType().isIndex())
1681  return emitOpError("src index to dma_start must have 'index' type");
1682  if (!isValidAffineIndexOperand(idx, scope))
1683  return emitOpError(
1684  "src index must be a valid dimension or symbol identifier");
1685  }
1686  for (auto idx : getDstIndices()) {
1687  if (!idx.getType().isIndex())
1688  return emitOpError("dst index to dma_start must have 'index' type");
1689  if (!isValidAffineIndexOperand(idx, scope))
1690  return emitOpError(
1691  "dst index must be a valid dimension or symbol identifier");
1692  }
1693  for (auto idx : getTagIndices()) {
1694  if (!idx.getType().isIndex())
1695  return emitOpError("tag index to dma_start must have 'index' type");
1696  if (!isValidAffineIndexOperand(idx, scope))
1697  return emitOpError(
1698  "tag index must be a valid dimension or symbol identifier");
1699  }
1700  return success();
1701 }
1702 
1703 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1704  SmallVectorImpl<OpFoldResult> &results) {
1705  /// dma_start(memrefcast) -> dma_start
1706  return memref::foldMemRefCast(*this);
1707 }
1708 
// Reports the memory effects of affine.dma_start by appending to `effects`:
// a read of the source memref operand, a write of the destination memref
// operand and a read of the tag memref operand.
// NOTE(review): the parameter type line and the trailing argument of each
// `emplace_back` call are not visible in this listing; confirm them against
// the original file.
1709 void AffineDmaStartOp::getEffects(
1711  &effects) {
1712  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
1714  effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
1716  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1718 }
1719 
1720 //===----------------------------------------------------------------------===//
1721 // AffineDmaWaitOp
1722 //===----------------------------------------------------------------------===//
1723 
1724 // TODO: Check that map operands are loop IVs or symbols.
1725 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1726  Value tagMemRef, AffineMap tagMap,
1727  ValueRange tagIndices, Value numElements) {
1728  result.addOperands(tagMemRef);
1729  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1730  result.addOperands(tagIndices);
1731  result.addOperands(numElements);
1732 }
1733 
// Prints affine.dma_wait in its custom form: the tag memref, its affine map
// applied to the tag indices inside square brackets, a comma, and finally
// " : " followed by the tag memref type.
// NOTE(review): the signature line and the statement between the "], " and
// " : " outputs are not visible in this listing (the parser expects the
// element-count operand at that position); confirm against the original file.
1735  p << " " << getTagMemRef() << '[';
  // Copy the indices into a vector to pass to printAffineMapOfSSAIds.
1736  SmallVector<Value, 2> operands(getTagIndices());
1737  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1738  p << "], ";
1740  p << " : " << getTagMemRef().getType();
1741 }
1742 
1743 // Parse AffineDmaWaitOp.
1744 // E.g.:
1745 // affine.dma_wait %tag[%index], %num_elements
1746 // : memref<1 x i32, (d0) -> (d0), 4>
1747 //
// On success, `result.operands` holds the tag memref, then its map operands,
// then the element count, and the tag map is stored in `result.attributes`.
// NOTE(review): the opening line of this signature and the declaration of
// `tagMapOperands` are not visible in this listing; confirm them against the
// original file.
1749  OperationState &result) {
1750  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1751  AffineMapAttr tagMapAttr;
1753  Type type;
1754  auto indexType = parser.getBuilder().getIndexType();
1755  OpAsmParser::UnresolvedOperand numElementsInfo;
1756 
1757  // Parse tag memref, its map operands, and dma size, then resolve them: the
  // memref takes the parsed type, everything else is resolved as `index`.
1758  if (parser.parseOperand(tagMemRefInfo) ||
1759  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1760  getTagMapAttrStrName(),
1761  result.attributes) ||
1762  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1763  parser.parseColonType(type) ||
1764  parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1765  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1766  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1767  return failure();
1768 
  // The single trailing type must be a memref type.
1769  if (!llvm::isa<MemRefType>(type))
1770  return parser.emitError(parser.getNameLoc(),
1771  "expected tag to be of memref type");
1772 
  // The number of parsed map operands must equal the map's input count.
1773  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1774  return parser.emitError(parser.getNameLoc(),
1775  "tag memref operand count != to map.numInputs");
1776  return success();
1777 }
1778 
1779 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1780  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1781  return emitOpError("expected DMA tag to be of memref type");
1782  Region *scope = getAffineScope(*this);
1783  for (auto idx : getTagIndices()) {
1784  if (!idx.getType().isIndex())
1785  return emitOpError("index to dma_wait must have 'index' type");
1786  if (!isValidAffineIndexOperand(idx, scope))
1787  return emitOpError(
1788  "index must be a valid dimension or symbol identifier");
1789  }
1790  return success();
1791 }
1792 
1793 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1794  SmallVectorImpl<OpFoldResult> &results) {
1795  /// dma_wait(memrefcast) -> dma_wait
1796  return memref::foldMemRefCast(*this);
1797 }
1798 
// Reports the memory effects of affine.dma_wait by appending to `effects` a
// read of the tag memref operand.
// NOTE(review): the parameter type line and the trailing argument of the
// `emplace_back` call are not visible in this listing; confirm them against
// the original file.
1799 void AffineDmaWaitOp::getEffects(
1801  &effects) {
1802  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1804 }
1805 
1806 //===----------------------------------------------------------------------===//
1807 // AffineForOp
1808 //===----------------------------------------------------------------------===//
1809 
1810 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1811 /// bodyBuilder are empty/null, we include default terminator op.
1812 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1813  ValueRange lbOperands, AffineMap lbMap,
1814  ValueRange ubOperands, AffineMap ubMap, int64_t step,
1815  ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1816  assert(((!lbMap && lbOperands.empty()) ||
1817  lbOperands.size() == lbMap.getNumInputs()) &&
1818  "lower bound operand count does not match the affine map");
1819  assert(((!ubMap && ubOperands.empty()) ||
1820  ubOperands.size() == ubMap.getNumInputs()) &&
1821  "upper bound operand count does not match the affine map");
1822  assert(step > 0 && "step has to be a positive integer constant");
1823 
1824  OpBuilder::InsertionGuard guard(builder);
1825 
1826  // Set variadic segment sizes.
1827  result.addAttribute(
1828  getOperandSegmentSizeAttr(),
1829  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
1830  static_cast<int32_t>(ubOperands.size()),
1831  static_cast<int32_t>(iterArgs.size())}));
1832 
1833  for (Value val : iterArgs)
1834  result.addTypes(val.getType());
1835 
1836  // Add an attribute for the step.
1837  result.addAttribute(getStepAttrName(result.name),
1838  builder.getIntegerAttr(builder.getIndexType(), step));
1839 
1840  // Add the lower bound.
1841  result.addAttribute(getLowerBoundMapAttrName(result.name),
1842  AffineMapAttr::get(lbMap));
1843  result.addOperands(lbOperands);
1844 
1845  // Add the upper bound.
1846  result.addAttribute(getUpperBoundMapAttrName(result.name),
1847  AffineMapAttr::get(ubMap));
1848  result.addOperands(ubOperands);
1849 
1850  result.addOperands(iterArgs);
1851  // Create a region and a block for the body. The argument of the region is
1852  // the loop induction variable.
1853  Region *bodyRegion = result.addRegion();
1854  Block *bodyBlock = builder.createBlock(bodyRegion);
1855  Value inductionVar =
1856  bodyBlock->addArgument(builder.getIndexType(), result.location);
1857  for (Value val : iterArgs)
1858  bodyBlock->addArgument(val.getType(), val.getLoc());
1859 
1860  // Create the default terminator if the builder is not provided and if the
1861  // iteration arguments are not provided. Otherwise, leave this to the caller
1862  // because we don't know which values to return from the loop.
1863  if (iterArgs.empty() && !bodyBuilder) {
1864  ensureTerminator(*bodyRegion, builder, result.location);
1865  } else if (bodyBuilder) {
1866  OpBuilder::InsertionGuard guard(builder);
1867  builder.setInsertionPointToStart(bodyBlock);
1868  bodyBuilder(builder, result.location, inductionVar,
1869  bodyBlock->getArguments().drop_front());
1870  }
1871 }
1872 
1873 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1874  int64_t ub, int64_t step, ValueRange iterArgs,
1875  BodyBuilderFn bodyBuilder) {
1876  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1877  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1878  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1879  bodyBuilder);
1880 }
1881 
// Verifies the body and bound operands of an affine.for op, and, when the op
// has results, that the number of loop-carried operands and of region iter
// args both equal the number of results.
1882 LogicalResult AffineForOp::verifyRegions() {
1883  // Check that the body has at least one block argument and that the first
1884  // one, the induction variable, has index type.
1885  auto *body = getBody();
1886  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1887  return emitOpError("expected body to have a single index argument for the "
1888  "induction variable");
1889 
1890  // Verify that the bound operands are valid dimension/symbols. Bounds whose
  // map takes no inputs have nothing to check.
  // NOTE(review): the first line of each of the two checks below (the call
  // that consumes the map's dimension count) is not visible in this listing;
  // confirm against the original file.
1891  /// Lower bound.
1892  if (getLowerBoundMap().getNumInputs() > 0)
1894  getLowerBoundMap().getNumDims())))
1895  return failure();
1896  /// Upper bound.
1897  if (getUpperBoundMap().getNumInputs() > 0)
1899  getUpperBoundMap().getNumDims())))
1900  return failure();
1901 
  // Nothing more to verify for a loop that does not define values.
1902  unsigned opNumResults = getNumResults();
1903  if (opNumResults == 0)
1904  return success();
1905 
1906  // If ForOp defines values, check that the number of the defined values
1907  // matches the number of ForOp initial iter operands and of backedge basic
1908  // block arguments.
1909  if (getNumIterOperands() != opNumResults)
1910  return emitOpError(
1911  "mismatch between the number of loop-carried values and results");
1912  if (getNumRegionIterArgs() != opNumResults)
1913  return emitOpError(
1914  "mismatch between the number of basic block args and results");
1915 
1916  return success();
1917 }
1918 
1919 /// Parse a for operation loop bound and record it on `result`.
/// `isLower` selects the lower bound (optional 'max' prefix) or the upper
/// bound (optional 'min' prefix). Three forms are accepted:
///  - a single SSA operand, stored as a symbol identity map;
///  - an affine map attribute followed by its dim and symbol operand list;
///  - an integer attribute, stored as a constant affine map.
/// The bound map is added to `result.attributes` under the lower/upper bound
/// map attribute name and any bound operands are appended to
/// `result.operands`.
/// NOTE(review): the declaration of `boundOpInfos` is not visible in this
/// listing; confirm it against the original file.
1920 static ParseResult parseBound(bool isLower, OperationState &result,
1921  OpAsmParser &p) {
1922  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1923  // the map has multiple results.
1924  bool failedToParsedMinMax =
1925  failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1926 
1927  auto &builder = p.getBuilder();
1928  auto boundAttrStrName =
1929  isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
1930  : AffineForOp::getUpperBoundMapAttrName(result.name);
1931 
1932  // Parse ssa-id as identity map.
1934  if (p.parseOperandList(boundOpInfos))
1935  return failure();
1936 
1937  if (!boundOpInfos.empty()) {
1938  // Check that only one operand was parsed.
1939  if (boundOpInfos.size() > 1)
1940  return p.emitError(p.getNameLoc(),
1941  "expected only one loop bound operand");
1942 
1943  // TODO: improve error message when SSA value is not of index type.
1944  // Currently it is 'use of value ... expects different type than prior uses'
1945  if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1946  result.operands))
1947  return failure();
1948 
1949  // Create an identity map using symbol id. This representation is optimized
1950  // for storage. Analysis passes may expand it into a multi-dimensional map
1951  // if desired.
1952  AffineMap map = builder.getSymbolIdentityMap();
1953  result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
1954  return success();
1955  }
1956 
1957  // Get the attribute location.
1958  SMLoc attrLoc = p.getCurrentLocation();
1959 
  // No SSA operand was given, so the bound must be spelled as an attribute.
  // The parsed attribute is added to `result.attributes` under the bound name.
1960  Attribute boundAttr;
1961  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
1962  result.attributes))
1963  return failure();
1964 
1965  // Parse full form - affine map followed by dim and symbol list.
1966  if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
  // Remember the operand count so the operands added for this bound can be
  // counted afterwards.
1967  unsigned currentNumOperands = result.operands.size();
1968  unsigned numDims;
1969  if (parseDimAndSymbolList(p, result.operands, numDims))
1970  return failure();
1971 
1972  auto map = affineMapAttr.getValue();
1973  if (map.getNumDims() != numDims)
1974  return p.emitError(
1975  p.getNameLoc(),
1976  "dim operand count and affine map dim count must match");
1977 
1978  unsigned numDimAndSymbolOperands =
1979  result.operands.size() - currentNumOperands;
1980  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1981  return p.emitError(
1982  p.getNameLoc(),
1983  "symbol operand count and affine map symbol count must match");
1984 
1985  // If the map has multiple results, make sure that we parsed the min/max
1986  // prefix.
1987  if (map.getNumResults() > 1 && failedToParsedMinMax) {
1988  if (isLower) {
1989  return p.emitError(attrLoc, "lower loop bound affine map with "
1990  "multiple results requires 'max' prefix");
1991  }
1992  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1993  "results requires 'min' prefix");
1994  }
1995  return success();
1996  }
1997 
1998  // Parse custom assembly form.
1999  if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
  // Drop the integer attribute that parseAttribute just appended and store
  // the equivalent constant affine map under the same name instead.
2000  result.attributes.pop_back();
2001  result.addAttribute(
2002  boundAttrStrName,
2003  AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2004  return success();
2005  }
2006 
2007  return p.emitError(
2008  p.getNameLoc(),
2009  "expected valid affine map representation for loop bounds");
2010 }
2011 
// Parses the custom form of affine.for: the induction variable, '=', the lower
// bound, 'to', the upper bound, an optional 'step', an optional
// 'iter_args(...) -> types' clause, the body region and an optional attribute
// dictionary.
// NOTE(review): the declarations of `regionArgs` and `operands` are not
// visible in this listing; confirm them against the original file.
2012 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2013  auto &builder = parser.getBuilder();
2014  OpAsmParser::Argument inductionVariable;
2015  inductionVariable.type = builder.getIndexType();
2016  // Parse the induction variable followed by '='.
2017  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2018  return failure();
2019 
2020  // Parse loop bounds. The number of operands each bound contributes is
  // measured as the growth of `result.operands` across the parseBound call.
2021  int64_t numOperands = result.operands.size();
2022  if (parseBound(/*isLower=*/true, result, parser))
2023  return failure();
2024  int64_t numLbOperands = result.operands.size() - numOperands;
2025  if (parser.parseKeyword("to", " between bounds"))
2026  return failure();
2027  numOperands = result.operands.size();
2028  if (parseBound(/*isLower=*/false, result, parser))
2029  return failure();
2030  int64_t numUbOperands = result.operands.size() - numOperands;
2031 
2032  // Parse the optional loop step, we default to 1 if one is not present.
2033  if (parser.parseOptionalKeyword("step")) {
2034  result.addAttribute(
2035  getStepAttrName(result.name),
2036  builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2037  } else {
2038  SMLoc stepLoc = parser.getCurrentLocation();
2039  IntegerAttr stepAttr;
2040  if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2041  getStepAttrName(result.name).data(),
2042  result.attributes))
2043  return failure();
2044 
  // NOTE(review): only negative steps are rejected here; a step of zero
  // passes this check even though the message asks for a positive value.
2045  if (stepAttr.getValue().isNegative())
2046  return parser.emitError(
2047  stepLoc,
2048  "expected step to be representable as a positive signed integer");
2049  }
2050 
2051  // Parse the optional initial iteration arguments.
2054 
2055  // Induction variable.
2056  regionArgs.push_back(inductionVariable);
2057 
2058  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2059  // Parse assignment list and results type list.
2060  if (parser.parseAssignmentList(regionArgs, operands) ||
2061  parser.parseArrowTypeList(result.types))
2062  return failure();
2063  // Resolve input operands. Each region argument after the induction
  // variable takes the corresponding result type, and its initial value is
  // resolved against that same type.
2064  for (auto argOperandType :
2065  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2066  Type type = std::get<2>(argOperandType);
2067  std::get<0>(argOperandType).type = type;
2068  if (parser.resolveOperand(std::get<1>(argOperandType), type,
2069  result.operands))
2070  return failure();
2071  }
2072  }
2073 
  // Record the sizes of the three operand groups: lower bound, upper bound
  // and loop-carried initial values.
2074  result.addAttribute(
2075  getOperandSegmentSizeAttr(),
2076  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2077  static_cast<int32_t>(numUbOperands),
2078  static_cast<int32_t>(operands.size())}));
2079 
2080  // Parse the body region.
2081  Region *body = result.addRegion();
  // There must be exactly one region argument per result, plus the induction
  // variable.
2082  if (regionArgs.size() != result.types.size() + 1)
2083  return parser.emitError(
2084  parser.getNameLoc(),
2085  "mismatch between the number of loop-carried values and results");
2086  if (parser.parseRegion(*body, regionArgs))
2087  return failure();
2088 
2089  AffineForOp::ensureTerminator(*body, builder, result.location);
2090 
2091  // Parse the optional attribute list.
2092  return parser.parseOptionalAttrDict(result.attributes);
2093 }
2094 
2095 static void printBound(AffineMapAttr boundMap,
2096  Operation::operand_range boundOperands,
2097  const char *prefix, OpAsmPrinter &p) {
2098  AffineMap map = boundMap.getValue();
2099 
2100  // Check if this bound should be printed using custom assembly form.
2101  // The decision to restrict printing custom assembly form to trivial cases
2102  // comes from the will to roundtrip MLIR binary -> text -> binary in a
2103  // lossless way.
2104  // Therefore, custom assembly form parsing and printing is only supported for
2105  // zero-operand constant maps and single symbol operand identity maps.
2106  if (map.getNumResults() == 1) {
2107  AffineExpr expr = map.getResult(0);
2108 
2109  // Print constant bound.
2110  if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2111  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2112  p << constExpr.getValue();
2113  return;
2114  }
2115  }
2116 
2117  // Print bound that consists of a single SSA symbol if the map is over a
2118  // single symbol.
2119  if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2120  if (dyn_cast<AffineSymbolExpr>(expr)) {
2121  p.printOperand(*boundOperands.begin());
2122  return;
2123  }
2124  }
2125  } else {
2126  // Map has multiple results. Print 'min' or 'max' prefix.
2127  p << prefix << ' ';
2128  }
2129 
2130  // Print the map and its operands.
2131  p << boundMap;
2132  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2133  map.getNumDims(), p);
2134 }
2135 
2136 unsigned AffineForOp::getNumIterOperands() {
2137  AffineMap lbMap = getLowerBoundMapAttr().getValue();
2138  AffineMap ubMap = getUpperBoundMapAttr().getValue();
2139 
2140  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2141 }
2142 
2143 std::optional<MutableArrayRef<OpOperand>>
2144 AffineForOp::getYieldedValuesMutable() {
2145  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2146 }
2147 
// Prints affine.for in its custom form: the induction variable, " = ", the
// lower bound, " to ", the upper bound, an optional " step N", an optional
// iter_args clause with result types, and the body region.
// NOTE(review): the signature line and the first line of the trailing call
// that receives the op attributes and the elided-attribute list are not
// visible in this listing; confirm them against the original file.
2149  p << ' ';
2150  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2151  /*omitType=*/true);
2152  p << " = ";
2153  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2154  p << " to ";
2155  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2156 
  // A step of 1 is the default and is not printed.
2157  if (getStepAsInt() != 1)
2158  p << " step " << getStepAsInt();
2159 
  // Block terminators are printed only when the loop carries values.
2160  bool printBlockTerminators = false;
2161  if (getNumIterOperands() > 0) {
2162  p << " iter_args(";
2163  auto regionArgs = getRegionIterArgs();
2164  auto operands = getInits();
2165 
  // Print each "region arg = init value" pair, comma separated.
2166  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2167  p << std::get<0>(it) << " = " << std::get<1>(it);
2168  });
2169  p << ") -> (" << getResultTypes() << ")";
2170  printBlockTerminators = true;
2171  }
2172 
2173  p << ' ';
2174  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2175  printBlockTerminators);
  // The bound maps, the step and the operand segment sizes are already part
  // of the custom syntax, so they are elided here.
2177  (*this)->getAttrs(),
2178  /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2179  getUpperBoundMapAttrName(getOperation()->getName()),
2180  getStepAttrName(getOperation()->getName()),
2181  getOperandSegmentSizeAttr()});
2182 }
2183 
2184 /// Fold the constant bounds of a loop.
2185 static LogicalResult foldLoopBounds(AffineForOp forOp) {
2186  auto foldLowerOrUpperBound = [&forOp](bool lower) {
2187  // Check to see if each of the operands is the result of a constant. If
2188  // so, get the value. If not, ignore it.
2189  SmallVector<Attribute, 8> operandConstants;
2190  auto boundOperands =
2191  lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2192  for (auto operand : boundOperands) {
2193  Attribute operandCst;
2194  matchPattern(operand, m_Constant(&operandCst));
2195  operandConstants.push_back(operandCst);
2196  }
2197 
2198  AffineMap boundMap =
2199  lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2200  assert(boundMap.getNumResults() >= 1 &&
2201  "bound maps should have at least one result");
2202  SmallVector<Attribute, 4> foldedResults;
2203  if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2204  return failure();
2205 
2206  // Compute the max or min as applicable over the results.
2207  assert(!foldedResults.empty() && "bounds should have at least one result");
2208  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2209  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2210  auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2211  maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2212  : llvm::APIntOps::smin(maxOrMin, foldedResult);
2213  }
2214  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2215  : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2216  return success();
2217  };
2218 
2219  // Try to fold the lower bound.
2220  bool folded = false;
2221  if (!forOp.hasConstantLowerBound())
2222  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2223 
2224  // Try to fold the upper bound.
2225  if (!forOp.hasConstantUpperBound())
2226  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2227  return success(folded);
2228 }
2229 
2230 /// Canonicalize the bounds of the given loop.
2231 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2232  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2233  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2234 
2235  auto lbMap = forOp.getLowerBoundMap();
2236  auto ubMap = forOp.getUpperBoundMap();
2237  auto prevLbMap = lbMap;
2238  auto prevUbMap = ubMap;
2239 
2240  composeAffineMapAndOperands(&lbMap, &lbOperands);
2241  canonicalizeMapAndOperands(&lbMap, &lbOperands);
2242  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
2243  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2244  lbMap = removeDuplicateExprs(lbMap);
2245 
2246  composeAffineMapAndOperands(&ubMap, &ubOperands);
2247  canonicalizeMapAndOperands(&ubMap, &ubOperands);
2248  ubMap = removeDuplicateExprs(ubMap);
2249 
2250  // Any canonicalization change always leads to updated map(s).
2251  if (lbMap == prevLbMap && ubMap == prevUbMap)
2252  return failure();
2253 
2254  if (lbMap != prevLbMap)
2255  forOp.setLowerBound(lbOperands, lbMap);
2256  if (ubMap != prevUbMap)
2257  forOp.setUpperBound(ubOperands, ubMap);
2258  return success();
2259 }
2260 
2261 namespace {
2262 /// Returns constant trip count in trivial cases.
2263 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2264  int64_t step = forOp.getStepAsInt();
2265  if (!forOp.hasConstantBounds() || step <= 0)
2266  return std::nullopt;
2267  int64_t lb = forOp.getConstantLowerBound();
2268  int64_t ub = forOp.getConstantUpperBound();
2269  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2270 }
2271 
2272 /// This is a pattern to fold trivially empty loop bodies.
2273 /// TODO: This should be moved into the folding hook.
2274 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
2276 
2277  LogicalResult matchAndRewrite(AffineForOp forOp,
2278  PatternRewriter &rewriter) const override {
2279  // Check that the body only contains a yield.
2280  if (!llvm::hasSingleElement(*forOp.getBody()))
2281  return failure();
2282  if (forOp.getNumResults() == 0)
2283  return success();
2284  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2285  if (tripCount && *tripCount == 0) {
2286  // The initial values of the iteration arguments would be the op's
2287  // results.
2288  rewriter.replaceOp(forOp, forOp.getInits());
2289  return success();
2290  }
2291  SmallVector<Value, 4> replacements;
2292  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2293  auto iterArgs = forOp.getRegionIterArgs();
2294  bool hasValDefinedOutsideLoop = false;
2295  bool iterArgsNotInOrder = false;
2296  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2297  Value val = yieldOp.getOperand(i);
2298  auto *iterArgIt = llvm::find(iterArgs, val);
2299  if (iterArgIt == iterArgs.end()) {
2300  // `val` is defined outside of the loop.
2301  assert(forOp.isDefinedOutsideOfLoop(val) &&
2302  "must be defined outside of the loop");
2303  hasValDefinedOutsideLoop = true;
2304  replacements.push_back(val);
2305  } else {
2306  unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2307  if (pos != i)
2308  iterArgsNotInOrder = true;
2309  replacements.push_back(forOp.getInits()[pos]);
2310  }
2311  }
2312  // Bail out when the trip count is unknown and the loop returns any value
2313  // defined outside of the loop or any iterArg out of order.
2314  if (!tripCount.has_value() &&
2315  (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2316  return failure();
2317  // Bail out when the loop iterates more than once and it returns any iterArg
2318  // out of order.
2319  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2320  return failure();
2321  rewriter.replaceOp(forOp, replacements);
2322  return success();
2323  }
2324 };
2325 } // namespace
2326 
2327 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2328  MLIRContext *context) {
2329  results.add<AffineForEmptyLoopFolder>(context);
2330 }
2331 
2332 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2333  assert((point.isParent() || point == getRegion()) && "invalid region point");
2334 
2335  // The initial operands map to the loop arguments after the induction
2336  // variable or are forwarded to the results when the trip count is zero.
2337  return getInits();
2338 }
2339 
// Appends to `regions` the possible control-flow successors when branching
// from `point`, which is either the parent op or the loop body region.
// NOTE(review): the parameter line declaring `point` and `regions` is not
// visible in this listing; confirm it against the original file.
2340 void AffineForOp::getSuccessorRegions(
2342  assert((point.isParent() || point == getRegion()) && "expected loop region");
2343  // The loop may typically branch back to its body or to the parent operation.
2344  // If the predecessor is the parent op and the trip count is known to be at
2345  // least one, branch into the body using the iterator arguments. And in cases
2346  // we know the trip count is zero, it can only branch back to its parent.
2347  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2348  if (point.isParent() && tripCount.has_value()) {
2349  if (tripCount.value() > 0) {
2350  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2351  return;
2352  }
2353  if (tripCount.value() == 0) {
2354  regions.push_back(RegionSuccessor(getResults()));
2355  return;
2356  }
2357  }
2358 
2359  // From the loop body, if the trip count is one, we can only branch back to
2360  // the parent.
2361  if (!point.isParent() && tripCount && *tripCount == 1) {
2362  regions.push_back(RegionSuccessor(getResults()));
2363  return;
2364  }
2365 
2366  // In all other cases, the loop may branch back to itself or the parent
2367  // operation.
2368  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2369  regions.push_back(RegionSuccessor(getResults()));
2370 }
2371 
2372 /// Returns true if the affine.for has zero iterations in trivial cases.
2373 static bool hasTrivialZeroTripCount(AffineForOp op) {
2374  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2375  return tripCount && *tripCount == 0;
2376 }
2377 
2378 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2379  SmallVectorImpl<OpFoldResult> &results) {
2380  bool folded = succeeded(foldLoopBounds(*this));
2381  folded |= succeeded(canonicalizeLoopBounds(*this));
2382  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2383  // The initial values of the loop-carried variables (iter_args) are the
2384  // results of the op. But this must be avoided for an affine.for op that
2385  // does not return any results. Since ops that do not return results cannot
2386  // be folded away, we would enter an infinite loop of folds on the same
2387  // affine.for op.
2388  results.assign(getInits().begin(), getInits().end());
2389  folded = true;
2390  }
2391  return success(folded);
2392 }
2393 
// Builds an AffineBound from this op, its lower bound operands and its lower
// bound map.
// NOTE(review): the signature line of this accessor is not visible in this
// listing; confirm it against the original file.
2395  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2396 }
2397 
// Builds an AffineBound from this op, its upper bound operands and its upper
// bound map.
// NOTE(review): the signature line of this accessor is not visible in this
// listing; confirm it against the original file.
2399  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2400 }
2401 
2402 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2403  assert(lbOperands.size() == map.getNumInputs());
2404  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2405  getLowerBoundOperandsMutable().assign(lbOperands);
2406  setLowerBoundMap(map);
2407 }
2408 
2409 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2410  assert(ubOperands.size() == map.getNumInputs());
2411  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2412  getUpperBoundOperandsMutable().assign(ubOperands);
2413  setUpperBoundMap(map);
2414 }
2415 
2416 bool AffineForOp::hasConstantLowerBound() {
2417  return getLowerBoundMap().isSingleConstant();
2418 }
2419 
2420 bool AffineForOp::hasConstantUpperBound() {
2421  return getUpperBoundMap().isSingleConstant();
2422 }
2423 
2424 int64_t AffineForOp::getConstantLowerBound() {
2425  return getLowerBoundMap().getSingleConstantResult();
2426 }
2427 
2428 int64_t AffineForOp::getConstantUpperBound() {
2429  return getUpperBoundMap().getSingleConstantResult();
2430 }
2431 
2432 void AffineForOp::setConstantLowerBound(int64_t value) {
2433  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2434 }
2435 
2436 void AffineForOp::setConstantUpperBound(int64_t value) {
2437  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2438 }
2439 
2440 AffineForOp::operand_range AffineForOp::getControlOperands() {
2441  return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2442  getUpperBoundOperands().size()};
2443 }
2444 
2445 bool AffineForOp::matchingBoundOperandList() {
2446  auto lbMap = getLowerBoundMap();
2447  auto ubMap = getUpperBoundMap();
2448  if (lbMap.getNumDims() != ubMap.getNumDims() ||
2449  lbMap.getNumSymbols() != ubMap.getNumSymbols())
2450  return false;
2451 
2452  unsigned numOperands = lbMap.getNumInputs();
2453  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2454  // Compare Value 's.
2455  if (getOperand(i) != getOperand(numOperands + i))
2456  return false;
2457  }
2458  return true;
2459 }
2460 
2461 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2462 
2463 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2464  return SmallVector<Value>{getInductionVar()};
2465 }
2466 
2467 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2468  if (!hasConstantLowerBound())
2469  return std::nullopt;
2470  OpBuilder b(getContext());
2472  OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2473 }
2474 
2475 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2476  OpBuilder b(getContext());
2478  OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2479 }
2480 
2481 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2482  if (!hasConstantUpperBound())
2483  return {};
2484  OpBuilder b(getContext());
2486  OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2487 }
2488 
2489 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2490  RewriterBase &rewriter, ValueRange newInitOperands,
2491  bool replaceInitOperandUsesInLoop,
2492  const NewYieldValuesFn &newYieldValuesFn) {
2493  // Create a new loop before the existing one, with the extra operands.
2494  OpBuilder::InsertionGuard g(rewriter);
2495  rewriter.setInsertionPoint(getOperation());
2496  auto inits = llvm::to_vector(getInits());
2497  inits.append(newInitOperands.begin(), newInitOperands.end());
2498  AffineForOp newLoop = rewriter.create<AffineForOp>(
2499  getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2500  getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
2501 
2502  // Generate the new yield values and append them to the scf.yield operation.
2503  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2504  ArrayRef<BlockArgument> newIterArgs =
2505  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2506  {
2507  OpBuilder::InsertionGuard g(rewriter);
2508  rewriter.setInsertionPoint(yieldOp);
2509  SmallVector<Value> newYieldedValues =
2510  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2511  assert(newInitOperands.size() == newYieldedValues.size() &&
2512  "expected as many new yield values as new iter operands");
2513  rewriter.modifyOpInPlace(yieldOp, [&]() {
2514  yieldOp.getOperandsMutable().append(newYieldedValues);
2515  });
2516  }
2517 
2518  // Move the loop body to the new op.
2519  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2520  newLoop.getBody()->getArguments().take_front(
2521  getBody()->getNumArguments()));
2522 
2523  if (replaceInitOperandUsesInLoop) {
2524  // Replace all uses of `newInitOperands` with the corresponding basic block
2525  // arguments.
2526  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2527  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2528  [&](OpOperand &use) {
2529  Operation *user = use.getOwner();
2530  return newLoop->isProperAncestor(user);
2531  });
2532  }
2533  }
2534 
2535  // Replace the old loop.
2536  rewriter.replaceOp(getOperation(),
2537  newLoop->getResults().take_front(getNumResults()));
2538  return cast<LoopLikeOpInterface>(newLoop.getOperation());
2539 }
2540 
2541 Speculation::Speculatability AffineForOp::getSpeculatability() {
2542  // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2543  // Start and End.
2544  //
2545  // For Step != 1, the loop may not terminate. We can add more smarts here if
2546  // needed.
2547  return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2549 }
2550 
2551 /// Returns true if the provided value is the induction variable of a
2552 /// AffineForOp.
2554  return getForInductionVarOwner(val) != AffineForOp();
2555 }
2556 
2558  return getAffineParallelInductionVarOwner(val) != nullptr;
2559 }
2560 
2563 }
2564 
2566  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2567  if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2568  return AffineForOp();
2569  if (auto forOp =
2570  ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2571  // Check to make sure `val` is the induction variable, not an iter_arg.
2572  return forOp.getInductionVar() == val ? forOp : AffineForOp();
2573  return AffineForOp();
2574 }
2575 
2577  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2578  if (!ivArg || !ivArg.getOwner())
2579  return nullptr;
2580  Operation *containingOp = ivArg.getOwner()->getParentOp();
2581  auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2582  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2583  return parallelOp;
2584  return nullptr;
2585 }
2586 
2587 /// Extracts the induction variables from a list of AffineForOps and returns
2588 /// them.
2590  SmallVectorImpl<Value> *ivs) {
2591  ivs->reserve(forInsts.size());
2592  for (auto forInst : forInsts)
2593  ivs->push_back(forInst.getInductionVar());
2594 }
2595 
2598  ivs.reserve(affineOps.size());
2599  for (Operation *op : affineOps) {
2600  // Add constraints from forOp's bounds.
2601  if (auto forOp = dyn_cast<AffineForOp>(op))
2602  ivs.push_back(forOp.getInductionVar());
2603  else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2604  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2605  ivs.push_back(parallelOp.getBody()->getArgument(i));
2606  }
2607 }
2608 
2609 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2610 /// operations.
2611 template <typename BoundListTy, typename LoopCreatorTy>
2613  OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2614  ArrayRef<int64_t> steps,
2615  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2616  LoopCreatorTy &&loopCreatorFn) {
2617  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2618  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2619 
2620  // If there are no loops to be constructed, construct the body anyway.
2621  OpBuilder::InsertionGuard guard(builder);
2622  if (lbs.empty()) {
2623  if (bodyBuilderFn)
2624  bodyBuilderFn(builder, loc, ValueRange());
2625  return;
2626  }
2627 
2628  // Create the loops iteratively and store the induction variables.
2630  ivs.reserve(lbs.size());
2631  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2632  // Callback for creating the loop body, always creates the terminator.
2633  auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
2634  ValueRange iterArgs) {
2635  ivs.push_back(iv);
2636  // In the innermost loop, call the body builder.
2637  if (i == e - 1 && bodyBuilderFn) {
2638  OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
2639  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2640  }
2641  nestedBuilder.create<AffineYieldOp>(nestedLoc);
2642  };
2643 
2644  // Delegate actual loop creation to the callback in order to dispatch
2645  // between constant- and variable-bound loops.
2646  auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2647  builder.setInsertionPointToStart(loop.getBody());
2648  }
2649 }
2650 
2651 /// Creates an affine loop from the bounds known to be constants.
2652 static AffineForOp
2654  int64_t ub, int64_t step,
2655  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2656  return builder.create<AffineForOp>(loc, lb, ub, step,
2657  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2658 }
2659 
2660 /// Creates an affine loop from the bounds that may or may not be constants.
2661 static AffineForOp
2663  int64_t step,
2664  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2665  std::optional<int64_t> lbConst = getConstantIntValue(lb);
2666  std::optional<int64_t> ubConst = getConstantIntValue(ub);
2667  if (lbConst && ubConst)
2668  return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2669  ubConst.value(), step, bodyBuilderFn);
2670  return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2671  builder.getDimIdentityMap(), step,
2672  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2673 }
2674 
2676  OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2678  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2679  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2681 }
2682 
2684  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2685  ArrayRef<int64_t> steps,
2686  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2687  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2689 }
2690 
2691 //===----------------------------------------------------------------------===//
2692 // AffineIfOp
2693 //===----------------------------------------------------------------------===//
2694 
2695 namespace {
2696 /// Remove else blocks that have nothing other than a zero value yield.
2697 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
2699 
2700  LogicalResult matchAndRewrite(AffineIfOp ifOp,
2701  PatternRewriter &rewriter) const override {
2702  if (ifOp.getElseRegion().empty() ||
2703  !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2704  return failure();
2705 
2706  rewriter.startOpModification(ifOp);
2707  rewriter.eraseBlock(ifOp.getElseBlock());
2708  rewriter.finalizeOpModification(ifOp);
2709  return success();
2710  }
2711 };
2712 
2713 /// Removes affine.if cond if the condition is always true or false in certain
2714 /// trivial cases. Promotes the then/else block in the parent operation block.
2715 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
2717 
2718  LogicalResult matchAndRewrite(AffineIfOp op,
2719  PatternRewriter &rewriter) const override {
2720 
2721  auto isTriviallyFalse = [](IntegerSet iSet) {
2722  return iSet.isEmptyIntegerSet();
2723  };
2724 
2725  auto isTriviallyTrue = [](IntegerSet iSet) {
2726  return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2727  iSet.getConstraint(0) == 0);
2728  };
2729 
2730  IntegerSet affineIfConditions = op.getIntegerSet();
2731  Block *blockToMove;
2732  if (isTriviallyFalse(affineIfConditions)) {
2733  // The absence, or equivalently, the emptiness of the else region need not
2734  // be checked when affine.if is returning results because if an affine.if
2735  // operation is returning results, it always has a non-empty else region.
2736  if (op.getNumResults() == 0 && !op.hasElse()) {
2737  // If the else region is absent, or equivalently, empty, remove the
2738  // affine.if operation (which is not returning any results).
2739  rewriter.eraseOp(op);
2740  return success();
2741  }
2742  blockToMove = op.getElseBlock();
2743  } else if (isTriviallyTrue(affineIfConditions)) {
2744  blockToMove = op.getThenBlock();
2745  } else {
2746  return failure();
2747  }
2748  Operation *blockToMoveTerminator = blockToMove->getTerminator();
2749  // Promote the "blockToMove" block to the parent operation block between the
2750  // prologue and epilogue of "op".
2751  rewriter.inlineBlockBefore(blockToMove, op);
2752  // Replace the "op" operation with the operands of the
2753  // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2754  // the affine.yield operation present in the "blockToMove" block. It has no
2755  // operands when affine.if is not returning results and therefore, in that
2756  // case, replaceOp just erases "op". When affine.if is not returning
2757  // results, the affine.yield operation can be omitted. It gets inserted
2758  // implicitly.
2759  rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2760  // Erase the "blockToMoveTerminator" operation since it is now in the parent
2761  // operation block, which already has its own terminator.
2762  rewriter.eraseOp(blockToMoveTerminator);
2763  return success();
2764  }
2765 };
2766 } // namespace
2767 
2768 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
2769 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
2770 void AffineIfOp::getSuccessorRegions(
2772  // If the predecessor is an AffineIfOp, then branching into both `then` and
2773  // `else` region is valid.
2774  if (point.isParent()) {
2775  regions.reserve(2);
2776  regions.push_back(
2777  RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
2778  // If the "else" region is empty, branch bach into parent.
2779  if (getElseRegion().empty()) {
2780  regions.push_back(getResults());
2781  } else {
2782  regions.push_back(
2783  RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
2784  }
2785  return;
2786  }
2787 
2788  // If the predecessor is the `else`/`then` region, then branching into parent
2789  // op is valid.
2790  regions.push_back(RegionSuccessor(getResults()));
2791 }
2792 
2793 LogicalResult AffineIfOp::verify() {
2794  // Verify that we have a condition attribute.
2795  // FIXME: This should be specified in the arguments list in ODS.
2796  auto conditionAttr =
2797  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2798  if (!conditionAttr)
2799  return emitOpError("requires an integer set attribute named 'condition'");
2800 
2801  // Verify that there are enough operands for the condition.
2802  IntegerSet condition = conditionAttr.getValue();
2803  if (getNumOperands() != condition.getNumInputs())
2804  return emitOpError("operand count and condition integer set dimension and "
2805  "symbol count must match");
2806 
2807  // Verify that the operands are valid dimension/symbols.
2808  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2809  condition.getNumDims())))
2810  return failure();
2811 
2812  return success();
2813 }
2814 
2815 ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
2816  // Parse the condition attribute set.
2817  IntegerSetAttr conditionAttr;
2818  unsigned numDims;
2819  if (parser.parseAttribute(conditionAttr,
2820  AffineIfOp::getConditionAttrStrName(),
2821  result.attributes) ||
2822  parseDimAndSymbolList(parser, result.operands, numDims))
2823  return failure();
2824 
2825  // Verify the condition operands.
2826  auto set = conditionAttr.getValue();
2827  if (set.getNumDims() != numDims)
2828  return parser.emitError(
2829  parser.getNameLoc(),
2830  "dim operand count and integer set dim count must match");
2831  if (numDims + set.getNumSymbols() != result.operands.size())
2832  return parser.emitError(
2833  parser.getNameLoc(),
2834  "symbol operand count and integer set symbol count must match");
2835 
2836  if (parser.parseOptionalArrowTypeList(result.types))
2837  return failure();
2838 
2839  // Create the regions for 'then' and 'else'. The latter must be created even
2840  // if it remains empty for the validity of the operation.
2841  result.regions.reserve(2);
2842  Region *thenRegion = result.addRegion();
2843  Region *elseRegion = result.addRegion();
2844 
2845  // Parse the 'then' region.
2846  if (parser.parseRegion(*thenRegion, {}, {}))
2847  return failure();
2848  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2849  result.location);
2850 
2851  // If we find an 'else' keyword then parse the 'else' region.
2852  if (!parser.parseOptionalKeyword("else")) {
2853  if (parser.parseRegion(*elseRegion, {}, {}))
2854  return failure();
2855  AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2856  result.location);
2857  }
2858 
2859  // Parse the optional attribute list.
2860  if (parser.parseOptionalAttrDict(result.attributes))
2861  return failure();
2862 
2863  return success();
2864 }
2865 
2867  auto conditionAttr =
2868  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2869  p << " " << conditionAttr;
2870  printDimAndSymbolList(operand_begin(), operand_end(),
2871  conditionAttr.getValue().getNumDims(), p);
2872  p.printOptionalArrowTypeList(getResultTypes());
2873  p << ' ';
2874  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
2875  /*printBlockTerminators=*/getNumResults());
2876 
2877  // Print the 'else' regions if it has any blocks.
2878  auto &elseRegion = this->getElseRegion();
2879  if (!elseRegion.empty()) {
2880  p << " else ";
2881  p.printRegion(elseRegion,
2882  /*printEntryBlockArgs=*/false,
2883  /*printBlockTerminators=*/getNumResults());
2884  }
2885 
2886  // Print the attribute list.
2887  p.printOptionalAttrDict((*this)->getAttrs(),
2888  /*elidedAttrs=*/getConditionAttrStrName());
2889 }
2890 
2891 IntegerSet AffineIfOp::getIntegerSet() {
2892  return (*this)
2893  ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2894  .getValue();
2895 }
2896 
2897 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2898  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
2899 }
2900 
2901 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2902  setIntegerSet(set);
2903  (*this)->setOperands(operands);
2904 }
2905 
2906 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2907  TypeRange resultTypes, IntegerSet set, ValueRange args,
2908  bool withElseRegion) {
2909  assert(resultTypes.empty() || withElseRegion);
2910  OpBuilder::InsertionGuard guard(builder);
2911 
2912  result.addTypes(resultTypes);
2913  result.addOperands(args);
2914  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
2915 
2916  Region *thenRegion = result.addRegion();
2917  builder.createBlock(thenRegion);
2918  if (resultTypes.empty())
2919  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2920 
2921  Region *elseRegion = result.addRegion();
2922  if (withElseRegion) {
2923  builder.createBlock(elseRegion);
2924  if (resultTypes.empty())
2925  AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2926  }
2927 }
2928 
2929 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2930  IntegerSet set, ValueRange args, bool withElseRegion) {
2931  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2932  withElseRegion);
2933 }
2934 
2935 /// Compose any affine.apply ops feeding into `operands` of the integer set
2936 /// `set` by composing the maps of such affine.apply ops with the integer
2937 /// set constraints.
2939  SmallVectorImpl<Value> &operands) {
2940  // We will simply reuse the API of the map composition by viewing the LHSs of
2941  // the equalities and inequalities of `set` as the affine exprs of an affine
2942  // map. Convert to equivalent map, compose, and convert back to set.
2943  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
2944  set.getConstraints(), set.getContext());
2945  // Check if any composition is possible.
2946  if (llvm::none_of(operands,
2947  [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
2948  return;
2949 
2950  composeAffineMapAndOperands(&map, &operands);
2951  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
2952  set.getEqFlags());
2953 }
2954 
2955 /// Canonicalize an affine if op's conditional (integer set + operands).
2956 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2957  auto set = getIntegerSet();
2958  SmallVector<Value, 4> operands(getOperands());
2959  composeSetAndOperands(set, operands);
2960  canonicalizeSetAndOperands(&set, &operands);
2961 
2962  // Check if the canonicalization or composition led to any change.
2963  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2964  return failure();
2965 
2966  setConditional(set, operands);
2967  return success();
2968 }
2969 
2970 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2971  MLIRContext *context) {
2972  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2973 }
2974 
2975 //===----------------------------------------------------------------------===//
2976 // AffineLoadOp
2977 //===----------------------------------------------------------------------===//
2978 
2979 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2980  AffineMap map, ValueRange operands) {
2981  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2982  result.addOperands(operands);
2983  if (map)
2984  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2985  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
2986  result.types.push_back(memrefType.getElementType());
2987 }
2988 
2989 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2990  Value memref, AffineMap map, ValueRange mapOperands) {
2991  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2992  result.addOperands(memref);
2993  result.addOperands(mapOperands);
2994  auto memrefType = llvm::cast<MemRefType>(memref.getType());
2995  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2996  result.types.push_back(memrefType.getElementType());
2997 }
2998 
2999 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3000  Value memref, ValueRange indices) {
3001  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3002  int64_t rank = memrefType.getRank();
3003  // Create identity map for memrefs with at least one dimension or () -> ()
3004  // for zero-dimensional memrefs.
3005  auto map =
3006  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3007  build(builder, result, memref, map, indices);
3008 }
3009 
3010 ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
3011  auto &builder = parser.getBuilder();
3012  auto indexTy = builder.getIndexType();
3013 
3014  MemRefType type;
3015  OpAsmParser::UnresolvedOperand memrefInfo;
3016  AffineMapAttr mapAttr;
3018  return failure(
3019  parser.parseOperand(memrefInfo) ||
3020  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3021  AffineLoadOp::getMapAttrStrName(),
3022  result.attributes) ||
3023  parser.parseOptionalAttrDict(result.attributes) ||
3024  parser.parseColonType(type) ||
3025  parser.resolveOperand(memrefInfo, type, result.operands) ||
3026  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3027  parser.addTypeToList(type.getElementType(), result.types));
3028 }
3029 
3031  p << " " << getMemRef() << '[';
3032  if (AffineMapAttr mapAttr =
3033  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3034  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3035  p << ']';
3036  p.printOptionalAttrDict((*this)->getAttrs(),
3037  /*elidedAttrs=*/{getMapAttrStrName()});
3038  p << " : " << getMemRefType();
3039 }
3040 
3041 /// Verify common indexing invariants of affine.load, affine.store,
3042 /// affine.vector_load and affine.vector_store.
3043 static LogicalResult
3044 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
3045  Operation::operand_range mapOperands,
3046  MemRefType memrefType, unsigned numIndexOperands) {
3047  AffineMap map = mapAttr.getValue();
3048  if (map.getNumResults() != memrefType.getRank())
3049  return op->emitOpError("affine map num results must equal memref rank");
3050  if (map.getNumInputs() != numIndexOperands)
3051  return op->emitOpError("expects as many subscripts as affine map inputs");
3052 
3053  Region *scope = getAffineScope(op);
3054  for (auto idx : mapOperands) {
3055  if (!idx.getType().isIndex())
3056  return op->emitOpError("index to load must have 'index' type");
3057  if (!isValidAffineIndexOperand(idx, scope))
3058  return op->emitOpError(
3059  "index must be a valid dimension or symbol identifier");
3060  }
3061 
3062  return success();
3063 }
3064 
3065 LogicalResult AffineLoadOp::verify() {
3066  auto memrefType = getMemRefType();
3067  if (getType() != memrefType.getElementType())
3068  return emitOpError("result type must match element type of memref");
3069 
3070  if (failed(verifyMemoryOpIndexing(
3071  getOperation(),
3072  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3073  getMapOperands(), memrefType,
3074  /*numIndexOperands=*/getNumOperands() - 1)))
3075  return failure();
3076 
3077  return success();
3078 }
3079 
3080 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3081  MLIRContext *context) {
3082  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3083 }
3084 
3085 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3086  /// load(memrefcast) -> load
3087  if (succeeded(memref::foldMemRefCast(*this)))
3088  return getResult();
3089 
3090  // Fold load from a global constant memref.
3091  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3092  if (!getGlobalOp)
3093  return {};
3094  // Get to the memref.global defining the symbol.
3095  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3096  if (!symbolTableOp)
3097  return {};
3098  auto global = dyn_cast_or_null<memref::GlobalOp>(
3099  SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
3100  if (!global)
3101  return {};
3102 
3103  // Check if the global memref is a constant.
3104  auto cstAttr =
3105  llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3106  if (!cstAttr)
3107  return {};
3108  // If it's a splat constant, we can fold irrespective of indices.
3109  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3110  return splatAttr.getSplatValue<Attribute>();
3111  // Otherwise, we can fold only if we know the indices.
3112  if (!getAffineMap().isConstant())
3113  return {};
3114  auto indices = llvm::to_vector<4>(
3115  llvm::map_range(getAffineMap().getConstantResults(),
3116  [](int64_t v) -> uint64_t { return v; }));
3117  return cstAttr.getValues<Attribute>()[indices];
3118 }
3119 
3120 //===----------------------------------------------------------------------===//
3121 // AffineStoreOp
3122 //===----------------------------------------------------------------------===//
3123 
3124 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3125  Value valueToStore, Value memref, AffineMap map,
3126  ValueRange mapOperands) {
3127  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3128  result.addOperands(valueToStore);
3129  result.addOperands(memref);
3130  result.addOperands(mapOperands);
3131  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3132 }
3133 
3134 // Use identity map.
3135 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3136  Value valueToStore, Value memref,
3137  ValueRange indices) {
3138  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3139  int64_t rank = memrefType.getRank();
3140  // Create identity map for memrefs with at least one dimension or () -> ()
3141  // for zero-dimensional memrefs.
3142  auto map =
3143  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3144  build(builder, result, valueToStore, memref, map, indices);
3145 }
3146 
3147 ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3148  auto indexTy = parser.getBuilder().getIndexType();
3149 
3150  MemRefType type;
3151  OpAsmParser::UnresolvedOperand storeValueInfo;
3152  OpAsmParser::UnresolvedOperand memrefInfo;
3153  AffineMapAttr mapAttr;
3155  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3156  parser.parseOperand(memrefInfo) ||
3157  parser.parseAffineMapOfSSAIds(
3158  mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3159  result.attributes) ||
3160  parser.parseOptionalAttrDict(result.attributes) ||
3161  parser.parseColonType(type) ||
3162  parser.resolveOperand(storeValueInfo, type.getElementType(),
3163  result.operands) ||
3164  parser.resolveOperand(memrefInfo, type, result.operands) ||
3165  parser.resolveOperands(mapOperands, indexTy, result.operands));
3166 }
3167 
3169  p << " " << getValueToStore();
3170  p << ", " << getMemRef() << '[';
3171  if (AffineMapAttr mapAttr =
3172  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3173  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3174  p << ']';
3175  p.printOptionalAttrDict((*this)->getAttrs(),
3176  /*elidedAttrs=*/{getMapAttrStrName()});
3177  p << " : " << getMemRefType();
3178 }
3179 
3180 LogicalResult AffineStoreOp::verify() {
3181  // The value to store must have the same type as memref element type.
3182  auto memrefType = getMemRefType();
3183  if (getValueToStore().getType() != memrefType.getElementType())
3184  return emitOpError(
3185  "value to store must have the same type as memref element type");
3186 
3187  if (failed(verifyMemoryOpIndexing(
3188  getOperation(),
3189  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3190  getMapOperands(), memrefType,
3191  /*numIndexOperands=*/getNumOperands() - 2)))
3192  return failure();
3193 
3194  return success();
3195 }
3196 
3197 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3198  MLIRContext *context) {
3199  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3200 }
3201 
3202 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3203  SmallVectorImpl<OpFoldResult> &results) {
3204  /// store(memrefcast) -> store
3205  return memref::foldMemRefCast(*this, getValueToStore());
3206 }
3207 
3208 //===----------------------------------------------------------------------===//
3209 // AffineMinMaxOpBase
3210 //===----------------------------------------------------------------------===//
3211 
3212 template <typename T>
3213 static LogicalResult verifyAffineMinMaxOp(T op) {
3214  // Verify that operand count matches affine map dimension and symbol count.
3215  if (op.getNumOperands() !=
3216  op.getMap().getNumDims() + op.getMap().getNumSymbols())
3217  return op.emitOpError(
3218  "operand count and affine map dimension and symbol count must match");
3219 
3220  if (op.getMap().getNumResults() == 0)
3221  return op.emitOpError("affine map expect at least one result");
3222  return success();
3223 }
3224 
3225 template <typename T>
3226 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3227  p << ' ' << op->getAttr(T::getMapAttrStrName());
3228  auto operands = op.getOperands();
3229  unsigned numDims = op.getMap().getNumDims();
3230  p << '(' << operands.take_front(numDims) << ')';
3231 
3232  if (operands.size() != numDims)
3233  p << '[' << operands.drop_front(numDims) << ']';
3234  p.printOptionalAttrDict(op->getAttrs(),
3235  /*elidedAttrs=*/{T::getMapAttrStrName()});
3236 }
3237 
3238 template <typename T>
3239 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3240  OperationState &result) {
3241  auto &builder = parser.getBuilder();
3242  auto indexType = builder.getIndexType();
3245  AffineMapAttr mapAttr;
3246  return failure(
3247  parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3248  result.attributes) ||
3249  parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
3250  parser.parseOperandList(symInfos,
3252  parser.parseOptionalAttrDict(result.attributes) ||
3253  parser.resolveOperands(dimInfos, indexType, result.operands) ||
3254  parser.resolveOperands(symInfos, indexType, result.operands) ||
3255  parser.addTypeToList(indexType, result.types));
3256 }
3257 
3258 /// Fold an affine min or max operation with the given operands. The operand
3259 /// list may contain nulls, which are interpreted as the operand not being a
3260 /// constant.
// Outcomes, in the order they are tried below:
//  - the folded map has exactly one symbol and is a symbol identity: operand 0
//    of the op is returned;
//  - not all results folded to constants: the op's "map" attribute is updated
//    in place and the op's own result is returned, or a null result is
//    returned when the folded map equals the current map;
//  - all results are constants: an index-typed IntegerAttr holding their
//    minimum (AffineMinOp) or maximum (AffineMaxOp) is returned.
3261 template <typename T>
// NOTE(review): listing line 3262 (the function signature) is missing from
// this extract; callers below invoke it as
// `foldMinMaxOp(*this, adaptor.getOperands())`.
3263  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3264  "expected affine min or max op");
3265 
3266  // Fold the affine map.
3267  // TODO: Fold more cases:
3268  // min(some_affine, some_affine + constant, ...), etc.
3269  SmallVector<int64_t, 2> results;
3270  auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3271 
// NOTE(review): operand 0 is the symbol operand only when the map has no
// dimension operands ahead of it -- confirm that is guaranteed here.
3272  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3273  return op.getOperand(0);
3274 
3275  // If some of the map results are not constant, try changing the map in-place.
3276  if (results.empty()) {
3277  // If the map is the same, report that folding did not happen.
3278  if (foldedMap == op.getMap())
3279  return {};
3280  op->setAttr("map", AffineMapAttr::get(foldedMap));
3281  return op.getResult();
3282  }
3283 
3284  // Otherwise, completely fold the op into a constant.
3285  auto resultIt = std::is_same<T, AffineMinOp>::value
3286  ? llvm::min_element(results)
3287  : llvm::max_element(results);
3288  if (resultIt == results.end())
3289  return {};
3290  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3291 }
3292 
3293 /// Remove duplicated expressions in affine min/max ops.
// The pattern keeps the first occurrence of every result expression, in
// order, and fails to match when no expression is repeated.
3294 template <typename T>
// NOTE(review): listing lines 3295-3296 (the pattern struct header) are
// missing from this extract.
3297 
3298  LogicalResult matchAndRewrite(T affineOp,
3299  PatternRewriter &rewriter) const override {
3300  AffineMap oldMap = affineOp.getAffineMap();
3301 
3302  SmallVector<AffineExpr, 4> newExprs;
3303  for (AffineExpr expr : oldMap.getResults()) {
3304  // This is a linear scan over newExprs, but it should be fine given that
3305  // we typically just have a few expressions per op.
3306  if (!llvm::is_contained(newExprs, expr))
3307  newExprs.push_back(expr);
3308  }
3309 
// Same number of expressions means nothing was dropped: no rewrite needed.
3310  if (newExprs.size() == oldMap.getNumResults())
3311  return failure();
3312 
// The dim/symbol counts and the operands are unchanged; only the result list
// shrinks.
3313  auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3314  newExprs, rewriter.getContext());
3315  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3316 
3317  return success();
3318  }
3319 };
3320 
3321 /// Merge an affine min/max op to its consumers if its consumer is also an
3322 /// affine min/max op.
3323 ///
3324 /// This pattern requires the producer affine min/max op is bound to a
3325 /// dimension/symbol that is used as a standalone expression in the consumer
3326 /// affine op's map.
3327 ///
3328 /// For example, a pattern like the following:
3329 ///
3330 /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3331 /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3332 ///
3333 /// Can be turned into:
3334 ///
3335 /// %1 = affine.min affine_map<
3336 /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3337 template <typename T>
// NOTE(review): listing lines 3338-3339 (the pattern struct header) are
// missing from this extract; the pattern is registered below as
// `MergeAffineMinMaxOp<...>`.
3340 
3341  LogicalResult matchAndRewrite(T affineOp,
3342  PatternRewriter &rewriter) const override {
3343  AffineMap oldMap = affineOp.getAffineMap();
// Map operands are laid out as all dimension operands followed by all symbol
// operands.
3344  ValueRange dimOperands =
3345  affineOp.getMapOperands().take_front(oldMap.getNumDims());
3346  ValueRange symOperands =
3347  affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3348 
3349  auto newDimOperands = llvm::to_vector<8>(dimOperands);
3350  auto newSymOperands = llvm::to_vector<8>(symOperands);
3351  SmallVector<AffineExpr, 4> newExprs;
3352  SmallVector<T, 4> producerOps;
3353 
3354  // Go over each expression to see whether it's a single dimension/symbol
3355  // with the corresponding operand which is the result of another affine
3356  // min/max op. If so, it can be merged into this affine op.
3357  for (AffineExpr expr : oldMap.getResults()) {
3358  if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3359  Value symValue = symOperands[symExpr.getPosition()];
3360  if (auto producerOp = symValue.getDefiningOp<T>()) {
3361  producerOps.push_back(producerOp);
3362  continue;
3363  }
3364  } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3365  Value dimValue = dimOperands[dimExpr.getPosition()];
3366  if (auto producerOp = dimValue.getDefiningOp<T>()) {
3367  producerOps.push_back(producerOp);
3368  continue;
3369  }
3370  }
3371  // For the above cases we will remove the expression by merging the
3372  // producer affine min/max's affine expressions. Otherwise we need to
3373  // keep the existing expression.
3374  newExprs.push_back(expr);
3375  }
3376 
3377  if (producerOps.empty())
3378  return failure();
3379 
// The consumer's own dims/symbols keep their positions (its operands are
// retained above); each producer's dims/symbols are appended after them.
3380  unsigned numUsedDims = oldMap.getNumDims();
3381  unsigned numUsedSyms = oldMap.getNumSymbols();
3382 
3383  // Now go over all producer affine ops and merge their expressions.
3384  for (T producerOp : producerOps) {
3385  AffineMap producerMap = producerOp.getAffineMap();
3386  unsigned numProducerDims = producerMap.getNumDims();
3387  unsigned numProducerSyms = producerMap.getNumSymbols();
3388 
3389  // Collect all dimension/symbol values.
3390  ValueRange dimValues =
3391  producerOp.getMapOperands().take_front(numProducerDims);
3392  ValueRange symValues =
3393  producerOp.getMapOperands().take_back(numProducerSyms);
3394  newDimOperands.append(dimValues.begin(), dimValues.end());
3395  newSymOperands.append(symValues.begin(), symValues.end());
3396 
3397  // For expressions we need to shift to avoid overlap.
3398  for (AffineExpr expr : producerMap.getResults()) {
3399  newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3400  .shiftSymbols(numProducerSyms, numUsedSyms));
3401  }
3402 
3403  numUsedDims += numProducerDims;
3404  numUsedSyms += numProducerSyms;
3405  }
3406 
3407  auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3408  rewriter.getContext());
// The new operand list is all dimension operands followed by all symbol
// operands, matching the layout read at the top of this function.
3409  auto newOperands =
3410  llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3411  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3412 
3413  return success();
3414  }
3415 };
3416 
3417 /// Canonicalize the result expression order of an affine map and return success
3418 /// if the order changed.
3419 ///
3420 /// The function flattens the map's affine expressions to coefficient arrays and
3421 /// sorts them in lexicographic order. A coefficient array contains a multiplier
3422 /// for every dimension/symbol and a constant term. The canonicalization fails
3423 /// if a result expression is not pure or if the flattening requires local
3424 /// variables that, unlike dimensions and symbols, have no global order.
3425 static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3426  SmallVector<SmallVector<int64_t>> flattenedExprs;
3427  for (const AffineExpr &resultExpr : map.getResults()) {
3428  // Fail if the expression is not pure.
3429  if (!resultExpr.isPureAffine())
3430  return failure();
3431 
3432  SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3433  auto flattenResult = flattener.walkPostOrder(resultExpr);
3434  if (failed(flattenResult))
3435  return failure();
3436 
3437  // Fail if the flattened expression has local variables.
3438  if (flattener.operandExprStack.back().size() !=
3439  map.getNumDims() + map.getNumSymbols() + 1)
3440  return failure();
3441 
3442  flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3443  flattener.operandExprStack.back().end());
3444  }
3445 
3446  // Fail if sorting is not necessary.
3447  if (llvm::is_sorted(flattenedExprs))
3448  return failure();
3449 
3450  // Reorder the result expressions according to their flattened form.
3451  SmallVector<unsigned> resultPermutation =
3452  llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3453  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3454  return flattenedExprs[lhs] < flattenedExprs[rhs];
3455  });
3456  SmallVector<AffineExpr> newExprs;
3457  for (unsigned idx : resultPermutation)
3458  newExprs.push_back(map.getResult(idx));
3459 
3460  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3461  map.getContext());
3462  return success();
3463 }
3464 
3465 /// Canonicalize the affine map result expression order of an affine min/max
3466 /// operation.
3467 ///
3468 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3469 /// expressions and replaces the operation if the order changed.
3470 ///
3471 /// For example, the following operation:
3472 ///
3473 /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3474 ///
3475 /// Turns into:
3476 ///
3477 /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3478 template <typename T>
// NOTE(review): listing lines 3479-3480 (the pattern struct header) are
// missing from this extract.
3481 
3482  LogicalResult matchAndRewrite(T affineOp,
3483  PatternRewriter &rewriter) const override {
// `map` is a local copy that the helper reorders in place; the helper fails
// when no reordering was possible or necessary.
3484  AffineMap map = affineOp.getAffineMap();
3485  if (failed(canonicalizeMapExprAndTermOrder(map)))
3486  return failure();
3487  rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3488  return success();
3489  }
3490 };
3491 
// Rewrites an affine min/max op whose map has exactly one result into an
// AffineApplyOp with the same map and operands; ops with any other number of
// results are left alone.
3492 template <typename T>
// NOTE(review): listing lines 3493-3494 (the pattern struct header) are
// missing from this extract.
3495 
3496  LogicalResult matchAndRewrite(T affineOp,
3497  PatternRewriter &rewriter) const override {
3498  if (affineOp.getMap().getNumResults() != 1)
3499  return failure();
3500  rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3501  affineOp.getOperands());
3502  return success();
3503  }
3504 };
3505 
3506 //===----------------------------------------------------------------------===//
3507 // AffineMinOp
3508 //===----------------------------------------------------------------------===//
3509 //
3510 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3511 //
3512 
3513 OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3514  return foldMinMaxOp(*this, adaptor.getOperands());
3515 }
3516 
// Registers the canonicalization patterns for affine.min.
3517 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3518  MLIRContext *context) {
// NOTE(review): listing lines 3519, 3520 and 3522 are missing from this
// extract, so only part of the pattern list passed to `patterns` is visible.
3521  MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3523  context);
3524 }
3525 
3526 LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3527 
3528 ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
3529  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3530 }
3531 
3532 void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3533 
3534 //===----------------------------------------------------------------------===//
3535 // AffineMaxOp
3536 //===----------------------------------------------------------------------===//
3537 //
3538 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3539 //
3540 
3541 OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3542  return foldMinMaxOp(*this, adaptor.getOperands());
3543 }
3544 
// Registers the canonicalization patterns for affine.max.
3545 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3546  MLIRContext *context) {
// NOTE(review): listing lines 3547, 3548 and 3550 are missing from this
// extract, so only part of the pattern list passed to `patterns` is visible.
3549  MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3551  context);
3552 }
3553 
3554 LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3555 
3556 ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
3557  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3558 }
3559 
3560 void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3561 
3562 //===----------------------------------------------------------------------===//
3563 // AffinePrefetchOp
3564 //===----------------------------------------------------------------------===//
3565 
3566 //
3567 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3568 //
// Parses the custom form of affine.prefetch:
//   memref `[` map-of-ssa-ids `]` `,` (`read`|`write`) `,`
//   `locality` `<` i32-attr `>` `,` (`data`|`instr`) attr-dict? `:` memref-type
// The memref operand is resolved against the parsed type and the map operands
// as `index`. The read/write and data/instr keywords are validated only after
// the whole syntax has been consumed, and are stored as boolean attributes.
3569 ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3570  OperationState &result) {
3571  auto &builder = parser.getBuilder();
3572  auto indexTy = builder.getIndexType();
3573 
3574  MemRefType type;
3575  OpAsmParser::UnresolvedOperand memrefInfo;
3576  IntegerAttr hintInfo;
3577  auto i32Type = parser.getBuilder().getIntegerType(32);
3578  StringRef readOrWrite, cacheType;
3579 
3580  AffineMapAttr mapAttr;
// NOTE(review): listing line 3581 is missing from this extract; it presumably
// declares the `mapOperands` vector used below.
3582  if (parser.parseOperand(memrefInfo) ||
3583  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3584  AffinePrefetchOp::getMapAttrStrName(),
3585  result.attributes) ||
3586  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3587  parser.parseComma() || parser.parseKeyword("locality") ||
3588  parser.parseLess() ||
3589  parser.parseAttribute(hintInfo, i32Type,
3590  AffinePrefetchOp::getLocalityHintAttrStrName(),
3591  result.attributes) ||
3592  parser.parseGreater() || parser.parseComma() ||
3593  parser.parseKeyword(&cacheType) ||
3594  parser.parseOptionalAttrDict(result.attributes) ||
3595  parser.parseColonType(type) ||
3596  parser.resolveOperand(memrefInfo, type, result.operands) ||
3597  parser.resolveOperands(mapOperands, indexTy, result.operands))
3598  return failure();
3599 
// The isWrite attribute is true for "write" and false for "read".
3600  if (readOrWrite != "read" && readOrWrite != "write")
3601  return parser.emitError(parser.getNameLoc(),
3602  "rw specifier has to be 'read' or 'write'");
3603  result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3604  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
3605 
// The isDataCache attribute is true for "data" and false for "instr".
3606  if (cacheType != "data" && cacheType != "instr")
3607  return parser.emitError(parser.getNameLoc(),
3608  "cache type has to be 'data' or 'instr'");
3609 
3610  result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3611  parser.getBuilder().getBoolAttr(cacheType == "data"));
3612 
3613  return success();
3614 }
3615 
// NOTE(review): listing line 3616 (the function signature) is missing from
// this extract; the body below looks like the printer for affine.prefetch --
// confirm.
// Prints the memref, the bracketed map operands (when a map attribute is
// present), the read/write specifier, the locality hint, the cache kind, any
// attributes not already printed, and finally the memref type.
3617  p << " " << getMemref() << '[';
3618  AffineMapAttr mapAttr =
3619  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3620  if (mapAttr)
3621  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3622  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3623  << "locality<" << getLocalityHint() << ">, "
3624  << (getIsDataCache() ? "data" : "instr");
// NOTE(review): listing line 3625 is missing from this extract; it presumably
// opens the printer call whose arguments follow.
3626  (*this)->getAttrs(),
3627  /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3628  getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3629  p << " : " << getMemRefType();
3630 }
3631 
3632 LogicalResult AffinePrefetchOp::verify() {
3633  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3634  if (mapAttr) {
3635  AffineMap map = mapAttr.getValue();
3636  if (map.getNumResults() != getMemRefType().getRank())
3637  return emitOpError("affine.prefetch affine map num results must equal"
3638  " memref rank");
3639  if (map.getNumInputs() + 1 != getNumOperands())
3640  return emitOpError("too few operands");
3641  } else {
3642  if (getNumOperands() != 1)
3643  return emitOpError("too few operands");
3644  }
3645 
3646  Region *scope = getAffineScope(*this);
3647  for (auto idx : getMapOperands()) {
3648  if (!isValidAffineIndexOperand(idx, scope))
3649  return emitOpError(
3650  "index must be a valid dimension or symbol identifier");
3651  }
3652  return success();
3653 }
3654 
3655 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3656  MLIRContext *context) {
3657  // prefetch(memrefcast) -> prefetch
3658  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3659 }
3660 
3661 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3662  SmallVectorImpl<OpFoldResult> &results) {
3663  /// prefetch(memrefcast) -> prefetch
3664  return memref::foldMemRefCast(*this);
3665 }
3666 
3667 //===----------------------------------------------------------------------===//
3668 // AffineParallelOp
3669 //===----------------------------------------------------------------------===//
3670 
3671 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3672  TypeRange resultTypes,
3673  ArrayRef<arith::AtomicRMWKind> reductions,
3674  ArrayRef<int64_t> ranges) {
3675  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3676  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3677  return builder.getConstantAffineMap(value);
3678  }));
3679  SmallVector<int64_t> steps(ranges.size(), 1);
3680  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3681  /*ubArgs=*/{}, steps);
3682 }
3683 
// General builder for affine.parallel. `lbMaps[i]` / `ubMaps[i]` give the
// lower / upper bound expressions of dimension i. All lower bound maps must
// share one input space (operands `lbArgs`), and likewise all upper bound maps
// (operands `ubArgs`). The per-dimension maps are concatenated into a single
// map attribute per bound, with a "groups" attribute recording how many
// results belong to each dimension. The body block gets one index argument
// per entry in `steps`.
3684 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3685  TypeRange resultTypes,
3686  ArrayRef<arith::AtomicRMWKind> reductions,
3687  ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
3688  ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
3689  ArrayRef<int64_t> steps) {
3690  assert(llvm::all_of(lbMaps,
3691  [lbMaps](AffineMap m) {
3692  return m.getNumDims() == lbMaps[0].getNumDims() &&
3693  m.getNumSymbols() == lbMaps[0].getNumSymbols();
3694  }) &&
3695  "expected all lower bounds maps to have the same number of dimensions "
3696  "and symbols");
3697  assert(llvm::all_of(ubMaps,
3698  [ubMaps](AffineMap m) {
3699  return m.getNumDims() == ubMaps[0].getNumDims() &&
3700  m.getNumSymbols() == ubMaps[0].getNumSymbols();
3701  }) &&
3702  "expected all upper bounds maps to have the same number of dimensions "
3703  "and symbols");
3704  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3705  "expected lower bound maps to have as many inputs as lower bound "
3706  "operands");
3707  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3708  "expected upper bound maps to have as many inputs as upper bound "
3709  "operands");
3710 
// The guard restores the builder's insertion point on exit; createBlock below
// moves it into the new body.
3711  OpBuilder::InsertionGuard guard(builder);
3712  result.addTypes(resultTypes);
3713 
3714  // Convert the reductions to integer attributes.
3715  SmallVector<Attribute, 4> reductionAttrs;
3716  for (arith::AtomicRMWKind reduction : reductions)
3717  reductionAttrs.push_back(
3718  builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
3719  result.addAttribute(getReductionsAttrStrName(),
3720  builder.getArrayAttr(reductionAttrs));
3721 
3722  // Concatenates maps defined in the same input space (same dimensions and
3723  // symbols) and appends each map's result count to `groups`. An empty list
3724  // of maps yields the map returned by `AffineMap::get(context)`.
3724  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
3725  SmallVectorImpl<int32_t> &groups) {
3726  if (maps.empty())
3727  return AffineMap::get(builder.getContext());
// NOTE(review): listing line 3728 is missing from this extract; it presumably
// declares the `exprs` vector used below.
3729  groups.reserve(groups.size() + maps.size());
3730  exprs.reserve(maps.size());
3731  for (AffineMap m : maps) {
3732  llvm::append_range(exprs, m.getResults());
3733  groups.push_back(m.getNumResults());
3734  }
3735  return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3736  maps[0].getContext());
3737  };
3738 
3739  // Set up the bounds.
3740  SmallVector<int32_t> lbGroups, ubGroups;
3741  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3742  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3743  result.addAttribute(getLowerBoundsMapAttrStrName(),
3744  AffineMapAttr::get(lbMap));
3745  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
3746  builder.getI32TensorAttr(lbGroups));
3747  result.addAttribute(getUpperBoundsMapAttrStrName(),
3748  AffineMapAttr::get(ubMap));
3749  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
3750  builder.getI32TensorAttr(ubGroups));
3751  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
// Operand layout: all lower bound operands first, then all upper bound ones.
3752  result.addOperands(lbArgs);
3753  result.addOperands(ubArgs);
3754 
3755  // Create a region and a block for the body.
3756  auto *bodyRegion = result.addRegion();
3757  Block *body = builder.createBlock(bodyRegion);
3758 
3759  // Add all the block arguments.
3760  for (unsigned i = 0, e = steps.size(); i < e; ++i)
3761  body->addArgument(IndexType::get(builder.getContext()), result.location);
// A terminator is only added automatically when the op has no results.
3762  if (resultTypes.empty())
3763  ensureTerminator(*bodyRegion, builder, result.location);
3764 }
3765 
3766 SmallVector<Region *> AffineParallelOp::getLoopRegions() {
3767  return {&getRegion()};
3768 }
3769 
3770 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3771 
3772 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3773  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3774 }
3775 
3776 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3777  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3778 }
3779 
3780 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3781  auto values = getLowerBoundsGroups().getValues<int32_t>();
3782  unsigned start = 0;
3783  for (unsigned i = 0; i < pos; ++i)
3784  start += values[i];
3785  return getLowerBoundsMap().getSliceMap(start, values[pos]);
3786 }
3787 
3788 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3789  auto values = getUpperBoundsGroups().getValues<int32_t>();
3790  unsigned start = 0;
3791  for (unsigned i = 0; i < pos; ++i)
3792  start += values[i];
3793  return getUpperBoundsMap().getSliceMap(start, values[pos]);
3794 }
3795 
3796 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
3797  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3798 }
3799 
3800 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
3801  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3802 }
3803 
// Returns the extent (upper bound minus lower bound) of every dimension when
// all of them are constant. Returns std::nullopt when the op has min/max
// bounds or when any difference is not a constant expression.
3804 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3805  if (hasMinMaxBounds())
3806  return std::nullopt;
3807 
3808  // Try to convert all the ranges to constant expressions.
// NOTE(review): listing line 3809 is missing from this extract; it presumably
// declares the `out` vector that is filled and returned below.
3810  AffineValueMap rangesValueMap;
3811  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3812  &rangesValueMap);
3813  out.reserve(rangesValueMap.getNumResults());
3814  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3815  auto expr = rangesValueMap.getResult(i);
3816  auto cst = dyn_cast<AffineConstantExpr>(expr);
// A single non-constant extent makes the whole query fail.
3817  if (!cst)
3818  return std::nullopt;
3819  out.push_back(cst.getValue());
3820  }
3821  return out;
3822 }
3823 
3824 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3825 
3826 OpBuilder AffineParallelOp::getBodyBuilder() {
3827  return OpBuilder(getBody(), std::prev(getBody()->end()));
3828 }
3829 
3830 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3831  assert(lbOperands.size() == map.getNumInputs() &&
3832  "operands to map must match number of inputs");
3833 
3834  auto ubOperands = getUpperBoundsOperands();
3835 
3836  SmallVector<Value, 4> newOperands(lbOperands);
3837  newOperands.append(ubOperands.begin(), ubOperands.end());
3838  (*this)->setOperands(newOperands);
3839 
3840  setLowerBoundsMapAttr(AffineMapAttr::get(map));
3841 }
3842 
3843 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3844  assert(ubOperands.size() == map.getNumInputs() &&
3845  "operands to map must match number of inputs");
3846 
3847  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3848  newOperands.append(ubOperands.begin(), ubOperands.end());
3849  (*this)->setOperands(newOperands);
3850 
3851  setUpperBoundsMapAttr(AffineMapAttr::get(map));
3852 }
3853 
3854 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
3855  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3856 }
3857 
3858 // check whether resultType match op or not in affine.parallel
3859 static bool isResultTypeMatchAtomicRMWKind(Type resultType,
3860  arith::AtomicRMWKind op) {
3861  switch (op) {
3862  case arith::AtomicRMWKind::addf:
3863  return isa<FloatType>(resultType);
3864  case arith::AtomicRMWKind::addi:
3865  return isa<IntegerType>(resultType);
3866  case arith::AtomicRMWKind::assign:
3867  return true;
3868  case arith::AtomicRMWKind::mulf:
3869  return isa<FloatType>(resultType);
3870  case arith::AtomicRMWKind::muli:
3871  return isa<IntegerType>(resultType);
3872  case arith::AtomicRMWKind::maximumf:
3873  return isa<FloatType>(resultType);
3874  case arith::AtomicRMWKind::minimumf:
3875  return isa<FloatType>(resultType);
3876  case arith::AtomicRMWKind::maxs: {
3877  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3878  return intType && intType.isSigned();
3879  }
3880  case arith::AtomicRMWKind::mins: {
3881  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3882  return intType && intType.isSigned();
3883  }
3884  case arith::AtomicRMWKind::maxu: {
3885  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3886  return intType && intType.isUnsigned();
3887  }
3888  case arith::AtomicRMWKind::minu: {
3889  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3890  return intType && intType.isUnsigned();
3891  }
3892  case arith::AtomicRMWKind::ori:
3893  return isa<IntegerType>(resultType);
3894  case arith::AtomicRMWKind::andi:
3895  return isa<IntegerType>(resultType);
3896  default:
3897  return false;
3898  }
3899 }
3900 
3901 LogicalResult AffineParallelOp::verify() {
3902  auto numDims = getNumDims();
3903  if (getLowerBoundsGroups().getNumElements() != numDims ||
3904  getUpperBoundsGroups().getNumElements() != numDims ||
3905  getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3906  return emitOpError() << "the number of region arguments ("
3907  << getBody()->getNumArguments()
3908  << ") and the number of map groups for lower ("
3909  << getLowerBoundsGroups().getNumElements()
3910  << ") and upper bound ("
3911  << getUpperBoundsGroups().getNumElements()
3912  << "), and the number of steps (" << getSteps().size()
3913  << ") must all match";
3914  }
3915 
3916  unsigned expectedNumLBResults = 0;
3917  for (APInt v : getLowerBoundsGroups())
3918  expectedNumLBResults += v.getZExtValue();
3919  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
3920  return emitOpError() << "expected lower bounds map to have "
3921  << expectedNumLBResults << " results";
3922  unsigned expectedNumUBResults = 0;
3923  for (APInt v : getUpperBoundsGroups())
3924  expectedNumUBResults += v.getZExtValue();
3925  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
3926  return emitOpError() << "expected upper bounds map to have "
3927  << expectedNumUBResults << " results";
3928 
3929  if (getReductions().size() != getNumResults())
3930  return emitOpError("a reduction must be specified for each output");
3931 
3932  // Verify reduction ops are all valid and each result type matches reduction
3933  // ops
3934  for (auto it : llvm::enumerate((getReductions()))) {
3935  Attribute attr = it.value();
3936  auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3937  if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
3938  return emitOpError("invalid reduction attribute");
3939  auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
3940  if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
3941  return emitOpError("result type cannot match reduction attribute");
3942  }
3943 
3944  // Verify that the bound operands are valid dimension/symbols.
3945  /// Lower bounds.
3946  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
3947  getLowerBoundsMap().getNumDims())))
3948  return failure();
3949  /// Upper bounds.
3950  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
3951  getUpperBoundsMap().getNumDims())))
3952  return failure();
3953  return success();
3954 }
3955 
3956 LogicalResult AffineValueMap::canonicalize() {
3957  SmallVector<Value, 4> newOperands{operands};
3958  auto newMap = getAffineMap();
3959  composeAffineMapAndOperands(&newMap, &newOperands);
3960  if (newMap == getAffineMap() && newOperands == operands)
3961  return failure();
3962  reset(newMap, newOperands);
3963  return success();
3964 }
3965 
3966 /// Canonicalize the bounds of the given loop.
3967 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3968  AffineValueMap lb = op.getLowerBoundsValueMap();
3969  bool lbCanonicalized = succeeded(lb.canonicalize());
3970 
3971  AffineValueMap ub = op.getUpperBoundsValueMap();
3972  bool ubCanonicalized = succeeded(ub.canonicalize());
3973 
3974  // Any canonicalization change always leads to updated map(s).
3975  if (!lbCanonicalized && !ubCanonicalized)
3976  return failure();
3977 
3978  if (lbCanonicalized)
3979  op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
3980  if (ubCanonicalized)
3981  op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
3982 
3983  return success();
3984 }
3985 
3986 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
3987  SmallVectorImpl<OpFoldResult> &results) {
3988  return canonicalizeLoopBounds(*this);
3989 }
3990 
3991 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
3992 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
3993 /// identifies which of the those expressions form max/min groups. `operands`
3994 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
3995 /// or "max".
3996 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
3997  DenseIntElementsAttr group, ValueRange operands,
3998  StringRef keyword) {
3999  AffineMap map = mapAttr.getValue();
4000  unsigned numDims = map.getNumDims();
4001  ValueRange dimOperands = operands.take_front(numDims);
4002  ValueRange symOperands = operands.drop_front(numDims);
4003  unsigned start = 0;
4004  for (llvm::APInt groupSize : group) {
4005  if (start != 0)
4006  p << ", ";
4007 
4008  unsigned size = groupSize.getZExtValue();
4009  if (size == 1) {
4010  p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4011  ++start;
4012  } else {
4013  p << keyword << '(';
4014  AffineMap submap = map.getSliceMap(start, size);
4015  p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4016  p << ')';
4017  start += size;
4018  }
4019  }
4020 }
4021 
// NOTE(review): listing line 4022 (the function signature) is missing from
// this extract; the body below looks like the printer for affine.parallel --
// confirm.
// Prints the induction variables, the lower bounds (groups printed with
// "max") and upper bounds (groups printed with "min"), the steps unless they
// are all 1, the reductions and result types when the op has results, the
// body region, and any attributes not already covered by the custom syntax.
4023  p << " (" << getBody()->getArguments() << ") = (";
4024  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4025  getLowerBoundsOperands(), "max");
4026  p << ") to (";
4027  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4028  getUpperBoundsOperands(), "min");
4029  p << ')';
4030  SmallVector<int64_t, 8> steps = getSteps();
4031  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4032  if (!elideSteps) {
4033  p << " step (";
4034  llvm::interleaveComma(steps, p);
4035  p << ')';
4036  }
4037  if (getNumResults()) {
4038  p << " reduce (";
// Each reduction is printed as the quoted string form of its AtomicRMWKind.
4039  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4040  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4041  llvm::cast<IntegerAttr>(attr).getInt());
4042  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4043  });
4044  p << ") -> (" << getResultTypes() << ")";
4045  }
4046 
4047  p << ' ';
// Block terminators are printed only when the op has results.
4048  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4049  /*printBlockTerminators=*/getNumResults());
// NOTE(review): listing line 4050 is missing from this extract; it presumably
// opens the printer call whose arguments follow.
4051  (*this)->getAttrs(),
4052  /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4053  AffineParallelOp::getLowerBoundsMapAttrStrName(),
4054  AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4055  AffineParallelOp::getUpperBoundsMapAttrStrName(),
4056  AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4057  AffineParallelOp::getStepsAttrStrName()});
4058 }
4059 
4060 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
4061 /// unique operands. Also populates `replacements` with affine expressions of
4062 /// `kind` that can be used to update affine maps previously accepting
4063 /// `operands` to accept `uniqueOperands` instead.
// One replacement expression is appended per parsed operand, in parse order;
// its position is the operand's index within `uniqueOperands`. Fails if any
// operand cannot be resolved as an `index` value.
// NOTE(review): listing lines 4064 and 4066 (the start of the signature and
// the declaration of the `operands` parameter) are missing from this extract.
4065  OpAsmParser &parser,
4067  SmallVectorImpl<Value> &uniqueOperands,
4068  SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4069  assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4070  "expected operands to be dim or symbol expression");
4071 
4072  Type indexType = parser.getBuilder().getIndexType();
4073  for (const auto &list : operands) {
4074  SmallVector<Value> valueOperands;
4075  if (parser.resolveOperands(list, indexType, valueOperands))
4076  return failure();
4077  for (Value operand : valueOperands) {
// Linear search: `pos` equals the current size when the operand is new.
4078  unsigned pos = std::distance(uniqueOperands.begin(),
4079  llvm::find(uniqueOperands, operand));
4080  if (pos == uniqueOperands.size())
4081  uniqueOperands.push_back(operand);
4082  replacements.push_back(
4083  kind == AffineExprKind::DimId
4084  ? getAffineDimExpr(pos, parser.getContext())
4085  : getAffineSymbolExpr(pos, parser.getContext()));
4086  }
4087  }
4088  return success();
4089 }
4090 
namespace {
/// Selects which bound of an affine.parallel a min/max group list belongs to:
/// `Min` selects the upper bounds attributes and `Max` the lower bounds
/// attributes when parsing.
enum class MinMaxKind {
  Min,
  Max
};
} // namespace
4094 
4095 /// Parses an affine map that can contain a min/max for groups of its results,
4096 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4097 /// `result` attributes with the map (flat list of expressions) and the grouping
4098 /// (list of integers that specify how many expressions to put into each
4099 /// min/max) attributes. Deduplicates repeated operands.
4100 ///
4101 /// parallel-bound ::= `(` parallel-group-list `)`
4102 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4103 /// parallel-group ::= simple-group | min-max-group
4104 /// simple-group ::= expr-of-ssa-ids
4105 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4106 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4107 ///
4108 /// Examples:
4109 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4110 /// (%0, max(%1 - 2 * %2))
4111 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4112  OperationState &result,
4113  MinMaxKind kind) {
4114  // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4115  // see: https://reviews.llvm.org/D134227#3821753
4116  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4117 
4118  StringRef mapName = kind == MinMaxKind::Min
4119  ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4120  : AffineParallelOp::getLowerBoundsMapAttrStrName();
4121  StringRef groupsName =
4122  kind == MinMaxKind::Min
4123  ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4124  : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4125 
4126  if (failed(parser.parseLParen()))
4127  return failure();
4128 
4129  if (succeeded(parser.parseOptionalRParen())) {
4130  result.addAttribute(
4131  mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4132  result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4133  return success();
4134  }
4135 
4136  SmallVector<AffineExpr> flatExprs;
4139  SmallVector<int32_t> numMapsPerGroup;
4141  auto parseOperands = [&]() {
4142  if (succeeded(parser.parseOptionalKeyword(
4143  kind == MinMaxKind::Min ? "min" : "max"))) {
4144  mapOperands.clear();
4145  AffineMapAttr map;
4146  if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4147  result.attributes,
4149  return failure();
4150  result.attributes.erase(tmpAttrStrName);
4151  llvm::append_range(flatExprs, map.getValue().getResults());
4152  auto operandsRef = llvm::ArrayRef(mapOperands);
4153  auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4155  auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4157  flatDimOperands.append(map.getValue().getNumResults(), dims);
4158  flatSymOperands.append(map.getValue().getNumResults(), syms);
4159  numMapsPerGroup.push_back(map.getValue().getNumResults());
4160  } else {
4161  if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4162  flatSymOperands.emplace_back(),
4163  flatExprs.emplace_back())))
4164  return failure();
4165  numMapsPerGroup.push_back(1);
4166  }
4167  return success();
4168  };
4169  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4170  return failure();
4171 
4172  unsigned totalNumDims = 0;
4173  unsigned totalNumSyms = 0;
4174  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4175  unsigned numDims = flatDimOperands[i].size();
4176  unsigned numSyms = flatSymOperands[i].size();
4177  flatExprs[i] = flatExprs[i]
4178  .shiftDims(numDims, totalNumDims)
4179  .shiftSymbols(numSyms, totalNumSyms);
4180  totalNumDims += numDims;
4181  totalNumSyms += numSyms;
4182  }
4183 
4184  // Deduplicate map operands.
4185  SmallVector<Value> dimOperands, symOperands;
4186  SmallVector<AffineExpr> dimRplacements, symRepacements;
4187  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4188  dimRplacements, AffineExprKind::DimId) ||
4189  deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4190  symRepacements, AffineExprKind::SymbolId))
4191  return failure();
4192 
4193  result.operands.append(dimOperands.begin(), dimOperands.end());
4194  result.operands.append(symOperands.begin(), symOperands.end());
4195 
4196  Builder &builder = parser.getBuilder();
4197  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4198  parser.getContext());
4199  flatMap = flatMap.replaceDimsAndSymbols(
4200  dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4201 
4202  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4203  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4204  return success();
4205 }
4206 
4207 //
4208 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4209 // `to` parallel-bound steps? region attr-dict?
4210 // steps ::= `steps` `(` integer-literals `)`
4211 //
4212 ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4213  OperationState &result) {
4214  auto &builder = parser.getBuilder();
4215  auto indexType = builder.getIndexType();
4218  parser.parseEqual() ||
4219  parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4220  parser.parseKeyword("to") ||
4221  parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4222  return failure();
4223 
4224  AffineMapAttr stepsMapAttr;
4225  NamedAttrList stepsAttrs;
4227  if (failed(parser.parseOptionalKeyword("step"))) {
4228  SmallVector<int64_t, 4> steps(ivs.size(), 1);
4229  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4230  builder.getI64ArrayAttr(steps));
4231  } else {
4232  if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4233  AffineParallelOp::getStepsAttrStrName(),
4234  stepsAttrs,
4236  return failure();
4237 
4238  // Convert steps from an AffineMap into an I64ArrayAttr.
4240  auto stepsMap = stepsMapAttr.getValue();
4241  for (const auto &result : stepsMap.getResults()) {
4242  auto constExpr = dyn_cast<AffineConstantExpr>(result);
4243  if (!constExpr)
4244  return parser.emitError(parser.getNameLoc(),
4245  "steps must be constant integers");
4246  steps.push_back(constExpr.getValue());
4247  }
4248  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4249  builder.getI64ArrayAttr(steps));
4250  }
4251 
4252  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4253  // quoted strings are a member of the enum AtomicRMWKind.
4254  SmallVector<Attribute, 4> reductions;
4255  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4256  if (parser.parseLParen())
4257  return failure();
4258  auto parseAttributes = [&]() -> ParseResult {
4259  // Parse a single quoted string via the attribute parsing, and then
4260  // verify it is a member of the enum and convert to it's integer
4261  // representation.
4262  StringAttr attrVal;
4263  NamedAttrList attrStorage;
4264  auto loc = parser.getCurrentLocation();
4265  if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4266  attrStorage))
4267  return failure();
4268  std::optional<arith::AtomicRMWKind> reduction =
4269  arith::symbolizeAtomicRMWKind(attrVal.getValue());
4270  if (!reduction)
4271  return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4272  reductions.push_back(
4273  builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4274  // While we keep getting commas, keep parsing.
4275  return success();
4276  };
4277  if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4278  return failure();
4279  }
4280  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4281  builder.getArrayAttr(reductions));
4282 
4283  // Parse return types of reductions (if any)
4284  if (parser.parseOptionalArrowTypeList(result.types))
4285  return failure();
4286 
4287  // Now parse the body.
4288  Region *body = result.addRegion();
4289  for (auto &iv : ivs)
4290  iv.type = indexType;
4291  if (parser.parseRegion(*body, ivs) ||
4292  parser.parseOptionalAttrDict(result.attributes))
4293  return failure();
4294 
4295  // Add a terminator if none was parsed.
4296  AffineParallelOp::ensureTerminator(*body, builder, result.location);
4297  return success();
4298 }
4299 
4300 //===----------------------------------------------------------------------===//
4301 // AffineYieldOp
4302 //===----------------------------------------------------------------------===//
4303 
4304 LogicalResult AffineYieldOp::verify() {
4305  auto *parentOp = (*this)->getParentOp();
4306  auto results = parentOp->getResults();
4307  auto operands = getOperands();
4308 
4309  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4310  return emitOpError() << "only terminates affine.if/for/parallel regions";
4311  if (parentOp->getNumResults() != getNumOperands())
4312  return emitOpError() << "parent of yield must have same number of "
4313  "results as the yield operands";
4314  for (auto it : llvm::zip(results, operands)) {
4315  if (std::get<0>(it).getType() != std::get<1>(it).getType())
4316  return emitOpError() << "types mismatch between yield op and its parent";
4317  }
4318 
4319  return success();
4320 }
4321 
4322 //===----------------------------------------------------------------------===//
4323 // AffineVectorLoadOp
4324 //===----------------------------------------------------------------------===//
4325 
4326 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4327  VectorType resultType, AffineMap map,
4328  ValueRange operands) {
4329  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4330  result.addOperands(operands);
4331  if (map)
4332  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4333  result.types.push_back(resultType);
4334 }
4335 
4336 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4337  VectorType resultType, Value memref,
4338  AffineMap map, ValueRange mapOperands) {
4339  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4340  result.addOperands(memref);
4341  result.addOperands(mapOperands);
4342  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4343  result.types.push_back(resultType);
4344 }
4345 
4346 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4347  VectorType resultType, Value memref,
4348  ValueRange indices) {
4349  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4350  int64_t rank = memrefType.getRank();
4351  // Create identity map for memrefs with at least one dimension or () -> ()
4352  // for zero-dimensional memrefs.
4353  auto map =
4354  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4355  build(builder, result, resultType, memref, map, indices);
4356 }
4357 
4358 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4359  MLIRContext *context) {
4360  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4361 }
4362 
4363 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4364  OperationState &result) {
4365  auto &builder = parser.getBuilder();
4366  auto indexTy = builder.getIndexType();
4367 
4368  MemRefType memrefType;
4369  VectorType resultType;
4370  OpAsmParser::UnresolvedOperand memrefInfo;
4371  AffineMapAttr mapAttr;
4373  return failure(
4374  parser.parseOperand(memrefInfo) ||
4375  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4376  AffineVectorLoadOp::getMapAttrStrName(),
4377  result.attributes) ||
4378  parser.parseOptionalAttrDict(result.attributes) ||
4379  parser.parseColonType(memrefType) || parser.parseComma() ||
4380  parser.parseType(resultType) ||
4381  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4382  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4383  parser.addTypeToList(resultType, result.types));
4384 }
4385 
4387  p << " " << getMemRef() << '[';
4388  if (AffineMapAttr mapAttr =
4389  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4390  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4391  p << ']';
4392  p.printOptionalAttrDict((*this)->getAttrs(),
4393  /*elidedAttrs=*/{getMapAttrStrName()});
4394  p << " : " << getMemRefType() << ", " << getType();
4395 }
4396 
4397 /// Verify common invariants of affine.vector_load and affine.vector_store.
4398 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4399  VectorType vectorType) {
4400  // Check that memref and vector element types match.
4401  if (memrefType.getElementType() != vectorType.getElementType())
4402  return op->emitOpError(
4403  "requires memref and vector types of the same elemental type");
4404  return success();
4405 }
4406 
4407 LogicalResult AffineVectorLoadOp::verify() {
4408  MemRefType memrefType = getMemRefType();
4409  if (failed(verifyMemoryOpIndexing(
4410  getOperation(),
4411  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4412  getMapOperands(), memrefType,
4413  /*numIndexOperands=*/getNumOperands() - 1)))
4414  return failure();
4415 
4416  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4417  return failure();
4418 
4419  return success();
4420 }
4421 
4422 //===----------------------------------------------------------------------===//
4423 // AffineVectorStoreOp
4424 //===----------------------------------------------------------------------===//
4425 
4426 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4427  Value valueToStore, Value memref, AffineMap map,
4428  ValueRange mapOperands) {
4429  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4430  result.addOperands(valueToStore);
4431  result.addOperands(memref);
4432  result.addOperands(mapOperands);
4433  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4434 }
4435 
4436 // Use identity map.
4437 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4438  Value valueToStore, Value memref,
4439  ValueRange indices) {
4440  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4441  int64_t rank = memrefType.getRank();
4442  // Create identity map for memrefs with at least one dimension or () -> ()
4443  // for zero-dimensional memrefs.
4444  auto map =
4445  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4446  build(builder, result, valueToStore, memref, map, indices);
4447 }
4448 void AffineVectorStoreOp::getCanonicalizationPatterns(
4449  RewritePatternSet &results, MLIRContext *context) {
4450  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4451 }
4452 
4453 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4454  OperationState &result) {
4455  auto indexTy = parser.getBuilder().getIndexType();
4456 
4457  MemRefType memrefType;
4458  VectorType resultType;
4459  OpAsmParser::UnresolvedOperand storeValueInfo;
4460  OpAsmParser::UnresolvedOperand memrefInfo;
4461  AffineMapAttr mapAttr;
4463  return failure(
4464  parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4465  parser.parseOperand(memrefInfo) ||
4466  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4467  AffineVectorStoreOp::getMapAttrStrName(),
4468  result.attributes) ||
4469  parser.parseOptionalAttrDict(result.attributes) ||
4470  parser.parseColonType(memrefType) || parser.parseComma() ||
4471  parser.parseType(resultType) ||
4472  parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4473  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4474  parser.resolveOperands(mapOperands, indexTy, result.operands));
4475 }
4476 
4478  p << " " << getValueToStore();
4479  p << ", " << getMemRef() << '[';
4480  if (AffineMapAttr mapAttr =
4481  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4482  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4483  p << ']';
4484  p.printOptionalAttrDict((*this)->getAttrs(),
4485  /*elidedAttrs=*/{getMapAttrStrName()});
4486  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4487 }
4488 
4489 LogicalResult AffineVectorStoreOp::verify() {
4490  MemRefType memrefType = getMemRefType();
4491  if (failed(verifyMemoryOpIndexing(
4492  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4493  getMapOperands(), memrefType,
4494  /*numIndexOperands=*/getNumOperands() - 2)))
4495  return failure();
4496 
4497  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4498  return failure();
4499 
4500  return success();
4501 }
4502 
4503 //===----------------------------------------------------------------------===//
4504 // DelinearizeIndexOp
4505 //===----------------------------------------------------------------------===//
4506 
4507 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4508  OperationState &odsState,
4509  Value linearIndex, ValueRange dynamicBasis,
4510  ArrayRef<int64_t> staticBasis,
4511  bool hasOuterBound) {
4512  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4513  : staticBasis.size() + 1,
4514  linearIndex.getType());
4515  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4516  staticBasis);
4517 }
4518 
4519 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4520  OperationState &odsState,
4521  Value linearIndex, ValueRange basis,
4522  bool hasOuterBound) {
4523  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4524  hasOuterBound = false;
4525  basis = basis.drop_front();
4526  }
4527  SmallVector<Value> dynamicBasis;
4528  SmallVector<int64_t> staticBasis;
4529  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4530  staticBasis);
4531  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4532  hasOuterBound);
4533 }
4534 
4535 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4536  OperationState &odsState,
4537  Value linearIndex,
4538  ArrayRef<OpFoldResult> basis,
4539  bool hasOuterBound) {
4540  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4541  hasOuterBound = false;
4542  basis = basis.drop_front();
4543  }
4544  SmallVector<Value> dynamicBasis;
4545  SmallVector<int64_t> staticBasis;
4546  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4547  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4548  hasOuterBound);
4549 }
4550 
4551 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4552  OperationState &odsState,
4553  Value linearIndex, ArrayRef<int64_t> basis,
4554  bool hasOuterBound) {
4555  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
4556 }
4557 
4558 LogicalResult AffineDelinearizeIndexOp::verify() {
4559  ArrayRef<int64_t> staticBasis = getStaticBasis();
4560  if (getNumResults() != staticBasis.size() &&
4561  getNumResults() != staticBasis.size() + 1)
4562  return emitOpError("should return an index for each basis element and up "
4563  "to one extra index");
4564 
4565  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4566  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4567  return emitOpError(
4568  "mismatch between dynamic and static basis (kDynamic marker but no "
4569  "corresponding dynamic basis entry) -- this can only happen due to an "
4570  "incorrect fold/rewrite");
4571 
4572  if (!llvm::all_of(staticBasis, [](int64_t v) {
4573  return v > 0 || ShapedType::isDynamic(v);
4574  }))
4575  return emitOpError("no basis element may be statically non-positive");
4576 
4577  return success();
4578 }
4579 
4580 /// Given mixed basis of affine.delinearize_index/linearize_index replace
4581 /// constant SSA values with the constant integer value and return the new
4582 /// static basis. In case no such candidate for replacement exists, this utility
4583 /// returns std::nullopt.
4584 static std::optional<SmallVector<int64_t>>
4586  MutableOperandRange mutableDynamicBasis,
4587  ArrayRef<Attribute> dynamicBasis) {
4588  uint64_t dynamicBasisIndex = 0;
4589  for (OpFoldResult basis : dynamicBasis) {
4590  if (basis) {
4591  mutableDynamicBasis.erase(dynamicBasisIndex);
4592  } else {
4593  ++dynamicBasisIndex;
4594  }
4595  }
4596 
4597  // No constant SSA value exists.
4598  if (dynamicBasisIndex == dynamicBasis.size())
4599  return std::nullopt;
4600 
4601  SmallVector<int64_t> staticBasis;
4602  for (OpFoldResult basis : mixedBasis) {
4603  std::optional<int64_t> basisVal = getConstantIntValue(basis);
4604  if (!basisVal)
4605  staticBasis.push_back(ShapedType::kDynamic);
4606  else
4607  staticBasis.push_back(*basisVal);
4608  }
4609 
4610  return staticBasis;
4611 }
4612 
4613 LogicalResult
4614 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4616  std::optional<SmallVector<int64_t>> maybeStaticBasis =
4617  foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
4618  adaptor.getDynamicBasis());
4619  if (maybeStaticBasis) {
4620  setStaticBasis(*maybeStaticBasis);
4621  return success();
4622  }
4623  // If we won't be doing any division or modulo (no basis or the one basis
4624  // element is purely advisory), simply return the input value.
4625  if (getNumResults() == 1) {
4626  result.push_back(getLinearIndex());
4627  return success();
4628  }
4629 
4630  if (adaptor.getLinearIndex() == nullptr)
4631  return failure();
4632 
4633  if (!adaptor.getDynamicBasis().empty())
4634  return failure();
4635 
4636  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
4637  Type attrType = getLinearIndex().getType();
4638 
4639  ArrayRef<int64_t> staticBasis = getStaticBasis();
4640  if (hasOuterBound())
4641  staticBasis = staticBasis.drop_front();
4642  for (int64_t modulus : llvm::reverse(staticBasis)) {
4643  result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
4644  highPart = llvm::divideFloorSigned(highPart, modulus);
4645  }
4646  result.push_back(IntegerAttr::get(attrType, highPart));
4647  std::reverse(result.begin(), result.end());
4648  return success();
4649 }
4650 
4651 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
4652  OpBuilder builder(getContext());
4653  if (hasOuterBound()) {
4654  if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4655  return getMixedValues(getStaticBasis().drop_front(),
4656  getDynamicBasis().drop_front(), builder);
4657 
4658  return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4659  builder);
4660  }
4661 
4662  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4663 }
4664 
4665 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
4666  SmallVector<OpFoldResult> ret = getMixedBasis();
4667  if (!hasOuterBound())
4668  ret.insert(ret.begin(), OpFoldResult());
4669  return ret;
4670 }
4671 
4672 namespace {
4673 
4674 // Drops delinearization indices that correspond to unit-extent basis
4675 struct DropUnitExtentBasis
4676  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4678 
4679  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4680  PatternRewriter &rewriter) const override {
4681  SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
4682  std::optional<Value> zero = std::nullopt;
4683  Location loc = delinearizeOp->getLoc();
4684  auto getZero = [&]() -> Value {
4685  if (!zero)
4686  zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4687  return zero.value();
4688  };
4689 
4690  // Replace all indices corresponding to unit-extent basis with 0.
4691  // Remaining basis can be used to get a new `affine.delinearize_index` op.
4692  SmallVector<OpFoldResult> newBasis;
4693  for (auto [index, basis] :
4694  llvm::enumerate(delinearizeOp.getPaddedBasis())) {
4695  std::optional<int64_t> basisVal =
4696  basis ? getConstantIntValue(basis) : std::nullopt;
4697  if (basisVal && *basisVal == 1)
4698  replacements[index] = getZero();
4699  else
4700  newBasis.push_back(basis);
4701  }
4702 
4703  if (newBasis.size() == delinearizeOp.getNumResults())
4704  return rewriter.notifyMatchFailure(delinearizeOp,
4705  "no unit basis elements");
4706 
4707  if (!newBasis.empty()) {
4708  // Will drop the leading nullptr from `basis` if there was no outer bound.
4709  auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4710  loc, delinearizeOp.getLinearIndex(), newBasis);
4711  int newIndex = 0;
4712  // Map back the new delinearized indices to the values they replace.
4713  for (auto &replacement : replacements) {
4714  if (replacement)
4715  continue;
4716  replacement = newDelinearizeOp->getResult(newIndex++);
4717  }
4718  }
4719 
4720  rewriter.replaceOp(delinearizeOp, replacements);
4721  return success();
4722  }
4723 };
4724 
4725 /// If a `affine.delinearize_index`'s input is a `affine.linearize_index
4726 /// disjoint` and the two operations end with the same basis elements,
4727 /// cancel those parts of the operations out because they are inverses
4728 /// of each other.
4729 ///
4730 /// If the operations have the same basis, cancel them entirely.
4731 ///
4732 /// The `disjoint` flag is needed on the `affine.linearize_index` because
4733 /// otherwise, there is no guarantee that the inputs to the linearization are
4734 /// in-bounds the way the outputs of the delinearization would be.
4735 struct CancelDelinearizeOfLinearizeDisjointExactTail
4736  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4738 
4739  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4740  PatternRewriter &rewriter) const override {
4741  auto linearizeOp = delinearizeOp.getLinearIndex()
4742  .getDefiningOp<affine::AffineLinearizeIndexOp>();
4743  if (!linearizeOp)
4744  return rewriter.notifyMatchFailure(delinearizeOp,
4745  "index doesn't come from linearize");
4746 
4747  if (!linearizeOp.getDisjoint())
4748  return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
4749 
4750  ValueRange linearizeIns = linearizeOp.getMultiIndex();
4751  // Note: we use the full basis so we don't lose outer bounds later.
4752  SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
4753  SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
4754  size_t numMatches = 0;
4755  for (auto [linSize, delinSize] : llvm::zip(
4756  llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
4757  if (linSize != delinSize)
4758  break;
4759  ++numMatches;
4760  }
4761 
4762  if (numMatches == 0)
4763  return rewriter.notifyMatchFailure(
4764  delinearizeOp, "final basis element doesn't match linearize");
4765 
4766  // The easy case: everything lines up and the basis match sup completely.
4767  if (numMatches == linearizeBasis.size() &&
4768  numMatches == delinearizeBasis.size() &&
4769  linearizeIns.size() == delinearizeOp.getNumResults()) {
4770  rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4771  return success();
4772  }
4773 
4774  Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
4775  linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
4776  ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
4777  linearizeOp.getDisjoint());
4778  auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
4779  delinearizeOp.getLoc(), newLinearize,
4780  ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
4781  delinearizeOp.hasOuterBound());
4782  SmallVector<Value> mergedResults(newDelinearize.getResults());
4783  mergedResults.append(linearizeIns.take_back(numMatches).begin(),
4784  linearizeIns.take_back(numMatches).end());
4785  rewriter.replaceOp(delinearizeOp, mergedResults);
4786  return success();
4787  }
4788 };
4789 
4790 /// If the input to a delinearization is a disjoint linearization, and the
4791 /// last k > 1 components of the delinearization basis multiply to the
4792 /// last component of the linearization basis, break the linearization and
4793 /// delinearization into two parts, peeling off the last input to linearization.
4794 ///
4795 /// For example:
4796 /// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
4797 /// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
4798 /// becomes
4799 /// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
4800 /// %1:2 = affine.delinearize_index %0 by (2, 3) : index
4801 /// %2:2 = affine.delinearize_index %x by (8, 4) : index
4802 /// where the original %1:4 is replaced by %1:2 ++ %2:2
4803 struct SplitDelinearizeSpanningLastLinearizeArg final
4804  : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4806 
4807  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4808  PatternRewriter &rewriter) const override {
4809  auto linearizeOp = delinearizeOp.getLinearIndex()
4810  .getDefiningOp<affine::AffineLinearizeIndexOp>();
4811  if (!linearizeOp)
4812  return rewriter.notifyMatchFailure(delinearizeOp,
4813  "index doesn't come from linearize");
4814 
4815  if (!linearizeOp.getDisjoint())
4816  return rewriter.notifyMatchFailure(linearizeOp,
4817  "linearize isn't disjoint");
4818 
4819  int64_t target = linearizeOp.getStaticBasis().back();
4820  if (ShapedType::isDynamic(target))
4821  return rewriter.notifyMatchFailure(
4822  linearizeOp, "linearize ends with dynamic basis value");
4823 
4824  int64_t sizeToSplit = 1;
4825  size_t elemsToSplit = 0;
4826  ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
4827  for (int64_t basisElem : llvm::reverse(basis)) {
4828  if (ShapedType::isDynamic(basisElem))
4829  return rewriter.notifyMatchFailure(
4830  delinearizeOp, "dynamic basis element while scanning for split");
4831  sizeToSplit *= basisElem;
4832  elemsToSplit += 1;
4833 
4834  if (sizeToSplit > target)
4835  return rewriter.notifyMatchFailure(delinearizeOp,
4836  "overshot last argument size");
4837  if (sizeToSplit == target)
4838  break;
4839  }
4840 
4841  if (sizeToSplit < target)
4842  return rewriter.notifyMatchFailure(
4843  delinearizeOp, "product of known basis elements doesn't exceed last "
4844  "linearize argument");
4845 
4846  if (elemsToSplit < 2)
4847  return rewriter.notifyMatchFailure(
4848  delinearizeOp,
4849  "need at least two elements to form the basis product");
4850 
4851  Value linearizeWithoutBack =
4852  rewriter.create<affine::AffineLinearizeIndexOp>(
4853  linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
4854  linearizeOp.getDynamicBasis(),
4855  linearizeOp.getStaticBasis().drop_back(),
4856  linearizeOp.getDisjoint());
4857  auto delinearizeWithoutSplitPart =
4858  rewriter.create<affine::AffineDelinearizeIndexOp>(
4859  delinearizeOp.getLoc(), linearizeWithoutBack,
4860  delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
4861  delinearizeOp.hasOuterBound());
4862  auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
4863  delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
4864  basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
4865  SmallVector<Value> results = llvm::to_vector(
4866  llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
4867  delinearizeBack.getResults()));
4868  rewriter.replaceOp(delinearizeOp, results);
4869 
4870  return success();
4871  }
4872 };
4873 } // namespace
4874 
4875 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4876  RewritePatternSet &patterns, MLIRContext *context) {
4877  patterns
4878  .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
4879  DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
4880  context);
4881 }
4882 
4883 //===----------------------------------------------------------------------===//
4884 // LinearizeIndexOp
4885 //===----------------------------------------------------------------------===//
4886 
4887 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4888  OperationState &odsState,
4889  ValueRange multiIndex, ValueRange basis,
4890  bool disjoint) {
4891  if (!basis.empty() && basis.front() == Value())
4892  basis = basis.drop_front();
4893  SmallVector<Value> dynamicBasis;
4894  SmallVector<int64_t> staticBasis;
4895  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4896  staticBasis);
4897  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4898 }
4899 
4900 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4901  OperationState &odsState,
4902  ValueRange multiIndex,
4903  ArrayRef<OpFoldResult> basis,
4904  bool disjoint) {
4905  if (!basis.empty() && basis.front() == OpFoldResult())
4906  basis = basis.drop_front();
4907  SmallVector<Value> dynamicBasis;
4908  SmallVector<int64_t> staticBasis;
4909  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4910  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4911 }
4912 
4913 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4914  OperationState &odsState,
4915  ValueRange multiIndex,
4916  ArrayRef<int64_t> basis, bool disjoint) {
4917  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
4918 }
4919 
4920 LogicalResult AffineLinearizeIndexOp::verify() {
4921  size_t numIndexes = getMultiIndex().size();
4922  size_t numBasisElems = getStaticBasis().size();
4923  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
4924  return emitOpError("should be passed a basis element for each index except "
4925  "possibly the first");
4926 
4927  auto dynamicMarkersCount =
4928  llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
4929  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4930  return emitOpError(
4931  "mismatch between dynamic and static basis (kDynamic marker but no "
4932  "corresponding dynamic basis entry) -- this can only happen due to an "
4933  "incorrect fold/rewrite");
4934 
4935  return success();
4936 }
4937 
4938 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
4939  std::optional<SmallVector<int64_t>> maybeStaticBasis =
4940  foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
4941  adaptor.getDynamicBasis());
4942  if (maybeStaticBasis) {
4943  setStaticBasis(*maybeStaticBasis);
4944  return getResult();
4945  }
4946  // No indices linearizes to zero.
4947  if (getMultiIndex().empty())
4948  return IntegerAttr::get(getResult().getType(), 0);
4949 
4950  // One single index linearizes to itself.
4951  if (getMultiIndex().size() == 1)
4952  return getMultiIndex().front();
4953 
4954  if (llvm::any_of(adaptor.getMultiIndex(),
4955  [](Attribute a) { return a == nullptr; }))
4956  return nullptr;
4957 
4958  if (!adaptor.getDynamicBasis().empty())
4959  return nullptr;
4960 
4961  int64_t result = 0;
4962  int64_t stride = 1;
4963  for (auto [length, indexAttr] :
4964  llvm::zip_first(llvm::reverse(getStaticBasis()),
4965  llvm::reverse(adaptor.getMultiIndex()))) {
4966  result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
4967  stride = stride * length;
4968  }
4969  // Handle the index element with no basis element.
4970  if (!hasOuterBound())
4971  result =
4972  result +
4973  cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
4974 
4975  return IntegerAttr::get(getResult().getType(), result);
4976 }
4977 
4978 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
4979  OpBuilder builder(getContext());
4980  if (hasOuterBound()) {
4981  if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4982  return getMixedValues(getStaticBasis().drop_front(),
4983  getDynamicBasis().drop_front(), builder);
4984 
4985  return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4986  builder);
4987  }
4988 
4989  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4990 }
4991 
4992 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
4993  SmallVector<OpFoldResult> ret = getMixedBasis();
4994  if (!hasOuterBound())
4995  ret.insert(ret.begin(), OpFoldResult());
4996  return ret;
4997 }
4998 
4999 namespace {
5000 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
5001 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
5002 /// %...d)`.
5003 
5004 /// Note that `disjoint` is required here, because, without it, we could have
5005 /// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
5006 /// is a valid operation where the `%c64` cannot be trivially dropped.
5007 ///
5008 /// Alternatively, if `%x` in the above is a known constant 0, remove it even if
5009 /// the operation isn't asserted to be `disjoint`.
5010 struct DropLinearizeUnitComponentsIfDisjointOrZero final
5011  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5013 
5014  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5015  PatternRewriter &rewriter) const override {
5016  ValueRange multiIndex = op.getMultiIndex();
5017  size_t numIndices = multiIndex.size();
5018  SmallVector<Value> newIndices;
5019  newIndices.reserve(numIndices);
5020  SmallVector<OpFoldResult> newBasis;
5021  newBasis.reserve(numIndices);
5022 
5023  if (!op.hasOuterBound()) {
5024  newIndices.push_back(multiIndex.front());
5025  multiIndex = multiIndex.drop_front();
5026  }
5027 
5028  SmallVector<OpFoldResult> basis = op.getMixedBasis();
5029  for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5030  std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
5031  if (!basisEntry || *basisEntry != 1) {
5032  newIndices.push_back(index);
5033  newBasis.push_back(basisElem);
5034  continue;
5035  }
5036 
5037  std::optional<int64_t> indexValue = getConstantIntValue(index);
5038  if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5039  newIndices.push_back(index);
5040  newBasis.push_back(basisElem);
5041  continue;
5042  }
5043  }
5044  if (newIndices.size() == numIndices)
5045  return rewriter.notifyMatchFailure(op,
5046  "no unit basis entries to replace");
5047 
5048  if (newIndices.size() == 0) {
5049  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
5050  return success();
5051  }
5052  rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5053  op, newIndices, newBasis, op.getDisjoint());
5054  return success();
5055  }
5056 };
5057 
5058 /// Return the product of `terms`, creating an `affine.apply` if any of them are
5059 /// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
5060 static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
5061  ArrayRef<OpFoldResult> terms) {
5062  int64_t nDynamic = 0;
5063  SmallVector<Value> dynamicPart;
5064  AffineExpr result = builder.getAffineConstantExpr(1);
5065  for (OpFoldResult term : terms) {
5066  if (!term)
5067  return term;
5068  std::optional<int64_t> maybeConst = getConstantIntValue(term);
5069  if (maybeConst) {
5070  result = result * builder.getAffineConstantExpr(*maybeConst);
5071  } else {
5072  dynamicPart.push_back(cast<Value>(term));
5073  result = result * builder.getAffineSymbolExpr(nDynamic++);
5074  }
5075  }
5076  if (auto constant = dyn_cast<AffineConstantExpr>(result))
5077  return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
5078  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
5079 }
5080 
5081 /// If conseceutive outputs of a delinearize_index are linearized with the same
5082 /// bounds, canonicalize away the redundant arithmetic.
5083 ///
5084 /// That is, if we have
5085 /// ```
5086 /// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
5087 /// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
5088 /// by (...e, B1, B2, ..., BK, ...f)
5089 /// ```
5090 ///
5091 /// We can rewrite this to
5092 /// ```
5093 /// B = B1 * B2 ... BK
5094 /// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b)
5095 /// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
5096 /// ```
5097 /// where we replace all results of %s unaffected by the change with results
5098 /// from %sMerged.
5099 ///
5100 /// As a special case, if all results of the delinearize are merged in this way
5101 /// we can replace those usages with %x, thus cancelling the delinearization
5102 /// entirely, as in
5103 /// ```
5104 /// %s:3 = affine.delinearize_index %x into (2, 4, 8)
5105 /// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
5106 /// ```
5107 /// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
5108 struct CancelLinearizeOfDelinearizePortion final
5109  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5111 
5112 private:
5113  // Struct representing a case where the cancellation pattern
5114  // applies. A `Match` means that `length` inputs to the linearize operation
5115  // starting at `linStart` can be cancelled with `length` outputs of
5116  // `delinearize`, starting from `delinStart`.
5117  struct Match {
5118  AffineDelinearizeIndexOp delinearize;
5119  unsigned linStart = 0;
5120  unsigned delinStart = 0;
5121  unsigned length = 0;
5122  };
5123 
5124 public:
5125  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
5126  PatternRewriter &rewriter) const override {
5127  SmallVector<Match> matches;
5128 
5129  const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
5130  ArrayRef<OpFoldResult> linBasisRef = linBasis;
5131 
5132  ValueRange multiIndex = linearizeOp.getMultiIndex();
5133  unsigned numLinArgs = multiIndex.size();
5134  unsigned linArgIdx = 0;
5135  // We only want to replace one run from the same delinearize op per
5136  // pattern invocation lest we run into invalidation issues.
5137  llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
5138  while (linArgIdx < numLinArgs) {
5139  auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
5140  if (!asResult) {
5141  linArgIdx++;
5142  continue;
5143  }
5144 
5145  auto delinearizeOp =
5146  dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
5147  if (!delinearizeOp) {
5148  linArgIdx++;
5149  continue;
5150  }
5151 
5152  /// Result 0 of the delinearize and argument 0 of the linearize can
5153  /// leave their maximum value unspecified. However, even if this happens
5154  /// we can still sometimes start the match process. Specifically, if
5155  /// - The argument we're matching is result 0 and argument 0 (so the
5156  /// bounds don't matter). For example,
5157  ///
5158  /// %0:2 = affine.delinearize_index %x into (8) : index, index
5159  /// %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...)
5160  /// allows cancellation
5161  /// - The delinearization doesn't specify a bound, but the linearization
5162  /// is `disjoint`, which asserts that the bound on the linearization is
5163  /// correct.
5164  unsigned delinArgIdx = asResult.getResultNumber();
5165  SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
5166  OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
5167  OpFoldResult firstLinBound = linBasis[linArgIdx];
5168  bool boundsMatch = firstDelinBound == firstLinBound;
5169  bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
5170  bool knownByDisjoint =
5171  linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
5172  if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5173  linArgIdx++;
5174  continue;
5175  }
5176 
5177  unsigned j = 1;
5178  unsigned numDelinOuts = delinearizeOp.getNumResults();
5179  for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
5180  ++j) {
5181  if (multiIndex[linArgIdx + j] !=
5182  delinearizeOp.getResult(delinArgIdx + j))
5183  break;
5184  if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
5185  break;
5186  }
5187  // If there're multiple matches against the same delinearize_index,
5188  // only rewrite the first one we find to prevent invalidations. The next
5189  // ones will be taken care of by subsequent pattern invocations.
5190  if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
5191  linArgIdx++;
5192  continue;
5193  }
5194  matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
5195  linArgIdx += j;
5196  }
5197 
5198  if (matches.empty())
5199  return rewriter.notifyMatchFailure(
5200  linearizeOp, "no run of delinearize outputs to deal with");
5201 
5202  // Record all the delinearize replacements so we can do them after creating
5203  // the new linearization operation, since the new operation might use
5204  // outputs of something we're replacing.
5205  SmallVector<SmallVector<Value>> delinearizeReplacements;
5206 
5207  SmallVector<Value> newIndex;
5208  newIndex.reserve(numLinArgs);
5209  SmallVector<OpFoldResult> newBasis;
5210  newBasis.reserve(numLinArgs);
5211  unsigned prevMatchEnd = 0;
5212  for (Match m : matches) {
5213  unsigned gap = m.linStart - prevMatchEnd;
5214  llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
5215  llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
5216  // Update here so we don't forget this during early continues
5217  prevMatchEnd = m.linStart + m.length;
5218 
5219  PatternRewriter::InsertionGuard g(rewriter);
5220  rewriter.setInsertionPoint(m.delinearize);
5221 
5222  ArrayRef<OpFoldResult> basisToMerge =
5223  linBasisRef.slice(m.linStart, m.length);
5224  // We use the slice from the linearize's basis above because of the
5225  // "bounds inferred from `disjoint`" case above.
5226  OpFoldResult newSize =
5227  computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
5228 
5229  // Trivial case where we can just skip past the delinearize all together
5230  if (m.length == m.delinearize.getNumResults()) {
5231  newIndex.push_back(m.delinearize.getLinearIndex());
5232  newBasis.push_back(newSize);
5233  // Pad out set of replacements so we don't do anything with this one.
5234  delinearizeReplacements.push_back(SmallVector<Value>());
5235  continue;
5236  }
5237 
5238  SmallVector<Value> newDelinResults;
5239  SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
5240  newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
5241  newDelinBasis.begin() + m.delinStart + m.length);
5242  newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
5243  auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5244  m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5245  newDelinBasis);
5246 
5247  // Since there may be other uses of the indices we just merged together,
5248  // create a residual affine.delinearize_index that delinearizes the
5249  // merged output into its component parts.
5250  Value combinedElem = newDelinearize.getResult(m.delinStart);
5251  auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5252  m.delinearize.getLoc(), combinedElem, basisToMerge);
5253 
5254  // Swap all the uses of the unaffected delinearize outputs to the new
5255  // delinearization so that the old code can be removed if this
5256  // linearize_index is the only user of the merged results.
5257  llvm::append_range(newDelinResults,
5258  newDelinearize.getResults().take_front(m.delinStart));
5259  llvm::append_range(newDelinResults, residualDelinearize.getResults());
5260  llvm::append_range(
5261  newDelinResults,
5262  newDelinearize.getResults().drop_front(m.delinStart + 1));
5263 
5264  delinearizeReplacements.push_back(newDelinResults);
5265  newIndex.push_back(combinedElem);
5266  newBasis.push_back(newSize);
5267  }
5268  llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
5269  llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
5270  rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
5271  linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
5272 
5273  for (auto [m, newResults] :
5274  llvm::zip_equal(matches, delinearizeReplacements)) {
5275  if (newResults.empty())
5276  continue;
5277  rewriter.replaceOp(m.delinearize, newResults);
5278  }
5279 
5280  return success();
5281  }
5282 };
5283 
5284 /// Strip leading zero from affine.linearize_index.
5285 ///
5286 /// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
5287 /// to `affine.linearize_index [...a] by (...b)` in all cases.
5288 struct DropLinearizeLeadingZero final
5289  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5291 
5292  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5293  PatternRewriter &rewriter) const override {
5294  Value leadingIdx = op.getMultiIndex().front();
5295  if (!matchPattern(leadingIdx, m_Zero()))
5296  return failure();
5297 
5298  if (op.getMultiIndex().size() == 1) {
5299  rewriter.replaceOp(op, leadingIdx);
5300  return success();
5301  }
5302 
5303  SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
5304  ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
5305  if (op.hasOuterBound())
5306  newMixedBasis = newMixedBasis.drop_front();
5307 
5308  rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5309  op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5310  return success();
5311  }
5312 };
5313 } // namespace
5314 
5315 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5316  RewritePatternSet &patterns, MLIRContext *context) {
5317  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5318  DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5319 }
5320 
5321 //===----------------------------------------------------------------------===//
5322 // TableGen'd op method definitions
5323 //===----------------------------------------------------------------------===//
5324 
5325 #define GET_OP_CLASSES
5326 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
Definition: AffineOps.cpp:2653
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
Definition: AffineOps.cpp:2373
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
Definition: AffineOps.cpp:1165
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
Definition: AffineOps.cpp:3226
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
Definition: AffineOps.cpp:3859
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
Definition: AffineOps.cpp:60
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
Definition: AffineOps.cpp:3996
static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
Definition: AffineOps.cpp:1001
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
Definition: AffineOps.cpp:3262
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
Definition: AffineOps.cpp:2231
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
Definition: AffineOps.cpp:787
static bool isValidAffineIndexOperand(Value value, Region *region)
Definition: AffineOps.cpp:466
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1359
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
Definition: AffineOps.cpp:1072
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:722
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
Definition: AffineOps.cpp:1920
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:714
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
Definition: AffineOps.cpp:2938
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
Definition: AffineOps.cpp:1024
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1316
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
Definition: AffineOps.cpp:4398
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
Definition: AffineOps.cpp:890
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
Definition: AffineOps.cpp:3239
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
Definition: AffineOps.cpp:329
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
Definition: AffineOps.cpp:662
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
Definition: AffineOps.cpp:4111
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
Definition: AffineOps.cpp:2662
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
Definition: AffineOps.cpp:471
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
Definition: AffineOps.cpp:624
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1252
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
Definition: AffineOps.cpp:2612
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
Definition: AffineOps.cpp:2185
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
Definition: AffineOps.cpp:503
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1267
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
Definition: AffineOps.cpp:348
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
Definition: AffineOps.cpp:690
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Definition: AffineOps.cpp:4064
static LogicalResult verifyAffineMinMaxOp(T op)
Definition: AffineOps.cpp:3213
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
Definition: AffineOps.cpp:2095
static LogicalResult verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
Definition: AffineOps.cpp:3044
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
Definition: AffineOps.cpp:4585
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
Definition: AffineOps.cpp:3425
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:76
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:81
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:917
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:243
MLIRContext * getContext() const
Definition: AffineExpr.cpp:33
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:662
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
Definition: AffineMap.h:221
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:267
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
Definition: AffineMap.h:228
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineMap.cpp:500
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineMap.h:280
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:515
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
Definition: AffineMap.cpp:128
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:654
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
Definition: AffineMap.cpp:434
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:203
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
AffineMap getDimIdentityMap()
Definition: Builders.cpp:423
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:427
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:408
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:412
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition: Builders.cpp:219
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
NoneType getNoneType()
Definition: Builders.cpp:128
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
Definition: Builders.cpp:416
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition: Builders.cpp:418
MLIRContext * getContext() const
Definition: Builders.h:56
AffineMap getSymbolIdentityMap()
Definition: Builders.cpp:436
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:306
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:321
IndexType getIndexType()
Definition: Builders.cpp:95
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44
unsigned getNumDims() const
Definition: IntegerSet.cpp:15
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
Definition: IntegerSet.cpp:57
unsigned getNumInputs() const
Definition: IntegerSet.cpp:17
ArrayRef< AffineExpr > getConstraints() const
Definition: IntegerSet.cpp:41
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
Definition: IntegerSet.cpp:51
unsigned getNumSymbols() const
Definition: IntegerSet.cpp:16
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:454
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:329
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:470
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:451
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the operation belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
operand_range::iterator operand_iterator
Definition: Operation.h:372
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:620
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
AffineBound represents a lower or upper bound in the for operation.
Definition: AffineOps.h:511
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
Definition: AffineOps.h:94
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Definition: AffineOps.h:303
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
Definition: AffineOps.cpp:3956
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
unsigned getNumResults() const
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
Definition: AffineOps.cpp:2675
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
Definition: AffineOps.cpp:1134
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
Definition: AffineOps.cpp:2589
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
Definition: AffineOps.cpp:278
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1241
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
Definition: AffineOps.cpp:2561
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: AffineOps.cpp:2565
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1144
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1435
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1307
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Definition: AffineOps.cpp:2553
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1300
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Definition: AffineOps.cpp:248
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
Definition: AffineOps.cpp:1440
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
Definition: AffineOps.cpp:2596
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Definition: AffineOps.cpp:392
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1194
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp that owns the provided value if it is among the induction variables of an AffineParallelOp.
Definition: AffineOps.cpp:2576
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
Definition: AffineOps.cpp:263
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
Definition: AffineOps.cpp:481
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
Definition: AffineOps.cpp:2557
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
Definition: AffineOps.cpp:1261
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:773
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
Definition: AffineMap.cpp:783
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explanatory.
AffineExprKind
Definition: AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:70
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:641
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
bool isStrided(MemRefType t)
Return "true" if the layout for t is compatible with strided semantics.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
Definition: AffineMap.cpp:745
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:627
Canonicalize the affine map result expression order of an affine min/max operation.
Definition: AffineOps.cpp:3479
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3482
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3496
Remove duplicated expressions in affine min/max ops.
Definition: AffineOps.cpp:3295
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3298
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
Definition: AffineOps.cpp:3338
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3341
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:294
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.