AffineOps.cpp
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/UB/IR/UBOps.h"
13 #include "mlir/IR/AffineExprVisitor.h"
14 #include "mlir/IR/IRMapping.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Interfaces/ShapedOpInterfaces.h"
20 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
21 #include "mlir/Transforms/InliningUtils.h"
22 #include "llvm/ADT/ScopeExit.h"
23 #include "llvm/ADT/SmallBitVector.h"
24 #include "llvm/ADT/SmallVectorExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/Support/Debug.h"
27 #include <numeric>
28 #include <optional>
29 
30 using namespace mlir;
31 using namespace mlir::affine;
32 
33 #define DEBUG_TYPE "affine-ops"
34 
35 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
36 
37 /// A utility function to check if a value is defined at the top level of
38 /// `region` or is an argument of `region`. A value of index type defined at the
39 /// top level of an `AffineScope` region is always a valid symbol for all
40 /// uses in that region.
41 static bool isTopLevelValue(Value value, Region *region) {
42  if (auto arg = llvm::dyn_cast<BlockArgument>(value))
43  return arg.getParentRegion() == region;
44  return value.getDefiningOp()->getParentRegion() == region;
45 }
46 
47 /// Checks if `value`, known to be a legal affine dimension or symbol in `src`
48 /// region, remains legal if the operation that uses it is inlined into `dest`
49 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
50 /// `isValidSymbol`, depending on the value being required to remain a valid
51 /// dimension or symbol.
52 static bool
53 remainsLegalAfterInline(Value value, Region *src, Region *dest,
54  const IRMapping &mapping,
55  function_ref<bool(Value, Region *)> legalityCheck) {
56  // If the value is a valid dimension for any other reason than being
57  // a top-level value, it will remain valid: constants get inlined
58  // with the function, transitive affine applies also get inlined and
59  // will be checked themselves, etc.
60  if (!isTopLevelValue(value, src))
61  return true;
62 
63  // If it's a top-level value because it's a block operand, i.e. a
64  // function argument, check whether the value replacing it after
65  // inlining is a valid dimension in the new region.
66  if (llvm::isa<BlockArgument>(value))
67  return legalityCheck(mapping.lookup(value), dest);
68 
69  // If it's a top-level value because it's defined in the region,
70  // it can only be inlined if the defining op is a constant or a
71  // `dim`, which can appear anywhere and be valid, since the defining
72  // op won't be top-level anymore after inlining.
73  Attribute operandCst;
74  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
75  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
76  isDimLikeOp;
77 }
78 
79 /// Checks if all values known to be legal affine dimensions or symbols in `src`
80 /// remain so if their respective users are inlined into `dest`.
81 static bool
82 remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
83  const IRMapping &mapping,
84  function_ref<bool(Value, Region *)> legalityCheck) {
85  return llvm::all_of(values, [&](Value v) {
86  return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
87  });
88 }
89 
90 /// Checks if an affine read or write operation remains legal after inlining
91 /// from `src` to `dest`.
92 template <typename OpTy>
93 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
94  const IRMapping &mapping) {
95  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
96  AffineWriteOpInterface>::value,
97  "only ops with affine read/write interface are supported");
98 
99  AffineMap map = op.getAffineMap();
100  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
101  ValueRange symbolOperands =
102  op.getMapOperands().take_back(map.getNumSymbols());
103  if (!remainsLegalAfterInline(
104  dimOperands, src, dest, mapping,
105  static_cast<bool (*)(Value, Region *)>(isValidDim)))
106  return false;
107  if (!remainsLegalAfterInline(
108  symbolOperands, src, dest, mapping,
109  static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
110  return false;
111  return true;
112 }
113 
114 /// Checks if an affine apply operation remains legal after inlining from `src`
115 /// to `dest`.
116 // Use "unused attribute" marker to silence clang-tidy warning stemming from
117 // the inability to see through "llvm::TypeSwitch".
118 template <>
119 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
120  Region *src, Region *dest,
121  const IRMapping &mapping) {
122  // If it's a valid dimension, we need to check that it remains so.
123  if (isValidDim(op.getResult(), src))
124  return remainsLegalAfterInline(
125  op.getMapOperands(), src, dest, mapping,
126  static_cast<bool (*)(Value, Region *)>(isValidDim));
127 
128  // Otherwise it must be a valid symbol, check that it remains so.
129  return remainsLegalAfterInline(
130  op.getMapOperands(), src, dest, mapping,
131  static_cast<bool (*)(Value, Region *)>(isValidSymbol));
132 }
133 
134 //===----------------------------------------------------------------------===//
135 // AffineDialect Interfaces
136 //===----------------------------------------------------------------------===//
137 
138 namespace {
139 /// This class defines the interface for handling inlining with affine
140 /// operations.
141 struct AffineInlinerInterface : public DialectInlinerInterface {
142  using DialectInlinerInterface::DialectInlinerInterface;
143 
144  //===--------------------------------------------------------------------===//
145  // Analysis Hooks
146  //===--------------------------------------------------------------------===//
147 
148  /// Returns true if the given region 'src' can be inlined into the region
149  /// 'dest' that is attached to an operation registered to the current dialect.
150  /// 'wouldBeCloned' is set if the region is cloned into its new location
151  /// rather than moved, indicating there may be other users.
152  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
153  IRMapping &valueMapping) const final {
154  // We can inline into affine loops and conditionals if this doesn't break
155  // affine value categorization rules.
156  Operation *destOp = dest->getParentOp();
157  if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
158  return false;
159 
160  // Multi-block regions cannot be inlined into affine constructs, all of
161  // which require single-block regions.
162  if (!llvm::hasSingleElement(*src))
163  return false;
164 
165  // Side-effecting operations that the affine dialect cannot understand
166  // should not be inlined.
167  Block &srcBlock = src->front();
168  for (Operation &op : srcBlock) {
169  // Ops with no side effects are fine.
170  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
171  if (iface.hasNoEffect())
172  continue;
173  }
174 
175  // Assuming the inlined region is valid, we only need to check if the
176  // inlining would change it.
177  bool remainsValid =
178  TypeSwitch<Operation *, bool>(&op)
179  .Case<AffineApplyOp, AffineReadOpInterface,
180  AffineWriteOpInterface>([&](auto op) {
181  return remainsLegalAfterInline(op, src, dest, valueMapping);
182  })
183  .Default([](Operation *) {
184  // Conservatively disallow inlining ops we cannot reason about.
185  return false;
186  });
187 
188  if (!remainsValid)
189  return false;
190  }
191 
192  return true;
193  }
194 
195  /// Returns true if the given operation 'op', that is registered to this
196  /// dialect, can be inlined into the given region, false otherwise.
197  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
198  IRMapping &valueMapping) const final {
199  // Always allow inlining affine operations into a region that is marked as
200  // affine scope, or into affine loops and conditionals. There are some edge
201  // cases when inlining *into* affine structures, but that is handled in the
202  // other 'isLegalToInline' hook above.
203  Operation *parentOp = region->getParentOp();
204  return parentOp->hasTrait<OpTrait::AffineScope>() ||
205  isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
206  }
207 
208  /// Affine regions should be analyzed recursively.
209  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
210 };
211 } // namespace
212 
213 //===----------------------------------------------------------------------===//
214 // AffineDialect
215 //===----------------------------------------------------------------------===//
216 
217 void AffineDialect::initialize() {
218  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
219 #define GET_OP_LIST
220 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
221  >();
222  addInterfaces<AffineInlinerInterface>();
223 }
224 
225 /// Materialize a single constant operation from a given attribute value with
226 /// the desired resultant type.
227 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
228  Attribute value, Type type,
229  Location loc) {
230  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
231  return builder.create<ub::PoisonOp>(loc, type, poison);
232  return arith::ConstantOp::materialize(builder, value, type, loc);
233 }
234 
235 /// A utility function to check if a value is defined at the top level of an
236 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
237 /// conservatively assume it is not top-level. A value of index type defined at
238 /// the top level is always a valid symbol.
239 bool mlir::affine::isTopLevelValue(Value value) {
240  if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
241  // The block owning the argument may be unlinked, e.g. when the surrounding
242  // region has not yet been attached to an Op, at which point the parent Op
243  // is null.
244  Operation *parentOp = arg.getOwner()->getParentOp();
245  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
246  }
247  // The defining Op may live in an unlinked block so its parent Op may be null.
248  Operation *parentOp = value.getDefiningOp()->getParentOp();
249  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
250 }
251 
252 /// Returns the closest region enclosing `op` that is held by an operation with
253 /// trait `AffineScope`; `nullptr` if there is no such region.
254 Region *mlir::affine::getAffineScope(Operation *op) {
255  auto *curOp = op;
256  while (auto *parentOp = curOp->getParentOp()) {
257  if (parentOp->hasTrait<OpTrait::AffineScope>())
258  return curOp->getParentRegion();
259  curOp = parentOp;
260  }
261  return nullptr;
262 }
263 
264 // A Value can be used as a dimension id iff it meets one of the following
265 // conditions:
266 // *) It is valid as a symbol.
267 // *) It is an induction variable.
268 // *) It is the result of an affine apply operation with dimension id arguments.
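//
// A minimal illustration (hypothetical IR, not taken from this file): in
//
//   func.func @f(%A: memref<100xf32>) {
//     affine.for %i = 0 to 100 {
//       %v = affine.load %A[%i] : memref<100xf32>
//     }
//     return
//   }
//
// the induction variable %i satisfies the second condition and is therefore a
// valid dimension id for the enclosed affine.load.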
269 bool mlir::affine::isValidDim(Value value) {
270  // The value must be an index type.
271  if (!value.getType().isIndex())
272  return false;
273 
274  if (auto *defOp = value.getDefiningOp())
275  return isValidDim(value, getAffineScope(defOp));
276 
277  // This value has to be a block argument for an op that has the
278  // `AffineScope` trait or for an affine.for or affine.parallel.
279  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
280  return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
281  isa<AffineForOp, AffineParallelOp>(parentOp));
282 }
283 
284 // A value can be used as a dimension id iff it meets one of the following
285 // conditions:
286 // *) It is valid as a symbol.
287 // *) It is an induction variable.
288 // *) It is the result of an affine apply operation with dimension id operands.
289 bool mlir::affine::isValidDim(Value value, Region *region) {
290  // The value must be an index type.
291  if (!value.getType().isIndex())
292  return false;
293 
294  // All valid symbols are okay.
295  if (isValidSymbol(value, region))
296  return true;
297 
298  auto *op = value.getDefiningOp();
299  if (!op) {
300  // This value has to be a block argument for an affine.for or an
301  // affine.parallel.
302  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
303  return isa<AffineForOp, AffineParallelOp>(parentOp);
304  }
305 
306  // Affine apply operation is ok if all of its operands are ok.
307  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
308  return applyOp.isValidDim(region);
309  // The dim op is okay if its operand memref/tensor is defined at the top
310  // level.
311  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
312  return isTopLevelValue(dimOp.getShapedValue());
313  return false;
314 }
315 
316 /// Returns true if the 'index' dimension of the `memref` defined by
317 /// `memrefDefOp` is statically shaped or is defined using a valid symbol
318 /// for `region`.
319 template <typename AnyMemRefDefOp>
320 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
321  Region *region) {
322  MemRefType memRefType = memrefDefOp.getType();
323 
324  // Dimension index is out of bounds.
325  if (index >= memRefType.getRank()) {
326  return false;
327  }
328 
329  // Statically shaped.
330  if (!memRefType.isDynamicDim(index))
331  return true;
332  // Get the position of the dimension among dynamic dimensions.
333  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
334  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
335  region);
336 }
337 
338 /// Returns true if the result of the dim op is a valid symbol for `region`.
339 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
340  // The dim op is okay if its source is defined at the top level.
341  if (isTopLevelValue(dimOp.getShapedValue()))
342  return true;
343 
344  // Conservatively handle remaining BlockArguments as non-valid symbols.
345  // E.g. scf.for iterArgs.
346  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
347  return false;
348 
349  // The dim op is also okay if its operand memref is a view/subview whose
350  // corresponding size is a valid symbol.
351  std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
352 
353  // Be conservative if we can't understand the dimension.
354  if (!index.has_value())
355  return false;
356 
357  int64_t i = index.value();
358  return TypeSwitch<Operation *, bool>(dimOp.getShapedValue().getDefiningOp())
359  .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
360  [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
361  .Default([](Operation *) { return false; });
362 }
363 
364 // A value can be used as a symbol (at all its use sites) iff it meets one of
365 // the following conditions:
366 // *) It is a constant.
367 // *) Its defining op or block arg appearance is immediately enclosed by an op
368 // with `AffineScope` trait.
369 // *) It is the result of an affine.apply operation with symbol operands.
370 // *) It is a result of the dim op on a memref whose corresponding size is a
371 // valid symbol.
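//
// A minimal illustration (hypothetical IR, not taken from this file): in
//
//   func.func @g(%A: memref<?xf32>) {
//     %c0 = arith.constant 0 : index
//     %n = memref.dim %A, %c0 : memref<?xf32>
//     affine.for %i = 0 to %n {
//     }
//     return
//   }
//
// %c0 is a constant and %n is defined at the top level of the function body
// (an `AffineScope` region), so both are valid symbols and %n can be used as
// a symbolic upper bound.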
372 bool mlir::affine::isValidSymbol(Value value) {
373  if (!value)
374  return false;
375 
376  // The value must be an index type.
377  if (!value.getType().isIndex())
378  return false;
379 
380  // Check that the value is a top level value.
381  if (isTopLevelValue(value))
382  return true;
383 
384  if (auto *defOp = value.getDefiningOp())
385  return isValidSymbol(value, getAffineScope(defOp));
386 
387  return false;
388 }
389 
390 /// A value can be used as a symbol for `region` iff it meets one of the
391 /// following conditions:
392 /// *) It is a constant.
393 /// *) It is the result of an affine apply operation with symbol arguments.
394 /// *) It is a result of the dim op on a memref whose corresponding size is
395 /// a valid symbol.
396 /// *) It is defined at the top level of 'region' or is its argument.
397 /// *) It dominates `region`'s parent op.
398 /// If `region` is null, conservatively assume the symbol definition scope does
399 /// not exist and only accept the values that would be symbols regardless of
400 /// the surrounding region structure, i.e. the first three cases above.
401 bool mlir::affine::isValidSymbol(Value value, Region *region) {
402  // The value must be an index type.
403  if (!value.getType().isIndex())
404  return false;
405 
406  // A top-level value is a valid symbol.
407  if (region && ::isTopLevelValue(value, region))
408  return true;
409 
410  auto *defOp = value.getDefiningOp();
411  if (!defOp) {
412  // A block argument that is not a top-level value is a valid symbol if it
413  // dominates region's parent op.
414  Operation *regionOp = region ? region->getParentOp() : nullptr;
415  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
416  if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
417  return isValidSymbol(value, parentOpRegion);
418  return false;
419  }
420 
421  // Constant operation is ok.
422  Attribute operandCst;
423  if (matchPattern(defOp, m_Constant(&operandCst)))
424  return true;
425 
426  // Affine apply operation is ok if all of its operands are ok.
427  if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
428  return applyOp.isValidSymbol(region);
429 
430  // Dim op results could be valid symbols at any level.
431  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
432  return isDimOpValidSymbol(dimOp, region);
433 
434  // Check for values dominating `region`'s parent op.
435  Operation *regionOp = region ? region->getParentOp() : nullptr;
436  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
437  if (auto *parentRegion = region->getParentOp()->getParentRegion())
438  return isValidSymbol(value, parentRegion);
439 
440  return false;
441 }
442 
443 // Returns true if 'value' is a valid index to an affine operation (e.g.
444 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
445 // `region` provides the polyhedral symbol scope. Returns false otherwise.
446 static bool isValidAffineIndexOperand(Value value, Region *region) {
447  return isValidDim(value, region) || isValidSymbol(value, region);
448 }
449 
450 /// Prints dimension and symbol list.
451 void mlir::affine::printDimAndSymbolList(Operation::operand_iterator begin,
452  Operation::operand_iterator end,
453  unsigned numDims, OpAsmPrinter &printer) {
454  OperandRange operands(begin, end);
455  printer << '(' << operands.take_front(numDims) << ')';
456  if (operands.size() > numDims)
457  printer << '[' << operands.drop_front(numDims) << ']';
458 }
459 
460 /// Parses dimension and symbol list and returns true if parsing failed.
461 ParseResult mlir::affine::parseDimAndSymbolList(
462  OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
463  SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
464  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
465  return failure();
466  // Store number of dimensions for validation by caller.
467  numDims = opInfos.size();
468 
469  // Parse the optional symbol operands.
470  auto indexTy = parser.getBuilder().getIndexType();
471  return failure(parser.parseOperandList(
472  opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
473  parser.resolveOperands(opInfos, indexTy, operands));
474 }
475 
476 /// Utility function to verify that a set of operands are valid dimension and
477 /// symbol identifiers. The operands should be laid out such that the dimension
478 /// operands are before the symbol operands. This function returns failure if
479 /// there was an invalid operand. An operation is provided to emit any necessary
480 /// errors.
481 template <typename OpTy>
482 static LogicalResult
483 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
484  unsigned numDims) {
485  unsigned opIt = 0;
486  for (auto operand : operands) {
487  if (opIt++ < numDims) {
488  if (!isValidDim(operand, getAffineScope(op)))
489  return op.emitOpError("operand cannot be used as a dimension id");
490  } else if (!isValidSymbol(operand, getAffineScope(op))) {
491  return op.emitOpError("operand cannot be used as a symbol");
492  }
493  }
494  return success();
495 }
496 
497 //===----------------------------------------------------------------------===//
498 // AffineApplyOp
499 //===----------------------------------------------------------------------===//
500 
501 AffineValueMap AffineApplyOp::getAffineValueMap() {
502  return AffineValueMap(getAffineMap(), getOperands(), getResult());
503 }
504 
505 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
506  auto &builder = parser.getBuilder();
507  auto indexTy = builder.getIndexType();
508 
509  AffineMapAttr mapAttr;
510  unsigned numDims;
511  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
512  parseDimAndSymbolList(parser, result.operands, numDims) ||
513  parser.parseOptionalAttrDict(result.attributes))
514  return failure();
515  auto map = mapAttr.getValue();
516 
517  if (map.getNumDims() != numDims ||
518  numDims + map.getNumSymbols() != result.operands.size()) {
519  return parser.emitError(parser.getNameLoc(),
520  "dimension or symbol index mismatch");
521  }
522 
523  result.types.append(map.getNumResults(), indexTy);
524  return success();
525 }
526 
527 void AffineApplyOp::print(OpAsmPrinter &p) {
528  p << " " << getMapAttr();
529  printDimAndSymbolList(operand_begin(), operand_end(),
530  getAffineMap().getNumDims(), p);
531  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
532 }
533 
534 LogicalResult AffineApplyOp::verify() {
535  // Check input and output dimensions match.
536  AffineMap affineMap = getMap();
537 
538  // Verify that operand count matches affine map dimension and symbol count.
539  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
540  return emitOpError(
541  "operand count and affine map dimension and symbol count must match");
542 
543  // Verify that the map only produces one result.
544  if (affineMap.getNumResults() != 1)
545  return emitOpError("mapping must produce one value");
546 
547  return success();
548 }
549 
550 // The result of the affine apply operation can be used as a dimension id if all
551 // its operands are valid dimension ids.
552 bool AffineApplyOp::isValidDim() {
553  return llvm::all_of(getOperands(),
554  [](Value op) { return affine::isValidDim(op); });
555 }
556 
557 // The result of the affine apply operation can be used as a dimension id if all
558 // its operands are valid dimension ids with the parent operation of `region`
559 // defining the polyhedral scope for symbols.
560 bool AffineApplyOp::isValidDim(Region *region) {
561  return llvm::all_of(getOperands(),
562  [&](Value op) { return ::isValidDim(op, region); });
563 }
564 
565 // The result of the affine apply operation can be used as a symbol if all its
566 // operands are symbols.
567 bool AffineApplyOp::isValidSymbol() {
568  return llvm::all_of(getOperands(),
569  [](Value op) { return affine::isValidSymbol(op); });
570 }
571 
572 // The result of the affine apply operation can be used as a symbol in `region`
573 // if all its operands are symbols in `region`.
574 bool AffineApplyOp::isValidSymbol(Region *region) {
575  return llvm::all_of(getOperands(), [&](Value operand) {
576  return affine::isValidSymbol(operand, region);
577  });
578 }
579 
580 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
581  auto map = getAffineMap();
582 
583  // Fold dims and symbols to existing values.
584  auto expr = map.getResult(0);
585  if (auto dim = dyn_cast<AffineDimExpr>(expr))
586  return getOperand(dim.getPosition());
587  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
588  return getOperand(map.getNumDims() + sym.getPosition());
589 
590  // Otherwise, default to folding the map.
591  SmallVector<Attribute, 1> result;
592  bool hasPoison = false;
593  auto foldResult =
594  map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
595  if (hasPoison)
596  return ub::PoisonAttr::get(getContext());
597  if (failed(foldResult))
598  return {};
599  return result[0];
600 }
601 
602 /// Returns the largest known divisor of `e`. Exploits information from the
603 /// values in `operands`.
604 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
605  // This method isn't aware of `operands`.
606  int64_t div = e.getLargestKnownDivisor();
607 
608  // We now make use of operands for the case `e` is a dim expression.
609  // TODO: More powerful simplification would have to modify
610  // getLargestKnownDivisor to take `operands` and exploit that information as
611  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
612  // can't be part of the IR library but of the `Analysis` library. The IR
613  // library can only really depend on simple O(1) checks.
614  auto dimExpr = dyn_cast<AffineDimExpr>(e);
615  // If it's not a dim expr, `div` is the best we have.
616  if (!dimExpr)
617  return div;
618 
619  // We simply exploit information from loop IVs.
620  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
621  // desired simplifications are expected to be part of other
622  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
623  // LoopAnalysis library.
624  Value operand = operands[dimExpr.getPosition()];
625  int64_t operandDivisor = 1;
626  // TODO: With the right accessors, this can be extended to
627  // LoopLikeOpInterface.
628  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
629  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
630  operandDivisor = forOp.getStepAsInt();
631  } else {
632  uint64_t lbLargestKnownDivisor =
633  forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
634  operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
635  }
636  }
637  return operandDivisor;
638 }
639 
640 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
641 /// being an affine dim expression or a constant.
642 static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef<Value> operands,
643  int64_t k) {
644  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
645  int64_t constVal = constExpr.getValue();
646  return constVal >= 0 && constVal < k;
647  }
648  auto dimExpr = dyn_cast<AffineDimExpr>(e);
649  if (!dimExpr)
650  return false;
651  Value operand = operands[dimExpr.getPosition()];
652  // TODO: With the right accessors, this can be extended to
653  // LoopLikeOpInterface.
654  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
655  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
656  forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
657  return true;
658  }
659  }
660 
661  // We don't consider other cases like `operand` being defined by a constant or
662  // an affine.apply op since such cases will already be handled by other
663  // patterns once propagation of loop IVs or constants has taken place.
664  return false;
665 }
666 
667 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
668 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
669 /// expression is in that form.
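/// For instance (a hypothetical operand setup): if d0 is the IV of an
/// affine.for with lower bound 0 and step 8, then e = d0 + 3 is of this form
/// with div = 8, quotientTimesDiv = d0, and rem = 3, since d0 is known to be
/// a multiple of 8 and 0 <= 3 < 8.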
670 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
671  AffineExpr &quotientTimesDiv, AffineExpr &rem) {
672  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
673  if (!bin || bin.getKind() != AffineExprKind::Add)
674  return false;
675 
676  AffineExpr llhs = bin.getLHS();
677  AffineExpr rlhs = bin.getRHS();
678  div = getLargestKnownDivisor(llhs, operands);
679  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
680  quotientTimesDiv = llhs;
681  rem = rlhs;
682  return true;
683  }
684  div = getLargestKnownDivisor(rlhs, operands);
685  if (isNonNegativeBoundedBy(llhs, operands, div)) {
686  quotientTimesDiv = rlhs;
687  rem = llhs;
688  return true;
689  }
690  return false;
691 }
692 
693 /// Gets the constant lower bound on an affine.for `iv`.
694 static std::optional<int64_t> getLowerBound(Value iv) {
695  AffineForOp forOp = getForInductionVarOwner(iv);
696  if (forOp && forOp.hasConstantLowerBound())
697  return forOp.getConstantLowerBound();
698  return std::nullopt;
699 }
700 
701 /// Gets the constant upper bound on an affine.for `iv`.
702 static std::optional<int64_t> getUpperBound(Value iv) {
703  AffineForOp forOp = getForInductionVarOwner(iv);
704  if (!forOp || !forOp.hasConstantUpperBound())
705  return std::nullopt;
706 
707  // If its lower bound is also known, we can get a more precise bound
708  // whenever the step is not one.
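  // For example (illustrative): for affine.for %i = 0 to 10 step 3, the IV
  // takes the values 0, 3, 6, 9, and the bound below evaluates to
  // 10 - 1 - (10 - 0 - 1) % 3 = 9.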
709  if (forOp.hasConstantLowerBound()) {
710  return forOp.getConstantUpperBound() - 1 -
711  (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
712  forOp.getStepAsInt();
713  }
714  return forOp.getConstantUpperBound() - 1;
715 }
716 
717 /// Determine a constant upper bound for `expr` if one exists while exploiting
718 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
719 /// is guaranteed to be less than or equal to it.
720 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
721  unsigned numSymbols,
722  ArrayRef<Value> operands) {
723  // Get the constant lower or upper bounds on the operands.
724  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
725  constLowerBounds.reserve(operands.size());
726  constUpperBounds.reserve(operands.size());
727  for (Value operand : operands) {
728  constLowerBounds.push_back(getLowerBound(operand));
729  constUpperBounds.push_back(getUpperBound(operand));
730  }
731 
732  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
733  return constExpr.getValue();
734 
735  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
736  constUpperBounds,
737  /*isUpper=*/true);
738 }
739 
740 /// Determine a constant lower bound for `expr` if one exists while exploiting
741 /// values in `operands`. Note that the lower bound is an inclusive one. `expr`
742 /// is guaranteed to be greater than or equal to it.
743 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
744  unsigned numSymbols,
745  ArrayRef<Value> operands) {
746  // Get the constant lower or upper bounds on the operands.
747  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
748  constLowerBounds.reserve(operands.size());
749  constUpperBounds.reserve(operands.size());
750  for (Value operand : operands) {
751  constLowerBounds.push_back(getLowerBound(operand));
752  constUpperBounds.push_back(getUpperBound(operand));
753  }
754 
755  std::optional<int64_t> lowerBound;
756  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
757  lowerBound = constExpr.getValue();
758  } else {
759  lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
760  constLowerBounds, constUpperBounds,
761  /*isUpper=*/false);
762  }
763  return lowerBound;
764 }
765 
766 /// Simplify `expr` while exploiting information from the values in `operands`.
767 static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
768  unsigned numSymbols,
769  ArrayRef<Value> operands) {
770  // We do this only for certain floordiv/mod expressions.
771  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
772  if (!binExpr)
773  return;
774 
775  // Simplify the child expressions first.
776  AffineExpr lhs = binExpr.getLHS();
777  AffineExpr rhs = binExpr.getRHS();
778  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
779  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
780  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
781 
782  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
783  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
784  expr.getKind() != AffineExprKind::CeilDiv &&
785  expr.getKind() != AffineExprKind::Mod)) {
786  return;
787  }
788 
789  // The `lhs` and `rhs` may differ after construction of the simplified expr.
790  lhs = binExpr.getLHS();
791  rhs = binExpr.getRHS();
792  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
793  if (!rhsConst)
794  return;
795 
796  int64_t rhsConstVal = rhsConst.getValue();
797  // Undefined expressions aren't touched; IR can still be valid with them.
798  if (rhsConstVal <= 0)
799  return;
800 
801  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
802  MLIRContext *context = expr.getContext();
803  std::optional<int64_t> lhsLbConst =
804  getLowerBound(lhs, numDims, numSymbols, operands);
805  std::optional<int64_t> lhsUbConst =
806  getUpperBound(lhs, numDims, numSymbols, operands);
807  if (lhsLbConst && lhsUbConst) {
808  int64_t lhsLbConstVal = *lhsLbConst;
809  int64_t lhsUbConstVal = *lhsUbConst;
810  // lhs floordiv c is a single value if the entire range of lhs has the same
811  // quotient.
812  if (binExpr.getKind() == AffineExprKind::FloorDiv &&
813  floorDiv(lhsLbConstVal, rhsConstVal) ==
814  floorDiv(lhsUbConstVal, rhsConstVal)) {
815  expr =
816  getAffineConstantExpr(floorDiv(lhsLbConstVal, rhsConstVal), context);
817  return;
818  }
819  // lhs ceildiv c is a single value if the entire range has the same ceil
820  // quotient.
821  if (binExpr.getKind() == AffineExprKind::CeilDiv &&
822  ceilDiv(lhsLbConstVal, rhsConstVal) ==
823  ceilDiv(lhsUbConstVal, rhsConstVal)) {
824  expr =
825  getAffineConstantExpr(ceilDiv(lhsLbConstVal, rhsConstVal), context);
826  return;
827  }
828  // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
829  if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
830  lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
831  expr = lhs;
832  return;
833  }
834  }
835 
836  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
837  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
838  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
839  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
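  // Worked example (illustrative): with e_1 = 8 * d0 and e_2 = d1 where
  // 0 <= d1 < 8 (so k = 8), (8 * d0 + d1) floordiv 16 simplifies to
  // (8 * d0) floordiv 16, and (8 * d0 + d1) mod 4 simplifies to d1 mod 4.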
840  AffineExpr quotientTimesDiv, rem;
841  int64_t divisor;
842  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
843  if (rhsConstVal % divisor == 0 &&
844  binExpr.getKind() == AffineExprKind::FloorDiv) {
845  expr = quotientTimesDiv.floorDiv(rhsConst);
846  } else if (divisor % rhsConstVal == 0 &&
847  binExpr.getKind() == AffineExprKind::Mod) {
848  expr = rem % rhsConst;
849  }
850  return;
851  }
852 
853  // Handle the simple cases where the LHS expression is either upper bounded
854  // by the RHS constant or is a known multiple of it.
855  // lhs floordiv c -> 0 if 0 <= lhs < c,
856  // lhs mod c -> 0 if lhs % c = 0.
857  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
858  binExpr.getKind() == AffineExprKind::FloorDiv) ||
859  (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
860  binExpr.getKind() == AffineExprKind::Mod)) {
861  expr = getAffineConstantExpr(0, expr.getContext());
862  }
863 }
864 
865 /// Simplify the expressions in `map` while making use of lower or upper bounds
866 /// of its operands. If `isMax` is true, the map is to be treated as a max of
867 /// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
868 /// d1) can be simplified to (8) if the operands are respectively lower bounded
869 /// by 2 and 0 (the second expression can't be lower than 8).
870 static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
871  ArrayRef<Value> operands,
872  bool isMax) {
873  // Can't simplify.
874  if (operands.empty())
875  return;
876 
877  // Get the upper or lower bound on an affine.for op IV using its range.
878  // Get the constant lower or upper bounds on the operands.
879  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
880  constLowerBounds.reserve(operands.size());
881  constUpperBounds.reserve(operands.size());
882  for (Value operand : operands) {
883  constLowerBounds.push_back(getLowerBound(operand));
884  constUpperBounds.push_back(getUpperBound(operand));
885  }
886 
887  // We will compute the lower and upper bounds on each of the expressions.
888  // Then, we will check (depending on max or min) whether a specific bound is
889  // redundant by checking if its highest (in the case of max) or its lowest
890  // (in the case of min) value is already lower than (or higher than) the
891  // lower bound (or upper bound in the case of min) of another bound.
892  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
893  lowerBounds.reserve(map.getNumResults());
894  upperBounds.reserve(map.getNumResults());
895  for (AffineExpr e : map.getResults()) {
896  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
897  lowerBounds.push_back(constExpr.getValue());
898  upperBounds.push_back(constExpr.getValue());
899  } else {
900  lowerBounds.push_back(
901  getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
902  constLowerBounds, constUpperBounds,
903  /*isUpper=*/false));
904  upperBounds.push_back(
905  getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
906  constLowerBounds, constUpperBounds,
907  /*isUpper=*/true));
908  }
909  }
910 
911  // Collect expressions that are not redundant.
912  SmallVector<AffineExpr, 4> irredundantExprs;
913  for (auto exprEn : llvm::enumerate(map.getResults())) {
914  AffineExpr e = exprEn.value();
915  unsigned i = exprEn.index();
916  // Some expressions can be turned into constants.
917  if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
918  e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
919 
920  // Check if the expression is redundant.
921  if (isMax) {
922  if (!upperBounds[i]) {
923  irredundantExprs.push_back(e);
924  continue;
925  }
926  // If there exists another expression such that its lower bound is greater
927  // than this expression's upper bound, it's redundant.
928  if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
929  auto otherLowerBound = en.value();
930  unsigned pos = en.index();
931  if (pos == i || !otherLowerBound)
932  return false;
933  if (*otherLowerBound > *upperBounds[i])
934  return true;
935  if (*otherLowerBound < *upperBounds[i])
936  return false;
937  // Equality case. When both expressions are considered redundant, we
938  // don't want to drop both of them. We keep the one that appears
939  // first.
940  if (upperBounds[pos] && lowerBounds[i] &&
941  lowerBounds[i] == upperBounds[i] &&
942  otherLowerBound == *upperBounds[pos] && i < pos)
943  return false;
944  return true;
945  }))
946  irredundantExprs.push_back(e);
947  } else {
948  if (!lowerBounds[i]) {
949  irredundantExprs.push_back(e);
950  continue;
951  }
952  // Likewise for the `min` case. Use the complement of the condition above.
953  if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
954  auto otherUpperBound = en.value();
955  unsigned pos = en.index();
956  if (pos == i || !otherUpperBound)
957  return false;
958  if (*otherUpperBound < *lowerBounds[i])
959  return true;
960  if (*otherUpperBound > *lowerBounds[i])
961  return false;
962  if (lowerBounds[pos] && upperBounds[i] &&
963  lowerBounds[i] == upperBounds[i] &&
964  otherUpperBound == lowerBounds[pos] && i < pos)
965  return false;
966  return true;
967  }))
968  irredundantExprs.push_back(e);
969  }
970  }
971 
972  // Create the map without the redundant expressions.
973  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
974  map.getContext());
975 }
976 
977 /// Simplify the map while exploiting information on the values in `operands`.
978 // Use "unused attribute" marker to silence warning stemming from the inability
979 // to see through the template expansion.
980 static void LLVM_ATTRIBUTE_UNUSED
981 simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
982  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
983  SmallVector<AffineExpr> newResults;
984  newResults.reserve(map.getNumResults());
985  for (AffineExpr expr : map.getResults()) {
986  simplifyExprAndOperands(expr, map.getNumDims(), map.getNumSymbols(),
987  operands);
988  newResults.push_back(expr);
989  }
990  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
991  map.getContext());
992 }
993 
994 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
995 /// defining AffineApplyOp expression and operands.
996 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
997 /// When `dimOrSymbolPosition >= dims.size()`,
998 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
999 /// Mutate `map`, `dims` and `syms` in place as follows:
1000 /// 1. `dims` and `syms` are only appended to.
1001 /// 2. `map` dim and symbols are gradually shifted to higher positions.
1002 /// 3. Old `dim` and `sym` entries are replaced by nullptr.
1003 /// This avoids the need for any bookkeeping.
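/// Sketch of the effect (with hypothetical values): if `map` is
/// (d0, d1) -> (d0 + d1) and the value bound to d1 is produced by
/// affine.apply affine_map<(d0) -> (d0 * 2)>(%x), the map becomes
/// (d0, d1, d2) -> (d0 + d2 * 2), %x is appended to `dims`, and the old d1
/// entry is set to nullptr.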
1004 static LogicalResult replaceDimOrSym(AffineMap *map,
1005  unsigned dimOrSymbolPosition,
1006  SmallVectorImpl<Value> &dims,
1007  SmallVectorImpl<Value> &syms) {
1008  MLIRContext *ctx = map->getContext();
1009  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1010  unsigned pos = isDimReplacement ? dimOrSymbolPosition
1011  : dimOrSymbolPosition - dims.size();
1012  Value &v = isDimReplacement ? dims[pos] : syms[pos];
1013  if (!v)
1014  return failure();
1015 
1016  auto affineApply = v.getDefiningOp<AffineApplyOp>();
1017  if (!affineApply)
1018  return failure();
1019 
1020  // At this point we will perform a replacement of `v`, set the entry in `dim`
1021  // or `sym` to nullptr immediately.
1022  v = nullptr;
1023 
1024  // Compute the map, dims and symbols coming from the AffineApplyOp.
1025  AffineMap composeMap = affineApply.getAffineMap();
1026  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1027  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1028  affineApply.getMapOperands().end());
1029  // Canonicalize the map to promote dims to symbols when possible. This is to
1030  // avoid generating invalid maps.
1031  canonicalizeMapAndOperands(&composeMap, &composeOperands);
1032  AffineExpr replacementExpr =
1033  composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
1034  ValueRange composeDims =
1035  ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
1036  ValueRange composeSyms =
1037  ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
1038  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
1039  : getAffineSymbolExpr(pos, ctx);
1040 
1041  // Append the dims and symbols where relevant and perform the replacement.
1042  dims.append(composeDims.begin(), composeDims.end());
1043  syms.append(composeSyms.begin(), composeSyms.end());
1044  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1045 
1046  return success();
1047 }
1048 
1049 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1050 /// iteratively. Perform canonicalization of map and operands as well as
1051 /// AffineMap simplification. `map` and `operands` are mutated in place.
1052 static void composeAffineMapAndOperands(AffineMap *map,
1053  SmallVectorImpl<Value> *operands) {
1054  if (map->getNumResults() == 0) {
1055  canonicalizeMapAndOperands(map, operands);
1056  *map = simplifyAffineMap(*map);
1057  return;
1058  }
1059 
1060  MLIRContext *ctx = map->getContext();
1061  SmallVector<Value, 4> dims(operands->begin(),
1062  operands->begin() + map->getNumDims());
1063  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1064  operands->end());
1065 
1066  // Iterate over dims and symbols coming from AffineApplyOp and replace until
1067  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1068  // and `syms` can only increase by construction.
1069  // The implementation uses a `while` loop to support the case of symbols
1070  // that may be constructed from dims; this may be overkill.
1071  while (true) {
1072  bool changed = false;
1073  for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1074  if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
1075  break;
1076  if (!changed)
1077  break;
1078  }
1079 
1080  // Clear operands so we can fill them anew.
1081  operands->clear();
1082 
1083  // At this point we may have introduced null operands, prune them out before
1084  // canonicalizing map and operands.
1085  unsigned nDims = 0, nSyms = 0;
1086  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1087  dimReplacements.reserve(dims.size());
1088  symReplacements.reserve(syms.size());
1089  for (auto *container : {&dims, &syms}) {
1090  bool isDim = (container == &dims);
1091  auto &repls = isDim ? dimReplacements : symReplacements;
1092  for (const auto &en : llvm::enumerate(*container)) {
1093  Value v = en.value();
1094  if (!v) {
1095  assert(isDim ? !map->isFunctionOfDim(en.index())
1096  : !map->isFunctionOfSymbol(en.index()) &&
1097  "map is function of unexpected expr@pos");
1098  repls.push_back(getAffineConstantExpr(0, ctx));
1099  continue;
1100  }
1101  repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1102  : getAffineSymbolExpr(nSyms++, ctx));
1103  operands->push_back(v);
1104  }
1105  }
1106  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1107  nSyms);
1108 
1109  // Canonicalize and simplify before returning.
1110  canonicalizeMapAndOperands(map, operands);
1111  *map = simplifyAffineMap(*map);
1112 }
1113 
1114 void mlir::affine::fullyComposeAffineMapAndOperands(
1115  AffineMap *map, SmallVectorImpl<Value> *operands) {
1116  while (llvm::any_of(*operands, [](Value v) {
1117  return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1118  })) {
1119  composeAffineMapAndOperands(map, operands);
1120  }
1121 }
1122 
1123 AffineApplyOp
1124 mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
1125  ArrayRef<OpFoldResult> operands) {
1126  SmallVector<Value> valueOperands;
1127  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1128  composeAffineMapAndOperands(&map, &valueOperands);
1129  assert(map);
1130  return b.create<AffineApplyOp>(loc, map, valueOperands);
1131 }
1132 
1133 AffineApplyOp
1134 mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
1135  ArrayRef<OpFoldResult> operands) {
1136  return makeComposedAffineApply(
1137  b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}).front(),
1138  operands);
1139 }
1140 
1141 /// Composes the given affine map with the given list of operands, pulling in
1142 /// the maps from any affine.apply operations that supply the operands.
1143 static void composeMultiResultAffineMap(AffineMap &map,
1144  SmallVectorImpl<Value> &operands) {
1145  // Compose and canonicalize each expression in the map individually because
1146  // composition only applies to single-result maps, collecting potentially
1147  // duplicate operands in a single list with shifted dimensions and symbols.
1148  SmallVector<Value> dims, symbols;
1149  SmallVector<AffineExpr> exprs;
1150  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1151  SmallVector<Value> submapOperands(operands.begin(), operands.end());
1152  AffineMap submap = map.getSubMap({i});
1153  fullyComposeAffineMapAndOperands(&submap, &submapOperands);
1154  canonicalizeMapAndOperands(&submap, &submapOperands);
1155  unsigned numNewDims = submap.getNumDims();
1156  submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1157  llvm::append_range(dims,
1158  ArrayRef<Value>(submapOperands).take_front(numNewDims));
1159  llvm::append_range(symbols,
1160  ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1161  exprs.push_back(submap.getResult(0));
1162  }
1163 
1164  // Canonicalize the map created from composed expressions to deduplicate the
1165  // dimension and symbol operands.
1166  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1167  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1168  canonicalizeMapAndOperands(&map, &operands);
1169 }
1170 
1171 OpFoldResult mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b,
1172  Location loc,
1173  AffineMap map,
1174  ArrayRef<OpFoldResult> operands) {
1175  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1176 
1177  // Create new builder without a listener, so that no notification is
1178  // triggered if the op is folded.
1179  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1180  // workaround is no longer needed.
1181  OpBuilder newBuilder(b.getContext());
1182  newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1183 
1184  // Create op.
1185  AffineApplyOp applyOp =
1186  makeComposedAffineApply(newBuilder, loc, map, operands);
1187 
1188  // Get constant operands.
1189  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1190  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1191  matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1192 
1193  // Try to fold the operation.
1194  SmallVector<OpFoldResult> foldResults;
1195  if (failed(applyOp->fold(constOperands, foldResults)) ||
1196  foldResults.empty()) {
1197  if (OpBuilder::Listener *listener = b.getListener())
1198  listener->notifyOperationInserted(applyOp);
1199  return applyOp.getResult();
1200  }
1201 
1202  applyOp->erase();
1203  assert(foldResults.size() == 1 && "expected 1 folded result");
1204  return foldResults.front();
1205 }
1206 
1207 OpFoldResult mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b,
1208  Location loc,
1209  AffineExpr expr,
1210  ArrayRef<OpFoldResult> operands) {
1211  return makeComposedFoldedAffineApply(
1212  b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}).front(),
1213  operands);
1214 }
1215 
1216 SmallVector<OpFoldResult>
1217 mlir::affine::makeComposedFoldedMultiResultAffineApply(
1218  OpBuilder &b, Location loc, AffineMap map,
1219  ArrayRef<OpFoldResult> operands) {
1220  return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
1221  [&](unsigned i) {
1222  return makeComposedFoldedAffineApply(
1223  b, loc, map.getSubMap({i}), operands);
1224  });
1225 }
1226 
1227 template <typename OpTy>
1228 static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map,
1229  ArrayRef<OpFoldResult> operands) {
1230  SmallVector<Value> valueOperands;
1231  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1232  composeMultiResultAffineMap(map, valueOperands);
1233  return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1234 }
1235 
1236 AffineMinOp
1237 mlir::affine::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
1238  ArrayRef<OpFoldResult> operands) {
1239  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1240 }
1241 
1242 template <typename OpTy>
1243 static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
1244  AffineMap map,
1245  ArrayRef<OpFoldResult> operands) {
1246  // Create new builder without a listener, so that no notification is
1247  // triggered if the op is folded.
1248  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1249  // workaround is no longer needed.
1250  OpBuilder newBuilder(b.getContext());
1251  newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1252 
1253  // Create op.
1254  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1255 
1256  // Get constant operands.
1257  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1258  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1259  matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1260 
1261  // Try to fold the operation.
1262  SmallVector<OpFoldResult> foldResults;
1263  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1264  foldResults.empty()) {
1265  if (OpBuilder::Listener *listener = b.getListener())
1266  listener->notifyOperationInserted(minMaxOp);
1267  return minMaxOp.getResult();
1268  }
1269 
1270  minMaxOp->erase();
1271  assert(foldResults.size() == 1 && "expected 1 folded result");
1272  return foldResults.front();
1273 }
1274 
1275 OpFoldResult mlir::affine::makeComposedFoldedAffineMin(OpBuilder &b,
1276  Location loc,
1277  AffineMap map,
1278  ArrayRef<OpFoldResult> operands) {
1279  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1280 }
1281 
1282 OpFoldResult mlir::affine::makeComposedFoldedAffineMax(OpBuilder &b,
1283  Location loc,
1284  AffineMap map,
1285  ArrayRef<OpFoldResult> operands) {
1286  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1287 }
1288 
1289 // A symbol may appear as a dim in affine.apply operations. This function
1290 // canonicalizes dims that are valid symbols into actual symbols.
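// For illustration (hypothetical operands): if the map is
// (d0, d1)[s0] -> (d0 + d1 + s0) and the operand bound to d1 is itself a
// valid symbol (e.g. a top-level constant), the map is rewritten to
// (d0)[s0, s1] -> (d0 + s1 + s0) and the promoted operand is moved after the
// existing symbol operands.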
1291 template <class MapOrSet>
1292 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1293  SmallVectorImpl<Value> *operands) {
1294  if (!mapOrSet || operands->empty())
1295  return;
1296 
1297  assert(mapOrSet->getNumInputs() == operands->size() &&
1298  "map/set inputs must match number of operands");
1299 
1300  auto *context = mapOrSet->getContext();
1301  SmallVector<Value, 8> resultOperands;
1302  resultOperands.reserve(operands->size());
1303  SmallVector<Value, 8> remappedSymbols;
1304  remappedSymbols.reserve(operands->size());
1305  unsigned nextDim = 0;
1306  unsigned nextSym = 0;
1307  unsigned oldNumSyms = mapOrSet->getNumSymbols();
1308  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1309  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1310  if (i < mapOrSet->getNumDims()) {
1311  if (isValidSymbol((*operands)[i])) {
1312  // This is a valid symbol that appears as a dim, canonicalize it.
1313  dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1314  remappedSymbols.push_back((*operands)[i]);
1315  } else {
1316  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1317  resultOperands.push_back((*operands)[i]);
1318  }
1319  } else {
1320  resultOperands.push_back((*operands)[i]);
1321  }
1322  }
1323 
1324  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1325  *operands = resultOperands;
1326  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1327  oldNumSyms + nextSym);
1328 
1329  assert(mapOrSet->getNumInputs() == operands->size() &&
1330  "map/set inputs must match number of operands");
1331 }
1332 
1333 // Works for either an affine map or an integer set.
1334 template <class MapOrSet>
1335 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1336  SmallVectorImpl<Value> *operands) {
1337  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1338  "Argument must be either of AffineMap or IntegerSet type");
1339 
1340  if (!mapOrSet || operands->empty())
1341  return;
1342 
1343  assert(mapOrSet->getNumInputs() == operands->size() &&
1344  "map/set inputs must match number of operands");
1345 
1346  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1347 
1348  // Check to see what dims are used.
1349  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1350  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1351  mapOrSet->walkExprs([&](AffineExpr expr) {
1352  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1353  usedDims[dimExpr.getPosition()] = true;
1354  else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1355  usedSyms[symExpr.getPosition()] = true;
1356  });
1357 
1358  auto *context = mapOrSet->getContext();
1359 
1360  SmallVector<Value, 8> resultOperands;
1361  resultOperands.reserve(operands->size());
1362 
1363  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1364  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1365  unsigned nextDim = 0;
1366  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1367  if (usedDims[i]) {
1368  // Remap dim positions for duplicate operands.
1369  auto it = seenDims.find((*operands)[i]);
1370  if (it == seenDims.end()) {
1371  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1372  resultOperands.push_back((*operands)[i]);
1373  seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1374  } else {
1375  dimRemapping[i] = it->second;
1376  }
1377  }
1378  }
1379  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1380  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1381  unsigned nextSym = 0;
1382  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1383  if (!usedSyms[i])
1384  continue;
1385  // Handle constant operands (only needed for symbolic operands since
1386  // constant operands in dimensional positions would have already been
1387  // promoted to symbolic positions above).
1388  IntegerAttr operandCst;
1389  if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1390  m_Constant(&operandCst))) {
1391  symRemapping[i] =
1392  getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1393  continue;
1394  }
1395  // Remap symbol positions for duplicate operands.
1396  auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1397  if (it == seenSymbols.end()) {
1398  symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1399  resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1400  seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1401  symRemapping[i]));
1402  } else {
1403  symRemapping[i] = it->second;
1404  }
1405  }
1406  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1407  nextDim, nextSym);
1408  *operands = resultOperands;
1409 }
1410 
1411 void mlir::affine::canonicalizeMapAndOperands(
1412  AffineMap *map, SmallVectorImpl<Value> *operands) {
1413  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1414 }
1415 
1416 void mlir::affine::canonicalizeSetAndOperands(
1417  IntegerSet *set, SmallVectorImpl<Value> *operands) {
1418  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1419 }
1420 
1421 namespace {
1422 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1423 /// maps that supply results into them.
1424 ///
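/// For illustration (hypothetical IR, not taken from this file):
///
///   %a = affine.apply affine_map<(d0) -> (d0 * 4)>(%i)
///   %v = affine.load %A[%a] : memref<128xf32>
///
/// composes into a single load using the combined map:
///
///   %v = affine.load %A[%i * 4] : memref<128xf32>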
1425 template <typename AffineOpTy>
1426 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1427  using OpRewritePattern<AffineOpTy>::OpRewritePattern;
1428 
1429  /// Replace the affine op with another instance of it with the supplied
1430  /// map and mapOperands.
1431  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1432  AffineMap map, ArrayRef<Value> mapOperands) const;
1433 
1434  LogicalResult matchAndRewrite(AffineOpTy affineOp,
1435  PatternRewriter &rewriter) const override {
1436  static_assert(
1437  llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1438  AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1439  AffineVectorStoreOp, AffineVectorLoadOp>::value,
1440  "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1441  "expected");
1442  auto map = affineOp.getAffineMap();
1443  AffineMap oldMap = map;
1444  auto oldOperands = affineOp.getMapOperands();
1445  SmallVector<Value, 8> resultOperands(oldOperands);
1446  composeAffineMapAndOperands(&map, &resultOperands);
1447  canonicalizeMapAndOperands(&map, &resultOperands);
1448  simplifyMapWithOperands(map, resultOperands);
1449  if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1450  resultOperands.begin()))
1451  return failure();
1452 
1453  replaceAffineOp(rewriter, affineOp, map, resultOperands);
1454  return success();
1455  }
1456 };
1457 
1458 // Specialize the template to account for the different build signatures for
1459 // affine load, prefetch, store, and vector load/store ops.
1460 template <>
1461 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1462  PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1463  ArrayRef<Value> mapOperands) const {
1464  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1465  mapOperands);
1466 }
1467 template <>
1468 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1469  PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1470  ArrayRef<Value> mapOperands) const {
1471  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1472  prefetch, prefetch.getMemref(), map, mapOperands,
1473  prefetch.getLocalityHint(), prefetch.getIsWrite(),
1474  prefetch.getIsDataCache());
1475 }
1476 template <>
1477 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1478  PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1479  ArrayRef<Value> mapOperands) const {
1480  rewriter.replaceOpWithNewOp<AffineStoreOp>(
1481  store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1482 }
1483 template <>
1484 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1485  PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1486  ArrayRef<Value> mapOperands) const {
1487  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1488  vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1489  mapOperands);
1490 }
1491 template <>
1492 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1493  PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1494  ArrayRef<Value> mapOperands) const {
1495  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1496  vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1497  mapOperands);
1498 }
1499 
1500 // Generic version for ops that don't have extra operands.
1501 template <typename AffineOpTy>
1502 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1503  PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1504  ArrayRef<Value> mapOperands) const {
1505  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1506 }
1507 } // namespace
1508 
1509 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1510  MLIRContext *context) {
1511  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1512 }
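The pattern registered above is normally driven by the greedy rewriter. A minimal sketch, assuming a surrounding pass that owns a hypothetical `funcOp` (the pass boilerplate is not part of this file):

  RewritePatternSet patterns(funcOp.getContext());
  AffineApplyOp::getCanonicalizationPatterns(patterns, funcOp.getContext());
  // Greedily applies SimplifyAffineOp<AffineApplyOp> (and folds) until the IR
  // reaches a fixed point.
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
    funcOp.emitWarning("affine.apply canonicalization did not converge");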
1513 
1514 //===----------------------------------------------------------------------===//
1515 // AffineDmaStartOp
1516 //===----------------------------------------------------------------------===//
1517 
1518 // TODO: Check that map operands are loop IVs or symbols.
1519 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1520  Value srcMemRef, AffineMap srcMap,
1521  ValueRange srcIndices, Value destMemRef,
1522  AffineMap dstMap, ValueRange destIndices,
1523  Value tagMemRef, AffineMap tagMap,
1524  ValueRange tagIndices, Value numElements,
1525  Value stride, Value elementsPerStride) {
1526  result.addOperands(srcMemRef);
1527  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1528  result.addOperands(srcIndices);
1529  result.addOperands(destMemRef);
1530  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1531  result.addOperands(destIndices);
1532  result.addOperands(tagMemRef);
1533  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1534  result.addOperands(tagIndices);
1535  result.addOperands(numElements);
1536  if (stride) {
1537  result.addOperands({stride, elementsPerStride});
1538  }
1539 }
1540 
1541 void AffineDmaStartOp::print(OpAsmPrinter &p) {
1542  p << " " << getSrcMemRef() << '[';
1543  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1544  p << "], " << getDstMemRef() << '[';
1545  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1546  p << "], " << getTagMemRef() << '[';
1547  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1548  p << "], " << getNumElements();
1549  if (isStrided()) {
1550  p << ", " << getStride();
1551  p << ", " << getNumElementsPerStride();
1552  }
1553  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1554  << getTagMemRefType();
1555 }
1556 
1557 // Parse AffineDmaStartOp.
1558 // Ex:
1559 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1560 // %stride, %num_elt_per_stride
1561 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1562 //
1563 ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
1564  OperationState &result) {
1565  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1566  AffineMapAttr srcMapAttr;
1567  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcMapOperands;
1568  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1569  AffineMapAttr dstMapAttr;
1570  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstMapOperands;
1571  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1572  AffineMapAttr tagMapAttr;
1573  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagMapOperands;
1574  OpAsmParser::UnresolvedOperand numElementsInfo;
1575  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1576 
1577  SmallVector<Type, 3> types;
1578  auto indexType = parser.getBuilder().getIndexType();
1579 
1580  // Parse and resolve the following list of operands:
1581  // *) dst memref followed by its affine maps operands (in square brackets).
1582  // *) src memref followed by its affine map operands (in square brackets).
1583  // *) tag memref followed by its affine map operands (in square brackets).
1584  // *) number of elements transferred by DMA operation.
1585  if (parser.parseOperand(srcMemRefInfo) ||
1586  parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1587  getSrcMapAttrStrName(),
1588  result.attributes) ||
1589  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1590  parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1591  getDstMapAttrStrName(),
1592  result.attributes) ||
1593  parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1594  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1595  getTagMapAttrStrName(),
1596  result.attributes) ||
1597  parser.parseComma() || parser.parseOperand(numElementsInfo))
1598  return failure();
1599 
1600  // Parse optional stride and elements per stride.
1601  if (parser.parseTrailingOperandList(strideInfo))
1602  return failure();
1603 
1604  if (!strideInfo.empty() && strideInfo.size() != 2) {
1605  return parser.emitError(parser.getNameLoc(),
1606  "expected two stride related operands");
1607  }
1608  bool isStrided = strideInfo.size() == 2;
1609 
1610  if (parser.parseColonTypeList(types))
1611  return failure();
1612 
1613  if (types.size() != 3)
1614  return parser.emitError(parser.getNameLoc(), "expected three types");
1615 
1616  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1617  parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1618  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1619  parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1620  parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1621  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1622  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1623  return failure();
1624 
1625  if (isStrided) {
1626  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1627  return failure();
1628  }
1629 
1630  // Check that src/dst/tag operand counts match their map.numInputs.
1631  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1632  dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1633  tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1634  return parser.emitError(parser.getNameLoc(),
1635  "memref operand count not equal to map.numInputs");
1636  return success();
1637 }
1638 
1639 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1640  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1641  return emitOpError("expected DMA source to be of memref type");
1642  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1643  return emitOpError("expected DMA destination to be of memref type");
1644  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1645  return emitOpError("expected DMA tag to be of memref type");
1646 
1647  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1648  getDstMap().getNumInputs() +
1649  getTagMap().getNumInputs();
1650  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1651  getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1652  return emitOpError("incorrect number of operands");
1653  }
1654 
1655  Region *scope = getAffineScope(*this);
1656  for (auto idx : getSrcIndices()) {
1657  if (!idx.getType().isIndex())
1658  return emitOpError("src index to dma_start must have 'index' type");
1659  if (!isValidAffineIndexOperand(idx, scope))
1660  return emitOpError(
1661  "src index must be a valid dimension or symbol identifier");
1662  }
1663  for (auto idx : getDstIndices()) {
1664  if (!idx.getType().isIndex())
1665  return emitOpError("dst index to dma_start must have 'index' type");
1666  if (!isValidAffineIndexOperand(idx, scope))
1667  return emitOpError(
1668  "dst index must be a valid dimension or symbol identifier");
1669  }
1670  for (auto idx : getTagIndices()) {
1671  if (!idx.getType().isIndex())
1672  return emitOpError("tag index to dma_start must have 'index' type");
1673  if (!isValidAffineIndexOperand(idx, scope))
1674  return emitOpError(
1675  "tag index must be a valid dimension or symbol identifier");
1676  }
1677  return success();
1678 }
1679 
1680 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1681  SmallVectorImpl<OpFoldResult> &results) {
1682  /// dma_start(memrefcast) -> dma_start
1683  return memref::foldMemRefCast(*this);
1684 }
1685 
1686 void AffineDmaStartOp::getEffects(
1687  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1688  &effects) {
1689  effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
1690  SideEffects::DefaultResource::get());
1691  effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
1692  SideEffects::DefaultResource::get());
1693  effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
1694  SideEffects::DefaultResource::get());
1695 }
1696 
1697 //===----------------------------------------------------------------------===//
1698 // AffineDmaWaitOp
1699 //===----------------------------------------------------------------------===//
1700 
1701 // TODO: Check that map operands are loop IVs or symbols.
1702 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1703  Value tagMemRef, AffineMap tagMap,
1704  ValueRange tagIndices, Value numElements) {
1705  result.addOperands(tagMemRef);
1706  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1707  result.addOperands(tagIndices);
1708  result.addOperands(numElements);
1709 }
1710 
1711 void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1712  p << " " << getTagMemRef() << '[';
1713  SmallVector<Value, 2> operands(getTagIndices());
1714  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1715  p << "], ";
1716  p.printOperand(getNumElements());
1717  p << " : " << getTagMemRef().getType();
1718 }
1719 
1720 // Parse AffineDmaWaitOp.
1721 // Eg:
1722 // affine.dma_wait %tag[%index], %num_elements
1723 // : memref<1 x i32, (d0) -> (d0), 4>
1724 //
1725 ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1726  OperationState &result) {
1727  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1728  AffineMapAttr tagMapAttr;
1729  SmallVector<OpAsmParser::UnresolvedOperand, 2> tagMapOperands;
1730  Type type;
1731  auto indexType = parser.getBuilder().getIndexType();
1732  OpAsmParser::UnresolvedOperand numElementsInfo;
1733 
1734  // Parse tag memref, its map operands, and dma size.
1735  if (parser.parseOperand(tagMemRefInfo) ||
1736  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1737  getTagMapAttrStrName(),
1738  result.attributes) ||
1739  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1740  parser.parseColonType(type) ||
1741  parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1742  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1743  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1744  return failure();
1745 
1746  if (!llvm::isa<MemRefType>(type))
1747  return parser.emitError(parser.getNameLoc(),
1748  "expected tag to be of memref type");
1749 
1750  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1751  return parser.emitError(parser.getNameLoc(),
1752  "tag memref operand count != to map.numInputs");
1753  return success();
1754 }
1755 
1756 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1757  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1758  return emitOpError("expected DMA tag to be of memref type");
1759  Region *scope = getAffineScope(*this);
1760  for (auto idx : getTagIndices()) {
1761  if (!idx.getType().isIndex())
1762  return emitOpError("index to dma_wait must have 'index' type");
1763  if (!isValidAffineIndexOperand(idx, scope))
1764  return emitOpError(
1765  "index must be a valid dimension or symbol identifier");
1766  }
1767  return success();
1768 }
1769 
1770 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1771  SmallVectorImpl<OpFoldResult> &results) {
1772  /// dma_wait(memrefcast) -> dma_wait
1773  return memref::foldMemRefCast(*this);
1774 }
1775 
1776 void AffineDmaWaitOp::getEffects(
1777  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1778  &effects) {
1779  effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
1780  SideEffects::DefaultResource::get());
1781 }
1782 
1783 //===----------------------------------------------------------------------===//
1784 // AffineForOp
1785 //===----------------------------------------------------------------------===//
1786 
1787 /// 'bodyBuilder' is used to build the body of affine.for. If 'iterArgs' is
1788 /// empty and 'bodyBuilder' is null, a default terminator op is inserted.
1789 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1790  ValueRange lbOperands, AffineMap lbMap,
1791  ValueRange ubOperands, AffineMap ubMap, int64_t step,
1792  ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1793  assert(((!lbMap && lbOperands.empty()) ||
1794  lbOperands.size() == lbMap.getNumInputs()) &&
1795  "lower bound operand count does not match the affine map");
1796  assert(((!ubMap && ubOperands.empty()) ||
1797  ubOperands.size() == ubMap.getNumInputs()) &&
1798  "upper bound operand count does not match the affine map");
1799  assert(step > 0 && "step has to be a positive integer constant");
1800 
1801  // Set variadic segment sizes.
1802  result.addAttribute(
1803  getOperandSegmentSizeAttr(),
1804  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
1805  static_cast<int32_t>(ubOperands.size()),
1806  static_cast<int32_t>(iterArgs.size())}));
1807 
1808  for (Value val : iterArgs)
1809  result.addTypes(val.getType());
1810 
1811  // Add an attribute for the step.
1812  result.addAttribute(getStepAttrName(result.name),
1813  builder.getIntegerAttr(builder.getIndexType(), step));
1814 
1815  // Add the lower bound.
1816  result.addAttribute(getLowerBoundMapAttrName(result.name),
1817  AffineMapAttr::get(lbMap));
1818  result.addOperands(lbOperands);
1819 
1820  // Add the upper bound.
1821  result.addAttribute(getUpperBoundMapAttrName(result.name),
1822  AffineMapAttr::get(ubMap));
1823  result.addOperands(ubOperands);
1824 
1825  result.addOperands(iterArgs);
1826  // Create a region and a block for the body. The argument of the region is
1827  // the loop induction variable.
1828  Region *bodyRegion = result.addRegion();
1829  bodyRegion->push_back(new Block);
1830  Block &bodyBlock = bodyRegion->front();
1831  Value inductionVar =
1832  bodyBlock.addArgument(builder.getIndexType(), result.location);
1833  for (Value val : iterArgs)
1834  bodyBlock.addArgument(val.getType(), val.getLoc());
1835 
1836  // Create the default terminator if the builder is not provided and if the
1837  // iteration arguments are not provided. Otherwise, leave this to the caller
1838  // because we don't know which values to return from the loop.
1839  if (iterArgs.empty() && !bodyBuilder) {
1840  ensureTerminator(*bodyRegion, builder, result.location);
1841  } else if (bodyBuilder) {
1842  OpBuilder::InsertionGuard guard(builder);
1843  builder.setInsertionPointToStart(&bodyBlock);
1844  bodyBuilder(builder, result.location, inductionVar,
1845  bodyBlock.getArguments().drop_front());
1846  }
1847 }
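As a usage sketch of the builder above (names such as `builder`, `loc`, and `init` are placeholders, not part of this file), a loop `affine.for %i = 0 to 128 step 2` carrying one value could be created as:

  auto loop = builder.create<AffineForOp>(
      loc, /*lb=*/0, /*ub=*/128, /*step=*/2, /*iterArgs=*/ValueRange{init},
      [](OpBuilder &nested, Location nestedLoc, Value iv, ValueRange args) {
        // With iter_args present, the body builder must create the terminator.
        nested.create<AffineYieldOp>(nestedLoc, args);
      });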
1848 
1849 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1850  int64_t ub, int64_t step, ValueRange iterArgs,
1851  BodyBuilderFn bodyBuilder) {
1852  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1853  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1854  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1855  bodyBuilder);
1856 }
1857 
1858 LogicalResult AffineForOp::verifyRegions() {
1859  // Check that the body defines a single block argument for the induction
1860  // variable.
1861  auto *body = getBody();
1862  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1863  return emitOpError("expected body to have a single index argument for the "
1864  "induction variable");
1865 
1866  // Verify that the bound operands are valid dimension/symbols.
1867  /// Lower bound.
1868  if (getLowerBoundMap().getNumInputs() > 0)
1869  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundOperands(),
1870  getLowerBoundMap().getNumDims())))
1871  return failure();
1872  /// Upper bound.
1873  if (getUpperBoundMap().getNumInputs() > 0)
1874  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundOperands(),
1875  getUpperBoundMap().getNumDims())))
1876  return failure();
1877 
1878  unsigned opNumResults = getNumResults();
1879  if (opNumResults == 0)
1880  return success();
1881 
1882  // If ForOp defines values, check that the number and types of the defined
1883  // values match ForOp initial iter operands and backedge basic block
1884  // arguments.
1885  if (getNumIterOperands() != opNumResults)
1886  return emitOpError(
1887  "mismatch between the number of loop-carried values and results");
1888  if (getNumRegionIterArgs() != opNumResults)
1889  return emitOpError(
1890  "mismatch between the number of basic block args and results");
1891 
1892  return success();
1893 }
1894 
1895 /// Parse a for operation loop bounds.
1896 static ParseResult parseBound(bool isLower, OperationState &result,
1897  OpAsmParser &p) {
1898  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1899  // the map has multiple results.
1900  bool failedToParsedMinMax =
1901  failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1902 
1903  auto &builder = p.getBuilder();
1904  auto boundAttrStrName =
1905  isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
1906  : AffineForOp::getUpperBoundMapAttrName(result.name);
1907 
1908  // Parse ssa-id as identity map.
1909  SmallVector<OpAsmParser::UnresolvedOperand, 1> boundOpInfos;
1910  if (p.parseOperandList(boundOpInfos))
1911  return failure();
1912 
1913  if (!boundOpInfos.empty()) {
1914  // Check that only one operand was parsed.
1915  if (boundOpInfos.size() > 1)
1916  return p.emitError(p.getNameLoc(),
1917  "expected only one loop bound operand");
1918 
1919  // TODO: improve error message when SSA value is not of index type.
1920  // Currently it is 'use of value ... expects different type than prior uses'
1921  if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1922  result.operands))
1923  return failure();
1924 
1925  // Create an identity map using symbol id. This representation is optimized
1926  // for storage. Analysis passes may expand it into a multi-dimensional map
1927  // if desired.
1928  AffineMap map = builder.getSymbolIdentityMap();
1929  result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
1930  return success();
1931  }
1932 
1933  // Get the attribute location.
1934  SMLoc attrLoc = p.getCurrentLocation();
1935 
1936  Attribute boundAttr;
1937  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
1938  result.attributes))
1939  return failure();
1940 
1941  // Parse full form - affine map followed by dim and symbol list.
1942  if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
1943  unsigned currentNumOperands = result.operands.size();
1944  unsigned numDims;
1945  if (parseDimAndSymbolList(p, result.operands, numDims))
1946  return failure();
1947 
1948  auto map = affineMapAttr.getValue();
1949  if (map.getNumDims() != numDims)
1950  return p.emitError(
1951  p.getNameLoc(),
1952  "dim operand count and affine map dim count must match");
1953 
1954  unsigned numDimAndSymbolOperands =
1955  result.operands.size() - currentNumOperands;
1956  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1957  return p.emitError(
1958  p.getNameLoc(),
1959  "symbol operand count and affine map symbol count must match");
1960 
1961  // If the map has multiple results, make sure that we parsed the min/max
1962  // prefix.
1963  if (map.getNumResults() > 1 && failedToParsedMinMax) {
1964  if (isLower) {
1965  return p.emitError(attrLoc, "lower loop bound affine map with "
1966  "multiple results requires 'max' prefix");
1967  }
1968  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1969  "results requires 'min' prefix");
1970  }
1971  return success();
1972  }
1973 
1974  // Parse custom assembly form.
1975  if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
1976  result.attributes.pop_back();
1977  result.addAttribute(
1978  boundAttrStrName,
1979  AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
1980  return success();
1981  }
1982 
1983  return p.emitError(
1984  p.getNameLoc(),
1985  "expected valid affine map representation for loop bounds");
1986 }
1987 
1988 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
1989  auto &builder = parser.getBuilder();
1990  OpAsmParser::Argument inductionVariable;
1991  inductionVariable.type = builder.getIndexType();
1992  // Parse the induction variable followed by '='.
1993  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
1994  return failure();
1995 
1996  // Parse loop bounds.
1997  int64_t numOperands = result.operands.size();
1998  if (parseBound(/*isLower=*/true, result, parser))
1999  return failure();
2000  int64_t numLbOperands = result.operands.size() - numOperands;
2001  if (parser.parseKeyword("to", " between bounds"))
2002  return failure();
2003  numOperands = result.operands.size();
2004  if (parseBound(/*isLower=*/false, result, parser))
2005  return failure();
2006  int64_t numUbOperands = result.operands.size() - numOperands;
2007 
2008  // Parse the optional loop step, we default to 1 if one is not present.
2009  if (parser.parseOptionalKeyword("step")) {
2010  result.addAttribute(
2011  getStepAttrName(result.name),
2012  builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2013  } else {
2014  SMLoc stepLoc = parser.getCurrentLocation();
2015  IntegerAttr stepAttr;
2016  if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2017  getStepAttrName(result.name).data(),
2018  result.attributes))
2019  return failure();
2020 
2021  if (stepAttr.getValue().isNegative())
2022  return parser.emitError(
2023  stepLoc,
2024  "expected step to be representable as a positive signed integer");
2025  }
2026 
2027  // Parse the optional initial iteration arguments.
2028  SmallVector<OpAsmParser::Argument, 4> regionArgs;
2029  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2030 
2031  // Induction variable.
2032  regionArgs.push_back(inductionVariable);
2033 
2034  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2035  // Parse assignment list and results type list.
2036  if (parser.parseAssignmentList(regionArgs, operands) ||
2037  parser.parseArrowTypeList(result.types))
2038  return failure();
2039  // Resolve input operands.
2040  for (auto argOperandType :
2041  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2042  Type type = std::get<2>(argOperandType);
2043  std::get<0>(argOperandType).type = type;
2044  if (parser.resolveOperand(std::get<1>(argOperandType), type,
2045  result.operands))
2046  return failure();
2047  }
2048  }
2049 
2050  result.addAttribute(
2051  getOperandSegmentSizeAttr(),
2052  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2053  static_cast<int32_t>(numUbOperands),
2054  static_cast<int32_t>(operands.size())}));
2055 
2056  // Parse the body region.
2057  Region *body = result.addRegion();
2058  if (regionArgs.size() != result.types.size() + 1)
2059  return parser.emitError(
2060  parser.getNameLoc(),
2061  "mismatch between the number of loop-carried values and results");
2062  if (parser.parseRegion(*body, regionArgs))
2063  return failure();
2064 
2065  AffineForOp::ensureTerminator(*body, builder, result.location);
2066 
2067  // Parse the optional attribute list.
2068  return parser.parseOptionalAttrDict(result.attributes);
2069 }
2070 
2071 static void printBound(AffineMapAttr boundMap,
2072  Operation::operand_range boundOperands,
2073  const char *prefix, OpAsmPrinter &p) {
2074  AffineMap map = boundMap.getValue();
2075 
2076  // Check if this bound should be printed using custom assembly form.
2077  // The decision to restrict printing custom assembly form to trivial cases
2078  // comes from the will to roundtrip MLIR binary -> text -> binary in a
2079  // lossless way.
2080  // Therefore, custom assembly form parsing and printing is only supported for
2081  // zero-operand constant maps and single symbol operand identity maps.
2082  if (map.getNumResults() == 1) {
2083  AffineExpr expr = map.getResult(0);
2084 
2085  // Print constant bound.
2086  if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2087  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2088  p << constExpr.getValue();
2089  return;
2090  }
2091  }
2092 
2093  // Print bound that consists of a single SSA symbol if the map is over a
2094  // single symbol.
2095  if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2096  if (dyn_cast<AffineSymbolExpr>(expr)) {
2097  p.printOperand(*boundOperands.begin());
2098  return;
2099  }
2100  }
2101  } else {
2102  // Map has multiple results. Print 'min' or 'max' prefix.
2103  p << prefix << ' ';
2104  }
2105 
2106  // Print the map and its operands.
2107  p << boundMap;
2108  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2109  map.getNumDims(), p);
2110 }
2111 
2112 unsigned AffineForOp::getNumIterOperands() {
2113  AffineMap lbMap = getLowerBoundMapAttr().getValue();
2114  AffineMap ubMap = getUpperBoundMapAttr().getValue();
2115 
2116  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2117 }
2118 
2119 MutableArrayRef<OpOperand> AffineForOp::getYieldedValuesMutable() {
2120  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2121 }
2122 
2123 void AffineForOp::print(OpAsmPrinter &p) {
2124  p << ' ';
2125  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2126  /*omitType=*/true);
2127  p << " = ";
2128  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2129  p << " to ";
2130  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2131 
2132  if (getStepAsInt() != 1)
2133  p << " step " << getStepAsInt();
2134 
2135  bool printBlockTerminators = false;
2136  if (getNumIterOperands() > 0) {
2137  p << " iter_args(";
2138  auto regionArgs = getRegionIterArgs();
2139  auto operands = getInits();
2140 
2141  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2142  p << std::get<0>(it) << " = " << std::get<1>(it);
2143  });
2144  p << ") -> (" << getResultTypes() << ")";
2145  printBlockTerminators = true;
2146  }
2147 
2148  p << ' ';
2149  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2150  printBlockTerminators);
2151  p.printOptionalAttrDict(
2152  (*this)->getAttrs(),
2153  /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2154  getUpperBoundMapAttrName(getOperation()->getName()),
2155  getStepAttrName(getOperation()->getName()),
2156  getOperandSegmentSizeAttr()});
2157 }
2158 
2159 /// Fold the constant bounds of a loop.
2160 static LogicalResult foldLoopBounds(AffineForOp forOp) {
2161  auto foldLowerOrUpperBound = [&forOp](bool lower) {
2162  // Check to see if each of the operands is the result of a constant. If
2163  // so, get the value. If not, ignore it.
2164  SmallVector<Attribute, 8> operandConstants;
2165  auto boundOperands =
2166  lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2167  for (auto operand : boundOperands) {
2168  Attribute operandCst;
2169  matchPattern(operand, m_Constant(&operandCst));
2170  operandConstants.push_back(operandCst);
2171  }
2172 
2173  AffineMap boundMap =
2174  lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2175  assert(boundMap.getNumResults() >= 1 &&
2176  "bound maps should have at least one result");
2177  SmallVector<Attribute, 4> foldedResults;
2178  if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2179  return failure();
2180 
2181  // Compute the max or min as applicable over the results.
2182  assert(!foldedResults.empty() && "bounds should have at least one result");
2183  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2184  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2185  auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2186  maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2187  : llvm::APIntOps::smin(maxOrMin, foldedResult);
2188  }
2189  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2190  : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2191  return success();
2192  };
2193 
2194  // Try to fold the lower bound.
2195  bool folded = false;
2196  if (!forOp.hasConstantLowerBound())
2197  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2198 
2199  // Try to fold the upper bound.
2200  if (!forOp.hasConstantUpperBound())
2201  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2202  return success(folded);
2203 }
2204 
2205 /// Canonicalize the bounds of the given loop.
2206 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2207  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2208  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2209 
2210  auto lbMap = forOp.getLowerBoundMap();
2211  auto ubMap = forOp.getUpperBoundMap();
2212  auto prevLbMap = lbMap;
2213  auto prevUbMap = ubMap;
2214 
2215  composeAffineMapAndOperands(&lbMap, &lbOperands);
2216  canonicalizeMapAndOperands(&lbMap, &lbOperands);
2217  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
2218  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2219  lbMap = removeDuplicateExprs(lbMap);
2220 
2221  composeAffineMapAndOperands(&ubMap, &ubOperands);
2222  canonicalizeMapAndOperands(&ubMap, &ubOperands);
2223  ubMap = removeDuplicateExprs(ubMap);
2224 
2225  // Any canonicalization change always leads to updated map(s).
2226  if (lbMap == prevLbMap && ubMap == prevUbMap)
2227  return failure();
2228 
2229  if (lbMap != prevLbMap)
2230  forOp.setLowerBound(lbOperands, lbMap);
2231  if (ubMap != prevUbMap)
2232  forOp.setUpperBound(ubOperands, ubMap);
2233  return success();
2234 }
2235 
2236 namespace {
2237 /// Returns constant trip count in trivial cases.
2238 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2239  int64_t step = forOp.getStepAsInt();
2240  if (!forOp.hasConstantBounds() || step <= 0)
2241  return std::nullopt;
2242  int64_t lb = forOp.getConstantLowerBound();
2243  int64_t ub = forOp.getConstantUpperBound();
2244  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2245 }
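The expression above is a ceiling division of the iteration span by the step; for example, lb = 0, ub = 10, step = 3 covers the iterations {0, 3, 6, 9}. A small illustrative check (not part of this file):

  static_assert((10 - 0 + 3 - 1) / 3 == 4, "ceildiv(10, 3) == 4");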
2246 
2247 /// This is a pattern to fold trivially empty loop bodies.
2248 /// TODO: This should be moved into the folding hook.
2249 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
2250  using OpRewritePattern<AffineForOp>::OpRewritePattern;
2251 
2252  LogicalResult matchAndRewrite(AffineForOp forOp,
2253  PatternRewriter &rewriter) const override {
2254  // Check that the body only contains a yield.
2255  if (!llvm::hasSingleElement(*forOp.getBody()))
2256  return failure();
2257  if (forOp.getNumResults() == 0)
2258  return success();
2259  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2260  if (tripCount && *tripCount == 0) {
2261  // The initial values of the iteration arguments would be the op's
2262  // results.
2263  rewriter.replaceOp(forOp, forOp.getInits());
2264  return success();
2265  }
2266  SmallVector<Value, 4> replacements;
2267  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2268  auto iterArgs = forOp.getRegionIterArgs();
2269  bool hasValDefinedOutsideLoop = false;
2270  bool iterArgsNotInOrder = false;
2271  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2272  Value val = yieldOp.getOperand(i);
2273  auto *iterArgIt = llvm::find(iterArgs, val);
2274  if (iterArgIt == iterArgs.end()) {
2275  // `val` is defined outside of the loop.
2276  assert(forOp.isDefinedOutsideOfLoop(val) &&
2277  "must be defined outside of the loop");
2278  hasValDefinedOutsideLoop = true;
2279  replacements.push_back(val);
2280  } else {
2281  unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2282  if (pos != i)
2283  iterArgsNotInOrder = true;
2284  replacements.push_back(forOp.getInits()[pos]);
2285  }
2286  }
2287  // Bail out when the trip count is unknown and the loop returns any value
2288  // defined outside of the loop or any iterArg out of order.
2289  if (!tripCount.has_value() &&
2290  (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2291  return failure();
2292  // Bail out when the loop iterates more than once and it returns any iterArg
2293  // out of order.
2294  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2295  return failure();
2296  rewriter.replaceOp(forOp, replacements);
2297  return success();
2298  }
2299 };
2300 } // namespace
2301 
2302 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2303  MLIRContext *context) {
2304  results.add<AffineForEmptyLoopFolder>(context);
2305 }
2306 
2307 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2308  assert((point.isParent() || point == getRegion()) && "invalid region point");
2309 
2310  // The initial operands map to the loop arguments after the induction
2311  // variable or are forwarded to the results when the trip count is zero.
2312  return getInits();
2313 }
2314 
2315 void AffineForOp::getSuccessorRegions(
2316  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2317  assert((point.isParent() || point == getRegion()) && "expected loop region");
2318  // The loop may typically branch back to its body or to the parent operation.
2319  // If the predecessor is the parent op and the trip count is known to be at
2320  // least one, branch into the body using the iterator arguments. If the trip
2321  // count is known to be zero, it can only branch back to its parent.
2322  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2323  if (point.isParent() && tripCount.has_value()) {
2324  if (tripCount.value() > 0) {
2325  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2326  return;
2327  }
2328  if (tripCount.value() == 0) {
2329  regions.push_back(RegionSuccessor(getResults()));
2330  return;
2331  }
2332  }
2333 
2334  // From the loop body, if the trip count is one, we can only branch back to
2335  // the parent.
2336  if (!point.isParent() && tripCount && *tripCount == 1) {
2337  regions.push_back(RegionSuccessor(getResults()));
2338  return;
2339  }
2340 
2341  // In all other cases, the loop may branch back to itself or the parent
2342  // operation.
2343  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2344  regions.push_back(RegionSuccessor(getResults()));
2345 }
2346 
2347 /// Returns true if the affine.for has zero iterations in trivial cases.
2348 static bool hasTrivialZeroTripCount(AffineForOp op) {
2349  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2350  return tripCount && *tripCount == 0;
2351 }
2352 
2353 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2354  SmallVectorImpl<OpFoldResult> &results) {
2355  bool folded = succeeded(foldLoopBounds(*this));
2356  folded |= succeeded(canonicalizeLoopBounds(*this));
2357  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2358  // The initial values of the loop-carried variables (iter_args) are the
2359  // results of the op. But this must be avoided for an affine.for op that
2360  // does not return any results. Since ops that do not return results cannot
2361  // be folded away, we would enter an infinite loop of folds on the same
2362  // affine.for op.
2363  results.assign(getInits().begin(), getInits().end());
2364  folded = true;
2365  }
2366  return success(folded);
2367 }
2368 
2369 AffineBound AffineForOp::getLowerBound() {
2370  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2371 }
2372 
2373 AffineBound AffineForOp::getUpperBound() {
2374  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2375 }
2376 
2377 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2378  assert(lbOperands.size() == map.getNumInputs());
2379  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2380  getLowerBoundOperandsMutable().assign(lbOperands);
2381  setLowerBoundMap(map);
2382 }
2383 
2384 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2385  assert(ubOperands.size() == map.getNumInputs());
2386  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2387  getUpperBoundOperandsMutable().assign(ubOperands);
2388  setUpperBoundMap(map);
2389 }
2390 
2391 bool AffineForOp::hasConstantLowerBound() {
2392  return getLowerBoundMap().isSingleConstant();
2393 }
2394 
2395 bool AffineForOp::hasConstantUpperBound() {
2396  return getUpperBoundMap().isSingleConstant();
2397 }
2398 
2399 int64_t AffineForOp::getConstantLowerBound() {
2400  return getLowerBoundMap().getSingleConstantResult();
2401 }
2402 
2403 int64_t AffineForOp::getConstantUpperBound() {
2404  return getUpperBoundMap().getSingleConstantResult();
2405 }
2406 
2407 void AffineForOp::setConstantLowerBound(int64_t value) {
2408  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2409 }
2410 
2411 void AffineForOp::setConstantUpperBound(int64_t value) {
2412  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2413 }
2414 
2415 AffineForOp::operand_range AffineForOp::getControlOperands() {
2416  return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2417  getUpperBoundOperands().size()};
2418 }
2419 
2420 bool AffineForOp::matchingBoundOperandList() {
2421  auto lbMap = getLowerBoundMap();
2422  auto ubMap = getUpperBoundMap();
2423  if (lbMap.getNumDims() != ubMap.getNumDims() ||
2424  lbMap.getNumSymbols() != ubMap.getNumSymbols())
2425  return false;
2426 
2427  unsigned numOperands = lbMap.getNumInputs();
2428  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2429  // Compare Values.
2430  if (getOperand(i) != getOperand(numOperands + i))
2431  return false;
2432  }
2433  return true;
2434 }
2435 
2436 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2437 
2438 std::optional<Value> AffineForOp::getSingleInductionVar() {
2439  return getInductionVar();
2440 }
2441 
2442 std::optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
2443  if (!hasConstantLowerBound())
2444  return std::nullopt;
2445  OpBuilder b(getContext());
2446  return OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()));
2447 }
2448 
2449 std::optional<OpFoldResult> AffineForOp::getSingleStep() {
2450  OpBuilder b(getContext());
2451  return OpFoldResult(b.getI64IntegerAttr(getStepAsInt()));
2452 }
2453 
2454 std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
2455  if (!hasConstantUpperBound())
2456  return std::nullopt;
2457  OpBuilder b(getContext());
2458  return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
2459 }
2460 
2461 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2462  RewriterBase &rewriter, ValueRange newInitOperands,
2463  bool replaceInitOperandUsesInLoop,
2464  const NewYieldValuesFn &newYieldValuesFn) {
2465  // Create a new loop before the existing one, with the extra operands.
2466  OpBuilder::InsertionGuard g(rewriter);
2467  rewriter.setInsertionPoint(getOperation());
2468  auto inits = llvm::to_vector(getInits());
2469  inits.append(newInitOperands.begin(), newInitOperands.end());
2470  AffineForOp newLoop = rewriter.create<AffineForOp>(
2471  getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2472  getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
2473 
2474  // Generate the new yield values and append them to the affine.yield operation.
2475  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2476  ArrayRef<BlockArgument> newIterArgs =
2477  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2478  {
2479  OpBuilder::InsertionGuard g(rewriter);
2480  rewriter.setInsertionPoint(yieldOp);
2481  SmallVector<Value> newYieldedValues =
2482  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2483  assert(newInitOperands.size() == newYieldedValues.size() &&
2484  "expected as many new yield values as new iter operands");
2485  rewriter.updateRootInPlace(yieldOp, [&]() {
2486  yieldOp.getOperandsMutable().append(newYieldedValues);
2487  });
2488  }
2489 
2490  // Move the loop body to the new op.
2491  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2492  newLoop.getBody()->getArguments().take_front(
2493  getBody()->getNumArguments()));
2494 
2495  if (replaceInitOperandUsesInLoop) {
2496  // Replace all uses of `newInitOperands` with the corresponding basic block
2497  // arguments.
2498  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2499  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2500  [&](OpOperand &use) {
2501  Operation *user = use.getOwner();
2502  return newLoop->isProperAncestor(user);
2503  });
2504  }
2505  }
2506 
2507  // Replace the old loop.
2508  rewriter.replaceOp(getOperation(),
2509  newLoop->getResults().take_front(getNumResults()));
2510  return cast<LoopLikeOpInterface>(newLoop.getOperation());
2511 }
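A usage sketch of the method above, assuming a hypothetical `forOp`, a `rewriter`, and a value `extra` defined above the loop:

  FailureOr<LoopLikeOpInterface> maybeNewLoop =
      forOp.replaceWithAdditionalYields(
          rewriter, /*newInitOperands=*/ValueRange{extra},
          /*replaceInitOperandUsesInLoop=*/true,
          [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
            // Pass the new region iter_arg through unchanged.
            return SmallVector<Value>{newBbArgs[0]};
          });
  if (failed(maybeNewLoop))
    return failure(); // the loop could not be rebuilt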
2512 
2513 Speculation::Speculatability AffineForOp::getSpeculatability() {
2514  // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2515  // Start and End.
2516  //
2517  // For Step != 1, the loop may not terminate. We can add more smarts here if
2518  // needed.
2519  return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2520  : Speculation::NotSpeculatable;
2521 }
2522 
2523 /// Returns true if the provided value is the induction variable of a
2524 /// AffineForOp.
2525 bool mlir::affine::isAffineForInductionVar(Value val) {
2526  return getForInductionVarOwner(val) != AffineForOp();
2527 }
2528 
2529 bool mlir::affine::isAffineParallelInductionVar(Value val) {
2530  return getAffineParallelInductionVarOwner(val) != nullptr;
2531 }
2532 
2533 bool mlir::affine::isAffineInductionVar(Value val) {
2534  return isAffineForInductionVar(val) || isAffineParallelInductionVar(val);
2535 }
2536 
2537 AffineForOp mlir::affine::getForInductionVarOwner(Value val) {
2538  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2539  if (!ivArg || !ivArg.getOwner())
2540  return AffineForOp();
2541  auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
2542  if (auto forOp = dyn_cast<AffineForOp>(containingInst))
2543  // Check to make sure `val` is the induction variable, not an iter_arg.
2544  return forOp.getInductionVar() == val ? forOp : AffineForOp();
2545  return AffineForOp();
2546 }
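For example, a transform that suspects a `Value v` is an affine.for induction variable can recover the owning loop and inspect its bounds (sketch only; `v` is hypothetical):

  if (AffineForOp owner = getForInductionVarOwner(v)) {
    if (owner.hasConstantLowerBound() && owner.hasConstantUpperBound())
      llvm::dbgs() << "iv range: [" << owner.getConstantLowerBound() << ", "
                   << owner.getConstantUpperBound() << ")\n";
  }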
2547 
2548 AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
2549  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2550  if (!ivArg || !ivArg.getOwner())
2551  return nullptr;
2552  Operation *containingOp = ivArg.getOwner()->getParentOp();
2553  auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2554  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2555  return parallelOp;
2556  return nullptr;
2557 }
2558 
2559 /// Extracts the induction variables from a list of AffineForOps and returns
2560 /// them.
2561 void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
2562  SmallVectorImpl<Value> *ivs) {
2563  ivs->reserve(forInsts.size());
2564  for (auto forInst : forInsts)
2565  ivs->push_back(forInst.getInductionVar());
2566 }
2567 
2568 void mlir::affine::extractInductionVars(ArrayRef<mlir::Operation *> affineOps,
2569  SmallVectorImpl<Value> &ivs) {
2570  ivs.reserve(affineOps.size());
2571  for (Operation *op : affineOps) {
2572  // Add constraints from forOp's bounds.
2573  if (auto forOp = dyn_cast<AffineForOp>(op))
2574  ivs.push_back(forOp.getInductionVar());
2575  else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2576  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2577  ivs.push_back(parallelOp.getBody()->getArgument(i));
2578  }
2579 }
2580 
2581 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2582 /// operations.
2583 template <typename BoundListTy, typename LoopCreatorTy>
2584 static void buildAffineLoopNestImpl(
2585  OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2586  ArrayRef<int64_t> steps,
2587  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2588  LoopCreatorTy &&loopCreatorFn) {
2589  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2590  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2591 
2592  // If there are no loops to be constructed, construct the body anyway.
2593  OpBuilder::InsertionGuard guard(builder);
2594  if (lbs.empty()) {
2595  if (bodyBuilderFn)
2596  bodyBuilderFn(builder, loc, ValueRange());
2597  return;
2598  }
2599 
2600  // Create the loops iteratively and store the induction variables.
2601  SmallVector<Value, 4> ivs;
2602  ivs.reserve(lbs.size());
2603  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2604  // Callback for creating the loop body, always creates the terminator.
2605  auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
2606  ValueRange iterArgs) {
2607  ivs.push_back(iv);
2608  // In the innermost loop, call the body builder.
2609  if (i == e - 1 && bodyBuilderFn) {
2610  OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
2611  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2612  }
2613  nestedBuilder.create<AffineYieldOp>(nestedLoc);
2614  };
2615 
2616  // Delegate actual loop creation to the callback in order to dispatch
2617  // between constant- and variable-bound loops.
2618  auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2619  builder.setInsertionPointToStart(loop.getBody());
2620  }
2621 }
2622 
2623 /// Creates an affine loop from the bounds known to be constants.
2624 static AffineForOp
2625 buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
2626  int64_t ub, int64_t step,
2627  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2628  return builder.create<AffineForOp>(loc, lb, ub, step,
2629  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2630 }
2631 
2632 /// Creates an affine loop from the bounds that may or may not be constants.
2633 static AffineForOp
2634 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
2635  int64_t step,
2636  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2637  std::optional<int64_t> lbConst = getConstantIntValue(lb);
2638  std::optional<int64_t> ubConst = getConstantIntValue(ub);
2639  if (lbConst && ubConst)
2640  return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2641  ubConst.value(), step, bodyBuilderFn);
2642  return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2643  builder.getDimIdentityMap(), step,
2644  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2645 }
2646 
2647 void mlir::affine::buildAffineLoopNest(
2648  OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2649  ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
2650  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2651  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2652  buildAffineLoopFromConstants);
2653 }
2654 
2655 void mlir::affine::buildAffineLoopNest(
2656  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2657  ArrayRef<int64_t> steps,
2658  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2659  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2660  buildAffineLoopFromValues);
2661 }
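A sketch of building a constant-bound 2-d nest with the helper above, assuming `builder`, `loc`, and a 2-d memref value `buf` (all placeholders):

  affine::buildAffineLoopNest(
      builder, loc, /*lbs=*/{0, 0}, /*ubs=*/{4, 8}, /*steps=*/{1, 1},
      [&](OpBuilder &b, Location nestedLoc, ValueRange ivs) {
        // The innermost body receives both induction variables.
        b.create<AffineLoadOp>(nestedLoc, buf, ivs);
      });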
2662 
2663 //===----------------------------------------------------------------------===//
2664 // AffineIfOp
2665 //===----------------------------------------------------------------------===//
2666 
2667 namespace {
2668 /// Remove else blocks that have nothing other than a zero value yield.
2669 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
2670  using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2671 
2672  LogicalResult matchAndRewrite(AffineIfOp ifOp,
2673  PatternRewriter &rewriter) const override {
2674  if (ifOp.getElseRegion().empty() ||
2675  !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2676  return failure();
2677 
2678  rewriter.startRootUpdate(ifOp);
2679  rewriter.eraseBlock(ifOp.getElseBlock());
2680  rewriter.finalizeRootUpdate(ifOp);
2681  return success();
2682  }
2683 };
2684 
2685 /// Removes affine.if cond if the condition is always true or false in certain
2686 /// trivial cases. Promotes the then/else block in the parent operation block.
2687 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
2688  using OpRewritePattern<AffineIfOp>::OpRewritePattern;
2689 
2690  LogicalResult matchAndRewrite(AffineIfOp op,
2691  PatternRewriter &rewriter) const override {
2692 
2693  auto isTriviallyFalse = [](IntegerSet iSet) {
2694  return iSet.isEmptyIntegerSet();
2695  };
2696 
2697  auto isTriviallyTrue = [](IntegerSet iSet) {
2698  return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2699  iSet.getConstraint(0) == 0);
2700  };
2701 
2702  IntegerSet affineIfConditions = op.getIntegerSet();
2703  Block *blockToMove;
2704  if (isTriviallyFalse(affineIfConditions)) {
2705  // The absence, or equivalently, the emptiness of the else region need not
2706  // be checked when affine.if is returning results because if an affine.if
2707  // operation is returning results, it always has a non-empty else region.
2708  if (op.getNumResults() == 0 && !op.hasElse()) {
2709  // If the else region is absent, or equivalently, empty, remove the
2710  // affine.if operation (which is not returning any results).
2711  rewriter.eraseOp(op);
2712  return success();
2713  }
2714  blockToMove = op.getElseBlock();
2715  } else if (isTriviallyTrue(affineIfConditions)) {
2716  blockToMove = op.getThenBlock();
2717  } else {
2718  return failure();
2719  }
2720  Operation *blockToMoveTerminator = blockToMove->getTerminator();
2721  // Promote the "blockToMove" block to the parent operation block between the
2722  // prologue and epilogue of "op".
2723  rewriter.inlineBlockBefore(blockToMove, op);
2724  // Replace the "op" operation with the operands of the
2725  // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2726  // the affine.yield operation present in the "blockToMove" block. It has no
2727  // operands when affine.if is not returning results and therefore, in that
2728  // case, replaceOp just erases "op". When affine.if is not returning
2729  // results, the affine.yield operation can be omitted. It gets inserted
2730  // implicitly.
2731  rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2732  // Erase the "blockToMoveTerminator" operation since it is now in the parent
2733  // operation block, which already has its own terminator.
2734  rewriter.eraseOp(blockToMoveTerminator);
2735  return success();
2736  }
2737 };
2738 } // namespace
2739 
2740 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
2741 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
2742 void AffineIfOp::getSuccessorRegions(
2743  RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2744  // If the predecessor is an AffineIfOp, then branching into both `then` and
2745  // `else` region is valid.
2746  if (point.isParent()) {
2747  regions.reserve(2);
2748  regions.push_back(
2749  RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
2750  // If the "else" region is empty, branch back into the parent.
2751  if (getElseRegion().empty()) {
2752  regions.push_back(getResults());
2753  } else {
2754  regions.push_back(
2755  RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
2756  }
2757  return;
2758  }
2759 
2760  // If the predecessor is the `else`/`then` region, then branching into parent
2761  // op is valid.
2762  regions.push_back(RegionSuccessor(getResults()));
2763 }
2764 
2765 LogicalResult AffineIfOp::verify() {
2766  // Verify that we have a condition attribute.
2767  // FIXME: This should be specified in the arguments list in ODS.
2768  auto conditionAttr =
2769  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2770  if (!conditionAttr)
2771  return emitOpError("requires an integer set attribute named 'condition'");
2772 
2773  // Verify that there are enough operands for the condition.
2774  IntegerSet condition = conditionAttr.getValue();
2775  if (getNumOperands() != condition.getNumInputs())
2776  return emitOpError("operand count and condition integer set dimension and "
2777  "symbol count must match");
2778 
2779  // Verify that the operands are valid dimension/symbols.
2780  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2781  condition.getNumDims())))
2782  return failure();
2783 
2784  return success();
2785 }
2786 
2787 ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
2788  // Parse the condition attribute set.
2789  IntegerSetAttr conditionAttr;
2790  unsigned numDims;
2791  if (parser.parseAttribute(conditionAttr,
2792  AffineIfOp::getConditionAttrStrName(),
2793  result.attributes) ||
2794  parseDimAndSymbolList(parser, result.operands, numDims))
2795  return failure();
2796 
2797  // Verify the condition operands.
2798  auto set = conditionAttr.getValue();
2799  if (set.getNumDims() != numDims)
2800  return parser.emitError(
2801  parser.getNameLoc(),
2802  "dim operand count and integer set dim count must match");
2803  if (numDims + set.getNumSymbols() != result.operands.size())
2804  return parser.emitError(
2805  parser.getNameLoc(),
2806  "symbol operand count and integer set symbol count must match");
2807 
2808  if (parser.parseOptionalArrowTypeList(result.types))
2809  return failure();
2810 
2811  // Create the regions for 'then' and 'else'. The latter must be created even
2812  // if it remains empty for the validity of the operation.
2813  result.regions.reserve(2);
2814  Region *thenRegion = result.addRegion();
2815  Region *elseRegion = result.addRegion();
2816 
2817  // Parse the 'then' region.
2818  if (parser.parseRegion(*thenRegion, {}, {}))
2819  return failure();
2820  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2821  result.location);
2822 
2823  // If we find an 'else' keyword then parse the 'else' region.
2824  if (!parser.parseOptionalKeyword("else")) {
2825  if (parser.parseRegion(*elseRegion, {}, {}))
2826  return failure();
2827  AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2828  result.location);
2829  }
2830 
2831  // Parse the optional attribute list.
2832  if (parser.parseOptionalAttrDict(result.attributes))
2833  return failure();
2834 
2835  return success();
2836 }
2837 
2838 void AffineIfOp::print(OpAsmPrinter &p) {
2839  auto conditionAttr =
2840  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2841  p << " " << conditionAttr;
2842  printDimAndSymbolList(operand_begin(), operand_end(),
2843  conditionAttr.getValue().getNumDims(), p);
2844  p.printOptionalArrowTypeList(getResultTypes());
2845  p << ' ';
2846  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
2847  /*printBlockTerminators=*/getNumResults());
2848 
2849  // Print the 'else' regions if it has any blocks.
2850  auto &elseRegion = this->getElseRegion();
2851  if (!elseRegion.empty()) {
2852  p << " else ";
2853  p.printRegion(elseRegion,
2854  /*printEntryBlockArgs=*/false,
2855  /*printBlockTerminators=*/getNumResults());
2856  }
2857 
2858  // Print the attribute list.
2859  p.printOptionalAttrDict((*this)->getAttrs(),
2860  /*elidedAttrs=*/getConditionAttrStrName());
2861 }
2862 
2863 IntegerSet AffineIfOp::getIntegerSet() {
2864  return (*this)
2865  ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2866  .getValue();
2867 }
2868 
2869 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2870  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
2871 }
2872 
2873 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2874  setIntegerSet(set);
2875  (*this)->setOperands(operands);
2876 }
2877 
2878 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2879  TypeRange resultTypes, IntegerSet set, ValueRange args,
2880  bool withElseRegion) {
2881  assert(resultTypes.empty() || withElseRegion);
2882  result.addTypes(resultTypes);
2883  result.addOperands(args);
2884  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
2885 
2886  Region *thenRegion = result.addRegion();
2887  thenRegion->push_back(new Block());
2888  if (resultTypes.empty())
2889  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2890 
2891  Region *elseRegion = result.addRegion();
2892  if (withElseRegion) {
2893  elseRegion->push_back(new Block());
2894  if (resultTypes.empty())
2895  AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2896  }
2897 }
2898 
2899 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2900  IntegerSet set, ValueRange args, bool withElseRegion) {
2901  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2902  withElseRegion);
2903 }
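As a sketch of the builders above, an `affine.if` guarding on `d0 <= 10` (i.e. the constraint `10 - d0 >= 0`) could be created as follows, where `builder`, `loc`, and the index value `i` are assumed placeholders:

  IntegerSet set = IntegerSet::get(
      /*dimCount=*/1, /*symbolCount=*/0,
      /*constraints=*/{getAffineConstantExpr(10, builder.getContext()) -
                       getAffineDimExpr(0, builder.getContext())},
      /*eqFlags=*/{false});
  auto ifOp = builder.create<AffineIfOp>(loc, set, ValueRange{i},
                                         /*withElseRegion=*/false);
  builder.setInsertionPointToStart(ifOp.getThenBlock());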
2904 
2905 /// Compose any affine.apply ops feeding into `operands` of the integer set
2906 /// `set` by composing the maps of such affine.apply ops with the integer
2907 /// set constraints.
2908 static void composeSetAndOperands(IntegerSet &set,
2909  SmallVectorImpl<Value> &operands) {
2910  // We will simply reuse the API of the map composition by viewing the LHSs of
2911  // the equalities and inequalities of `set` as the affine exprs of an affine
2912  // map. Convert to equivalent map, compose, and convert back to set.
2913  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
2914  set.getConstraints(), set.getContext());
2915  // Check if any composition is possible.
2916  if (llvm::none_of(operands,
2917  [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
2918  return;
2919 
2920  composeAffineMapAndOperands(&map, &operands);
2921  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
2922  set.getEqFlags());
2923 }
2924 
2925 /// Canonicalize an affine if op's conditional (integer set + operands).
2926 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2927  auto set = getIntegerSet();
2928  SmallVector<Value, 4> operands(getOperands());
2929  composeSetAndOperands(set, operands);
2930  canonicalizeSetAndOperands(&set, &operands);
2931 
2932  // Check if the canonicalization or composition led to any change.
2933  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2934  return failure();
2935 
2936  setConditional(set, operands);
2937  return success();
2938 }
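// For illustration (SSA names and the set below are hypothetical, not taken
// from this file): composing a feeding affine.apply into the condition
// rewrites
//   %d = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
//   affine.if affine_set<(d0) : (d0 - 8 >= 0)>(%d) { ... }
// into
//   affine.if affine_set<(d0) : (d0 - 7 >= 0)>(%i) { ... }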
2939 
2940 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2941  MLIRContext *context) {
2942  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2943 }
2944 
2945 //===----------------------------------------------------------------------===//
2946 // AffineLoadOp
2947 //===----------------------------------------------------------------------===//
2948 
2949 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2950  AffineMap map, ValueRange operands) {
2951  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2952  result.addOperands(operands);
2953  if (map)
2954  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2955  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
2956  result.types.push_back(memrefType.getElementType());
2957 }
2958 
2959 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2960  Value memref, AffineMap map, ValueRange mapOperands) {
2961  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2962  result.addOperands(memref);
2963  result.addOperands(mapOperands);
2964  auto memrefType = llvm::cast<MemRefType>(memref.getType());
2965  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
2966  result.types.push_back(memrefType.getElementType());
2967 }
2968 
2969 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2970  Value memref, ValueRange indices) {
2971  auto memrefType = llvm::cast<MemRefType>(memref.getType());
2972  int64_t rank = memrefType.getRank();
2973  // Create identity map for memrefs with at least one dimension or () -> ()
2974  // for zero-dimensional memrefs.
2975  auto map =
2976  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2977  build(builder, result, memref, map, indices);
2978 }
2979 
2980 ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
2981  auto &builder = parser.getBuilder();
2982  auto indexTy = builder.getIndexType();
2983 
2984  MemRefType type;
2985  OpAsmParser::UnresolvedOperand memrefInfo;
2986  AffineMapAttr mapAttr;
2987  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
2988  return failure(
2989  parser.parseOperand(memrefInfo) ||
2990  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2991  AffineLoadOp::getMapAttrStrName(),
2992  result.attributes) ||
2993  parser.parseOptionalAttrDict(result.attributes) ||
2994  parser.parseColonType(type) ||
2995  parser.resolveOperand(memrefInfo, type, result.operands) ||
2996  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
2997  parser.addTypeToList(type.getElementType(), result.types));
2998 }
2999 
3000 void AffineLoadOp::print(OpAsmPrinter &p) {
3001  p << " " << getMemRef() << '[';
3002  if (AffineMapAttr mapAttr =
3003  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3004  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3005  p << ']';
3006  p.printOptionalAttrDict((*this)->getAttrs(),
3007  /*elidedAttrs=*/{getMapAttrStrName()});
3008  p << " : " << getMemRefType();
3009 }
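// A sketch of the custom syntax handled by the parser/printer above (operand
// names are hypothetical):
//   %v = affine.load %mem[%i + 3, %j] : memref<100x100xf32>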
3010 
3011 /// Verify common indexing invariants of affine.load, affine.store,
3012 /// affine.vector_load and affine.vector_store.
3013 static LogicalResult
3014 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
3015  Operation::operand_range mapOperands,
3016  MemRefType memrefType, unsigned numIndexOperands) {
3017  AffineMap map = mapAttr.getValue();
3018  if (map.getNumResults() != memrefType.getRank())
3019  return op->emitOpError("affine map num results must equal memref rank");
3020  if (map.getNumInputs() != numIndexOperands)
3021  return op->emitOpError("expects as many subscripts as affine map inputs");
3022 
3023  Region *scope = getAffineScope(op);
3024  for (auto idx : mapOperands) {
3025  if (!idx.getType().isIndex())
3026  return op->emitOpError("index to load must have 'index' type");
3027  if (!isValidAffineIndexOperand(idx, scope))
3028  return op->emitOpError(
3029  "index must be a valid dimension or symbol identifier");
3030  }
3031 
3032  return success();
3033 }
3034 
3035 LogicalResult AffineLoadOp::verify() {
3036  auto memrefType = getMemRefType();
3037  if (getType() != memrefType.getElementType())
3038  return emitOpError("result type must match element type of memref");
3039 
3040  if (failed(verifyMemoryOpIndexing(
3041  getOperation(),
3042  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3043  getMapOperands(), memrefType,
3044  /*numIndexOperands=*/getNumOperands() - 1)))
3045  return failure();
3046 
3047  return success();
3048 }
3049 
3050 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3051  MLIRContext *context) {
3052  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3053 }
3054 
3055 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3056  /// load(memrefcast) -> load
3057  if (succeeded(memref::foldMemRefCast(*this)))
3058  return getResult();
3059 
3060  // Fold load from a global constant memref.
3061  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3062  if (!getGlobalOp)
3063  return {};
3064  // Get to the memref.global defining the symbol.
3065  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3066  if (!symbolTableOp)
3067  return {};
3068  auto global = dyn_cast_or_null<memref::GlobalOp>(
3069  SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
3070  if (!global)
3071  return {};
3072 
3073  // Check if the global memref is a constant.
3074  auto cstAttr =
3075  llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3076  if (!cstAttr)
3077  return {};
3078  // If it's a splat constant, we can fold irrespective of indices.
3079  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3080  return splatAttr.getSplatValue<Attribute>();
3081  // Otherwise, we can fold only if we know the indices.
3082  if (!getAffineMap().isConstant())
3083  return {};
3084  auto indices = llvm::to_vector<4>(
3085  llvm::map_range(getAffineMap().getConstantResults(),
3086  [](int64_t v) -> uint64_t { return v; }));
3087  return cstAttr.getValues<Attribute>()[indices];
3088 }
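// A minimal sketch of the global-constant fold above (the symbol @cst and the
// SSA names are hypothetical):
//   memref.global "private" constant @cst : memref<3xf32> = dense<4.0>
//   ...
//   %m = memref.get_global @cst : memref<3xf32>
//   %v = affine.load %m[%i] : memref<3xf32>
// Because the initializer is a splat, %v folds to the constant 4.0 regardless
// of the index; non-splat initializers fold only when the indices are
// constant.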
3089 
3090 //===----------------------------------------------------------------------===//
3091 // AffineStoreOp
3092 //===----------------------------------------------------------------------===//
3093 
3094 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3095  Value valueToStore, Value memref, AffineMap map,
3096  ValueRange mapOperands) {
3097  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3098  result.addOperands(valueToStore);
3099  result.addOperands(memref);
3100  result.addOperands(mapOperands);
3101  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3102 }
3103 
3104 // Use identity map.
3105 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3106  Value valueToStore, Value memref,
3107  ValueRange indices) {
3108  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3109  int64_t rank = memrefType.getRank();
3110  // Create identity map for memrefs with at least one dimension or () -> ()
3111  // for zero-dimensional memrefs.
3112  auto map =
3113  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3114  build(builder, result, valueToStore, memref, map, indices);
3115 }
3116 
3117 ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3118  auto indexTy = parser.getBuilder().getIndexType();
3119 
3120  MemRefType type;
3121  OpAsmParser::UnresolvedOperand storeValueInfo;
3122  OpAsmParser::UnresolvedOperand memrefInfo;
3123  AffineMapAttr mapAttr;
3124  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3125  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3126  parser.parseOperand(memrefInfo) ||
3127  parser.parseAffineMapOfSSAIds(
3128  mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3129  result.attributes) ||
3130  parser.parseOptionalAttrDict(result.attributes) ||
3131  parser.parseColonType(type) ||
3132  parser.resolveOperand(storeValueInfo, type.getElementType(),
3133  result.operands) ||
3134  parser.resolveOperand(memrefInfo, type, result.operands) ||
3135  parser.resolveOperands(mapOperands, indexTy, result.operands));
3136 }
3137 
3138 void AffineStoreOp::print(OpAsmPrinter &p) {
3139  p << " " << getValueToStore();
3140  p << ", " << getMemRef() << '[';
3141  if (AffineMapAttr mapAttr =
3142  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3143  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3144  p << ']';
3145  p.printOptionalAttrDict((*this)->getAttrs(),
3146  /*elidedAttrs=*/{getMapAttrStrName()});
3147  p << " : " << getMemRefType();
3148 }
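// A sketch of the custom syntax handled above (operand names are
// hypothetical):
//   affine.store %val, %mem[%i + 3, %j] : memref<100x100xf32>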
3149 
3150 LogicalResult AffineStoreOp::verify() {
3151  // The value to store must have the same type as memref element type.
3152  auto memrefType = getMemRefType();
3153  if (getValueToStore().getType() != memrefType.getElementType())
3154  return emitOpError(
3155  "value to store must have the same type as memref element type");
3156 
3157  if (failed(verifyMemoryOpIndexing(
3158  getOperation(),
3159  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3160  getMapOperands(), memrefType,
3161  /*numIndexOperands=*/getNumOperands() - 2)))
3162  return failure();
3163 
3164  return success();
3165 }
3166 
3167 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3168  MLIRContext *context) {
3169  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3170 }
3171 
3172 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3173  SmallVectorImpl<OpFoldResult> &results) {
3174  /// store(memrefcast) -> store
3175  return memref::foldMemRefCast(*this, getValueToStore());
3176 }
3177 
3178 //===----------------------------------------------------------------------===//
3179 // AffineMinMaxOpBase
3180 //===----------------------------------------------------------------------===//
3181 
3182 template <typename T>
3183 static LogicalResult verifyAffineMinMaxOp(T op) {
3184 // Verify that operand count matches affine map dimension and symbol count.
3185  if (op.getNumOperands() !=
3186  op.getMap().getNumDims() + op.getMap().getNumSymbols())
3187  return op.emitOpError(
3188  "operand count and affine map dimension and symbol count must match");
3189  return success();
3190 }
3191 
3192 template <typename T>
3193 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3194  p << ' ' << op->getAttr(T::getMapAttrStrName());
3195  auto operands = op.getOperands();
3196  unsigned numDims = op.getMap().getNumDims();
3197  p << '(' << operands.take_front(numDims) << ')';
3198 
3199  if (operands.size() != numDims)
3200  p << '[' << operands.drop_front(numDims) << ']';
3201  p.printOptionalAttrDict(op->getAttrs(),
3202  /*elidedAttrs=*/{T::getMapAttrStrName()});
3203 }
3204 
3205 template <typename T>
3206 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3207  OperationState &result) {
3208  auto &builder = parser.getBuilder();
3209  auto indexType = builder.getIndexType();
3210  SmallVector<OpAsmParser::UnresolvedOperand, 8> dimInfos;
3211  SmallVector<OpAsmParser::UnresolvedOperand, 8> symInfos;
3212  AffineMapAttr mapAttr;
3213  return failure(
3214  parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3215  result.attributes) ||
3216  parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
3217  parser.parseOperandList(symInfos,
3218  OpAsmParser::Delimiter::OptionalSquare) ||
3219  parser.parseOptionalAttrDict(result.attributes) ||
3220  parser.resolveOperands(dimInfos, indexType, result.operands) ||
3221  parser.resolveOperands(symInfos, indexType, result.operands) ||
3222  parser.addTypeToList(indexType, result.types));
3223 }
3224 
3225 /// Fold an affine min or max operation with the given operands. The operand
3226 /// list may contain nulls, which are interpreted as the operand not being a
3227 /// constant.
3228 template <typename T>
3229 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
3230  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3231  "expected affine min or max op");
3232 
3233  // Fold the affine map.
3234  // TODO: Fold more cases:
3235  // min(some_affine, some_affine + constant, ...), etc.
3236  SmallVector<int64_t, 2> results;
3237  auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3238 
3239  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3240  return op.getOperand(0);
3241 
3242  // If some of the map results are not constant, try changing the map in-place.
3243  if (results.empty()) {
3244  // If the map is the same, report that folding did not happen.
3245  if (foldedMap == op.getMap())
3246  return {};
3247  op->setAttr("map", AffineMapAttr::get(foldedMap));
3248  return op.getResult();
3249  }
3250 
3251  // Otherwise, completely fold the op into a constant.
3252  auto resultIt = std::is_same<T, AffineMinOp>::value
3253  ? std::min_element(results.begin(), results.end())
3254  : std::max_element(results.begin(), results.end());
3255  if (resultIt == results.end())
3256  return {};
3257  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3258 }
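// For illustration (values are hypothetical): if %c is a constant index 10,
//   %0 = affine.min affine_map<(d0) -> (1000, d0 + 512)>(%c)
// fully folds to the index constant 522; with non-constant operands only the
// map is partially folded, when possible.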
3259 
3260 /// Remove duplicated expressions in affine min/max ops.
3261 template <typename T>
3262 struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
3263  using OpRewritePattern<T>::OpRewritePattern;
3264 
3265  LogicalResult matchAndRewrite(T affineOp,
3266  PatternRewriter &rewriter) const override {
3267  AffineMap oldMap = affineOp.getAffineMap();
3268 
3269  SmallVector<AffineExpr, 4> newExprs;
3270  for (AffineExpr expr : oldMap.getResults()) {
3271  // This is a linear scan over newExprs, but it should be fine given that
3272  // we typically just have a few expressions per op.
3273  if (!llvm::is_contained(newExprs, expr))
3274  newExprs.push_back(expr);
3275  }
3276 
3277  if (newExprs.size() == oldMap.getNumResults())
3278  return failure();
3279 
3280  auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3281  newExprs, rewriter.getContext());
3282  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3283 
3284  return success();
3285  }
3286 };
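// For example (illustrative), the pattern above rewrites
//   %0 = affine.min affine_map<(d0) -> (d0, d0, 42)>(%i)
// into
//   %0 = affine.min affine_map<(d0) -> (d0, 42)>(%i)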
3287 
3288 /// Merge an affine min/max op to its consumers if its consumer is also an
3289 /// affine min/max op.
3290 ///
3291 /// This pattern requires that the producer affine min/max op be bound to a
3292 /// dimension/symbol that is used as a standalone expression in the consumer
3293 /// affine op's map.
3294 ///
3295 /// For example, a pattern like the following:
3296 ///
3297 /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3298 /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3299 ///
3300 /// Can be turned into:
3301 ///
3302 /// %1 = affine.min affine_map<
3303 /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3304 template <typename T>
3305 struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
3306  using OpRewritePattern<T>::OpRewritePattern;
3307 
3308  LogicalResult matchAndRewrite(T affineOp,
3309  PatternRewriter &rewriter) const override {
3310  AffineMap oldMap = affineOp.getAffineMap();
3311  ValueRange dimOperands =
3312  affineOp.getMapOperands().take_front(oldMap.getNumDims());
3313  ValueRange symOperands =
3314  affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3315 
3316  auto newDimOperands = llvm::to_vector<8>(dimOperands);
3317  auto newSymOperands = llvm::to_vector<8>(symOperands);
3318  SmallVector<AffineExpr, 4> newExprs;
3319  SmallVector<T, 4> producerOps;
3320 
3321  // Go over each expression to see whether it's a single dimension/symbol
3322  // with the corresponding operand which is the result of another affine
3323  // min/max op. If so, it can be merged into this affine op.
3324  for (AffineExpr expr : oldMap.getResults()) {
3325  if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3326  Value symValue = symOperands[symExpr.getPosition()];
3327  if (auto producerOp = symValue.getDefiningOp<T>()) {
3328  producerOps.push_back(producerOp);
3329  continue;
3330  }
3331  } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3332  Value dimValue = dimOperands[dimExpr.getPosition()];
3333  if (auto producerOp = dimValue.getDefiningOp<T>()) {
3334  producerOps.push_back(producerOp);
3335  continue;
3336  }
3337  }
3338  // For the above cases we will remove the expression by merging the
3339  // producer affine min/max's affine expressions. Otherwise we need to
3340  // keep the existing expression.
3341  newExprs.push_back(expr);
3342  }
3343 
3344  if (producerOps.empty())
3345  return failure();
3346 
3347  unsigned numUsedDims = oldMap.getNumDims();
3348  unsigned numUsedSyms = oldMap.getNumSymbols();
3349 
3350  // Now go over all producer affine ops and merge their expressions.
3351  for (T producerOp : producerOps) {
3352  AffineMap producerMap = producerOp.getAffineMap();
3353  unsigned numProducerDims = producerMap.getNumDims();
3354  unsigned numProducerSyms = producerMap.getNumSymbols();
3355 
3356  // Collect all dimension/symbol values.
3357  ValueRange dimValues =
3358  producerOp.getMapOperands().take_front(numProducerDims);
3359  ValueRange symValues =
3360  producerOp.getMapOperands().take_back(numProducerSyms);
3361  newDimOperands.append(dimValues.begin(), dimValues.end());
3362  newSymOperands.append(symValues.begin(), symValues.end());
3363 
3364  // For expressions we need to shift to avoid overlap.
3365  for (AffineExpr expr : producerMap.getResults()) {
3366  newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3367  .shiftSymbols(numProducerSyms, numUsedSyms));
3368  }
3369 
3370  numUsedDims += numProducerDims;
3371  numUsedSyms += numProducerSyms;
3372  }
3373 
3374  auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3375  rewriter.getContext());
3376  auto newOperands =
3377  llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3378  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3379 
3380  return success();
3381  }
3382 };
3383 
3384 /// Canonicalize the result expression order of an affine map and return success
3385 /// if the order changed.
3386 ///
3387 /// The function flattens the map's affine expressions to coefficient arrays and
3388 /// sorts them in lexicographic order. A coefficient array contains a multiplier
3389 /// for every dimension/symbol and a constant term. The canonicalization fails
3390 /// if a result expression is not pure or if the flattening requires local
3391 /// variables that, unlike dimensions and symbols, have no global order.
3392 static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3393  SmallVector<SmallVector<int64_t>> flattenedExprs;
3394  for (const AffineExpr &resultExpr : map.getResults()) {
3395  // Fail if the expression is not pure.
3396  if (!resultExpr.isPureAffine())
3397  return failure();
3398 
3399  SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3400  auto flattenResult = flattener.walkPostOrder(resultExpr);
3401  if (failed(flattenResult))
3402  return failure();
3403 
3404  // Fail if the flattened expression has local variables.
3405  if (flattener.operandExprStack.back().size() !=
3406  map.getNumDims() + map.getNumSymbols() + 1)
3407  return failure();
3408 
3409  flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3410  flattener.operandExprStack.back().end());
3411  }
3412 
3413  // Fail if sorting is not necessary.
3414  if (llvm::is_sorted(flattenedExprs))
3415  return failure();
3416 
3417  // Reorder the result expressions according to their flattened form.
3418  SmallVector<unsigned> resultPermutation =
3419  llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3420  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3421  return flattenedExprs[lhs] < flattenedExprs[rhs];
3422  });
3423  SmallVector<AffineExpr> newExprs;
3424  for (unsigned idx : resultPermutation)
3425  newExprs.push_back(map.getResult(idx));
3426 
3427  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3428  map.getContext());
3429  return success();
3430 }
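// As a sketch of the ordering used above: with one dimension and one symbol,
// the expression d0 + 2 * s0 + 3 flattens to the coefficient vector
// [1, 2, 3] (dim coefficients, then symbol coefficients, then the constant
// term), and result expressions are sorted lexicographically by these vectors.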
3431 
3432 /// Canonicalize the affine map result expression order of an affine min/max
3433 /// operation.
3434 ///
3435 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3436 /// expressions and replaces the operation if the order changed.
3437 ///
3438 /// For example, the following operation:
3439 ///
3440 /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3441 ///
3442 /// Turns into:
3443 ///
3444 /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3445 template <typename T>
3446 struct CanonicalizeAffineMinMaxOpExprAndTermOrder : public OpRewritePattern<T> {
3447  using OpRewritePattern<T>::OpRewritePattern;
3448 
3449  LogicalResult matchAndRewrite(T affineOp,
3450  PatternRewriter &rewriter) const override {
3451  AffineMap map = affineOp.getAffineMap();
3452  if (failed(canonicalizeMapExprAndTermOrder(map)))
3453  return failure();
3454  rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3455  return success();
3456  }
3457 };
3458 
3459 template <typename T>
3460 struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
3461  using OpRewritePattern<T>::OpRewritePattern;
3462 
3463  LogicalResult matchAndRewrite(T affineOp,
3464  PatternRewriter &rewriter) const override {
3465  if (affineOp.getMap().getNumResults() != 1)
3466  return failure();
3467  rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3468  affineOp.getOperands());
3469  return success();
3470  }
3471 };
3472 
3473 //===----------------------------------------------------------------------===//
3474 // AffineMinOp
3475 //===----------------------------------------------------------------------===//
3476 //
3477 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3478 //
3479 
3480 OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3481  return foldMinMaxOp(*this, adaptor.getOperands());
3482 }
3483 
3484 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3485  MLIRContext *context) {
3486  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
3487  DeduplicateAffineMinMaxExpressions<AffineMinOp>,
3488  MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3489  CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMinOp>>(
3490  context);
3491 }
3492 
3493 LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3494 
3495 ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
3496  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3497 }
3498 
3499 void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3500 
3501 //===----------------------------------------------------------------------===//
3502 // AffineMaxOp
3503 //===----------------------------------------------------------------------===//
3504 //
3505 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3506 //
3507 
3508 OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3509  return foldMinMaxOp(*this, adaptor.getOperands());
3510 }
3511 
3512 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3513  MLIRContext *context) {
3514  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
3515  DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
3516  MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3517  CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMaxOp>>(
3518  context);
3519 }
3520 
3521 LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3522 
3523 ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
3524  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3525 }
3526 
3527 void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3528 
3529 //===----------------------------------------------------------------------===//
3530 // AffinePrefetchOp
3531 //===----------------------------------------------------------------------===//
3532 
3533 //
3534 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3535 //
3536 ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3537  OperationState &result) {
3538  auto &builder = parser.getBuilder();
3539  auto indexTy = builder.getIndexType();
3540 
3541  MemRefType type;
3542  OpAsmParser::UnresolvedOperand memrefInfo;
3543  IntegerAttr hintInfo;
3544  auto i32Type = parser.getBuilder().getIntegerType(32);
3545  StringRef readOrWrite, cacheType;
3546 
3547  AffineMapAttr mapAttr;
3548  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3549  if (parser.parseOperand(memrefInfo) ||
3550  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3551  AffinePrefetchOp::getMapAttrStrName(),
3552  result.attributes) ||
3553  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3554  parser.parseComma() || parser.parseKeyword("locality") ||
3555  parser.parseLess() ||
3556  parser.parseAttribute(hintInfo, i32Type,
3557  AffinePrefetchOp::getLocalityHintAttrStrName(),
3558  result.attributes) ||
3559  parser.parseGreater() || parser.parseComma() ||
3560  parser.parseKeyword(&cacheType) ||
3561  parser.parseOptionalAttrDict(result.attributes) ||
3562  parser.parseColonType(type) ||
3563  parser.resolveOperand(memrefInfo, type, result.operands) ||
3564  parser.resolveOperands(mapOperands, indexTy, result.operands))
3565  return failure();
3566 
3567  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
3568  return parser.emitError(parser.getNameLoc(),
3569  "rw specifier has to be 'read' or 'write'");
3570  result.addAttribute(
3571  AffinePrefetchOp::getIsWriteAttrStrName(),
3572  parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
3573 
3574  if (!cacheType.equals("data") && !cacheType.equals("instr"))
3575  return parser.emitError(parser.getNameLoc(),
3576  "cache type has to be 'data' or 'instr'");
3577 
3578  result.addAttribute(
3579  AffinePrefetchOp::getIsDataCacheAttrStrName(),
3580  parser.getBuilder().getBoolAttr(cacheType.equals("data")));
3581 
3582  return success();
3583 }
3584 
3585 void AffinePrefetchOp::print(OpAsmPrinter &p) {
3586  p << " " << getMemref() << '[';
3587  AffineMapAttr mapAttr =
3588  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3589  if (mapAttr)
3590  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3591  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3592  << "locality<" << getLocalityHint() << ">, "
3593  << (getIsDataCache() ? "data" : "instr");
3594  p.printOptionalAttrDict(
3595  (*this)->getAttrs(),
3596  /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3597  getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3598  p << " : " << getMemRefType();
3599 }
3600 
3601 LogicalResult AffinePrefetchOp::verify() {
3602  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3603  if (mapAttr) {
3604  AffineMap map = mapAttr.getValue();
3605  if (map.getNumResults() != getMemRefType().getRank())
3606  return emitOpError("affine.prefetch affine map num results must equal"
3607  " memref rank");
3608  if (map.getNumInputs() + 1 != getNumOperands())
3609  return emitOpError("too few operands");
3610  } else {
3611  if (getNumOperands() != 1)
3612  return emitOpError("too few operands");
3613  }
3614 
3615  Region *scope = getAffineScope(*this);
3616  for (auto idx : getMapOperands()) {
3617  if (!isValidAffineIndexOperand(idx, scope))
3618  return emitOpError(
3619  "index must be a valid dimension or symbol identifier");
3620  }
3621  return success();
3622 }
3623 
3624 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
3625  MLIRContext *context) {
3626  // prefetch(memrefcast) -> prefetch
3627  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
3628 }
3629 
3630 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
3631  SmallVectorImpl<OpFoldResult> &results) {
3632  /// prefetch(memrefcast) -> prefetch
3633  return memref::foldMemRefCast(*this);
3634 }
3635 
3636 //===----------------------------------------------------------------------===//
3637 // AffineParallelOp
3638 //===----------------------------------------------------------------------===//
3639 
3640 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3641  TypeRange resultTypes,
3642  ArrayRef<arith::AtomicRMWKind> reductions,
3643  ArrayRef<int64_t> ranges) {
3644  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3645  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3646  return builder.getConstantAffineMap(value);
3647  }));
3648  SmallVector<int64_t> steps(ranges.size(), 1);
3649  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3650  /*ubArgs=*/{}, steps);
3651 }
3652 
3653 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3654  TypeRange resultTypes,
3655  ArrayRef<arith::AtomicRMWKind> reductions,
3656  ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
3657  ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
3658  ArrayRef<int64_t> steps) {
3659  assert(llvm::all_of(lbMaps,
3660  [lbMaps](AffineMap m) {
3661  return m.getNumDims() == lbMaps[0].getNumDims() &&
3662  m.getNumSymbols() == lbMaps[0].getNumSymbols();
3663  }) &&
3664  "expected all lower bounds maps to have the same number of dimensions "
3665  "and symbols");
3666  assert(llvm::all_of(ubMaps,
3667  [ubMaps](AffineMap m) {
3668  return m.getNumDims() == ubMaps[0].getNumDims() &&
3669  m.getNumSymbols() == ubMaps[0].getNumSymbols();
3670  }) &&
3671  "expected all upper bounds maps to have the same number of dimensions "
3672  "and symbols");
3673  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3674  "expected lower bound maps to have as many inputs as lower bound "
3675  "operands");
3676  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3677  "expected upper bound maps to have as many inputs as upper bound "
3678  "operands");
3679 
3680  result.addTypes(resultTypes);
3681 
3682  // Convert the reductions to integer attributes.
3683  SmallVector<Attribute, 4> reductionAttrs;
3684  for (arith::AtomicRMWKind reduction : reductions)
3685  reductionAttrs.push_back(
3686  builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
3687  result.addAttribute(getReductionsAttrStrName(),
3688  builder.getArrayAttr(reductionAttrs));
3689 
3690  // Concatenates maps defined in the same input space (same dimensions and
3691  // symbols), assumes there is at least one map.
3692  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
3693  SmallVectorImpl<int32_t> &groups) {
3694  if (maps.empty())
3695  return AffineMap::get(builder.getContext());
3696  SmallVector<AffineExpr> exprs;
3697  groups.reserve(groups.size() + maps.size());
3698  exprs.reserve(maps.size());
3699  for (AffineMap m : maps) {
3700  llvm::append_range(exprs, m.getResults());
3701  groups.push_back(m.getNumResults());
3702  }
3703  return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3704  maps[0].getContext());
3705  };
3706 
3707  // Set up the bounds.
3708  SmallVector<int32_t> lbGroups, ubGroups;
3709  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3710  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3711  result.addAttribute(getLowerBoundsMapAttrStrName(),
3712  AffineMapAttr::get(lbMap));
3713  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
3714  builder.getI32TensorAttr(lbGroups));
3715  result.addAttribute(getUpperBoundsMapAttrStrName(),
3716  AffineMapAttr::get(ubMap));
3717  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
3718  builder.getI32TensorAttr(ubGroups));
3719  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
3720  result.addOperands(lbArgs);
3721  result.addOperands(ubArgs);
3722 
3723  // Create a region and a block for the body.
3724  auto *bodyRegion = result.addRegion();
3725  auto *body = new Block();
3726  // Add all the block arguments.
3727  for (unsigned i = 0, e = steps.size(); i < e; ++i)
3728  body->addArgument(IndexType::get(builder.getContext()), result.location);
3729  bodyRegion->push_back(body);
3730  if (resultTypes.empty())
3731  ensureTerminator(*bodyRegion, builder, result.location);
3732 }
3733 
3734 SmallVector<Region *> AffineParallelOp::getLoopRegions() {
3735  return {&getRegion()};
3736 }
3737 
3738 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3739 
3740 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
3741  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
3742 }
3743 
3744 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
3745  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
3746 }
3747 
3748 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3749  auto values = getLowerBoundsGroups().getValues<int32_t>();
3750  unsigned start = 0;
3751  for (unsigned i = 0; i < pos; ++i)
3752  start += values[i];
3753  return getLowerBoundsMap().getSliceMap(start, values[pos]);
3754 }
3755 
3756 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3757  auto values = getUpperBoundsGroups().getValues<int32_t>();
3758  unsigned start = 0;
3759  for (unsigned i = 0; i < pos; ++i)
3760  start += values[i];
3761  return getUpperBoundsMap().getSliceMap(start, values[pos]);
3762 }
3763 
3764 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
3765  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
3766 }
3767 
3768 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
3769  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
3770 }
3771 
3772 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3773  if (hasMinMaxBounds())
3774  return std::nullopt;
3775 
3776  // Try to convert all the ranges to constant expressions.
3777  SmallVector<int64_t, 8> out;
3778  AffineValueMap rangesValueMap;
3779  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3780  &rangesValueMap);
3781  out.reserve(rangesValueMap.getNumResults());
3782  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3783  auto expr = rangesValueMap.getResult(i);
3784  auto cst = dyn_cast<AffineConstantExpr>(expr);
3785  if (!cst)
3786  return std::nullopt;
3787  out.push_back(cst.getValue());
3788  }
3789  return out;
3790 }
3791 
3792 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3793 
3794 OpBuilder AffineParallelOp::getBodyBuilder() {
3795  return OpBuilder(getBody(), std::prev(getBody()->end()));
3796 }
3797 
3798 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3799  assert(lbOperands.size() == map.getNumInputs() &&
3800  "operands to map must match number of inputs");
3801 
3802  auto ubOperands = getUpperBoundsOperands();
3803 
3804  SmallVector<Value, 4> newOperands(lbOperands);
3805  newOperands.append(ubOperands.begin(), ubOperands.end());
3806  (*this)->setOperands(newOperands);
3807 
3808  setLowerBoundsMapAttr(AffineMapAttr::get(map));
3809 }
3810 
3811 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3812  assert(ubOperands.size() == map.getNumInputs() &&
3813  "operands to map must match number of inputs");
3814 
3815  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3816  newOperands.append(ubOperands.begin(), ubOperands.end());
3817  (*this)->setOperands(newOperands);
3818 
3819  setUpperBoundsMapAttr(AffineMapAttr::get(map));
3820 }
3821 
3822 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
3823  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
3824 }
3825 
3826 // Check whether `resultType` is compatible with the reduction kind `op` used
3827 // in affine.parallel.
3827 static bool isResultTypeMatchAtomicRMWKind(Type resultType,
3828  arith::AtomicRMWKind op) {
3829  switch (op) {
3830  case arith::AtomicRMWKind::addf:
3831  return isa<FloatType>(resultType);
3832  case arith::AtomicRMWKind::addi:
3833  return isa<IntegerType>(resultType);
3834  case arith::AtomicRMWKind::assign:
3835  return true;
3836  case arith::AtomicRMWKind::mulf:
3837  return isa<FloatType>(resultType);
3838  case arith::AtomicRMWKind::muli:
3839  return isa<IntegerType>(resultType);
3840  case arith::AtomicRMWKind::maximumf:
3841  return isa<FloatType>(resultType);
3842  case arith::AtomicRMWKind::minimumf:
3843  return isa<FloatType>(resultType);
3844  case arith::AtomicRMWKind::maxs: {
3845  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3846  return intType && intType.isSigned();
3847  }
3848  case arith::AtomicRMWKind::mins: {
3849  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3850  return intType && intType.isSigned();
3851  }
3852  case arith::AtomicRMWKind::maxu: {
3853  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3854  return intType && intType.isUnsigned();
3855  }
3856  case arith::AtomicRMWKind::minu: {
3857  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3858  return intType && intType.isUnsigned();
3859  }
3860  case arith::AtomicRMWKind::ori:
3861  return isa<IntegerType>(resultType);
3862  case arith::AtomicRMWKind::andi:
3863  return isa<IntegerType>(resultType);
3864  default:
3865  return false;
3866  }
3867 }
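// Note that "maxs"/"mins" ("maxu"/"minu") reductions require an explicitly
// signed (unsigned) integer result type such as si32 (ui32); a signless i32
// result fails the check above.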
3868 
3869 LogicalResult AffineParallelOp::verify() {
3870  auto numDims = getNumDims();
3871  if (getLowerBoundsGroups().getNumElements() != numDims ||
3872  getUpperBoundsGroups().getNumElements() != numDims ||
3873  getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
3874  return emitOpError() << "the number of region arguments ("
3875  << getBody()->getNumArguments()
3876  << ") and the number of map groups for lower ("
3877  << getLowerBoundsGroups().getNumElements()
3878  << ") and upper bound ("
3879  << getUpperBoundsGroups().getNumElements()
3880  << "), and the number of steps (" << getSteps().size()
3881  << ") must all match";
3882  }
3883 
3884  unsigned expectedNumLBResults = 0;
3885  for (APInt v : getLowerBoundsGroups())
3886  expectedNumLBResults += v.getZExtValue();
3887  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
3888  return emitOpError() << "expected lower bounds map to have "
3889  << expectedNumLBResults << " results";
3890  unsigned expectedNumUBResults = 0;
3891  for (APInt v : getUpperBoundsGroups())
3892  expectedNumUBResults += v.getZExtValue();
3893  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
3894  return emitOpError() << "expected upper bounds map to have "
3895  << expectedNumUBResults << " results";
3896 
3897  if (getReductions().size() != getNumResults())
3898  return emitOpError("a reduction must be specified for each output");
3899 
3900  // Verify that the reduction ops are all valid and that each result type
3901  // matches its corresponding reduction op.
3902  for (auto it : llvm::enumerate((getReductions()))) {
3903  Attribute attr = it.value();
3904  auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
3905  if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
3906  return emitOpError("invalid reduction attribute");
3907  auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
3908  if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
3909  return emitOpError("result type cannot match reduction attribute");
3910  }
3911 
3912  // Verify that the bound operands are valid dimension/symbols.
3913  /// Lower bounds.
3914  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
3915  getLowerBoundsMap().getNumDims())))
3916  return failure();
3917  /// Upper bounds.
3918  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
3919  getUpperBoundsMap().getNumDims())))
3920  return failure();
3921  return success();
3922 }
3923 
3924 LogicalResult AffineValueMap::canonicalize() {
3925  SmallVector<Value, 4> newOperands{operands};
3926  auto newMap = getAffineMap();
3927  composeAffineMapAndOperands(&newMap, &newOperands);
3928  if (newMap == getAffineMap() && newOperands == operands)
3929  return failure();
3930  reset(newMap, newOperands);
3931  return success();
3932 }
3933 
3934 /// Canonicalize the bounds of the given loop.
3935 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3936  AffineValueMap lb = op.getLowerBoundsValueMap();
3937  bool lbCanonicalized = succeeded(lb.canonicalize());
3938 
3939  AffineValueMap ub = op.getUpperBoundsValueMap();
3940  bool ubCanonicalized = succeeded(ub.canonicalize());
3941 
3942  // Any canonicalization change always leads to updated map(s).
3943  if (!lbCanonicalized && !ubCanonicalized)
3944  return failure();
3945 
3946  if (lbCanonicalized)
3947  op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
3948  if (ubCanonicalized)
3949  op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
3950 
3951  return success();
3952 }
3953 
3954 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
3955  SmallVectorImpl<OpFoldResult> &results) {
3956  return canonicalizeLoopBounds(*this);
3957 }
3958 
3959 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
3960 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
3961 /// identifies which of those expressions form max/min groups. `operands`
3962 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
3963 /// or "max".
3964 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
3965  DenseIntElementsAttr group, ValueRange operands,
3966  StringRef keyword) {
3967  AffineMap map = mapAttr.getValue();
3968  unsigned numDims = map.getNumDims();
3969  ValueRange dimOperands = operands.take_front(numDims);
3970  ValueRange symOperands = operands.drop_front(numDims);
3971  unsigned start = 0;
3972  for (llvm::APInt groupSize : group) {
3973  if (start != 0)
3974  p << ", ";
3975 
3976  unsigned size = groupSize.getZExtValue();
3977  if (size == 1) {
3978  p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
3979  ++start;
3980  } else {
3981  p << keyword << '(';
3982  AffineMap submap = map.getSliceMap(start, size);
3983  p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
3984  p << ')';
3985  start += size;
3986  }
3987  }
3988 }
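// For illustration (hypothetical values): a bound group of two expressions
// prints as "max(%i, %j + 1)" for lower bounds (or "min(...)" for upper
// bounds), while a single-expression group prints without the keyword,
// e.g. "%i".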
3989 
3990 void AffineParallelOp::print(OpAsmPrinter &p) {
3991  p << " (" << getBody()->getArguments() << ") = (";
3992  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
3993  getLowerBoundsOperands(), "max");
3994  p << ") to (";
3995  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
3996  getUpperBoundsOperands(), "min");
3997  p << ')';
3998  SmallVector<int64_t, 8> steps = getSteps();
3999  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4000  if (!elideSteps) {
4001  p << " step (";
4002  llvm::interleaveComma(steps, p);
4003  p << ')';
4004  }
4005  if (getNumResults()) {
4006  p << " reduce (";
4007  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4008  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4009  llvm::cast<IntegerAttr>(attr).getInt());
4010  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4011  });
4012  p << ") -> (" << getResultTypes() << ")";
4013  }
4014 
4015  p << ' ';
4016  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4017  /*printBlockTerminators=*/getNumResults());
4018  p.printOptionalAttrDict(
4019  (*this)->getAttrs(),
4020  /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4021  AffineParallelOp::getLowerBoundsMapAttrStrName(),
4022  AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4023  AffineParallelOp::getUpperBoundsMapAttrStrName(),
4024  AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4025  AffineParallelOp::getStepsAttrStrName()});
4026 }
4027 
4028 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
4029 /// unique operands. Also populates `replacements` with affine expressions of
4030 /// `kind` that can be used to update affine maps previously accepting
4031 /// `operands` to accept `uniqueOperands` instead.
4032 static ParseResult deduplicateAndResolveOperands(
4033  OpAsmParser &parser,
4034  ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
4035  SmallVectorImpl<Value> &uniqueOperands,
4036  SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4037  assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4038  "expected operands to be dim or symbol expression");
4039 
4040  Type indexType = parser.getBuilder().getIndexType();
4041  for (const auto &list : operands) {
4042  SmallVector<Value> valueOperands;
4043  if (parser.resolveOperands(list, indexType, valueOperands))
4044  return failure();
4045  for (Value operand : valueOperands) {
4046  unsigned pos = std::distance(uniqueOperands.begin(),
4047  llvm::find(uniqueOperands, operand));
4048  if (pos == uniqueOperands.size())
4049  uniqueOperands.push_back(operand);
4050  replacements.push_back(
4051  kind == AffineExprKind::DimId
4052  ? getAffineDimExpr(pos, parser.getContext())
4053  : getAffineSymbolExpr(pos, parser.getContext()));
4054  }
4055  }
4056  return success();
4057 }
4058 
4059 namespace {
4060 enum class MinMaxKind { Min, Max };
4061 } // namespace
4062 
4063 /// Parses an affine map that can contain a min/max for groups of its results,
4064 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4065 /// `result` attributes with the map (flat list of expressions) and the grouping
4066 /// (list of integers that specify how many expressions to put into each
4067 /// min/max) attributes. Deduplicates repeated operands.
4068 ///
4069 /// parallel-bound ::= `(` parallel-group-list `)`
4070 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4071 /// parallel-group ::= simple-group | min-max-group
4072 /// simple-group ::= expr-of-ssa-ids
4073 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4074 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4075 ///
4076 /// Examples:
4077 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4078 /// (%0, max(%1 - 2 * %2))
4079 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4080  OperationState &result,
4081  MinMaxKind kind) {
4082  // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4083  // see: https://reviews.llvm.org/D134227#3821753
4084  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4085 
4086  StringRef mapName = kind == MinMaxKind::Min
4087  ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4088  : AffineParallelOp::getLowerBoundsMapAttrStrName();
4089  StringRef groupsName =
4090  kind == MinMaxKind::Min
4091  ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4092  : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4093 
4094  if (failed(parser.parseLParen()))
4095  return failure();
4096 
4097  if (succeeded(parser.parseOptionalRParen())) {
4098  result.addAttribute(
4099  mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4100  result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4101  return success();
4102  }
4103 
4104  SmallVector<AffineExpr> flatExprs;
4105  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
4106  SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
4107  SmallVector<int32_t> numMapsPerGroup;
4108  SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
4109  auto parseOperands = [&]() {
4110  if (succeeded(parser.parseOptionalKeyword(
4111  kind == MinMaxKind::Min ? "min" : "max"))) {
4112  mapOperands.clear();
4113  AffineMapAttr map;
4114  if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4115  result.attributes,
4116  OpAsmParser::Delimiter::Paren)))
4117  return failure();
4118  result.attributes.erase(tmpAttrStrName);
4119  llvm::append_range(flatExprs, map.getValue().getResults());
4120  auto operandsRef = llvm::ArrayRef(mapOperands);
4121  auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4122  SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef.begin(),
4123  dimsRef.end());
4124  auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4125  SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef.begin(),
4126  symsRef.end());
4127  flatDimOperands.append(map.getValue().getNumResults(), dims);
4128  flatSymOperands.append(map.getValue().getNumResults(), syms);
4129  numMapsPerGroup.push_back(map.getValue().getNumResults());
4130  } else {
4131  if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4132  flatSymOperands.emplace_back(),
4133  flatExprs.emplace_back())))
4134  return failure();
4135  numMapsPerGroup.push_back(1);
4136  }
4137  return success();
4138  };
4139  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4140  return failure();
4141 
4142  unsigned totalNumDims = 0;
4143  unsigned totalNumSyms = 0;
4144  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4145  unsigned numDims = flatDimOperands[i].size();
4146  unsigned numSyms = flatSymOperands[i].size();
4147  flatExprs[i] = flatExprs[i]
4148  .shiftDims(numDims, totalNumDims)
4149  .shiftSymbols(numSyms, totalNumSyms);
4150  totalNumDims += numDims;
4151  totalNumSyms += numSyms;
4152  }
4153 
4154  // Deduplicate map operands.
4155  SmallVector<Value> dimOperands, symOperands;
4156  SmallVector<AffineExpr> dimReplacements, symReplacements;
4157  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4158  dimReplacements, AffineExprKind::DimId) ||
4159  deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4160  symReplacements, AffineExprKind::SymbolId))
4161  return failure();
4162 
4163  result.operands.append(dimOperands.begin(), dimOperands.end());
4164  result.operands.append(symOperands.begin(), symOperands.end());
4165 
4166  Builder &builder = parser.getBuilder();
4167  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4168  parser.getContext());
4169  flatMap = flatMap.replaceDimsAndSymbols(
4170  dimReplacements, symReplacements, dimOperands.size(), symOperands.size());
4171 
4172  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4173  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4174  return success();
4175 }
4176 
4177 //
4178 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4179 // `to` parallel-bound steps? region attr-dict?
4180 // steps ::= `steps` `(` integer-literals `)`
4181 //
4182 ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4183  OperationState &result) {
4184  auto &builder = parser.getBuilder();
4185  auto indexType = builder.getIndexType();
4186  SmallVector<OpAsmParser::Argument, 4> ivs;
4187  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
4188  parser.parseEqual() ||
4189  parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4190  parser.parseKeyword("to") ||
4191  parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4192  return failure();
4193 
4194  AffineMapAttr stepsMapAttr;
4195  NamedAttrList stepsAttrs;
4196  SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
4197  if (failed(parser.parseOptionalKeyword("step"))) {
4198  SmallVector<int64_t, 4> steps(ivs.size(), 1);
4199  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4200  builder.getI64ArrayAttr(steps));
4201  } else {
4202  if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4203  AffineParallelOp::getStepsAttrStrName(),
4204  stepsAttrs,
4205  OpAsmParser::Delimiter::Paren))
4206  return failure();
4207 
4208  // Convert steps from an AffineMap into an I64ArrayAttr.
4209  SmallVector<int64_t, 4> steps;
4210  auto stepsMap = stepsMapAttr.getValue();
4211  for (const auto &result : stepsMap.getResults()) {
4212  auto constExpr = dyn_cast<AffineConstantExpr>(result);
4213  if (!constExpr)
4214  return parser.emitError(parser.getNameLoc(),
4215  "steps must be constant integers");
4216  steps.push_back(constExpr.getValue());
4217  }
4218  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4219  builder.getI64ArrayAttr(steps));
4220  }
4221 
4222  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4223  // quoted strings are a member of the enum AtomicRMWKind.
4224  SmallVector<Attribute, 4> reductions;
4225  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4226  if (parser.parseLParen())
4227  return failure();
4228  auto parseAttributes = [&]() -> ParseResult {
4229  // Parse a single quoted string via the attribute parsing, and then
4230  // verify it is a member of the enum and convert to its integer
4231  // representation.
4232  StringAttr attrVal;
4233  NamedAttrList attrStorage;
4234  auto loc = parser.getCurrentLocation();
4235  if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4236  attrStorage))
4237  return failure();
4238  std::optional<arith::AtomicRMWKind> reduction =
4239  arith::symbolizeAtomicRMWKind(attrVal.getValue());
4240  if (!reduction)
4241  return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4242  reductions.push_back(
4243  builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4244  // While we keep getting commas, keep parsing.
4245  return success();
4246  };
4247  if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4248  return failure();
4249  }
4250  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4251  builder.getArrayAttr(reductions));
4252 
4253  // Parse return types of reductions (if any)
4254  if (parser.parseOptionalArrowTypeList(result.types))
4255  return failure();
4256 
4257  // Now parse the body.
4258  Region *body = result.addRegion();
4259  for (auto &iv : ivs)
4260  iv.type = indexType;
4261  if (parser.parseRegion(*body, ivs) ||
4262  parser.parseOptionalAttrDict(result.attributes))
4263  return failure();
4264 
4265  // Add a terminator if none was parsed.
4266  AffineParallelOp::ensureTerminator(*body, builder, result.location);
4267  return success();
4268 }
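// A sketch of the syntax accepted above (SSA names, %N, and the memref are
// hypothetical):
//   affine.parallel (%i, %j) = (0, 0) to (%N, 128) step (1, 8) { ... }
//   %s = affine.parallel (%k) = (0) to (64) reduce ("addf") -> (f32) {
//     %v = affine.load %buf[%k] : memref<64xf32>
//     affine.yield %v : f32
//   }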
4269 
4270 //===----------------------------------------------------------------------===//
4271 // AffineYieldOp
4272 //===----------------------------------------------------------------------===//
4273 
4274 LogicalResult AffineYieldOp::verify() {
4275  auto *parentOp = (*this)->getParentOp();
4276  auto results = parentOp->getResults();
4277  auto operands = getOperands();
4278 
4279  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4280  return emitOpError() << "only terminates affine.if/for/parallel regions";
4281  if (parentOp->getNumResults() != getNumOperands())
4282  return emitOpError() << "parent of yield must have same number of "
4283  "results as the yield operands";
4284  for (auto it : llvm::zip(results, operands)) {
4285  if (std::get<0>(it).getType() != std::get<1>(it).getType())
4286  return emitOpError() << "types mismatch between yield op and its parent";
4287  }
4288 
4289  return success();
4290 }
4291 
4292 //===----------------------------------------------------------------------===//
4293 // AffineVectorLoadOp
4294 //===----------------------------------------------------------------------===//
4295 
4296 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4297  VectorType resultType, AffineMap map,
4298  ValueRange operands) {
4299  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4300  result.addOperands(operands);
4301  if (map)
4302  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4303  result.types.push_back(resultType);
4304 }
4305 
4306 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4307  VectorType resultType, Value memref,
4308  AffineMap map, ValueRange mapOperands) {
4309  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4310  result.addOperands(memref);
4311  result.addOperands(mapOperands);
4312  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4313  result.types.push_back(resultType);
4314 }
4315 
4316 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4317  VectorType resultType, Value memref,
4318  ValueRange indices) {
4319  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4320  int64_t rank = memrefType.getRank();
4321  // Create identity map for memrefs with at least one dimension or () -> ()
4322  // for zero-dimensional memrefs.
4323  auto map =
4324  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4325  build(builder, result, resultType, memref, map, indices);
4326 }
4327 
4328 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4329  MLIRContext *context) {
4330  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4331 }
4332 
4333 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4334  OperationState &result) {
4335  auto &builder = parser.getBuilder();
4336  auto indexTy = builder.getIndexType();
4337 
4338  MemRefType memrefType;
4339  VectorType resultType;
4340  OpAsmParser::UnresolvedOperand memrefInfo;
4341  AffineMapAttr mapAttr;
4342  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4343  return failure(
4344  parser.parseOperand(memrefInfo) ||
4345  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4346  AffineVectorLoadOp::getMapAttrStrName(),
4347  result.attributes) ||
4348  parser.parseOptionalAttrDict(result.attributes) ||
4349  parser.parseColonType(memrefType) || parser.parseComma() ||
4350  parser.parseType(resultType) ||
4351  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4352  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4353  parser.addTypeToList(resultType, result.types));
4354 }
4355 
4356 void AffineVectorLoadOp::print(OpAsmPrinter &p) {
4357  p << " " << getMemRef() << '[';
4358  if (AffineMapAttr mapAttr =
4359  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4360  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4361  p << ']';
4362  p.printOptionalAttrDict((*this)->getAttrs(),
4363  /*elidedAttrs=*/{getMapAttrStrName()});
4364  p << " : " << getMemRefType() << ", " << getType();
4365 }
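// A sketch of the custom syntax handled above (operand names are
// hypothetical):
//   %v = affine.vector_load %mem[%i, %j] : memref<100x100xf32>, vector<8xf32>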
4366 
4367 /// Verify common invariants of affine.vector_load and affine.vector_store.
4368 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4369  VectorType vectorType) {
4370  // Check that memref and vector element types match.
4371  if (memrefType.getElementType() != vectorType.getElementType())
4372  return op->emitOpError(
4373  "requires memref and vector types of the same elemental type");
4374  return success();
4375 }
4376 
4377 LogicalResult AffineVectorLoadOp::verify() {
4378  MemRefType memrefType = getMemRefType();
4379  if (failed(verifyMemoryOpIndexing(
4380  getOperation(),
4381  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4382  getMapOperands(), memrefType,
4383  /*numIndexOperands=*/getNumOperands() - 1)))
4384  return failure();
4385 
4386  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4387  return failure();
4388 
4389  return success();
4390 }
4391 
4392 //===----------------------------------------------------------------------===//
4393 // AffineVectorStoreOp
4394 //===----------------------------------------------------------------------===//
4395 
4396 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4397  Value valueToStore, Value memref, AffineMap map,
4398  ValueRange mapOperands) {
4399  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4400  result.addOperands(valueToStore);
4401  result.addOperands(memref);
4402  result.addOperands(mapOperands);
4403  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4404 }
4405 
4406 // Use identity map.
4407 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4408  Value valueToStore, Value memref,
4409  ValueRange indices) {
4410  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4411  int64_t rank = memrefType.getRank();
4412  // Create identity map for memrefs with at least one dimension or () -> ()
4413  // for zero-dimensional memrefs.
4414  auto map =
4415  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4416  build(builder, result, valueToStore, memref, map, indices);
4417 }
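
// Illustrative sketch, not part of the upstream source: using the
// explicit-map builder above to store a vector at an offset position.
// `loc`, `valueToStore`, `memref`, and `iv` are hypothetical values supplied
// by the caller.
static AffineVectorStoreOp buildExampleVectorStore(OpBuilder &b, Location loc,
                                                   Value valueToStore,
                                                   Value memref, Value iv) {
  // Access map (d0) -> (d0 + 4): the store starts four elements past `iv`.
  AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
                                 b.getAffineDimExpr(0) + 4);
  return b.create<AffineVectorStoreOp>(loc, valueToStore, memref, map,
                                       ValueRange{iv});
}
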
4418 void AffineVectorStoreOp::getCanonicalizationPatterns(
4419  RewritePatternSet &results, MLIRContext *context) {
4420  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4421 }
4422 
4423 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4424  OperationState &result) {
4425  auto indexTy = parser.getBuilder().getIndexType();
4426 
4427  MemRefType memrefType;
4428  VectorType resultType;
4429  OpAsmParser::UnresolvedOperand storeValueInfo;
4430  OpAsmParser::UnresolvedOperand memrefInfo;
4431  AffineMapAttr mapAttr;
4432  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4433  return failure(
4434  parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4435  parser.parseOperand(memrefInfo) ||
4436  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4437  AffineVectorStoreOp::getMapAttrStrName(),
4438  result.attributes) ||
4439  parser.parseOptionalAttrDict(result.attributes) ||
4440  parser.parseColonType(memrefType) || parser.parseComma() ||
4441  parser.parseType(resultType) ||
4442  parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4443  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4444  parser.resolveOperands(mapOperands, indexTy, result.operands));
4445 }
4446 
4447 void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4448  p << " " << getValueToStore();
4449  p << ", " << getMemRef() << '[';
4450  if (AffineMapAttr mapAttr =
4451  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4452  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4453  p << ']';
4454  p.printOptionalAttrDict((*this)->getAttrs(),
4455  /*elidedAttrs=*/{getMapAttrStrName()});
4456  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4457 }
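
// For illustration (mirroring the example in the op documentation, not part
// of this file): the printer above produces forms such as
//   affine.vector_store %v0, %0[%i0 + 3, %i1 + 7]
//     : memref<100x100xf32>, vector<8xf32>
// i.e. the stored value, the memref with its affine subscripts, and then the
// memref type followed by the stored vector type.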
4458 
4459 LogicalResult AffineVectorStoreOp::verify() {
4460  MemRefType memrefType = getMemRefType();
4461  if (failed(verifyMemoryOpIndexing(
4462  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4463  getMapOperands(), memrefType,
4464  /*numIndexOperands=*/getNumOperands() - 2)))
4465  return failure();
4466 
4467  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4468  return failure();
4469 
4470  return success();
4471 }
4472 
4473 //===----------------------------------------------------------------------===//
4474 // DelinearizeIndexOp
4475 //===----------------------------------------------------------------------===//
4476 
4477 void AffineDelinearizeIndexOp::build(OpBuilder &builder, OperationState &result,
4478  Value linearIndex,
4479  ArrayRef<OpFoldResult> basis) {
4480  result.addTypes(SmallVector<Type>(basis.size(), builder.getIndexType()));
4481  result.addOperands(linearIndex);
4482  SmallVector<Value> basisValues =
4483  llvm::map_to_vector(basis, [&](OpFoldResult ofr) -> Value {
4484  std::optional<int64_t> staticDim = getConstantIntValue(ofr);
4485  if (staticDim.has_value())
4486  return builder.create<arith::ConstantIndexOp>(result.location,
4487  *staticDim);
4488  return llvm::dyn_cast_if_present<Value>(ofr);
4489  });
4490  result.addOperands(basisValues);
4491 }
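
// Illustrative sketch, not part of the upstream source: delinearizing a flat
// index into coordinates for a hypothetical 16x8 iteration space. `loc` and
// `linear` are values supplied by the caller; the constant basis elements are
// materialized as arith.constant index ops by the builder above.
static AffineDelinearizeIndexOp buildExampleDelinearize(OpBuilder &b,
                                                        Location loc,
                                                        Value linear) {
  SmallVector<OpFoldResult> basis = {b.getIndexAttr(16), b.getIndexAttr(8)};
  return b.create<AffineDelinearizeIndexOp>(loc, linear, basis);
}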
4492 
4493 LogicalResult AffineDelinearizeIndexOp::verify() {
4494  if (getBasis().empty())
4495  return emitOpError("basis should not be empty");
4496  if (getNumResults() != getBasis().size())
4497  return emitOpError("should return an index for each basis element");
4498  return success();
4499 }
4500 
4501 //===----------------------------------------------------------------------===//
4502 // TableGen'd op method definitions
4503 //===----------------------------------------------------------------------===//
4504 
4505 #define GET_OP_CLASSES
4506 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"