MLIR  21.0.0git
AffineOps.cpp
Go to the documentation of this file.
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/IRMapping.h"
17 #include "mlir/IR/IntegerSet.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/OpDefinition.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/Value.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
27 #include "llvm/ADT/SmallVectorExtras.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/LogicalResult.h"
31 #include "llvm/Support/MathExtras.h"
32 #include <numeric>
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::affine;
37 
38 using llvm::divideCeilSigned;
39 using llvm::divideFloorSigned;
40 using llvm::mod;
41 
42 #define DEBUG_TYPE "affine-ops"
43 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
44 
45 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
46 
47 /// A utility function to check if a value is defined at the top level of
48 /// `region` or is an argument of `region`. A value of index type defined at the
49 /// top level of a `AffineScope` region is always a valid symbol for all
50 /// uses in that region.
52  if (auto arg = llvm::dyn_cast<BlockArgument>(value))
53  return arg.getParentRegion() == region;
54  return value.getDefiningOp()->getParentRegion() == region;
55 }
56 
57 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
58 /// region remains legal if the operation that uses it is inlined into `dest`
59 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
60 /// `isValidSymbol`, depending on the value being required to remain a valid
61 /// dimension or symbol.
62 static bool
64  const IRMapping &mapping,
65  function_ref<bool(Value, Region *)> legalityCheck) {
66  // If the value is a valid dimension for any other reason than being
67  // a top-level value, it will remain valid: constants get inlined
68  // with the function, transitive affine applies also get inlined and
69  // will be checked themselves, etc.
70  if (!isTopLevelValue(value, src))
71  return true;
72 
73  // If it's a top-level value because it's a block operand, i.e. a
74  // function argument, check whether the value replacing it after
75  // inlining is a valid dimension in the new region.
76  if (llvm::isa<BlockArgument>(value))
77  return legalityCheck(mapping.lookup(value), dest);
78 
79  // If it's a top-level value because it's defined in the region,
80  // it can only be inlined if the defining op is a constant or a
81  // `dim`, which can appear anywhere and be valid, since the defining
82  // op won't be top-level anymore after inlining.
83  Attribute operandCst;
84  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
85  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
86  isDimLikeOp;
87 }
88 
89 /// Checks if all values known to be legal affine dimensions or symbols in `src`
90 /// remain so if their respective users are inlined into `dest`.
91 static bool
93  const IRMapping &mapping,
94  function_ref<bool(Value, Region *)> legalityCheck) {
95  return llvm::all_of(values, [&](Value v) {
96  return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
97  });
98 }
99 
100 /// Checks if an affine read or write operation remains legal after inlining
101 /// from `src` to `dest`.
102 template <typename OpTy>
103 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
104  const IRMapping &mapping) {
105  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
106  AffineWriteOpInterface>::value,
107  "only ops with affine read/write interface are supported");
108 
109  AffineMap map = op.getAffineMap();
110  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
111  ValueRange symbolOperands =
112  op.getMapOperands().take_back(map.getNumSymbols());
114  dimOperands, src, dest, mapping,
115  static_cast<bool (*)(Value, Region *)>(isValidDim)))
116  return false;
118  symbolOperands, src, dest, mapping,
119  static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
120  return false;
121  return true;
122 }
123 
124 /// Checks if an affine apply operation remains legal after inlining from `src`
125 /// to `dest`.
126 // Use "unused attribute" marker to silence clang-tidy warning stemming from
127 // the inability to see through "llvm::TypeSwitch".
128 template <>
129 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
130  Region *src, Region *dest,
131  const IRMapping &mapping) {
132  // If it's a valid dimension, we need to check that it remains so.
133  if (isValidDim(op.getResult(), src))
135  op.getMapOperands(), src, dest, mapping,
136  static_cast<bool (*)(Value, Region *)>(isValidDim));
137 
138  // Otherwise it must be a valid symbol, check that it remains so.
140  op.getMapOperands(), src, dest, mapping,
141  static_cast<bool (*)(Value, Region *)>(isValidSymbol));
142 }
143 
144 //===----------------------------------------------------------------------===//
145 // AffineDialect Interfaces
146 //===----------------------------------------------------------------------===//
147 
148 namespace {
149 /// This class defines the interface for handling inlining with affine
150 /// operations.
151 struct AffineInlinerInterface : public DialectInlinerInterface {
153 
154  //===--------------------------------------------------------------------===//
155  // Analysis Hooks
156  //===--------------------------------------------------------------------===//
157 
158  /// Returns true if the given region 'src' can be inlined into the region
159  /// 'dest' that is attached to an operation registered to the current dialect.
160  /// 'wouldBeCloned' is set if the region is cloned into its new location
161  /// rather than moved, indicating there may be other users.
162  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
163  IRMapping &valueMapping) const final {
164  // We can inline into affine loops and conditionals if this doesn't break
165  // affine value categorization rules.
166  Operation *destOp = dest->getParentOp();
167  if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
168  return false;
169 
170  // Multi-block regions cannot be inlined into affine constructs, all of
171  // which require single-block regions.
172  if (!llvm::hasSingleElement(*src))
173  return false;
174 
175  // Side-effecting operations that the affine dialect cannot understand
176  // should not be inlined.
177  Block &srcBlock = src->front();
178  for (Operation &op : srcBlock) {
179  // Ops with no side effects are fine,
180  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
181  if (iface.hasNoEffect())
182  continue;
183  }
184 
185  // Assuming the inlined region is valid, we only need to check if the
186  // inlining would change it.
187  bool remainsValid =
189  .Case<AffineApplyOp, AffineReadOpInterface,
190  AffineWriteOpInterface>([&](auto op) {
191  return remainsLegalAfterInline(op, src, dest, valueMapping);
192  })
193  .Default([](Operation *) {
194  // Conservatively disallow inlining ops we cannot reason about.
195  return false;
196  });
197 
198  if (!remainsValid)
199  return false;
200  }
201 
202  return true;
203  }
204 
205  /// Returns true if the given operation 'op', that is registered to this
206  /// dialect, can be inlined into the given region, false otherwise.
207  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
208  IRMapping &valueMapping) const final {
209  // Always allow inlining affine operations into a region that is marked as
210  // affine scope, or into affine loops and conditionals. There are some edge
211  // cases when inlining *into* affine structures, but that is handled in the
212  // other 'isLegalToInline' hook above.
213  Operation *parentOp = region->getParentOp();
214  return parentOp->hasTrait<OpTrait::AffineScope>() ||
215  isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
216  }
217 
218  /// Affine regions should be analyzed recursively.
219  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
220 };
221 } // namespace
222 
223 //===----------------------------------------------------------------------===//
224 // AffineDialect
225 //===----------------------------------------------------------------------===//
226 
/// Registers the affine dialect's operations and interfaces.
void AffineDialect::initialize() {
  // AffineDmaStartOp and AffineDmaWaitOp are listed explicitly because they
  // are implemented in C++ rather than generated into the GET_OP_LIST
  // expansion below.
  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
                >();
  // Enables inlining into/out of affine constructs (see
  // AffineInlinerInterface above for the legality rules).
  addInterfaces<AffineInlinerInterface>();
  // Promise (rather than eagerly attach) the ValueBoundsOpInterface models
  // for these ops; the implementations are registered elsewhere.
  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
                            AffineMinOp>();
}
236 
237 /// Materialize a single constant operation from a given attribute value with
238 /// the desired resultant type.
240  Attribute value, Type type,
241  Location loc) {
242  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
243  return builder.create<ub::PoisonOp>(loc, type, poison);
244  return arith::ConstantOp::materialize(builder, value, type, loc);
245 }
246 
247 /// A utility function to check if a value is defined at the top level of an
248 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
249 /// conservatively assume it is not top-level. A value of index type defined at
250 /// the top level is always a valid symbol.
252  if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
253  // The block owning the argument may be unlinked, e.g. when the surrounding
254  // region has not yet been attached to an Op, at which point the parent Op
255  // is null.
256  Operation *parentOp = arg.getOwner()->getParentOp();
257  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
258  }
259  // The defining Op may live in an unlinked block so its parent Op may be null.
260  Operation *parentOp = value.getDefiningOp()->getParentOp();
261  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
262 }
263 
264 /// Returns the closest region enclosing `op` that is held by an operation with
265 /// trait `AffineScope`; `nullptr` if there is no such region.
267  auto *curOp = op;
268  while (auto *parentOp = curOp->getParentOp()) {
269  if (parentOp->hasTrait<OpTrait::AffineScope>())
270  return curOp->getParentRegion();
271  curOp = parentOp;
272  }
273  return nullptr;
274 }
275 
277  Operation *curOp = op;
278  while (auto *parentOp = curOp->getParentOp()) {
279  if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
280  return curOp->getParentRegion();
281  curOp = parentOp;
282  }
283  return nullptr;
284 }
285 
286 // A Value can be used as a dimension id iff it meets one of the following
287 // conditions:
288 // *) It is valid as a symbol.
289 // *) It is an induction variable.
290 // *) It is the result of affine apply operation with dimension id arguments.
292  // The value must be an index type.
293  if (!value.getType().isIndex())
294  return false;
295 
296  if (auto *defOp = value.getDefiningOp())
297  return isValidDim(value, getAffineScope(defOp));
298 
299  // This value has to be a block argument for an op that has the
300  // `AffineScope` trait or an induction var of an affine.for or
301  // affine.parallel.
302  if (isAffineInductionVar(value))
303  return true;
304  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
305  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
306 }
307 
308 // Value can be used as a dimension id iff it meets one of the following
309 // conditions:
310 // *) It is valid as a symbol.
311 // *) It is an induction variable.
312 // *) It is the result of an affine apply operation with dimension id operands.
313 // *) It is the result of a more specialized index transformation (ex.
314 // delinearize_index or linearize_index) with dimension id operands.
315 bool mlir::affine::isValidDim(Value value, Region *region) {
316  // The value must be an index type.
317  if (!value.getType().isIndex())
318  return false;
319 
320  // All valid symbols are okay.
321  if (isValidSymbol(value, region))
322  return true;
323 
324  auto *op = value.getDefiningOp();
325  if (!op) {
326  // This value has to be an induction var for an affine.for or an
327  // affine.parallel.
328  return isAffineInductionVar(value);
329  }
330 
331  // Affine apply operation is ok if all of its operands are ok.
332  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
333  return applyOp.isValidDim(region);
334  // delinearize_index and linearize_index are special forms of apply
335  // and so are valid dimensions if all their arguments are valid dimensions.
336  if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
337  return llvm::all_of(op->getOperands(),
338  [&](Value arg) { return ::isValidDim(arg, region); });
339  // The dim op is okay if its operand memref/tensor is defined at the top
340  // level.
341  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
342  return isTopLevelValue(dimOp.getShapedValue());
343  return false;
344 }
345 
346 /// Returns true if the 'index' dimension of the `memref` defined by
347 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
348 /// for `region`.
349 template <typename AnyMemRefDefOp>
350 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
351  Region *region) {
352  MemRefType memRefType = memrefDefOp.getType();
353 
354  // Dimension index is out of bounds.
355  if (index >= memRefType.getRank()) {
356  return false;
357  }
358 
359  // Statically shaped.
360  if (!memRefType.isDynamicDim(index))
361  return true;
362  // Get the position of the dimension among dynamic dimensions;
363  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
364  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
365  region);
366 }
367 
368 /// Returns true if the result of the dim op is a valid symbol for `region`.
369 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
370  // The dim op is okay if its source is defined at the top level.
371  if (isTopLevelValue(dimOp.getShapedValue()))
372  return true;
373 
374  // Conservatively handle remaining BlockArguments as non-valid symbols.
375  // E.g. scf.for iterArgs.
376  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
377  return false;
378 
379  // The dim op is also okay if its operand memref is a view/subview whose
380  // corresponding size is a valid symbol.
381  std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
382 
383  // Be conservative if we can't understand the dimension.
384  if (!index.has_value())
385  return false;
386 
387  // Skip over all memref.cast ops (if any).
388  Operation *op = dimOp.getShapedValue().getDefiningOp();
389  while (auto castOp = dyn_cast<memref::CastOp>(op)) {
390  // Bail on unranked memrefs.
391  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
392  return false;
393  op = castOp.getSource().getDefiningOp();
394  if (!op)
395  return false;
396  }
397 
398  int64_t i = index.value();
400  .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
401  [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
402  .Default([](Operation *) { return false; });
403 }
404 
405 // A value can be used as a symbol (at all its use sites) iff it meets one of
406 // the following conditions:
407 // *) It is a constant.
408 // *) Its defining op or block arg appearance is immediately enclosed by an op
409 // with `AffineScope` trait.
410 // *) It is the result of an affine.apply operation with symbol operands.
411 // *) It is a result of the dim op on a memref whose corresponding size is a
412 // valid symbol.
414  if (!value)
415  return false;
416 
417  // The value must be an index type.
418  if (!value.getType().isIndex())
419  return false;
420 
421  // Check that the value is a top level value.
422  if (isTopLevelValue(value))
423  return true;
424 
425  if (auto *defOp = value.getDefiningOp())
426  return isValidSymbol(value, getAffineScope(defOp));
427 
428  return false;
429 }
430 
431 /// A value can be used as a symbol for `region` iff it meets one of the
432 /// following conditions:
433 /// *) It is a constant.
434 /// *) It is a result of a `Pure` operation whose operands are valid symbolic
435 /// *) identifiers.
436 /// *) It is a result of the dim op on a memref whose corresponding size is
437 /// a valid symbol.
438 /// *) It is defined at the top level of 'region' or is its argument.
439 /// *) It dominates `region`'s parent op.
440 /// If `region` is null, conservatively assume the symbol definition scope does
441 /// not exist and only accept the values that would be symbols regardless of
442 /// the surrounding region structure, i.e. the first three cases above.
444  // The value must be an index type.
445  if (!value.getType().isIndex())
446  return false;
447 
448  // A top-level value is a valid symbol.
449  if (region && ::isTopLevelValue(value, region))
450  return true;
451 
452  auto *defOp = value.getDefiningOp();
453  if (!defOp) {
454  // A block argument that is not a top-level value is a valid symbol if it
455  // dominates region's parent op.
456  Operation *regionOp = region ? region->getParentOp() : nullptr;
457  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
458  if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
459  return isValidSymbol(value, parentOpRegion);
460  return false;
461  }
462 
463  // Constant operation is ok.
464  Attribute operandCst;
465  if (matchPattern(defOp, m_Constant(&operandCst)))
466  return true;
467 
468  // `Pure` operation that whose operands are valid symbolic identifiers.
469  if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
470  return affine::isValidSymbol(operand, region);
471  })) {
472  return true;
473  }
474 
475  // Dim op results could be valid symbols at any level.
476  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
477  return isDimOpValidSymbol(dimOp, region);
478 
479  // Check for values dominating `region`'s parent op.
480  Operation *regionOp = region ? region->getParentOp() : nullptr;
481  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
482  if (auto *parentRegion = region->getParentOp()->getParentRegion())
483  return isValidSymbol(value, parentRegion);
484 
485  return false;
486 }
487 
488 // Returns true if 'value' is a valid index to an affine operation (e.g.
489 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
490 // `region` provides the polyhedral symbol scope. Returns false otherwise.
491 static bool isValidAffineIndexOperand(Value value, Region *region) {
492  return isValidDim(value, region) || isValidSymbol(value, region);
493 }
494 
495 /// Prints dimension and symbol list.
498  unsigned numDims, OpAsmPrinter &printer) {
499  OperandRange operands(begin, end);
500  printer << '(' << operands.take_front(numDims) << ')';
501  if (operands.size() > numDims)
502  printer << '[' << operands.drop_front(numDims) << ']';
503 }
504 
505 /// Parses dimension and symbol list and returns true if parsing failed.
507  OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
509  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
510  return failure();
511  // Store number of dimensions for validation by caller.
512  numDims = opInfos.size();
513 
514  // Parse the optional symbol operands.
515  auto indexTy = parser.getBuilder().getIndexType();
516  return failure(parser.parseOperandList(
518  parser.resolveOperands(opInfos, indexTy, operands));
519 }
520 
521 /// Utility function to verify that a set of operands are valid dimension and
522 /// symbol identifiers. The operands should be laid out such that the dimension
523 /// operands are before the symbol operands. This function returns failure if
524 /// there was an invalid operand. An operation is provided to emit any necessary
525 /// errors.
526 template <typename OpTy>
527 static LogicalResult
529  unsigned numDims) {
530  unsigned opIt = 0;
531  for (auto operand : operands) {
532  if (opIt++ < numDims) {
533  if (!isValidDim(operand, getAffineScope(op)))
534  return op.emitOpError("operand cannot be used as a dimension id");
535  } else if (!isValidSymbol(operand, getAffineScope(op))) {
536  return op.emitOpError("operand cannot be used as a symbol");
537  }
538  }
539  return success();
540 }
541 
542 //===----------------------------------------------------------------------===//
543 // AffineApplyOp
544 //===----------------------------------------------------------------------===//
545 
546 AffineValueMap AffineApplyOp::getAffineValueMap() {
547  return AffineValueMap(getAffineMap(), getOperands(), getResult());
548 }
549 
550 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
551  auto &builder = parser.getBuilder();
552  auto indexTy = builder.getIndexType();
553 
554  AffineMapAttr mapAttr;
555  unsigned numDims;
556  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
557  parseDimAndSymbolList(parser, result.operands, numDims) ||
558  parser.parseOptionalAttrDict(result.attributes))
559  return failure();
560  auto map = mapAttr.getValue();
561 
562  if (map.getNumDims() != numDims ||
563  numDims + map.getNumSymbols() != result.operands.size()) {
564  return parser.emitError(parser.getNameLoc(),
565  "dimension or symbol index mismatch");
566  }
567 
568  result.types.append(map.getNumResults(), indexTy);
569  return success();
570 }
571 
573  p << " " << getMapAttr();
574  printDimAndSymbolList(operand_begin(), operand_end(),
575  getAffineMap().getNumDims(), p);
576  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
577 }
578 
579 LogicalResult AffineApplyOp::verify() {
580  // Check input and output dimensions match.
581  AffineMap affineMap = getMap();
582 
583  // Verify that operand count matches affine map dimension and symbol count.
584  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
585  return emitOpError(
586  "operand count and affine map dimension and symbol count must match");
587 
588  // Verify that the map only produces one result.
589  if (affineMap.getNumResults() != 1)
590  return emitOpError("mapping must produce one value");
591 
592  // Do not allow valid dims to be used in symbol positions. We do allow
593  // affine.apply to use operands for values that may neither qualify as affine
594  // dims or affine symbols due to usage outside of affine ops, analyses, etc.
595  Region *region = getAffineScope(*this);
596  for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
597  if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
598  return emitError("dimensional operand cannot be used as a symbol");
599  }
600 
601  return success();
602 }
603 
604 // The result of the affine apply operation can be used as a dimension id if all
605 // its operands are valid dimension ids.
607  return llvm::all_of(getOperands(),
608  [](Value op) { return affine::isValidDim(op); });
609 }
610 
611 // The result of the affine apply operation can be used as a dimension id if all
612 // its operands are valid dimension ids with the parent operation of `region`
613 // defining the polyhedral scope for symbols.
614 bool AffineApplyOp::isValidDim(Region *region) {
615  return llvm::all_of(getOperands(),
616  [&](Value op) { return ::isValidDim(op, region); });
617 }
618 
619 // The result of the affine apply operation can be used as a symbol if all its
620 // operands are symbols.
622  return llvm::all_of(getOperands(),
623  [](Value op) { return affine::isValidSymbol(op); });
624 }
625 
626 // The result of the affine apply operation can be used as a symbol in `region`
627 // if all its operands are symbols in `region`.
628 bool AffineApplyOp::isValidSymbol(Region *region) {
629  return llvm::all_of(getOperands(), [&](Value operand) {
630  return affine::isValidSymbol(operand, region);
631  });
632 }
633 
634 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
635  auto map = getAffineMap();
636 
637  // Fold dims and symbols to existing values.
638  auto expr = map.getResult(0);
639  if (auto dim = dyn_cast<AffineDimExpr>(expr))
640  return getOperand(dim.getPosition());
641  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
642  return getOperand(map.getNumDims() + sym.getPosition());
643 
644  // Otherwise, default to folding the map.
646  bool hasPoison = false;
647  auto foldResult =
648  map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
649  if (hasPoison)
651  if (failed(foldResult))
652  return {};
653  return result[0];
654 }
655 
656 /// Returns the largest known divisor of `e`. Exploits information from the
657 /// values in `operands`.
658 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
659  // This method isn't aware of `operands`.
660  int64_t div = e.getLargestKnownDivisor();
661 
662  // We now make use of operands for the case `e` is a dim expression.
663  // TODO: More powerful simplification would have to modify
664  // getLargestKnownDivisor to take `operands` and exploit that information as
665  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
666  // can't be part of the IR library but of the `Analysis` library. The IR
667  // library can only really depend on simple O(1) checks.
668  auto dimExpr = dyn_cast<AffineDimExpr>(e);
669  // If it's not a dim expr, `div` is the best we have.
670  if (!dimExpr)
671  return div;
672 
673  // We simply exploit information from loop IVs.
674  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
675  // desired simplifications are expected to be part of other
676  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
677  // LoopAnalysis library.
678  Value operand = operands[dimExpr.getPosition()];
679  int64_t operandDivisor = 1;
680  // TODO: With the right accessors, this can be extended to
681  // LoopLikeOpInterface.
682  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
683  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
684  operandDivisor = forOp.getStepAsInt();
685  } else {
686  uint64_t lbLargestKnownDivisor =
687  forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
688  operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
689  }
690  }
691  return operandDivisor;
692 }
693 
694 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
695 /// being an affine dim expression or a constant.
697  int64_t k) {
698  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
699  int64_t constVal = constExpr.getValue();
700  return constVal >= 0 && constVal < k;
701  }
702  auto dimExpr = dyn_cast<AffineDimExpr>(e);
703  if (!dimExpr)
704  return false;
705  Value operand = operands[dimExpr.getPosition()];
706  // TODO: With the right accessors, this can be extended to
707  // LoopLikeOpInterface.
708  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
709  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
710  forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
711  return true;
712  }
713  }
714 
715  // We don't consider other cases like `operand` being defined by a constant or
716  // an affine.apply op since such cases will already be handled by other
717  // patterns and propagation of loop IVs or constant would happen.
718  return false;
719 }
720 
721 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
722 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
723 /// expression is in that form.
724 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
725  AffineExpr &quotientTimesDiv, AffineExpr &rem) {
726  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
727  if (!bin || bin.getKind() != AffineExprKind::Add)
728  return false;
729 
730  AffineExpr llhs = bin.getLHS();
731  AffineExpr rlhs = bin.getRHS();
732  div = getLargestKnownDivisor(llhs, operands);
733  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
734  quotientTimesDiv = llhs;
735  rem = rlhs;
736  return true;
737  }
738  div = getLargestKnownDivisor(rlhs, operands);
739  if (isNonNegativeBoundedBy(llhs, operands, div)) {
740  quotientTimesDiv = rlhs;
741  rem = llhs;
742  return true;
743  }
744  return false;
745 }
746 
747 /// Gets the constant lower bound on an `iv`.
748 static std::optional<int64_t> getLowerBound(Value iv) {
749  AffineForOp forOp = getForInductionVarOwner(iv);
750  if (forOp && forOp.hasConstantLowerBound())
751  return forOp.getConstantLowerBound();
752  return std::nullopt;
753 }
754 
755 /// Gets the constant upper bound on an affine.for `iv`.
756 static std::optional<int64_t> getUpperBound(Value iv) {
757  AffineForOp forOp = getForInductionVarOwner(iv);
758  if (!forOp || !forOp.hasConstantUpperBound())
759  return std::nullopt;
760 
761  // If its lower bound is also known, we can get a more precise bound
762  // whenever the step is not one.
763  if (forOp.hasConstantLowerBound()) {
764  return forOp.getConstantUpperBound() - 1 -
765  (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
766  forOp.getStepAsInt();
767  }
768  return forOp.getConstantUpperBound() - 1;
769 }
770 
771 /// Determine a constant upper bound for `expr` if one exists while exploiting
772 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
773 /// is guaranteed to be less than or equal to it.
774 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
775  unsigned numSymbols,
776  ArrayRef<Value> operands) {
777  // Get the constant lower or upper bounds on the operands.
778  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
779  constLowerBounds.reserve(operands.size());
780  constUpperBounds.reserve(operands.size());
781  for (Value operand : operands) {
782  constLowerBounds.push_back(getLowerBound(operand));
783  constUpperBounds.push_back(getUpperBound(operand));
784  }
785 
786  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
787  return constExpr.getValue();
788 
789  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
790  constUpperBounds,
791  /*isUpper=*/true);
792 }
793 
794 /// Determine a constant lower bound for `expr` if one exists while exploiting
795 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
796 /// is guaranteed to be less than or equal to it.
797 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
798  unsigned numSymbols,
799  ArrayRef<Value> operands) {
800  // Get the constant lower or upper bounds on the operands.
801  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
802  constLowerBounds.reserve(operands.size());
803  constUpperBounds.reserve(operands.size());
804  for (Value operand : operands) {
805  constLowerBounds.push_back(getLowerBound(operand));
806  constUpperBounds.push_back(getUpperBound(operand));
807  }
808 
809  std::optional<int64_t> lowerBound;
810  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
811  lowerBound = constExpr.getValue();
812  } else {
813  lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
814  constLowerBounds, constUpperBounds,
815  /*isUpper=*/false);
816  }
817  return lowerBound;
818 }
819 
/// Simplify `expr` while exploiting information from the values in `operands`.
/// `numDims`/`numSymbols` describe how dim/symbol positions in `expr` map to
/// `operands`. Only floordiv/ceildiv/mod expressions with a positive constant
/// RHS are simplified; `expr` is updated in place.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
                                    unsigned numSymbols,
                                    ArrayRef<Value> operands) {
  // We do this only for certain floordiv/mod expressions.
  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binExpr)
    return;

  // Simplify the child expressions first.
  AffineExpr lhs = binExpr.getLHS();
  AffineExpr rhs = binExpr.getRHS();
  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);

  // Re-inspect: rebuilding above may have folded the expression into a
  // non-binary form or a kind we don't handle.
  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
                   expr.getKind() != AffineExprKind::CeilDiv &&
                   expr.getKind() != AffineExprKind::Mod)) {
    return;
  }

  // The `lhs` and `rhs` may be different post construction of simplified expr.
  lhs = binExpr.getLHS();
  rhs = binExpr.getRHS();
  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
  if (!rhsConst)
    return;

  int64_t rhsConstVal = rhsConst.getValue();
  // Undefined expressions aren't touched; IR can still be valid with them.
  if (rhsConstVal <= 0)
    return;

  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
  MLIRContext *context = expr.getContext();
  std::optional<int64_t> lhsLbConst =
      getLowerBound(lhs, numDims, numSymbols, operands);
  std::optional<int64_t> lhsUbConst =
      getUpperBound(lhs, numDims, numSymbols, operands);
  if (lhsLbConst && lhsUbConst) {
    int64_t lhsLbConstVal = *lhsLbConst;
    int64_t lhsUbConstVal = *lhsUbConst;
    // lhs floordiv c is a single value lhs is bounded in a range `c` that has
    // the same quotient.
    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
        divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
            divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
      expr = getAffineConstantExpr(
          divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
      return;
    }
    // lhs ceildiv c is a single value if the entire range has the same ceil
    // quotient.
    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
        divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
            divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
      expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
                                   context);
      return;
    }
    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
      expr = lhs;
      return;
    }
  }

  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
  AffineExpr quotientTimesDiv, rem;
  int64_t divisor;
  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
    if (rhsConstVal % divisor == 0 &&
        binExpr.getKind() == AffineExprKind::FloorDiv) {
      expr = quotientTimesDiv.floorDiv(rhsConst);
    } else if (divisor % rhsConstVal == 0 &&
               binExpr.getKind() == AffineExprKind::Mod) {
      expr = rem % rhsConst;
    }
    return;
  }

  // Handle the simple case when the LHS expression can be either upper
  // bounded or is a known multiple of RHS constant.
  // lhs floordiv c -> 0 if 0 <= lhs < c,
  // lhs mod c -> 0 if lhs % c = 0.
  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
       binExpr.getKind() == AffineExprKind::FloorDiv) ||
      (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
       binExpr.getKind() == AffineExprKind::Mod)) {
    expr = getAffineConstantExpr(0, expr.getContext());
  }
}
918 
/// Simplify the expressions in `map` while making use of lower or upper bounds
/// of its operands. If `isMax` is true, the map is to be treated as a max of
/// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
/// d1) can be simplified to (8) if the operands are respectively lower bounded
/// by 2 and 0 (the second expression can't be lower than 8).
static void simplifyMinOrMaxExprWithOperands(AffineMap &map,
                                             ArrayRef<Value> operands,
                                             bool isMax) {
  // Can't simplify.
  if (operands.empty())
    return;

  // Get the upper or lower bound on an affine.for op IV using its range.
  // Get the constant lower or upper bounds on the operands.
  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
  constLowerBounds.reserve(operands.size());
  constUpperBounds.reserve(operands.size());
  for (Value operand : operands) {
    constLowerBounds.push_back(getLowerBound(operand));
    constUpperBounds.push_back(getUpperBound(operand));
  }

  // We will compute the lower and upper bounds on each of the expressions
  // Then, we will check (depending on max or min) as to whether a specific
  // bound is redundant by checking if its highest (in case of max) and its
  // lowest (in the case of min) value is already lower than (or higher than)
  // the lower bound (or upper bound in the case of min) of another bound.
  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
  lowerBounds.reserve(map.getNumResults());
  upperBounds.reserve(map.getNumResults());
  for (AffineExpr e : map.getResults()) {
    if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
      lowerBounds.push_back(constExpr.getValue());
      upperBounds.push_back(constExpr.getValue());
    } else {
      lowerBounds.push_back(
          getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
                                constLowerBounds, constUpperBounds,
                                /*isUpper=*/false));
      upperBounds.push_back(
          getBoundForAffineExpr(e, map.getNumDims(), map.getNumSymbols(),
                                constLowerBounds, constUpperBounds,
                                /*isUpper=*/true));
    }
  }

  // Collect expressions that are not redundant.
  SmallVector<AffineExpr, 4> irredundantExprs;
  for (auto exprEn : llvm::enumerate(map.getResults())) {
    AffineExpr e = exprEn.value();
    unsigned i = exprEn.index();
    // Some expressions can be turned into constants.
    if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
      e = getAffineConstantExpr(*lowerBounds[i], e.getContext());

    // Check if the expression is redundant.
    if (isMax) {
      if (!upperBounds[i]) {
        irredundantExprs.push_back(e);
        continue;
      }
      // If there exists another expression such that its lower bound is greater
      // than this expression's upper bound, it's redundant.
      if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
            auto otherLowerBound = en.value();
            unsigned pos = en.index();
            if (pos == i || !otherLowerBound)
              return false;
            if (*otherLowerBound > *upperBounds[i])
              return true;
            if (*otherLowerBound < *upperBounds[i])
              return false;
            // Equality case. When both expressions are considered redundant, we
            // don't want to get both of them. We keep the one that appears
            // first.
            if (upperBounds[pos] && lowerBounds[i] &&
                lowerBounds[i] == upperBounds[i] &&
                otherLowerBound == *upperBounds[pos] && i < pos)
              return false;
            return true;
          }))
        irredundantExprs.push_back(e);
    } else {
      if (!lowerBounds[i]) {
        irredundantExprs.push_back(e);
        continue;
      }
      // Likewise for the `min` case. Use the complement of the condition above.
      if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
            auto otherUpperBound = en.value();
            unsigned pos = en.index();
            if (pos == i || !otherUpperBound)
              return false;
            if (*otherUpperBound < *lowerBounds[i])
              return true;
            if (*otherUpperBound > *lowerBounds[i])
              return false;
            if (lowerBounds[pos] && upperBounds[i] &&
                lowerBounds[i] == upperBounds[i] &&
                otherUpperBound == lowerBounds[pos] && i < pos)
              return false;
            return true;
          }))
        irredundantExprs.push_back(e);
    }
  }

  // Create the map without the redundant expressions.
  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
                       map.getContext());
}
1030 
/// Simplify the map while exploiting information on the values in `operands`.
// Use "unused attribute" marker to silence warning stemming from the inability
// to see through the template expansion.
static void LLVM_ATTRIBUTE_UNUSED
simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
  SmallVector<AffineExpr> newResults;
  newResults.reserve(map.getNumResults());
  for (AffineExpr expr : map.getResults()) {
    // Each result expression is simplified in place using operand bound info.
    simplifyExprAndOperands(expr, map.getNumDims(), map.getNumSymbols(),
                            operands);
    newResults.push_back(expr);
  }
  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
                       map.getContext());
}
1047 
/// Assuming `dimOrSym` is a quantity in the apply op map `map` and defined by
/// `minOp = affine_min(x_1, ..., x_n)`. This function checks that:
/// `0 < affine_min(x_1, ..., x_n)` and proceeds with replacing the patterns:
/// ```
/// dimOrSym.ceildiv(x_k)
/// (dimOrSym + x_k - 1).floordiv(x_k)
/// ```
/// by `1` for all `k` in `1, ..., n`. This is possible because `x / x_k <= 1`.
///
///
/// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
/// `minOp` is positive.
static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
                                                           AffineExpr dimOrSym,
                                                           AffineMap *map,
                                                           ValueRange dims,
                                                           ValueRange syms) {
  AffineMap affineMinMap = minOp.getAffineMap();

  LLVM_DEBUG({
    DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n";
  });

  // Check the value is positive.
  for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
    // Compare each expression in the minimum against 0.
    if (!ValueBoundsConstraintSet::compare(
            getAsIndexOpFoldResult(minOp.getContext(), 0),
            ValueBoundsConstraintSet::ComparisonOperator::LT,
            ValueBoundsConstraintSet::Variable(affineMinMap.getSubMap({i}),
                                               minOp.getOperands())))
      return failure();
  }

  /// Convert affine symbols and dimensions in minOp to symbols or dimensions in
  /// the apply op affine map.
  DenseMap<AffineExpr, AffineExpr> dimSymConversionTable;
  SmallVector<unsigned> unmappedDims, unmappedSyms;
  for (auto [i, dim] : llvm::enumerate(minOp.getDimOperands())) {
    auto it = llvm::find(dims, dim);
    if (it == dims.end()) {
      unmappedDims.push_back(i);
      continue;
    }
    dimSymConversionTable[getAffineDimExpr(i, minOp.getContext())] =
        getAffineDimExpr(it.getIndex(), minOp.getContext());
  }
  for (auto [i, sym] : llvm::enumerate(minOp.getSymbolOperands())) {
    auto it = llvm::find(syms, sym);
    if (it == syms.end()) {
      unmappedSyms.push_back(i);
      continue;
    }
    dimSymConversionTable[getAffineSymbolExpr(i, minOp.getContext())] =
        getAffineSymbolExpr(it.getIndex(), minOp.getContext());
  }

  // Create the replacement map.
  DenseMap<AffineExpr, AffineExpr> repl;
  AffineExpr c1 = getAffineConstantExpr(1, minOp.getContext());
  for (AffineExpr expr : affineMinMap.getResults()) {
    // If we cannot express the result in terms of the apply map symbols and
    // dims then continue.
    if (llvm::any_of(unmappedDims,
                     [&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
        llvm::any_of(unmappedSyms,
                     [&](unsigned i) { return expr.isFunctionOfSymbol(i); }))
      continue;

    AffineExpr convertedExpr = expr.replace(dimSymConversionTable);

    // dimOrSym.ceilDiv(expr) -> 1
    repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
    // (dimOrSym + expr - 1).floorDiv(expr) -> 1
    repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
  }
  AffineMap initialMap = *map;
  *map = initialMap.replace(repl, initialMap.getNumDims(),
                            initialMap.getNumSymbols());
  // Report success only if a replacement actually fired.
  return success(*map != initialMap);
}
1129 
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
/// When `dimOrSymbolPosition >= dims.size()`,
/// AffineSymbolExpr@[pos - dims.size()] is replaced.
/// Mutate `map`,`dims` and `syms` in place as follows:
///   1. `dims` and `syms` are only appended to.
///   2. `map` dim and symbols are gradually shifted to higher positions.
///   3. Old `dim` and `sym` entries are replaced by nullptr
/// This avoids the need for any bookkeeping.
/// If `replaceAffineMin` is set to true, additionally triggers more expensive
/// replacements involving affine_min operations.
/// Returns success if a replacement was performed, failure otherwise.
static LogicalResult replaceDimOrSym(AffineMap *map,
                                     unsigned dimOrSymbolPosition,
                                     SmallVectorImpl<Value> &dims,
                                     SmallVectorImpl<Value> &syms,
                                     bool replaceAffineMin) {
  MLIRContext *ctx = map->getContext();
  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
  unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                  : dimOrSymbolPosition - dims.size();
  Value &v = isDimReplacement ? dims[pos] : syms[pos];
  // A null entry was already consumed by an earlier replacement.
  if (!v)
    return failure();

  if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
    AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx)
                                           : getAffineSymbolExpr(pos, ctx);
    return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map, dims,
                                                 syms);
  }

  auto affineApply = v.getDefiningOp<AffineApplyOp>();
  if (!affineApply)
    return failure();

  // At this point we will perform a replacement of `v`, set the entry in `dim`
  // or `sym` to nullptr immediately.
  v = nullptr;

  // Compute the map, dims and symbols coming from the AffineApplyOp.
  AffineMap composeMap = affineApply.getAffineMap();
  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
                                     affineApply.getMapOperands().end());
  // Canonicalize the map to promote dims to symbols when possible. This is to
  // avoid generating invalid maps.
  canonicalizeMapAndOperands(&composeMap, &composeOperands);
  AffineExpr replacementExpr =
      composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
  ValueRange composeDims =
      ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
  ValueRange composeSyms =
      ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
                                          : getAffineSymbolExpr(pos, ctx);

  // Append the dims and symbols where relevant and perform the replacement.
  dims.append(composeDims.begin(), composeDims.end());
  syms.append(composeSyms.begin(), composeSyms.end());
  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());

  return success();
}
1194 
/// Iterate over `operands` and fold away all those produced by an AffineApplyOp
/// iteratively. Perform canonicalization of map and operands as well as
/// AffineMap simplification. `map` and `operands` are mutated in place.
static void composeAffineMapAndOperands(AffineMap *map,
                                        SmallVectorImpl<Value> *operands,
                                        bool composeAffineMin = false) {
  if (map->getNumResults() == 0) {
    canonicalizeMapAndOperands(map, operands);
    *map = simplifyAffineMap(*map);
    return;
  }

  MLIRContext *ctx = map->getContext();
  SmallVector<Value, 4> dims(operands->begin(),
                             operands->begin() + map->getNumDims());
  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
                             operands->end());

  // Iterate over dims and symbols coming from AffineApplyOp and replace until
  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
  // and `syms` can only increase by construction.
  // The implementation uses a `while` loop to support the case of symbols
  // that may be constructed from dims; this may be overkill.
  while (true) {
    bool changed = false;
    for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
      if ((changed |=
               succeeded(replaceDimOrSym(map, pos, dims, syms, composeAffineMin))))
        break;
    if (!changed)
      break;
  }

  // Clear operands so we can fill them anew.
  operands->clear();

  // At this point we may have introduced null operands, prune them out before
  // canonicalizing map and operands.
  unsigned nDims = 0, nSyms = 0;
  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
  dimReplacements.reserve(dims.size());
  symReplacements.reserve(syms.size());
  for (auto *container : {&dims, &syms}) {
    bool isDim = (container == &dims);
    auto &repls = isDim ? dimReplacements : symReplacements;
    for (const auto &en : llvm::enumerate(*container)) {
      Value v = en.value();
      if (!v) {
        // Null entries were consumed by composition; the map must not refer
        // to them anymore, so a constant 0 placeholder is safe.
        assert(isDim ? !map->isFunctionOfDim(en.index())
                     : !map->isFunctionOfSymbol(en.index()) &&
                           "map is function of unexpected expr@pos");
        repls.push_back(getAffineConstantExpr(0, ctx));
        continue;
      }
      repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
                            : getAffineSymbolExpr(nSyms++, ctx));
      operands->push_back(v);
    }
  }
  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
                                    nSyms);

  // Canonicalize and simplify before returning.
  canonicalizeMapAndOperands(map, operands);
  *map = simplifyAffineMap(*map);
}
1261 
void mlir::affine::fullyComposeAffineMapAndOperands(
    AffineMap *map, SmallVectorImpl<Value> *operands, bool composeAffineMin) {
  // Keep composing while any operand is still produced by an affine.apply.
  while (llvm::any_of(*operands, [](Value v) {
    return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
  })) {
    composeAffineMapAndOperands(map, operands, composeAffineMin);
  }
  // Additional trailing step for AffineMinOps in case no chains of AffineApply.
  if (composeAffineMin && llvm::any_of(*operands, [](Value v) {
        return isa_and_nonnull<AffineMinOp>(v.getDefiningOp());
      })) {
    composeAffineMapAndOperands(map, operands, composeAffineMin);
  }
}
1276 
// Builds an affine.apply after fully composing `map` with the maps of any
// affine.apply (and, when `composeAffineMin` is set, affine.min) ops defining
// its operands.
AffineApplyOp
mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
                                      ArrayRef<OpFoldResult> operands,
                                      bool composeAffineMin) {
  SmallVector<Value> valueOperands;
  // Fold attribute (constant) operands directly into the map first.
  map = foldAttributesIntoMap(b, map, operands, valueOperands);
  composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
  assert(map);
  return b.create<AffineApplyOp>(loc, map, valueOperands);
}
1287 
// Convenience overload: wraps a single AffineExpr into a one-result map and
// delegates to the AffineMap overload above.
AffineApplyOp
mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
                                      ArrayRef<OpFoldResult> operands,
                                      bool composeAffineMin) {
  return makeComposedAffineApply(
      b, loc,
      AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}, b.getContext())
          .front(),
      operands, composeAffineMin);
}
1298 
/// Composes the given affine map with the given list of operands, pulling in
/// the maps from any affine.apply operations that supply the operands.
static void composeMultiResultAffineMap(AffineMap &map,
                                        SmallVectorImpl<Value> &operands,
                                        bool composeAffineMin = false) {
  // Compose and canonicalize each expression in the map individually because
  // composition only applies to single-result maps, collecting potentially
  // duplicate operands in a single list with shifted dimensions and symbols.
  SmallVector<Value> dims, symbols;
  SmallVector<AffineExpr> exprs;
  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
    SmallVector<Value> submapOperands(operands.begin(), operands.end());
    AffineMap submap = map.getSubMap({i});
    fullyComposeAffineMapAndOperands(&submap, &submapOperands,
                                     composeAffineMin);
    canonicalizeMapAndOperands(&submap, &submapOperands);
    unsigned numNewDims = submap.getNumDims();
    // Shift this submap's dims/symbols past those accumulated so far so the
    // concatenated operand list stays consistent.
    submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
    llvm::append_range(dims,
                       ArrayRef<Value>(submapOperands).take_front(numNewDims));
    llvm::append_range(symbols,
                       ArrayRef<Value>(submapOperands).drop_front(numNewDims));
    exprs.push_back(submap.getResult(0));
  }

  // Canonicalize the map created from composed expressions to deduplicate the
  // dimension and symbol operands.
  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
  canonicalizeMapAndOperands(&map, &operands);
}
1330 
// Like makeComposedAffineApply, but constant-folds the op away when possible,
// returning an Attribute instead of creating an operation.
OpFoldResult mlir::affine::makeComposedFoldedAffineApply(
    OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
    bool composeAffineMin) {
  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");

  // Create new builder without a listener, so that no notification is
  // triggered if the op is folded.
  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
  // workaround is no longer needed.
  OpBuilder newBuilder(b.getContext());
  newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());

  // Create op.
  AffineApplyOp applyOp =
      makeComposedAffineApply(newBuilder, loc, map, operands, composeAffineMin);

  // Get constant operands.
  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
    matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));

  // Try to fold the operation.
  SmallVector<OpFoldResult> foldResults;
  if (failed(applyOp->fold(constOperands, foldResults)) ||
      foldResults.empty()) {
    // Not foldable: keep the op and notify the caller's listener now.
    if (OpBuilder::Listener *listener = b.getListener())
      listener->notifyOperationInserted(applyOp, /*previous=*/{});
    return applyOp.getResult();
  }

  applyOp->erase();
  return llvm::getSingleElement(foldResults);
}
1364 
// Convenience overload: wraps a single AffineExpr into a one-result map and
// delegates to the AffineMap overload above.
OpFoldResult mlir::affine::makeComposedFoldedAffineApply(
    OpBuilder &b, Location loc, AffineExpr expr,
    ArrayRef<OpFoldResult> operands, bool composeAffineMin) {
  return makeComposedFoldedAffineApply(
      b, loc,
      AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}, b.getContext())
          .front(),
      operands, composeAffineMin);
}
1374 
// Applies makeComposedFoldedAffineApply to each single-result submap of `map`.
SmallVector<OpFoldResult>
mlir::affine::makeComposedFoldedMultiResultAffineApply(
    OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
    bool composeAffineMin) {
  return llvm::map_to_vector(
      llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
        return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
                                             operands, composeAffineMin);
      });
}
1385 
/// Common implementation for building a composed affine.min/affine.max op:
/// folds attribute operands into the map, composes defining affine.applys,
/// then builds the op.
template <typename OpTy>
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map,
                               ArrayRef<OpFoldResult> operands) {
  SmallVector<Value> valueOperands;
  map = foldAttributesIntoMap(b, map, operands, valueOperands);
  composeMultiResultAffineMap(map, valueOperands);
  return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
}
1394 
AffineMinOp
mlir::affine::makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
                                    ArrayRef<OpFoldResult> operands) {
  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
}
1400 
/// Common implementation: build a composed affine.min/max and constant-fold it
/// away when possible, returning an OpFoldResult.
template <typename OpTy>
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
                                             AffineMap map,
                                             ArrayRef<OpFoldResult> operands) {
  // Create new builder without a listener, so that no notification is
  // triggered if the op is folded.
  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
  // workaround is no longer needed.
  OpBuilder newBuilder(b.getContext());
  newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());

  // Create op.
  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);

  // Get constant operands.
  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
    matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));

  // Try to fold the operation.
  SmallVector<OpFoldResult> foldResults;
  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
      foldResults.empty()) {
    // Not foldable: keep the op and notify the caller's listener now.
    if (OpBuilder::Listener *listener = b.getListener())
      listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
    return minMaxOp.getResult();
  }

  minMaxOp->erase();
  return llvm::getSingleElement(foldResults);
}
1432 
OpFoldResult
mlir::affine::makeComposedFoldedAffineMin(OpBuilder &b, Location loc,
                                          AffineMap map,
                                          ArrayRef<OpFoldResult> operands) {
  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
}
1439 
OpFoldResult
mlir::affine::makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
                                          AffineMap map,
                                          ArrayRef<OpFoldResult> operands) {
  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
}
1446 
1447 // A symbol may appear as a dim in affine.apply operations. This function
1448 // canonicalizes dims that are valid symbols into actual symbols.
1449 template <class MapOrSet>
1450 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1451  SmallVectorImpl<Value> *operands) {
1452  if (!mapOrSet || operands->empty())
1453  return;
1454 
1455  assert(mapOrSet->getNumInputs() == operands->size() &&
1456  "map/set inputs must match number of operands");
1457 
1458  auto *context = mapOrSet->getContext();
1459  SmallVector<Value, 8> resultOperands;
1460  resultOperands.reserve(operands->size());
1461  SmallVector<Value, 8> remappedSymbols;
1462  remappedSymbols.reserve(operands->size());
1463  unsigned nextDim = 0;
1464  unsigned nextSym = 0;
1465  unsigned oldNumSyms = mapOrSet->getNumSymbols();
1466  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1467  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1468  if (i < mapOrSet->getNumDims()) {
1469  if (isValidSymbol((*operands)[i])) {
1470  // This is a valid symbol that appears as a dim, canonicalize it.
1471  dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1472  remappedSymbols.push_back((*operands)[i]);
1473  } else {
1474  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1475  resultOperands.push_back((*operands)[i]);
1476  }
1477  } else {
1478  resultOperands.push_back((*operands)[i]);
1479  }
1480  }
1481 
1482  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1483  *operands = resultOperands;
1484  *mapOrSet = mapOrSet->replaceDimsAndSymbols(
1485  dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
1486 
1487  assert(mapOrSet->getNumInputs() == operands->size() &&
1488  "map/set inputs must match number of operands");
1489 }
1490 
1491 /// A valid affine dimension may appear as a symbol in affine.apply operations.
1492 /// Given an application of `operands` to an affine map or integer set
1493 /// `mapOrSet`, this function canonicalizes symbols of `mapOrSet` that are valid
1494 /// dims, but not valid symbols into actual dims. Without such a legalization,
1495 /// the affine.apply will be invalid. This method is the exact inverse of
1496 /// canonicalizePromotedSymbols.
1497 template <class MapOrSet>
1498 static void legalizeDemotedDims(MapOrSet &mapOrSet,
1499  SmallVectorImpl<Value> &operands) {
1500  if (!mapOrSet || operands.empty())
1501  return;
1502 
1503  unsigned numOperands = operands.size();
1504 
1505  assert(mapOrSet.getNumInputs() == numOperands &&
1506  "map/set inputs must match number of operands");
1507 
1508  auto *context = mapOrSet.getContext();
1509  SmallVector<Value, 8> resultOperands;
1510  resultOperands.reserve(numOperands);
1511  SmallVector<Value, 8> remappedDims;
1512  remappedDims.reserve(numOperands);
1513  SmallVector<Value, 8> symOperands;
1514  symOperands.reserve(mapOrSet.getNumSymbols());
1515  unsigned nextSym = 0;
1516  unsigned nextDim = 0;
1517  unsigned oldNumDims = mapOrSet.getNumDims();
1518  SmallVector<AffineExpr, 8> symRemapping(mapOrSet.getNumSymbols());
1519  resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
1520  for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
1521  if (operands[i] && isValidDim(operands[i]) && !isValidSymbol(operands[i])) {
1522  // This is a valid dim that appears as a symbol, legalize it.
1523  symRemapping[i - oldNumDims] =
1524  getAffineDimExpr(oldNumDims + nextDim++, context);
1525  remappedDims.push_back(operands[i]);
1526  } else {
1527  symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
1528  symOperands.push_back(operands[i]);
1529  }
1530  }
1531 
1532  append_range(resultOperands, remappedDims);
1533  append_range(resultOperands, symOperands);
1534  operands = resultOperands;
1535  mapOrSet = mapOrSet.replaceDimsAndSymbols(
1536  /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
1537 
1538  assert(mapOrSet.getNumInputs() == operands.size() &&
1539  "map/set inputs must match number of operands");
1540 }
1541 
// Works for either an affine map or an integer set.
// Canonicalization steps: promote symbol-like dims, legalize dim-like symbols,
// then drop unused inputs, deduplicate repeated operands, and fold constant
// symbol operands directly into the map/set.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
  legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());

  // Drop unused dims and deduplicate repeated dim operands by remapping them
  // to a single position.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}
1620 
// Canonicalizes `map` and `operands` in place: drops unused dims/symbols,
// deduplicates repeated operands, and folds constant symbol operands, by
// delegating to the shared map-or-set template above.
// NOTE(review): the qualified function-name line of this signature was lost in
// extraction; only the parameter list is visible here.
1622  AffineMap *map, SmallVectorImpl<Value> *operands) {
1623  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1624 }
1625 
// IntegerSet counterpart of the map canonicalization: same operand
// deduplication/constant-folding, instantiated for IntegerSet.
// NOTE(review): the qualified function-name line of this signature was lost in
// extraction; only the parameter list is visible here.
1627  IntegerSet *set, SmallVectorImpl<Value> *operands) {
1628  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1629 }
1630 
1631 namespace {
1632 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1633 /// maps that supply results into them.
1634 ///
1635 template <typename AffineOpTy>
1636 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1638 
1639  /// Replace the affine op with another instance of it with the supplied
1640  /// map and mapOperands.
1641  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1642  AffineMap map, ArrayRef<Value> mapOperands) const;
1643 
1644  LogicalResult matchAndRewrite(AffineOpTy affineOp,
1645  PatternRewriter &rewriter) const override {
       // This pattern is instantiated only for the fixed set of affine ops
       // below; anything else is a programming error caught at compile time.
1646  static_assert(
1647  llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1648  AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1649  AffineVectorStoreOp, AffineVectorLoadOp>::value,
1650  "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1651  "expected");
1652  auto map = affineOp.getAffineMap();
1653  AffineMap oldMap = map;
1654  auto oldOperands = affineOp.getMapOperands();
1655  SmallVector<Value, 8> resultOperands(oldOperands);
       // Compose producer affine.apply maps into this op's map, then
       // canonicalize and simplify the result.
1656  composeAffineMapAndOperands(&map, &resultOperands);
1657  canonicalizeMapAndOperands(&map, &resultOperands);
1658  simplifyMapWithOperands(map, resultOperands);
       // Bail out when composition/canonicalization was a no-op, so the
       // rewriter does not loop on an unchanged op.
1659  if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1660  resultOperands.begin()))
1661  return failure();
1662 
1663  replaceAffineOp(rewriter, affineOp, map, resultOperands);
1664  return success();
1665  }
1666 };
1667 
1668 // Specialize the template to account for the different build signatures for
1669 // affine load, store, and apply ops.
1670 template <>
1671 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1672  PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1673  ArrayRef<Value> mapOperands) const {
1674  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1675  mapOperands);
1676 }
1677 template <>
1678 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1679  PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1680  ArrayRef<Value> mapOperands) const {
      // Prefetch carries extra unit attributes (isWrite/localityHint/
      // isDataCache) that must be forwarded to the replacement op.
1681  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1682  prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1683  prefetch.getLocalityHint(), prefetch.getIsDataCache());
1684 }
1685 template <>
1686 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1687  PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1688  ArrayRef<Value> mapOperands) const {
1689  rewriter.replaceOpWithNewOp<AffineStoreOp>(
1690  store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1691 }
1692 template <>
1693 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1694  PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1695  ArrayRef<Value> mapOperands) const {
1696  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1697  vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1698  mapOperands);
1699 }
1700 template <>
1701 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1702  PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1703  ArrayRef<Value> mapOperands) const {
1704  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1705  vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1706  mapOperands);
1707 }
1708 
1709 // Generic version for ops that don't have extra operands.
1710 template <typename AffineOpTy>
1711 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1712  PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1713  ArrayRef<Value> mapOperands) const {
1714  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1715 }
1716 } // namespace
1717 
1718 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1719  MLIRContext *context) {
1720  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1721 }
1722 
1723 //===----------------------------------------------------------------------===//
1724 // AffineDmaStartOp
1725 //===----------------------------------------------------------------------===//
1726 
1727 // TODO: Check that map operands are loop IVs or symbols.
// Builds an affine.dma_start op. The operand list is laid out exactly as the
// accessors expect it:
//   srcMemRef, srcIndices..., destMemRef, destIndices...,
//   tagMemRef, tagIndices..., numElements [, stride, elementsPerStride]
// The three affine maps are stored as attributes, not operands. The
// stride/elementsPerStride pair is appended only when `stride` is non-null,
// so both values must be provided together for a strided DMA.
1728 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1729  Value srcMemRef, AffineMap srcMap,
1730  ValueRange srcIndices, Value destMemRef,
1731  AffineMap dstMap, ValueRange destIndices,
1732  Value tagMemRef, AffineMap tagMap,
1733  ValueRange tagIndices, Value numElements,
1734  Value stride, Value elementsPerStride) {
1735  result.addOperands(srcMemRef);
1736  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1737  result.addOperands(srcIndices);
1738  result.addOperands(destMemRef);
1739  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1740  result.addOperands(destIndices);
1741  result.addOperands(tagMemRef);
1742  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1743  result.addOperands(tagIndices);
1744  result.addOperands(numElements);
1745  if (stride) {
1746  result.addOperands({stride, elementsPerStride});
1747  }
1748 }
1749 
// Printer for affine.dma_start:
//   %src[map-of-ids], %dst[map-of-ids], %tag[map-of-ids], %numElements
//   [, %stride, %elementsPerStride] : src-type, dst-type, tag-type
// NOTE(review): the enclosing signature line (taking the OpAsmPrinter `p`)
// was lost in extraction.
1751  p << " " << getSrcMemRef() << '[';
1752  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1753  p << "], " << getDstMemRef() << '[';
1754  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1755  p << "], " << getTagMemRef() << '[';
1756  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1757  p << "], " << getNumElements();
1758  if (isStrided()) {
1759  p << ", " << getStride();
1760  p << ", " << getNumElementsPerStride();
1761  }
1762  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1763  << getTagMemRefType();
1764 }
1765 
1766 // Parse AffineDmaStartOp.
1767 // Ex:
1768 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1769 // %stride, %num_elt_per_stride
1770 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1771 //
// NOTE(review): the signature line and the SmallVector declarations for
// srcMapOperands/dstMapOperands/tagMapOperands/strideInfo were lost in
// extraction (doc lines 1772, 1776, 1779, 1782, 1784).
1773  OperationState &result) {
1774  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1775  AffineMapAttr srcMapAttr;
1777  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1778  AffineMapAttr dstMapAttr;
1780  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1781  AffineMapAttr tagMapAttr;
1783  OpAsmParser::UnresolvedOperand numElementsInfo;
1785 
1786  SmallVector<Type, 3> types;
1787  auto indexType = parser.getBuilder().getIndexType();
1788 
1789  // Parse and resolve the following list of operands:
1790  // *) dst memref followed by its affine maps operands (in square brackets).
1791  // *) src memref followed by its affine map operands (in square brackets).
1792  // *) tag memref followed by its affine map operands (in square brackets).
1793  // *) number of elements transferred by DMA operation.
1794  if (parser.parseOperand(srcMemRefInfo) ||
1795  parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1796  getSrcMapAttrStrName(),
1797  result.attributes) ||
1798  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1799  parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1800  getDstMapAttrStrName(),
1801  result.attributes) ||
1802  parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1803  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1804  getTagMapAttrStrName(),
1805  result.attributes) ||
1806  parser.parseComma() || parser.parseOperand(numElementsInfo))
1807  return failure();
1808 
1809  // Parse optional stride and elements per stride.
1810  if (parser.parseTrailingOperandList(strideInfo))
1811  return failure();
1812 
  // The stride pair is all-or-nothing: either both stride and
  // elements-per-stride are present, or neither is.
1813  if (!strideInfo.empty() && strideInfo.size() != 2) {
1814  return parser.emitError(parser.getNameLoc(),
1815  "expected two stride related operands");
1816  }
1817  bool isStrided = strideInfo.size() == 2;
1818 
1819  if (parser.parseColonTypeList(types))
1820  return failure();
1821 
1822  if (types.size() != 3)
1823  return parser.emitError(parser.getNameLoc(), "expected three types");
1824 
  // Resolve in the same order the builder adds operands: src, src indices,
  // dst, dst indices, tag, tag indices, numElements.
1825  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1826  parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1827  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1828  parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1829  parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1830  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1831  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1832  return failure();
1833 
1834  if (isStrided) {
1835  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1836  return failure();
1837  }
1838 
1839  // Check that src/dst/tag operand counts match their map.numInputs.
1840  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1841  dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1842  tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1843  return parser.emitError(parser.getNameLoc(),
1844  "memref operand count not equal to map.numInputs");
1845  return success();
1846 }
1847 
1848 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1849  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1850  return emitOpError("expected DMA source to be of memref type");
1851  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1852  return emitOpError("expected DMA destination to be of memref type");
1853  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1854  return emitOpError("expected DMA tag to be of memref type");
1855 
1856  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1857  getDstMap().getNumInputs() +
1858  getTagMap().getNumInputs();
1859  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1860  getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1861  return emitOpError("incorrect number of operands");
1862  }
1863 
1864  Region *scope = getAffineScope(*this);
1865  for (auto idx : getSrcIndices()) {
1866  if (!idx.getType().isIndex())
1867  return emitOpError("src index to dma_start must have 'index' type");
1868  if (!isValidAffineIndexOperand(idx, scope))
1869  return emitOpError(
1870  "src index must be a valid dimension or symbol identifier");
1871  }
1872  for (auto idx : getDstIndices()) {
1873  if (!idx.getType().isIndex())
1874  return emitOpError("dst index to dma_start must have 'index' type");
1875  if (!isValidAffineIndexOperand(idx, scope))
1876  return emitOpError(
1877  "dst index must be a valid dimension or symbol identifier");
1878  }
1879  for (auto idx : getTagIndices()) {
1880  if (!idx.getType().isIndex())
1881  return emitOpError("tag index to dma_start must have 'index' type");
1882  if (!isValidAffineIndexOperand(idx, scope))
1883  return emitOpError(
1884  "tag index must be a valid dimension or symbol identifier");
1885  }
1886  return success();
1887 }
1888 
// Folds affine.dma_start by looking through memref.cast operands:
// dma_start(memref.cast(x)) -> dma_start(x). Returns success iff an operand
// was updated in place; no result values are produced.
1889 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1890  SmallVectorImpl<OpFoldResult> &results) {
1891  /// dma_start(memrefcast) -> dma_start
1892  return memref::foldMemRefCast(*this);
1893 }
1894 
// Memory effects of affine.dma_start: reads the source and tag memrefs,
// writes the destination memref.
// NOTE(review): the effect-list parameter type and the resource arguments of
// each emplace_back (doc lines 1896, 1899, 1901, 1903) were lost in
// extraction.
1895 void AffineDmaStartOp::getEffects(
1897  &effects) {
1898  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
1900  effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
1902  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1904 }
1905 
1906 //===----------------------------------------------------------------------===//
1907 // AffineDmaWaitOp
1908 //===----------------------------------------------------------------------===//
1909 
1910 // TODO: Check that map operands are loop IVs or symbols.
// Builds an affine.dma_wait op. Operand layout:
//   tagMemRef, tagIndices..., numElements
// The tag affine map is stored as an attribute rather than an operand.
1911 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1912  Value tagMemRef, AffineMap tagMap,
1913  ValueRange tagIndices, Value numElements) {
1914  result.addOperands(tagMemRef);
1915  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1916  result.addOperands(tagIndices);
1917  result.addOperands(numElements);
1918 }
1919 
// Printer for affine.dma_wait: %tag[map-of-ids], %numElements : tag-type.
// NOTE(review): the enclosing signature line (doc line 1920) and the line
// printing the element count (doc line 1925) were lost in extraction.
1921  p << " " << getTagMemRef() << '[';
1922  SmallVector<Value, 2> operands(getTagIndices());
1923  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1924  p << "], ";
1926  p << " : " << getTagMemRef().getType();
1927 }
1928 
1929 // Parse AffineDmaWaitOp.
1930 // Eg:
1931 // affine.dma_wait %tag[%index], %num_elements
1932 // : memref<1 x i32, (d0) -> (d0), 4>
1933 //
// NOTE(review): the signature line and the declaration of `tagMapOperands`
// (doc lines 1934, 1938) were lost in extraction.
1935  OperationState &result) {
1936  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1937  AffineMapAttr tagMapAttr;
1939  Type type;
1940  auto indexType = parser.getBuilder().getIndexType();
1941  OpAsmParser::UnresolvedOperand numElementsInfo;
1942 
1943  // Parse tag memref, its map operands, and dma size.
1944  if (parser.parseOperand(tagMemRefInfo) ||
1945  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1946  getTagMapAttrStrName(),
1947  result.attributes) ||
1948  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1949  parser.parseColonType(type) ||
1950  parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1951  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1952  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1953  return failure();
1954 
  // The trailing type applies to the tag memref and must actually be a
  // memref type.
1955  if (!llvm::isa<MemRefType>(type))
1956  return parser.emitError(parser.getNameLoc(),
1957  "expected tag to be of memref type");
1958 
1959  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1960  return parser.emitError(parser.getNameLoc(),
1961  "tag memref operand count != to map.numInputs");
1962  return success();
1963 }
1964 
1965 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1966  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1967  return emitOpError("expected DMA tag to be of memref type");
1968  Region *scope = getAffineScope(*this);
1969  for (auto idx : getTagIndices()) {
1970  if (!idx.getType().isIndex())
1971  return emitOpError("index to dma_wait must have 'index' type");
1972  if (!isValidAffineIndexOperand(idx, scope))
1973  return emitOpError(
1974  "index must be a valid dimension or symbol identifier");
1975  }
1976  return success();
1977 }
1978 
// Folds affine.dma_wait by looking through memref.cast operands:
// dma_wait(memref.cast(x)) -> dma_wait(x). Returns success iff an operand
// was updated in place; no result values are produced.
1979 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1980  SmallVectorImpl<OpFoldResult> &results) {
1981  /// dma_wait(memrefcast) -> dma_wait
1982  return memref::foldMemRefCast(*this);
1983 }
1984 
// Memory effects of affine.dma_wait: reads the tag memref (polled for DMA
// completion).
// NOTE(review): the effect-list parameter type and the resource argument of
// emplace_back (doc lines 1986, 1989) were lost in extraction.
1985 void AffineDmaWaitOp::getEffects(
1987  &effects) {
1988  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1990 }
1991 
1992 //===----------------------------------------------------------------------===//
1993 // AffineForOp
1994 //===----------------------------------------------------------------------===//
1995 
1996 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1997 /// bodyBuilder are empty/null, we include default terminator op.
// Operand layout (recorded via the operand-segment-sizes attribute):
//   lbOperands..., ubOperands..., iterArgs...
// Bound maps and the step are attributes. One result is added per iter arg.
1998 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1999  ValueRange lbOperands, AffineMap lbMap,
2000  ValueRange ubOperands, AffineMap ubMap, int64_t step,
2001  ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
2002  assert(((!lbMap && lbOperands.empty()) ||
2003  lbOperands.size() == lbMap.getNumInputs()) &&
2004  "lower bound operand count does not match the affine map");
2005  assert(((!ubMap && ubOperands.empty()) ||
2006  ubOperands.size() == ubMap.getNumInputs()) &&
2007  "upper bound operand count does not match the affine map");
2008  assert(step > 0 && "step has to be a positive integer constant");
2009 
  // Restore the builder's insertion point after creating the body block.
2010  OpBuilder::InsertionGuard guard(builder);
2011 
2012  // Set variadic segment sizes.
2013  result.addAttribute(
2014  getOperandSegmentSizeAttr(),
2015  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
2016  static_cast<int32_t>(ubOperands.size()),
2017  static_cast<int32_t>(iterArgs.size())}));
2018 
2019  for (Value val : iterArgs)
2020  result.addTypes(val.getType());
2021 
2022  // Add an attribute for the step.
2023  result.addAttribute(getStepAttrName(result.name),
2024  builder.getIntegerAttr(builder.getIndexType(), step));
2025 
2026  // Add the lower bound.
2027  result.addAttribute(getLowerBoundMapAttrName(result.name),
2028  AffineMapAttr::get(lbMap));
2029  result.addOperands(lbOperands);
2030 
2031  // Add the upper bound.
2032  result.addAttribute(getUpperBoundMapAttrName(result.name),
2033  AffineMapAttr::get(ubMap));
2034  result.addOperands(ubOperands);
2035 
2036  result.addOperands(iterArgs);
2037  // Create a region and a block for the body. The argument of the region is
2038  // the loop induction variable.
2039  Region *bodyRegion = result.addRegion();
2040  Block *bodyBlock = builder.createBlock(bodyRegion);
2041  Value inductionVar =
2042  bodyBlock->addArgument(builder.getIndexType(), result.location);
  // One block argument per iter arg carries the loop-carried value.
2043  for (Value val : iterArgs)
2044  bodyBlock->addArgument(val.getType(), val.getLoc());
2045 
2046  // Create the default terminator if the builder is not provided and if the
2047  // iteration arguments are not provided. Otherwise, leave this to the caller
2048  // because we don't know which values to return from the loop.
2049  if (iterArgs.empty() && !bodyBuilder) {
2050  ensureTerminator(*bodyRegion, builder, result.location);
2051  } else if (bodyBuilder) {
2052  OpBuilder::InsertionGuard guard(builder);
2053  builder.setInsertionPointToStart(bodyBlock);
2054  bodyBuilder(builder, result.location, inductionVar,
2055  bodyBlock->getArguments().drop_front());
2056  }
2057 }
2058 
2059 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
2060  int64_t ub, int64_t step, ValueRange iterArgs,
2061  BodyBuilderFn bodyBuilder) {
2062  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
2063  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
2064  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
2065  bodyBuilder);
2066 }
2067 
// Region verifier for affine.for: checks the induction-variable block
// argument, the validity of bound operands, that each bound map has at least
// one result, and that result/iter-arg counts agree.
2068 LogicalResult AffineForOp::verifyRegions() {
2069  // Check that the body defines as single block argument for the induction
2070  // variable.
2071  auto *body = getBody();
2072  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
2073  return emitOpError("expected body to have a single index argument for the "
2074  "induction variable");
2075 
2076  // Verify that the bound operands are valid dimension/symbols.
  // NOTE(review): the call lines invoking the dim/symbol verifier on the
  // bound operands (doc lines 2079, 2084) were lost in extraction; only
  // their trailing argument lines remain below.
2077  /// Lower bound.
2078  if (getLowerBoundMap().getNumInputs() > 0)
2080  getLowerBoundMap().getNumDims())))
2081  return failure();
2082  /// Upper bound.
2083  if (getUpperBoundMap().getNumInputs() > 0)
2085  getUpperBoundMap().getNumDims())))
2086  return failure();
2087  if (getLowerBoundMap().getNumResults() < 1)
2088  return emitOpError("expected lower bound map to have at least one result");
2089  if (getUpperBoundMap().getNumResults() < 1)
2090  return emitOpError("expected upper bound map to have at least one result");
2091 
2092  unsigned opNumResults = getNumResults();
2093  if (opNumResults == 0)
2094  return success();
2095 
2096  // If ForOp defines values, check that the number and types of the defined
2097  // values match ForOp initial iter operands and backedge basic block
2098  // arguments.
2099  if (getNumIterOperands() != opNumResults)
2100  return emitOpError(
2101  "mismatch between the number of loop-carried values and results");
2102  if (getNumRegionIterArgs() != opNumResults)
2103  return emitOpError(
2104  "mismatch between the number of basic block args and results");
2105 
2106  return success();
2107 }
2108 
2109 /// Parse a for operation loop bounds.
// Accepts three assembly forms: a bare SSA id (stored as a symbol identity
// map), the full form (affine map attribute followed by its dim/symbol
// operand list), or a bare integer constant (stored as a constant map).
// NOTE(review): the declaration of `boundOpInfos` (doc line 2123) was lost in
// extraction.
2110 static ParseResult parseBound(bool isLower, OperationState &result,
2111  OpAsmParser &p) {
2112  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
2113  // the map has multiple results.
2114  bool failedToParsedMinMax =
2115  failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
2116 
2117  auto &builder = p.getBuilder();
2118  auto boundAttrStrName =
2119  isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
2120  : AffineForOp::getUpperBoundMapAttrName(result.name);
2121 
2122  // Parse ssa-id as identity map.
2124  if (p.parseOperandList(boundOpInfos))
2125  return failure();
2126 
2127  if (!boundOpInfos.empty()) {
2128  // Check that only one operand was parsed.
2129  if (boundOpInfos.size() > 1)
2130  return p.emitError(p.getNameLoc(),
2131  "expected only one loop bound operand");
2132 
2133  // TODO: improve error message when SSA value is not of index type.
2134  // Currently it is 'use of value ... expects different type than prior uses'
2135  if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
2136  result.operands))
2137  return failure();
2138 
2139  // Create an identity map using symbol id. This representation is optimized
2140  // for storage. Analysis passes may expand it into a multi-dimensional map
2141  // if desired.
2142  AffineMap map = builder.getSymbolIdentityMap();
2143  result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
2144  return success();
2145  }
2146 
2147  // Get the attribute location.
2148  SMLoc attrLoc = p.getCurrentLocation();
2149 
2150  Attribute boundAttr;
2151  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
2152  result.attributes))
2153  return failure();
2154 
2155  // Parse full form - affine map followed by dim and symbol list.
2156  if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
2157  unsigned currentNumOperands = result.operands.size();
2158  unsigned numDims;
2159  if (parseDimAndSymbolList(p, result.operands, numDims))
2160  return failure();
2161 
2162  auto map = affineMapAttr.getValue();
2163  if (map.getNumDims() != numDims)
2164  return p.emitError(
2165  p.getNameLoc(),
2166  "dim operand count and affine map dim count must match");
2167 
2168  unsigned numDimAndSymbolOperands =
2169  result.operands.size() - currentNumOperands;
2170  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2171  return p.emitError(
2172  p.getNameLoc(),
2173  "symbol operand count and affine map symbol count must match");
2174 
2175  // If the map has multiple results, make sure that we parsed the min/max
2176  // prefix.
2177  if (map.getNumResults() > 1 && failedToParsedMinMax) {
2178  if (isLower) {
2179  return p.emitError(attrLoc, "lower loop bound affine map with "
2180  "multiple results requires 'max' prefix");
2181  }
2182  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
2183  "results requires 'min' prefix");
2184  }
2185  return success();
2186  }
2187 
2188  // Parse custom assembly form.
2189  if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
  // Replace the raw integer attribute just parsed with the equivalent
  // constant affine map under the same attribute name.
2190  result.attributes.pop_back();
2191  result.addAttribute(
2192  boundAttrStrName,
2193  AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2194  return success();
2195  }
2196 
2197  return p.emitError(
2198  p.getNameLoc(),
2199  "expected valid affine map representation for loop bounds");
2200 }
2201 
// Parser for affine.for:
//   affine.for %iv = <lb> to <ub> [step <s>] [iter_args(...) -> (types)]
//   { body } [attr-dict]
2202 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2203  auto &builder = parser.getBuilder();
2204  OpAsmParser::Argument inductionVariable;
2205  inductionVariable.type = builder.getIndexType();
2206  // Parse the induction variable followed by '='.
2207  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2208  return failure();
2209 
2210  // Parse loop bounds.
  // Track how many operands each bound contributed so the
  // operand-segment-sizes attribute can be set below.
2211  int64_t numOperands = result.operands.size();
2212  if (parseBound(/*isLower=*/true, result, parser))
2213  return failure();
2214  int64_t numLbOperands = result.operands.size() - numOperands;
2215  if (parser.parseKeyword("to", " between bounds"))
2216  return failure();
2217  numOperands = result.operands.size();
2218  if (parseBound(/*isLower=*/false, result, parser))
2219  return failure();
2220  int64_t numUbOperands = result.operands.size() - numOperands;
2221 
2222  // Parse the optional loop step, we default to 1 if one is not present.
2223  if (parser.parseOptionalKeyword("step")) {
2224  result.addAttribute(
2225  getStepAttrName(result.name),
2226  builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2227  } else {
2228  SMLoc stepLoc = parser.getCurrentLocation();
2229  IntegerAttr stepAttr;
2230  if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2231  getStepAttrName(result.name).data(),
2232  result.attributes))
2233  return failure();
2234 
2235  if (stepAttr.getValue().isNegative())
2236  return parser.emitError(
2237  stepLoc,
2238  "expected step to be representable as a positive signed integer");
2239  }
2240 
2241  // Parse the optional initial iteration arguments.
  // NOTE(review): the declarations of `regionArgs` and `operands` (doc lines
  // 2242-2243) were lost in extraction.
2244 
2245  // Induction variable.
2246  regionArgs.push_back(inductionVariable);
2247 
2248  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2249  // Parse assignment list and results type list.
2250  if (parser.parseAssignmentList(regionArgs, operands) ||
2251  parser.parseArrowTypeList(result.types))
2252  return failure();
2253  // Resolve input operands.
2254  for (auto argOperandType :
2255  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2256  Type type = std::get<2>(argOperandType);
2257  std::get<0>(argOperandType).type = type;
2258  if (parser.resolveOperand(std::get<1>(argOperandType), type,
2259  result.operands))
2260  return failure();
2261  }
2262  }
2263 
2264  result.addAttribute(
2265  getOperandSegmentSizeAttr(),
2266  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2267  static_cast<int32_t>(numUbOperands),
2268  static_cast<int32_t>(operands.size())}));
2269 
2270  // Parse the body region.
2271  Region *body = result.addRegion();
  // regionArgs holds the induction variable plus one arg per result.
2272  if (regionArgs.size() != result.types.size() + 1)
2273  return parser.emitError(
2274  parser.getNameLoc(),
2275  "mismatch between the number of loop-carried values and results");
2276  if (parser.parseRegion(*body, regionArgs))
2277  return failure();
2278 
2279  AffineForOp::ensureTerminator(*body, builder, result.location);
2280 
2281  // Parse the optional attribute list.
2282  return parser.parseOptionalAttrDict(result.attributes);
2283 }
2284 
2285 static void printBound(AffineMapAttr boundMap,
2286  Operation::operand_range boundOperands,
2287  const char *prefix, OpAsmPrinter &p) {
2288  AffineMap map = boundMap.getValue();
2289 
2290  // Check if this bound should be printed using custom assembly form.
2291  // The decision to restrict printing custom assembly form to trivial cases
2292  // comes from the will to roundtrip MLIR binary -> text -> binary in a
2293  // lossless way.
2294  // Therefore, custom assembly form parsing and printing is only supported for
2295  // zero-operand constant maps and single symbol operand identity maps.
2296  if (map.getNumResults() == 1) {
2297  AffineExpr expr = map.getResult(0);
2298 
2299  // Print constant bound.
2300  if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2301  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2302  p << constExpr.getValue();
2303  return;
2304  }
2305  }
2306 
2307  // Print bound that consists of a single SSA symbol if the map is over a
2308  // single symbol.
2309  if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2310  if (isa<AffineSymbolExpr>(expr)) {
2311  p.printOperand(*boundOperands.begin());
2312  return;
2313  }
2314  }
2315  } else {
2316  // Map has multiple results. Print 'min' or 'max' prefix.
2317  p << prefix << ' ';
2318  }
2319 
2320  // Print the map and its operands.
2321  p << boundMap;
2322  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2323  map.getNumDims(), p);
2324 }
2325 
2326 unsigned AffineForOp::getNumIterOperands() {
2327  AffineMap lbMap = getLowerBoundMapAttr().getValue();
2328  AffineMap ubMap = getUpperBoundMapAttr().getValue();
2329 
2330  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2331 }
2332 
// Returns the mutable operand range of the body's affine.yield terminator,
// i.e. the values carried back to the next iteration / returned as results.
2333 std::optional<MutableArrayRef<OpOperand>>
2334 AffineForOp::getYieldedValuesMutable() {
2335  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2336 }
2337 
// Printer for affine.for; inverse of AffineForOp::parse above. Elides the
// bound maps, step, and segment-size attributes, which are printed inline.
// NOTE(review): the enclosing signature line (doc line 2338) and the
// p.printOptionalAttrDict call line (doc line 2366) were lost in extraction.
2339  p << ' ';
2340  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2341  /*omitType=*/true);
2342  p << " = ";
  // Lower bounds take 'max' (loop starts at the max of the results), upper
  // bounds take 'min'.
2343  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2344  p << " to ";
2345  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2346 
  // The step is elided when it is the default value 1.
2347  if (getStepAsInt() != 1)
2348  p << " step " << getStepAsInt();
2349 
2350  bool printBlockTerminators = false;
2351  if (getNumIterOperands() > 0) {
2352  p << " iter_args(";
2353  auto regionArgs = getRegionIterArgs();
2354  auto operands = getInits();
2355 
2356  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2357  p << std::get<0>(it) << " = " << std::get<1>(it);
2358  });
2359  p << ") -> (" << getResultTypes() << ")";
2360  printBlockTerminators = true;
2361  }
2362 
2363  p << ' ';
2364  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2365  printBlockTerminators);
2367  (*this)->getAttrs(),
2368  /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2369  getUpperBoundMapAttrName(getOperation()->getName()),
2370  getStepAttrName(getOperation()->getName()),
2371  getOperandSegmentSizeAttr()});
2372 }
2373 
2374 /// Fold the constant bounds of a loop.
2375 static LogicalResult foldLoopBounds(AffineForOp forOp) {
2376  auto foldLowerOrUpperBound = [&forOp](bool lower) {
2377  // Check to see if each of the operands is the result of a constant. If
2378  // so, get the value. If not, ignore it.
2379  SmallVector<Attribute, 8> operandConstants;
2380  auto boundOperands =
2381  lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2382  for (auto operand : boundOperands) {
2383  Attribute operandCst;
2384  matchPattern(operand, m_Constant(&operandCst));
2385  operandConstants.push_back(operandCst);
2386  }
2387 
2388  AffineMap boundMap =
2389  lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2390  assert(boundMap.getNumResults() >= 1 &&
2391  "bound maps should have at least one result");
2392  SmallVector<Attribute, 4> foldedResults;
2393  if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2394  return failure();
2395 
2396  // Compute the max or min as applicable over the results.
2397  assert(!foldedResults.empty() && "bounds should have at least one result");
2398  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2399  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2400  auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2401  maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2402  : llvm::APIntOps::smin(maxOrMin, foldedResult);
2403  }
2404  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2405  : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2406  return success();
2407  };
2408 
2409  // Try to fold the lower bound.
2410  bool folded = false;
2411  if (!forOp.hasConstantLowerBound())
2412  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2413 
2414  // Try to fold the upper bound.
2415  if (!forOp.hasConstantUpperBound())
2416  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2417  return success(folded);
2418 }
2419 
2420 /// Canonicalize the bounds of the given loop.
// Composes producer affine.apply ops into each bound map, canonicalizes the
// operands, simplifies min/max expressions (a lower bound is a max, an upper
// bound a min), and drops duplicate map results. Returns success iff either
// map changed; only changed bounds are written back.
2421 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2422  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2423  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2424 
2425  auto lbMap = forOp.getLowerBoundMap();
2426  auto ubMap = forOp.getUpperBoundMap();
2427  auto prevLbMap = lbMap;
2428  auto prevUbMap = ubMap;
2429 
2430  composeAffineMapAndOperands(&lbMap, &lbOperands);
2431  canonicalizeMapAndOperands(&lbMap, &lbOperands);
2432  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
  // NOTE(review): the upper bound is min-simplified here, *before* its
  // operands are composed/canonicalized below — asymmetric with the lower
  // bound; confirm this ordering is intentional.
2433  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2434  lbMap = removeDuplicateExprs(lbMap);
2435 
2436  composeAffineMapAndOperands(&ubMap, &ubOperands);
2437  canonicalizeMapAndOperands(&ubMap, &ubOperands);
2438  ubMap = removeDuplicateExprs(ubMap);
2439 
2440  // Any canonicalization change always leads to updated map(s).
2441  if (lbMap == prevLbMap && ubMap == prevUbMap)
2442  return failure();
2443 
2444  if (lbMap != prevLbMap)
2445  forOp.setLowerBound(lbOperands, lbMap);
2446  if (ubMap != prevUbMap)
2447  forOp.setUpperBound(ubOperands, ubMap);
2448  return success();
2449 }
2450 
2451 namespace {
2452 /// Returns constant trip count in trivial cases.
2453 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2454  int64_t step = forOp.getStepAsInt();
2455  if (!forOp.hasConstantBounds() || step <= 0)
2456  return std::nullopt;
2457  int64_t lb = forOp.getConstantLowerBound();
2458  int64_t ub = forOp.getConstantUpperBound();
2459  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2460 }
2461 
/// This is a pattern to fold trivially empty loop bodies.
/// TODO: This should be moved into the folding hook.
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
  using OpRewritePattern<AffineForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Check that the body only contains a yield.
    if (!llvm::hasSingleElement(*forOp.getBody()))
      return failure();
    // NOTE(review): for a result-less empty loop this returns success without
    // mutating the IR -- presumably such loops are cleaned up as trivially
    // dead elsewhere; confirm against the pattern driver's expectations.
    if (forOp.getNumResults() == 0)
      return success();
    std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
    if (tripCount == 0) {
      // The initial values of the iteration arguments would be the op's
      // results.
      rewriter.replaceOp(forOp, forOp.getInits());
      return success();
    }
    SmallVector<Value, 4> replacements;
    auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
    auto iterArgs = forOp.getRegionIterArgs();
    bool hasValDefinedOutsideLoop = false;
    bool iterArgsNotInOrder = false;
    for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
      Value val = yieldOp.getOperand(i);
      auto *iterArgIt = llvm::find(iterArgs, val);
      // TODO: It should be possible to perform a replacement by computing the
      // last value of the IV based on the bounds and the step.
      if (val == forOp.getInductionVar())
        return failure();
      if (iterArgIt == iterArgs.end()) {
        // `val` is defined outside of the loop.
        assert(forOp.isDefinedOutsideOfLoop(val) &&
               "must be defined outside of the loop");
        hasValDefinedOutsideLoop = true;
        replacements.push_back(val);
      } else {
        // `val` yields an iter_arg; forward the corresponding init value,
        // noting whether the yield order matches the iter_arg order.
        unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
        if (pos != i)
          iterArgsNotInOrder = true;
        replacements.push_back(forOp.getInits()[pos]);
      }
    }
    // Bail out when the trip count is unknown and the loop returns any value
    // defined outside of the loop or any iterArg out of order.
    if (!tripCount.has_value() &&
        (hasValDefinedOutsideLoop || iterArgsNotInOrder))
      return failure();
    // Bail out when the loop iterates more than once and it returns any
    // iterArg out of order.
    if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
      return failure();
    rewriter.replaceOp(forOp, replacements);
    return success();
  }
};
2519 } // namespace
2520 
/// Registers AffineForOp canonicalization patterns (empty-loop folding).
void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<AffineForEmptyLoopFolder>(context);
}
2525 
/// Returns the operands forwarded on entry into the loop region.
OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert((point.isParent() || point == getRegion()) && "invalid region point");

  // The initial operands map to the loop arguments after the induction
  // variable or are forwarded to the results when the trip count is zero.
  return getInits();
}
2533 
/// Reports which regions may execute after `point`, using the trivially
/// constant trip count (when known) to prune impossible successors.
void AffineForOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  assert((point.isParent() || point == getRegion()) && "expected loop region");
  // The loop may typically branch back to its body or to the parent operation.
  // If the predecessor is the parent op and the trip count is known to be at
  // least one, branch into the body using the iterator arguments. And in cases
  // we know the trip count is zero, it can only branch back to its parent.
  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
  if (point.isParent() && tripCount.has_value()) {
    if (tripCount.value() > 0) {
      regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
      return;
    }
    if (tripCount.value() == 0) {
      regions.push_back(RegionSuccessor(getResults()));
      return;
    }
  }

  // From the loop body, if the trip count is one, we can only branch back to
  // the parent.
  if (!point.isParent() && tripCount == 1) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // In all other cases, the loop may branch back to itself or the parent
  // operation.
  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
  regions.push_back(RegionSuccessor(getResults()));
}
2565 
2566 /// Returns true if the affine.for has zero iterations in trivial cases.
2567 static bool hasTrivialZeroTripCount(AffineForOp op) {
2568  return getTrivialConstantTripCount(op) == 0;
2569 }
2570 
/// Folds the loop in place by simplifying its bounds; additionally forwards
/// the init values as results when the loop provably runs zero times.
LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &results) {
  bool folded = succeeded(foldLoopBounds(*this));
  folded |= succeeded(canonicalizeLoopBounds(*this));
  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
    // The initial values of the loop-carried variables (iter_args) are the
    // results of the op. But this must be avoided for an affine.for op that
    // does not return any results. Since ops that do not return results cannot
    // be folded away, we would enter an infinite loop of folds on the same
    // affine.for op.
    results.assign(getInits().begin(), getInits().end());
    folded = true;
  }
  return success(folded);
}
2586 
/// Returns the lower bound as an AffineBound (operands + map).
AffineBound AffineForOp::getLowerBound() {
  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
}
2590 
/// Returns the upper bound as an AffineBound (operands + map).
AffineBound AffineForOp::getUpperBound() {
  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
}
2594 
/// Replaces the lower bound with `map` applied to `lbOperands`.
void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  getLowerBoundOperandsMutable().assign(lbOperands);
  setLowerBoundMap(map);
}
2601 
/// Replaces the upper bound with `map` applied to `ubOperands`.
void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  getUpperBoundOperandsMutable().assign(ubOperands);
  setUpperBoundMap(map);
}
2608 
/// Returns true if the lower bound map is a single constant expression.
bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();
}

/// Returns true if the upper bound map is a single constant expression.
bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();
}

/// Returns the constant lower bound; only valid if hasConstantLowerBound().
int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();
}

/// Returns the constant upper bound; only valid if hasConstantUpperBound().
int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();
}

/// Sets the lower bound to the constant `value` (operand-less constant map).
void AffineForOp::setConstantLowerBound(int64_t value) {
  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}

/// Sets the upper bound to the constant `value` (operand-less constant map).
void AffineForOp::setConstantUpperBound(int64_t value) {
  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
2632 
/// Returns the operands controlling the loop: the lower-bound operands
/// followed by the upper-bound operands.
AffineForOp::operand_range AffineForOp::getControlOperands() {
  return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
                               getUpperBoundOperands().size()};
}
2637 
2638 bool AffineForOp::matchingBoundOperandList() {
2639  auto lbMap = getLowerBoundMap();
2640  auto ubMap = getUpperBoundMap();
2641  if (lbMap.getNumDims() != ubMap.getNumDims() ||
2642  lbMap.getNumSymbols() != ubMap.getNumSymbols())
2643  return false;
2644 
2645  unsigned numOperands = lbMap.getNumInputs();
2646  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2647  // Compare Value 's.
2648  if (getOperand(i) != getOperand(numOperands + i))
2649  return false;
2650  }
2651  return true;
2652 }
2653 
/// Returns the single body region of the loop.
SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }

/// Returns the single induction variable of the loop.
std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
  return SmallVector<Value>{getInductionVar()};
}
2659 
/// Returns the lower bound as an attribute when it is a single constant;
/// std::nullopt otherwise.
std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
  if (!hasConstantLowerBound())
    return std::nullopt;
  OpBuilder b(getContext());
  return SmallVector<OpFoldResult>{
      OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
}
2667 
/// Returns the (always-constant) step as an attribute.
std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
  OpBuilder b(getContext());
  return SmallVector<OpFoldResult>{
      OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
}
2673 
2674 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2675  if (!hasConstantUpperBound())
2676  return {};
2677  OpBuilder b(getContext());
2679  OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2680 }
2681 
/// Replaces this loop with a new AffineForOp that carries `newInitOperands` as
/// extra iter_args, yielding the values produced by `newYieldValuesFn`. If
/// `replaceInitOperandUsesInLoop` is set, in-loop uses of the new init values
/// are redirected to the corresponding block arguments.
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    RewriterBase &rewriter, ValueRange newInitOperands,
    bool replaceInitOperandUsesInLoop,
    const NewYieldValuesFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(getOperation());
  auto inits = llvm::to_vector(getInits());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  AffineForOp newLoop = rewriter.create<AffineForOp>(
      getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
      getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);

  // Generate the new yield values and append them to the affine.yield
  // operation.
  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().append(newYieldedValues);
    });
  }

  // Move the loop body to the new op.
  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // Replace all uses of `newInitOperands` with the corresponding basic block
    // arguments, but only inside the new loop.
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
                                 [&](OpOperand &use) {
                                   Operation *user = use.getOwner();
                                   return newLoop->isProperAncestor(user);
                                 });
    }
  }

  // Replace the old loop: the new loop's leading results line up with the old
  // loop's results; the trailing results are the newly added yields.
  rewriter.replaceOp(getOperation(),
                     newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
2733 
Speculation::Speculatability AffineForOp::getSpeculatability() {
  // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
  // Start and End.
  //
  // For Step != 1, the loop may not terminate. We can add more smarts here if
  // needed.
  return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
                             : Speculation::NotSpeculatable;
}
2743 
/// Returns true if the provided value is the induction variable of an
/// AffineForOp.
bool mlir::affine::isAffineForInductionVar(Value val) {
  return getForInductionVarOwner(val) != AffineForOp();
}
2749 
/// Returns true if the provided value is an induction variable of an
/// AffineParallelOp.
bool mlir::affine::isAffineParallelInductionVar(Value val) {
  return getAffineParallelInductionVarOwner(val) != nullptr;
}
2753 
/// Returns true if the provided value is an induction variable of either an
/// affine.for or an affine.parallel op.
bool mlir::affine::isAffineInductionVar(Value val) {
  return isAffineForInductionVar(val) || isAffineParallelInductionVar(val);
}
2757 
/// Returns the AffineForOp whose induction variable is `val`, or a null
/// AffineForOp if `val` is not an affine.for induction variable.
AffineForOp mlir::affine::getForInductionVarOwner(Value val) {
  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
  if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
    return AffineForOp();
  if (auto forOp =
          ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
    // Check to make sure `val` is the induction variable, not an iter_arg.
    return forOp.getInductionVar() == val ? forOp : AffineForOp();
  return AffineForOp();
}
2768 
/// Returns the AffineParallelOp for which `val` is one of the induction
/// variables, or null if there is none.
AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
  if (!ivArg || !ivArg.getOwner())
    return nullptr;
  Operation *containingOp = ivArg.getOwner()->getParentOp();
  auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
    return parallelOp;
  return nullptr;
}
2779 
/// Extracts the induction variables from a list of AffineForOps and places
/// them in the output argument `ivs`.
void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
                                           SmallVectorImpl<Value> *ivs) {
  ivs->reserve(forInsts.size());
  for (auto forInst : forInsts)
    ivs->push_back(forInst.getInductionVar());
}
2788 
/// Extracts the induction variables from a list of affine.for / affine.parallel
/// ops and places them in the output argument `ivs`.
void mlir::affine::extractInductionVars(ArrayRef<Operation *> affineOps,
                                        SmallVectorImpl<Value> &ivs) {
  ivs.reserve(affineOps.size());
  for (Operation *op : affineOps) {
    // Add constraints from forOp's bounds.
    if (auto forOp = dyn_cast<AffineForOp>(op))
      ivs.push_back(forOp.getInductionVar());
    else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
      // affine.parallel contributes one IV per body block argument.
      for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
        ivs.push_back(parallelOp.getBody()->getArgument(i));
  }
}
2801 
/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
/// operations. `bodyBuilderFn`, if non-null, is invoked in the innermost loop
/// with all induction variables.
template <typename BoundListTy, typename LoopCreatorTy>
static void buildAffineLoopNestImpl(
    OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    LoopCreatorTy &&loopCreatorFn) {
  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");

  // If there are no loops to be constructed, construct the body anyway.
  OpBuilder::InsertionGuard guard(builder);
  if (lbs.empty()) {
    if (bodyBuilderFn)
      bodyBuilderFn(builder, loc, ValueRange());
    return;
  }

  // Create the loops iteratively and store the induction variables.
  SmallVector<Value, 4> ivs;
  ivs.reserve(lbs.size());
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    // Callback for creating the loop body, always creates the terminator.
    auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
                        ValueRange iterArgs) {
      ivs.push_back(iv);
      // In the innermost loop, call the body builder.
      if (i == e - 1 && bodyBuilderFn) {
        OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
      }
      nestedBuilder.create<AffineYieldOp>(nestedLoc);
    };

    // Delegate actual loop creation to the callback in order to dispatch
    // between constant- and variable-bound loops.
    auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
    builder.setInsertionPointToStart(loop.getBody());
  }
}
2843 
/// Creates an affine loop from the bounds known to be constants.
static AffineForOp
buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
                             int64_t ub, int64_t step,
                             AffineForOp::BodyBuilderFn bodyBuilderFn) {
  return builder.create<AffineForOp>(loc, lb, ub, step,
                                     /*iterArgs=*/ValueRange(), bodyBuilderFn);
}
2852 
/// Creates an affine loop from the bounds that may or may not be constants;
/// falls back to the constant-bound builder when both bounds fold to
/// constants.
static AffineForOp
buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
                          int64_t step,
                          AffineForOp::BodyBuilderFn bodyBuilderFn) {
  std::optional<int64_t> lbConst = getConstantIntValue(lb);
  std::optional<int64_t> ubConst = getConstantIntValue(ub);
  if (lbConst && ubConst)
    return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
                                        ubConst.value(), step, bodyBuilderFn);
  return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
                                     builder.getDimIdentityMap(), step,
                                     /*iterArgs=*/ValueRange(), bodyBuilderFn);
}
2867 
/// Builds a perfect affine loop nest with constant bounds.
void mlir::affine::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
    ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromConstants);
}
2875 
/// Builds a perfect affine loop nest with SSA-value bounds.
void mlir::affine::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromValues);
}
2883 
2884 //===----------------------------------------------------------------------===//
2885 // AffineIfOp
2886 //===----------------------------------------------------------------------===//
2887 
2888 namespace {
/// Remove else blocks that have nothing other than a zero value yield.
struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp ifOp,
                                PatternRewriter &rewriter) const override {
    // Only applies when the else block exists, holds just its terminator, and
    // the op yields no results (so nothing flows out of the else region).
    if (ifOp.getElseRegion().empty() ||
        !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
      return failure();

    rewriter.startOpModification(ifOp);
    rewriter.eraseBlock(ifOp.getElseBlock());
    rewriter.finalizeOpModification(ifOp);
    return success();
  }
};
2905 
/// Removes affine.if cond if the condition is always true or false in certain
/// trivial cases. Promotes the then/else block in the parent operation block.
struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp op,
                                PatternRewriter &rewriter) const override {

    auto isTriviallyFalse = [](IntegerSet iSet) {
      return iSet.isEmptyIntegerSet();
    };

    // A set with the single constraint 0 == 0 holds for all operands.
    auto isTriviallyTrue = [](IntegerSet iSet) {
      return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
              iSet.getConstraint(0) == 0);
    };

    IntegerSet affineIfConditions = op.getIntegerSet();
    Block *blockToMove;
    if (isTriviallyFalse(affineIfConditions)) {
      // The absence, or equivalently, the emptiness of the else region need not
      // be checked when affine.if is returning results because if an affine.if
      // operation is returning results, it always has a non-empty else region.
      if (op.getNumResults() == 0 && !op.hasElse()) {
        // If the else region is absent, or equivalently, empty, remove the
        // affine.if operation (which is not returning any results).
        rewriter.eraseOp(op);
        return success();
      }
      blockToMove = op.getElseBlock();
    } else if (isTriviallyTrue(affineIfConditions)) {
      blockToMove = op.getThenBlock();
    } else {
      return failure();
    }
    Operation *blockToMoveTerminator = blockToMove->getTerminator();
    // Promote the "blockToMove" block to the parent operation block between the
    // prologue and epilogue of "op".
    rewriter.inlineBlockBefore(blockToMove, op);
    // Replace the "op" operation with the operands of the
    // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
    // the affine.yield operation present in the "blockToMove" block. It has no
    // operands when affine.if is not returning results and therefore, in that
    // case, replaceOp just erases "op". When affine.if is not returning
    // results, the affine.yield operation can be omitted. It gets inserted
    // implicitly.
    rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
    // Erase the "blockToMoveTerminator" operation since it is now in the parent
    // operation block, which already has its own terminator.
    rewriter.eraseOp(blockToMoveTerminator);
    return success();
  }
};
2959 } // namespace
2960 
/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
void AffineIfOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // If the predecessor is an AffineIfOp, then branching into both `then` and
  // `else` region is valid.
  if (point.isParent()) {
    regions.reserve(2);
    regions.push_back(
        RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
    // If the "else" region is empty, branch back into parent.
    if (getElseRegion().empty()) {
      regions.push_back(getResults());
    } else {
      regions.push_back(
          RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
    }
    return;
  }

  // If the predecessor is the `else`/`then` region, then branching into parent
  // op is valid.
  regions.push_back(RegionSuccessor(getResults()));
}
2985 
/// Verifies the condition attribute exists, the operand count matches the
/// integer set's inputs, and operands are valid affine dims/symbols.
LogicalResult AffineIfOp::verify() {
  // Verify that we have a condition attribute.
  // FIXME: This should be specified in the arguments list in ODS.
  auto conditionAttr =
      (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
  if (!conditionAttr)
    return emitOpError("requires an integer set attribute named 'condition'");

  // Verify that there are enough operands for the condition.
  IntegerSet condition = conditionAttr.getValue();
  if (getNumOperands() != condition.getNumInputs())
    return emitOpError("operand count and condition integer set dimension and "
                       "symbol count must match");

  // Verify that the operands are valid dimension/symbols.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
                                           condition.getNumDims())))
    return failure();

  return success();
}
3007 
/// Parses an affine.if op: condition integer set, dim/symbol operands,
/// optional result types, `then` region, optional `else` region, and an
/// optional attribute dictionary.
ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the condition attribute set.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(conditionAttr,
                            AffineIfOp::getConditionAttrStrName(),
                            result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();

  // Verify the condition operands.
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        parser.getNameLoc(),
        "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        parser.getNameLoc(),
        "symbol operand count and integer set symbol count must match");

  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Create the regions for 'then' and 'else'. The latter must be created even
  // if it remains empty for the validity of the operation.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, {}, {}))
    return failure();
  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
                               result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return failure();
    AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
                                 result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
3058 
/// Prints the affine.if op: condition set, operands, result types, the `then`
/// region, an optional `else` region, and remaining attributes.
void AffineIfOp::print(OpAsmPrinter &p) {
  auto conditionAttr =
      (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
  p << " " << conditionAttr;
  printDimAndSymbolList(operand_begin(), operand_end(),
                        conditionAttr.getValue().getNumDims(), p);
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  // Terminators are printed only when the op has results (then they carry the
  // yielded values).
  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/getNumResults());

  // Print the 'else' regions if it has any blocks.
  auto &elseRegion = this->getElseRegion();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/getNumResults());
  }

  // Print the attribute list.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/getConditionAttrStrName());
}
3083 
/// Returns the condition integer set stored in the 'condition' attribute.
IntegerSet AffineIfOp::getIntegerSet() {
  return (*this)
      ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
      .getValue();
}

/// Replaces the condition integer set (operands unchanged).
void AffineIfOp::setIntegerSet(IntegerSet newSet) {
  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
}

/// Replaces both the condition integer set and its operands.
void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
  setIntegerSet(set);
  (*this)->setOperands(operands);
}
3098 
/// Builds an affine.if with the given result types, condition `set` applied
/// to `args`, and an optional else region. Results require an else region.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       TypeRange resultTypes, IntegerSet set, ValueRange args,
                       bool withElseRegion) {
  assert(resultTypes.empty() || withElseRegion);
  OpBuilder::InsertionGuard guard(builder);

  result.addTypes(resultTypes);
  result.addOperands(args);
  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));

  Region *thenRegion = result.addRegion();
  builder.createBlock(thenRegion);
  // Only auto-insert the terminator when no values must be yielded; otherwise
  // the caller provides the yield.
  if (resultTypes.empty())
    AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);

  Region *elseRegion = result.addRegion();
  if (withElseRegion) {
    builder.createBlock(elseRegion);
    if (resultTypes.empty())
      AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
  }
}

/// Convenience builder for a result-less affine.if.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       IntegerSet set, ValueRange args, bool withElseRegion) {
  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
                    withElseRegion);
}
3127 
/// Compose any affine.apply ops feeding into `operands` of the integer set
/// `set` by composing the maps of such affine.apply ops with the integer
/// set constraints.
static void composeSetAndOperands(IntegerSet &set,
                                  SmallVectorImpl<Value> &operands,
                                  bool composeAffineMin = false) {
  // We will simply reuse the API of the map composition by viewing the LHSs of
  // the equalities and inequalities of `set` as the affine exprs of an affine
  // map. Convert to equivalent map, compose, and convert back to set.
  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
                            set.getConstraints(), set.getContext());
  // Check if any composition is possible.
  if (llvm::none_of(operands,
                    [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
    return;

  composeAffineMapAndOperands(&map, &operands, composeAffineMin);
  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
                        set.getEqFlags());
}
3148 
/// Canonicalize an affine if op's conditional (integer set + operands).
LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  auto set = getIntegerSet();
  SmallVector<Value, 4> operands(getOperands());
  composeSetAndOperands(set, operands);
  canonicalizeSetAndOperands(&set, &operands);

  // Check if the canonicalization or composition led to any change.
  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
    return failure();

  setConditional(set, operands);
  return success();
}
3163 
/// Registers AffineIfOp canonicalization patterns.
void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
}
3168 
3169 //===----------------------------------------------------------------------===//
3170 // AffineLoadOp
3171 //===----------------------------------------------------------------------===//
3172 
/// Builds an affine.load from a flat operand list: operands[0] is the memref,
/// the rest are the map operands.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         AffineMap map, ValueRange operands) {
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  result.addOperands(operands);
  if (map)
    result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
  result.types.push_back(memrefType.getElementType());
}

/// Builds an affine.load with an explicit memref, map, and map operands.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         Value memref, AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(memref);
  result.addOperands(mapOperands);
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  result.types.push_back(memrefType.getElementType());
}

/// Builds an affine.load with an identity map over the given indices.
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         Value memref, ValueRange indices) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
  // Create identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, memref, map, indices);
}
3203 
/// Parses an affine.load: `memref[affine-map-of-ssa-ids] attrs : memref-type`.
ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineLoadOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}
3223 
/// Prints the affine.load in the same form parse() accepts.
void AffineLoadOp::print(OpAsmPrinter &p) {
  p << " " << getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType();
}
3234 
3235 /// Verify common indexing invariants of affine.load, affine.store,
3236 /// affine.vector_load and affine.vector_store.
3237 template <typename AffineMemOpTy>
3238 static LogicalResult
3239 verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr,
3240  Operation::operand_range mapOperands,
3241  MemRefType memrefType, unsigned numIndexOperands) {
3242  AffineMap map = mapAttr.getValue();
3243  if (map.getNumResults() != memrefType.getRank())
3244  return op->emitOpError("affine map num results must equal memref rank");
3245  if (map.getNumInputs() != numIndexOperands)
3246  return op->emitOpError("expects as many subscripts as affine map inputs");
3247 
3248  for (auto idx : mapOperands) {
3249  if (!idx.getType().isIndex())
3250  return op->emitOpError("index to load must have 'index' type");
3251  }
3252  if (failed(verifyDimAndSymbolIdentifiers(op, mapOperands, map.getNumDims())))
3253  return failure();
3254 
3255  return success();
3256 }
3257 
3258 LogicalResult AffineLoadOp::verify() {
3259  auto memrefType = getMemRefType();
3260  if (getType() != memrefType.getElementType())
3261  return emitOpError("result type must match element type of memref");
3262 
3263  if (failed(verifyMemoryOpIndexing(
3264  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3265  getMapOperands(), memrefType,
3266  /*numIndexOperands=*/getNumOperands() - 1)))
3267  return failure();
3268 
3269  return success();
3270 }
3271 
// Registers canonicalization patterns for affine.load: composing any
// affine.apply ops feeding the subscripts into the load's own map.
void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
}
3276 
3277 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3278  /// load(memrefcast) -> load
3279  if (succeeded(memref::foldMemRefCast(*this)))
3280  return getResult();
3281 
3282  // Fold load from a global constant memref.
3283  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3284  if (!getGlobalOp)
3285  return {};
3286  // Get to the memref.global defining the symbol.
3287  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3288  if (!symbolTableOp)
3289  return {};
3290  auto global = dyn_cast_or_null<memref::GlobalOp>(
3291  SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
3292  if (!global)
3293  return {};
3294 
3295  // Check if the global memref is a constant.
3296  auto cstAttr =
3297  llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3298  if (!cstAttr)
3299  return {};
3300  // If it's a splat constant, we can fold irrespective of indices.
3301  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
3302  return splatAttr.getSplatValue<Attribute>();
3303  // Otherwise, we can fold only if we know the indices.
3304  if (!getAffineMap().isConstant())
3305  return {};
3306  auto indices = llvm::to_vector<4>(
3307  llvm::map_range(getAffineMap().getConstantResults(),
3308  [](int64_t v) -> uint64_t { return v; }));
3309  return cstAttr.getValues<Attribute>()[indices];
3310 }
3311 
3312 //===----------------------------------------------------------------------===//
3313 // AffineStoreOp
3314 //===----------------------------------------------------------------------===//
3315 
3316 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3317  Value valueToStore, Value memref, AffineMap map,
3318  ValueRange mapOperands) {
3319  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3320  result.addOperands(valueToStore);
3321  result.addOperands(memref);
3322  result.addOperands(mapOperands);
3323  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3324 }
3325 
3326 // Use identity map.
3327 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3328  Value valueToStore, Value memref,
3329  ValueRange indices) {
3330  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3331  int64_t rank = memrefType.getRank();
3332  // Create identity map for memrefs with at least one dimension or () -> ()
3333  // for zero-dimensional memrefs.
3334  auto map =
3335  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3336  build(builder, result, valueToStore, memref, map, indices);
3337 }
3338 
3339 ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3340  auto indexTy = parser.getBuilder().getIndexType();
3341 
3342  MemRefType type;
3343  OpAsmParser::UnresolvedOperand storeValueInfo;
3344  OpAsmParser::UnresolvedOperand memrefInfo;
3345  AffineMapAttr mapAttr;
3347  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3348  parser.parseOperand(memrefInfo) ||
3349  parser.parseAffineMapOfSSAIds(
3350  mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3351  result.attributes) ||
3352  parser.parseOptionalAttrDict(result.attributes) ||
3353  parser.parseColonType(type) ||
3354  parser.resolveOperand(storeValueInfo, type.getElementType(),
3355  result.operands) ||
3356  parser.resolveOperand(memrefInfo, type, result.operands) ||
3357  parser.resolveOperands(mapOperands, indexTy, result.operands));
3358 }
3359 
3361  p << " " << getValueToStore();
3362  p << ", " << getMemRef() << '[';
3363  if (AffineMapAttr mapAttr =
3364  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3365  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3366  p << ']';
3367  p.printOptionalAttrDict((*this)->getAttrs(),
3368  /*elidedAttrs=*/{getMapAttrStrName()});
3369  p << " : " << getMemRefType();
3370 }
3371 
3372 LogicalResult AffineStoreOp::verify() {
3373  // The value to store must have the same type as memref element type.
3374  auto memrefType = getMemRefType();
3375  if (getValueToStore().getType() != memrefType.getElementType())
3376  return emitOpError(
3377  "value to store must have the same type as memref element type");
3378 
3379  if (failed(verifyMemoryOpIndexing(
3380  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3381  getMapOperands(), memrefType,
3382  /*numIndexOperands=*/getNumOperands() - 2)))
3383  return failure();
3384 
3385  return success();
3386 }
3387 
// Registers canonicalization patterns for affine.store: composing any
// affine.apply ops feeding the subscripts into the store's own map.
void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
}
3392 
// Folds affine.store by looking through memref.cast producers of the memref.
LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return memref::foldMemRefCast(*this, getValueToStore());
}
3398 
3399 //===----------------------------------------------------------------------===//
3400 // AffineMinMaxOpBase
3401 //===----------------------------------------------------------------------===//
3402 
3403 template <typename T>
3404 static LogicalResult verifyAffineMinMaxOp(T op) {
3405  // Verify that operand count matches affine map dimension and symbol count.
3406  if (op.getNumOperands() !=
3407  op.getMap().getNumDims() + op.getMap().getNumSymbols())
3408  return op.emitOpError(
3409  "operand count and affine map dimension and symbol count must match");
3410 
3411  if (op.getMap().getNumResults() == 0)
3412  return op.emitOpError("affine map expect at least one result");
3413  return success();
3414 }
3415 
// Shared printer for affine.min/affine.max: prints the map attribute, then
// dimension operands in parentheses and symbol operands (if any) in square
// brackets, followed by the remaining attributes.
template <typename T>
static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
  p << ' ' << op->getAttr(T::getMapAttrStrName());
  auto operands = op.getOperands();
  unsigned numDims = op.getMap().getNumDims();
  // The first numDims operands bind the map's dimensions; the rest (if any)
  // bind its symbols.
  p << '(' << operands.take_front(numDims) << ')';

  if (operands.size() != numDims)
    p << '[' << operands.drop_front(numDims) << ']';
  p.printOptionalAttrDict(op->getAttrs(),
                          /*elidedAttrs=*/{T::getMapAttrStrName()});
}
3428 
3429 template <typename T>
3430 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3431  OperationState &result) {
3432  auto &builder = parser.getBuilder();
3433  auto indexType = builder.getIndexType();
3436  AffineMapAttr mapAttr;
3437  return failure(
3438  parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3439  result.attributes) ||
3440  parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
3441  parser.parseOperandList(symInfos,
3443  parser.parseOptionalAttrDict(result.attributes) ||
3444  parser.resolveOperands(dimInfos, indexType, result.operands) ||
3445  parser.resolveOperands(symInfos, indexType, result.operands) ||
3446  parser.addTypeToList(indexType, result.types));
3447 }
3448 
3449 /// Fold an affine min or max operation with the given operands. The operand
3450 /// list may contain nulls, which are interpreted as the operand not being a
3451 /// constant.
3452 template <typename T>
3454  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3455  "expected affine min or max op");
3456 
3457  // Fold the affine map.
3458  // TODO: Fold more cases:
3459  // min(some_affine, some_affine + constant, ...), etc.
3460  SmallVector<int64_t, 2> results;
3461  auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3462 
3463  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3464  return op.getOperand(0);
3465 
3466  // If some of the map results are not constant, try changing the map in-place.
3467  if (results.empty()) {
3468  // If the map is the same, report that folding did not happen.
3469  if (foldedMap == op.getMap())
3470  return {};
3471  op->setAttr("map", AffineMapAttr::get(foldedMap));
3472  return op.getResult();
3473  }
3474 
3475  // Otherwise, completely fold the op into a constant.
3476  auto resultIt = std::is_same<T, AffineMinOp>::value
3477  ? llvm::min_element(results)
3478  : llvm::max_element(results);
3479  if (resultIt == results.end())
3480  return {};
3481  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3482 }
3483 
3484 /// Remove duplicated expressions in affine min/max ops.
3485 template <typename T>
3488 
3489  LogicalResult matchAndRewrite(T affineOp,
3490  PatternRewriter &rewriter) const override {
3491  AffineMap oldMap = affineOp.getAffineMap();
3492 
3493  SmallVector<AffineExpr, 4> newExprs;
3494  for (AffineExpr expr : oldMap.getResults()) {
3495  // This is a linear scan over newExprs, but it should be fine given that
3496  // we typically just have a few expressions per op.
3497  if (!llvm::is_contained(newExprs, expr))
3498  newExprs.push_back(expr);
3499  }
3500 
3501  if (newExprs.size() == oldMap.getNumResults())
3502  return failure();
3503 
3504  auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3505  newExprs, rewriter.getContext());
3506  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3507 
3508  return success();
3509  }
3510 };
3511 
3512 /// Merge an affine min/max op to its consumers if its consumer is also an
3513 /// affine min/max op.
3514 ///
3515 /// This pattern requires the producer affine min/max op is bound to a
3516 /// dimension/symbol that is used as a standalone expression in the consumer
3517 /// affine op's map.
3518 ///
3519 /// For example, a pattern like the following:
3520 ///
3521 /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3522 /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3523 ///
3524 /// Can be turned into:
3525 ///
3526 /// %1 = affine.min affine_map<
3527 /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3528 template <typename T>
3531 
3532  LogicalResult matchAndRewrite(T affineOp,
3533  PatternRewriter &rewriter) const override {
3534  AffineMap oldMap = affineOp.getAffineMap();
3535  ValueRange dimOperands =
3536  affineOp.getMapOperands().take_front(oldMap.getNumDims());
3537  ValueRange symOperands =
3538  affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3539 
3540  auto newDimOperands = llvm::to_vector<8>(dimOperands);
3541  auto newSymOperands = llvm::to_vector<8>(symOperands);
3542  SmallVector<AffineExpr, 4> newExprs;
3543  SmallVector<T, 4> producerOps;
3544 
3545  // Go over each expression to see whether it's a single dimension/symbol
3546  // with the corresponding operand which is the result of another affine
3547  // min/max op. If So it can be merged into this affine op.
3548  for (AffineExpr expr : oldMap.getResults()) {
3549  if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3550  Value symValue = symOperands[symExpr.getPosition()];
3551  if (auto producerOp = symValue.getDefiningOp<T>()) {
3552  producerOps.push_back(producerOp);
3553  continue;
3554  }
3555  } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3556  Value dimValue = dimOperands[dimExpr.getPosition()];
3557  if (auto producerOp = dimValue.getDefiningOp<T>()) {
3558  producerOps.push_back(producerOp);
3559  continue;
3560  }
3561  }
3562  // For the above cases we will remove the expression by merging the
3563  // producer affine min/max's affine expressions. Otherwise we need to
3564  // keep the existing expression.
3565  newExprs.push_back(expr);
3566  }
3567 
3568  if (producerOps.empty())
3569  return failure();
3570 
3571  unsigned numUsedDims = oldMap.getNumDims();
3572  unsigned numUsedSyms = oldMap.getNumSymbols();
3573 
3574  // Now go over all producer affine ops and merge their expressions.
3575  for (T producerOp : producerOps) {
3576  AffineMap producerMap = producerOp.getAffineMap();
3577  unsigned numProducerDims = producerMap.getNumDims();
3578  unsigned numProducerSyms = producerMap.getNumSymbols();
3579 
3580  // Collect all dimension/symbol values.
3581  ValueRange dimValues =
3582  producerOp.getMapOperands().take_front(numProducerDims);
3583  ValueRange symValues =
3584  producerOp.getMapOperands().take_back(numProducerSyms);
3585  newDimOperands.append(dimValues.begin(), dimValues.end());
3586  newSymOperands.append(symValues.begin(), symValues.end());
3587 
3588  // For expressions we need to shift to avoid overlap.
3589  for (AffineExpr expr : producerMap.getResults()) {
3590  newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3591  .shiftSymbols(numProducerSyms, numUsedSyms));
3592  }
3593 
3594  numUsedDims += numProducerDims;
3595  numUsedSyms += numProducerSyms;
3596  }
3597 
3598  auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3599  rewriter.getContext());
3600  auto newOperands =
3601  llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3602  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3603 
3604  return success();
3605  }
3606 };
3607 
/// Canonicalize the result expression order of an affine map and return success
/// if the order changed.
///
/// The function flattens the map's affine expressions to coefficient arrays and
/// sorts them in lexicographic order. A coefficient array contains a multiplier
/// for every dimension/symbol and a constant term. The canonicalization fails
/// if a result expression is not pure or if the flattening requires local
/// variables that, unlike dimensions and symbols, have no global order.
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
  SmallVector<SmallVector<int64_t>> flattenedExprs;
  for (const AffineExpr &resultExpr : map.getResults()) {
    // Fail if the expression is not pure.
    if (!resultExpr.isPureAffine())
      return failure();

    SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
    auto flattenResult = flattener.walkPostOrder(resultExpr);
    if (failed(flattenResult))
      return failure();

    // Fail if the flattened expression has local variables.
    // (A pure dim/symbol flattening has exactly numDims + numSymbols
    // coefficients plus one constant term; anything more is a local.)
    if (flattener.operandExprStack.back().size() !=
        map.getNumDims() + map.getNumSymbols() + 1)
      return failure();

    flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
                                flattener.operandExprStack.back().end());
  }

  // Fail if sorting is not necessary.
  if (llvm::is_sorted(flattenedExprs))
    return failure();

  // Reorder the result expressions according to their flattened form.
  SmallVector<unsigned> resultPermutation =
      llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
    return flattenedExprs[lhs] < flattenedExprs[rhs];
  });
  SmallVector<AffineExpr> newExprs;
  for (unsigned idx : resultPermutation)
    newExprs.push_back(map.getResult(idx));

  // Rebuild the map in place with the sorted result expressions.
  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
                       map.getContext());
  return success();
}
3655 
3656 /// Canonicalize the affine map result expression order of an affine min/max
3657 /// operation.
3658 ///
3659 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3660 /// expressions and replaces the operation if the order changed.
3661 ///
3662 /// For example, the following operation:
3663 ///
3664 /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3665 ///
3666 /// Turns into:
3667 ///
3668 /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3669 template <typename T>
3672 
3673  LogicalResult matchAndRewrite(T affineOp,
3674  PatternRewriter &rewriter) const override {
3675  AffineMap map = affineOp.getAffineMap();
3676  if (failed(canonicalizeMapExprAndTermOrder(map)))
3677  return failure();
3678  rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3679  return success();
3680  }
3681 };
3682 
3683 template <typename T>
3686 
3687  LogicalResult matchAndRewrite(T affineOp,
3688  PatternRewriter &rewriter) const override {
3689  if (affineOp.getMap().getNumResults() != 1)
3690  return failure();
3691  rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3692  affineOp.getOperands());
3693  return success();
3694  }
3695 };
3696 
3697 //===----------------------------------------------------------------------===//
3698 // AffineMinOp
3699 //===----------------------------------------------------------------------===//
3700 //
3701 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3702 //
3703 
// Folds affine.min via the shared min/max folder.
OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(*this, adaptor.getOperands());
}
3707 
3708 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3709  MLIRContext *context) {
3712  MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3714  context);
3715 }
3716 
// Defers to the shared min/max verifier.
LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3718 
// Defers to the shared min/max parser.
ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
}
3722 
// Defers to the shared min/max printer.
void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3724 
3725 //===----------------------------------------------------------------------===//
3726 // AffineMaxOp
3727 //===----------------------------------------------------------------------===//
3728 //
3729 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3730 //
3731 
// Folds affine.max via the shared min/max folder.
OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(*this, adaptor.getOperands());
}
3735 
3736 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3737  MLIRContext *context) {
3740  MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3742  context);
3743 }
3744 
// Defers to the shared min/max verifier.
LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3746 
// Defers to the shared min/max parser.
ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
}
3750 
// Defers to the shared min/max printer.
void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3752 
3753 //===----------------------------------------------------------------------===//
3754 // AffinePrefetchOp
3755 //===----------------------------------------------------------------------===//
3756 
3757 //
3758 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3759 //
3760 ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3761  OperationState &result) {
3762  auto &builder = parser.getBuilder();
3763  auto indexTy = builder.getIndexType();
3764 
3765  MemRefType type;
3766  OpAsmParser::UnresolvedOperand memrefInfo;
3767  IntegerAttr hintInfo;
3768  auto i32Type = parser.getBuilder().getIntegerType(32);
3769  StringRef readOrWrite, cacheType;
3770 
3771  AffineMapAttr mapAttr;
3773  if (parser.parseOperand(memrefInfo) ||
3774  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3775  AffinePrefetchOp::getMapAttrStrName(),
3776  result.attributes) ||
3777  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3778  parser.parseComma() || parser.parseKeyword("locality") ||
3779  parser.parseLess() ||
3780  parser.parseAttribute(hintInfo, i32Type,
3781  AffinePrefetchOp::getLocalityHintAttrStrName(),
3782  result.attributes) ||
3783  parser.parseGreater() || parser.parseComma() ||
3784  parser.parseKeyword(&cacheType) ||
3785  parser.parseOptionalAttrDict(result.attributes) ||
3786  parser.parseColonType(type) ||
3787  parser.resolveOperand(memrefInfo, type, result.operands) ||
3788  parser.resolveOperands(mapOperands, indexTy, result.operands))
3789  return failure();
3790 
3791  if (readOrWrite != "read" && readOrWrite != "write")
3792  return parser.emitError(parser.getNameLoc(),
3793  "rw specifier has to be 'read' or 'write'");
3794  result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3795  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
3796 
3797  if (cacheType != "data" && cacheType != "instr")
3798  return parser.emitError(parser.getNameLoc(),
3799  "cache type has to be 'data' or 'instr'");
3800 
3801  result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3802  parser.getBuilder().getBoolAttr(cacheType == "data"));
3803 
3804  return success();
3805 }
3806 
3808  p << " " << getMemref() << '[';
3809  AffineMapAttr mapAttr =
3810  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3811  if (mapAttr)
3812  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3813  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3814  << "locality<" << getLocalityHint() << ">, "
3815  << (getIsDataCache() ? "data" : "instr");
3817  (*this)->getAttrs(),
3818  /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3819  getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3820  p << " : " << getMemRefType();
3821 }
3822 
// Verifies affine.prefetch: the map (if present) must match the memref rank
// and operand count, and all subscripts must be valid affine dims/symbols.
LogicalResult AffinePrefetchOp::verify() {
  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
  if (mapAttr) {
    AffineMap map = mapAttr.getValue();
    // One map result per memref dimension.
    if (map.getNumResults() != getMemRefType().getRank())
      return emitOpError("affine.prefetch affine map num results must equal"
                         " memref rank");
    // Operands: the memref itself plus one operand per map input.
    if (map.getNumInputs() + 1 != getNumOperands())
      return emitOpError("too few operands");
  } else {
    // Without a map the only operand is the memref.
    if (getNumOperands() != 1)
      return emitOpError("too few operands");
  }

  // Subscripts must be valid dims/symbols of the enclosing affine scope.
  Region *scope = getAffineScope(*this);
  for (auto idx : getMapOperands()) {
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError(
          "index must be a valid dimension or symbol identifier");
  }
  return success();
}
3845 
// Registers canonicalization patterns for affine.prefetch.
void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // prefetch(memrefcast) -> prefetch
  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
}
3851 
// Folds affine.prefetch by looking through memref.cast producers.
LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// prefetch(memrefcast) -> prefetch
  return memref::foldMemRefCast(*this);
}
3857 
3858 //===----------------------------------------------------------------------===//
3859 // AffineParallelOp
3860 //===----------------------------------------------------------------------===//
3861 
// Builds an affine.parallel over constant ranges [0, range) with unit steps,
// delegating to the map-based overload below.
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<arith::AtomicRMWKind> reductions,
                             ArrayRef<int64_t> ranges) {
  // Lower bounds are all the constant 0; upper bounds are the given ranges.
  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
    return builder.getConstantAffineMap(value);
  }));
  SmallVector<int64_t> steps(ranges.size(), 1);
  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
        /*ubArgs=*/{}, steps);
}
3874 
3875 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3876  TypeRange resultTypes,
3877  ArrayRef<arith::AtomicRMWKind> reductions,
3878  ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
3879  ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
3880  ArrayRef<int64_t> steps) {
3881  assert(llvm::all_of(lbMaps,
3882  [lbMaps](AffineMap m) {
3883  return m.getNumDims() == lbMaps[0].getNumDims() &&
3884  m.getNumSymbols() == lbMaps[0].getNumSymbols();
3885  }) &&
3886  "expected all lower bounds maps to have the same number of dimensions "
3887  "and symbols");
3888  assert(llvm::all_of(ubMaps,
3889  [ubMaps](AffineMap m) {
3890  return m.getNumDims() == ubMaps[0].getNumDims() &&
3891  m.getNumSymbols() == ubMaps[0].getNumSymbols();
3892  }) &&
3893  "expected all upper bounds maps to have the same number of dimensions "
3894  "and symbols");
3895  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3896  "expected lower bound maps to have as many inputs as lower bound "
3897  "operands");
3898  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3899  "expected upper bound maps to have as many inputs as upper bound "
3900  "operands");
3901 
3902  OpBuilder::InsertionGuard guard(builder);
3903  result.addTypes(resultTypes);
3904 
3905  // Convert the reductions to integer attributes.
3906  SmallVector<Attribute, 4> reductionAttrs;
3907  for (arith::AtomicRMWKind reduction : reductions)
3908  reductionAttrs.push_back(
3909  builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
3910  result.addAttribute(getReductionsAttrStrName(),
3911  builder.getArrayAttr(reductionAttrs));
3912 
3913  // Concatenates maps defined in the same input space (same dimensions and
3914  // symbols), assumes there is at least one map.
3915  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
3916  SmallVectorImpl<int32_t> &groups) {
3917  if (maps.empty())
3918  return AffineMap::get(builder.getContext());
3920  groups.reserve(groups.size() + maps.size());
3921  exprs.reserve(maps.size());
3922  for (AffineMap m : maps) {
3923  llvm::append_range(exprs, m.getResults());
3924  groups.push_back(m.getNumResults());
3925  }
3926  return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3927  maps[0].getContext());
3928  };
3929 
3930  // Set up the bounds.
3931  SmallVector<int32_t> lbGroups, ubGroups;
3932  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3933  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3934  result.addAttribute(getLowerBoundsMapAttrStrName(),
3935  AffineMapAttr::get(lbMap));
3936  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
3937  builder.getI32TensorAttr(lbGroups));
3938  result.addAttribute(getUpperBoundsMapAttrStrName(),
3939  AffineMapAttr::get(ubMap));
3940  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
3941  builder.getI32TensorAttr(ubGroups));
3942  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
3943  result.addOperands(lbArgs);
3944  result.addOperands(ubArgs);
3945 
3946  // Create a region and a block for the body.
3947  auto *bodyRegion = result.addRegion();
3948  Block *body = builder.createBlock(bodyRegion);
3949 
3950  // Add all the block arguments.
3951  for (unsigned i = 0, e = steps.size(); i < e; ++i)
3952  body->addArgument(IndexType::get(builder.getContext()), result.location);
3953  if (resultTypes.empty())
3954  ensureTerminator(*bodyRegion, builder, result.location);
3955 }
3956 
// An affine.parallel has a single loop region: its body.
SmallVector<Region *> AffineParallelOp::getLoopRegions() {
  return {&getRegion()};
}
3960 
// There is exactly one step per induction variable.
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3962 
// Operand storage is [lb operands..., ub operands...]; the lower-bound
// operands are the leading inputs of the lower-bounds map.
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
}
3966 
// Everything after the lower-bound operands belongs to the upper bounds.
AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
}
3970 
// Returns the slice of the concatenated lower-bounds map that belongs to
// induction variable 'pos'. The groups attribute records how many results
// each IV contributes, so the slice starts after the preceding groups.
AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
  auto values = getLowerBoundsGroups().getValues<int32_t>();
  unsigned start = 0;
  for (unsigned i = 0; i < pos; ++i)
    start += values[i];
  return getLowerBoundsMap().getSliceMap(start, values[pos]);
}
3978 
// Returns the slice of the concatenated upper-bounds map that belongs to
// induction variable 'pos'; mirrors getLowerBoundMap.
AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
  auto values = getUpperBoundsGroups().getValues<int32_t>();
  unsigned start = 0;
  for (unsigned i = 0; i < pos; ++i)
    start += values[i];
  return getUpperBoundsMap().getSliceMap(start, values[pos]);
}
3986 
// Pairs the lower-bounds map with its operands.
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
}
3990 
// Pairs the upper-bounds map with its operands.
AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
}
3994 
3995 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3996  if (hasMinMaxBounds())
3997  return std::nullopt;
3998 
3999  // Try to convert all the ranges to constant expressions.
4001  AffineValueMap rangesValueMap;
4002  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
4003  &rangesValueMap);
4004  out.reserve(rangesValueMap.getNumResults());
4005  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
4006  auto expr = rangesValueMap.getResult(i);
4007  auto cst = dyn_cast<AffineConstantExpr>(expr);
4008  if (!cst)
4009  return std::nullopt;
4010  out.push_back(cst.getValue());
4011  }
4012  return out;
4013 }
4014 
// The body is the single block of the op's region.
Block *AffineParallelOp::getBody() { return &getRegion().front(); }
4016 
// Returns a builder positioned just before the body's terminator.
OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));
}
4020 
// Replaces the lower bounds (map + operands). Operand storage is
// [lb operands..., ub operands...], so the whole operand list is rebuilt
// with the existing upper-bound operands preserved.
void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs() &&
         "operands to map must match number of inputs");

  auto ubOperands = getUpperBoundsOperands();

  SmallVector<Value, 4> newOperands(lbOperands);
  newOperands.append(ubOperands.begin(), ubOperands.end());
  (*this)->setOperands(newOperands);

  setLowerBoundsMapAttr(AffineMapAttr::get(map));
}
4033 
// Replaces the upper bounds (map + operands); mirrors setLowerBounds with
// the existing lower-bound operands preserved.
void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs() &&
         "operands to map must match number of inputs");

  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
  newOperands.append(ubOperands.begin(), ubOperands.end());
  (*this)->setOperands(newOperands);

  setUpperBoundsMapAttr(AffineMapAttr::get(map));
}
4044 
// Replaces the per-IV steps attribute.
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
4048 
4049 // check whether resultType match op or not in affine.parallel
4050 static bool isResultTypeMatchAtomicRMWKind(Type resultType,
4051  arith::AtomicRMWKind op) {
4052  switch (op) {
4053  case arith::AtomicRMWKind::addf:
4054  return isa<FloatType>(resultType);
4055  case arith::AtomicRMWKind::addi:
4056  return isa<IntegerType>(resultType);
4057  case arith::AtomicRMWKind::assign:
4058  return true;
4059  case arith::AtomicRMWKind::mulf:
4060  return isa<FloatType>(resultType);
4061  case arith::AtomicRMWKind::muli:
4062  return isa<IntegerType>(resultType);
4063  case arith::AtomicRMWKind::maximumf:
4064  return isa<FloatType>(resultType);
4065  case arith::AtomicRMWKind::minimumf:
4066  return isa<FloatType>(resultType);
4067  case arith::AtomicRMWKind::maxs: {
4068  auto intType = llvm::dyn_cast<IntegerType>(resultType);
4069  return intType && intType.isSigned();
4070  }
4071  case arith::AtomicRMWKind::mins: {
4072  auto intType = llvm::dyn_cast<IntegerType>(resultType);
4073  return intType && intType.isSigned();
4074  }
4075  case arith::AtomicRMWKind::maxu: {
4076  auto intType = llvm::dyn_cast<IntegerType>(resultType);
4077  return intType && intType.isUnsigned();
4078  }
4079  case arith::AtomicRMWKind::minu: {
4080  auto intType = llvm::dyn_cast<IntegerType>(resultType);
4081  return intType && intType.isUnsigned();
4082  }
4083  case arith::AtomicRMWKind::ori:
4084  return isa<IntegerType>(resultType);
4085  case arith::AtomicRMWKind::andi:
4086  return isa<IntegerType>(resultType);
4087  default:
4088  return false;
4089  }
4090 }
4091 
4092 LogicalResult AffineParallelOp::verify() {
4093  auto numDims = getNumDims();
4094  if (getLowerBoundsGroups().getNumElements() != numDims ||
4095  getUpperBoundsGroups().getNumElements() != numDims ||
4096  getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
4097  return emitOpError() << "the number of region arguments ("
4098  << getBody()->getNumArguments()
4099  << ") and the number of map groups for lower ("
4100  << getLowerBoundsGroups().getNumElements()
4101  << ") and upper bound ("
4102  << getUpperBoundsGroups().getNumElements()
4103  << "), and the number of steps (" << getSteps().size()
4104  << ") must all match";
4105  }
4106 
4107  unsigned expectedNumLBResults = 0;
4108  for (APInt v : getLowerBoundsGroups()) {
4109  unsigned results = v.getZExtValue();
4110  if (results == 0)
4111  return emitOpError()
4112  << "expected lower bound map to have at least one result";
4113  expectedNumLBResults += results;
4114  }
4115  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
4116  return emitOpError() << "expected lower bounds map to have "
4117  << expectedNumLBResults << " results";
4118  unsigned expectedNumUBResults = 0;
4119  for (APInt v : getUpperBoundsGroups()) {
4120  unsigned results = v.getZExtValue();
4121  if (results == 0)
4122  return emitOpError()
4123  << "expected upper bound map to have at least one result";
4124  expectedNumUBResults += results;
4125  }
4126  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
4127  return emitOpError() << "expected upper bounds map to have "
4128  << expectedNumUBResults << " results";
4129 
4130  if (getReductions().size() != getNumResults())
4131  return emitOpError("a reduction must be specified for each output");
4132 
4133  // Verify reduction ops are all valid and each result type matches reduction
4134  // ops
4135  for (auto it : llvm::enumerate((getReductions()))) {
4136  Attribute attr = it.value();
4137  auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
4138  if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
4139  return emitOpError("invalid reduction attribute");
4140  auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
4141  if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
4142  return emitOpError("result type cannot match reduction attribute");
4143  }
4144 
4145  // Verify that the bound operands are valid dimension/symbols.
4146  /// Lower bounds.
4147  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
4148  getLowerBoundsMap().getNumDims())))
4149  return failure();
4150  /// Upper bounds.
4151  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
4152  getUpperBoundsMap().getNumDims())))
4153  return failure();
4154  return success();
4155 }
4156 
4157 LogicalResult AffineValueMap::canonicalize() {
4158  SmallVector<Value, 4> newOperands{operands};
4159  auto newMap = getAffineMap();
4160  composeAffineMapAndOperands(&newMap, &newOperands);
4161  if (newMap == getAffineMap() && newOperands == operands)
4162  return failure();
4163  reset(newMap, newOperands);
4164  return success();
4165 }
4166 
4167 /// Canonicalize the bounds of the given loop.
4168 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
4169  AffineValueMap lb = op.getLowerBoundsValueMap();
4170  bool lbCanonicalized = succeeded(lb.canonicalize());
4171 
4172  AffineValueMap ub = op.getUpperBoundsValueMap();
4173  bool ubCanonicalized = succeeded(ub.canonicalize());
4174 
4175  // Any canonicalization change always leads to updated map(s).
4176  if (!lbCanonicalized && !ubCanonicalized)
4177  return failure();
4178 
4179  if (lbCanonicalized)
4180  op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
4181  if (ubCanonicalized)
4182  op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
4183 
4184  return success();
4185 }
4186 
// Folding an affine.parallel only canonicalizes its bound maps in place; it
// never produces replacement values, so `results` is left empty.
LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  return canonicalizeLoopBounds(*this);
}
4191 
4192 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
4193 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
4194 /// identifies which of the those expressions form max/min groups. `operands`
4195 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
4196 /// or "max".
4197 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
4198  DenseIntElementsAttr group, ValueRange operands,
4199  StringRef keyword) {
4200  AffineMap map = mapAttr.getValue();
4201  unsigned numDims = map.getNumDims();
4202  ValueRange dimOperands = operands.take_front(numDims);
4203  ValueRange symOperands = operands.drop_front(numDims);
4204  unsigned start = 0;
4205  for (llvm::APInt groupSize : group) {
4206  if (start != 0)
4207  p << ", ";
4208 
4209  unsigned size = groupSize.getZExtValue();
4210  if (size == 1) {
4211  p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4212  ++start;
4213  } else {
4214  p << keyword << '(';
4215  AffineMap submap = map.getSliceMap(start, size);
4216  p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4217  p << ')';
4218  start += size;
4219  }
4220  }
4221 }
4222 
4224  p << " (" << getBody()->getArguments() << ") = (";
4225  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4226  getLowerBoundsOperands(), "max");
4227  p << ") to (";
4228  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4229  getUpperBoundsOperands(), "min");
4230  p << ')';
4231  SmallVector<int64_t, 8> steps = getSteps();
4232  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4233  if (!elideSteps) {
4234  p << " step (";
4235  llvm::interleaveComma(steps, p);
4236  p << ')';
4237  }
4238  if (getNumResults()) {
4239  p << " reduce (";
4240  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4241  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4242  llvm::cast<IntegerAttr>(attr).getInt());
4243  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4244  });
4245  p << ") -> (" << getResultTypes() << ")";
4246  }
4247 
4248  p << ' ';
4249  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4250  /*printBlockTerminators=*/getNumResults());
4252  (*this)->getAttrs(),
4253  /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4254  AffineParallelOp::getLowerBoundsMapAttrStrName(),
4255  AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4256  AffineParallelOp::getUpperBoundsMapAttrStrName(),
4257  AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4258  AffineParallelOp::getStepsAttrStrName()});
4259 }
4260 
4261 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
4262 /// unique operands. Also populates `replacements with affine expressions of
4263 /// `kind` that can be used to update affine maps previously accepting a
4264 /// `operands` to accept `uniqueOperands` instead.
4266  OpAsmParser &parser,
4268  SmallVectorImpl<Value> &uniqueOperands,
4271  "expected operands to be dim or symbol expression");
4272 
4273  Type indexType = parser.getBuilder().getIndexType();
4274  for (const auto &list : operands) {
4275  SmallVector<Value> valueOperands;
4276  if (parser.resolveOperands(list, indexType, valueOperands))
4277  return failure();
4278  for (Value operand : valueOperands) {
4279  unsigned pos = std::distance(uniqueOperands.begin(),
4280  llvm::find(uniqueOperands, operand));
4281  if (pos == uniqueOperands.size())
4282  uniqueOperands.push_back(operand);
4283  replacements.push_back(
4285  ? getAffineDimExpr(pos, parser.getContext())
4286  : getAffineSymbolExpr(pos, parser.getContext()));
4287  }
4288  }
4289  return success();
4290 }
4291 
namespace {
/// Whether a parsed parallel bound group uses `min` (upper bounds) or `max`
/// (lower bounds) as its grouping keyword.
enum class MinMaxKind { Min, Max };
} // namespace
4295 
4296 /// Parses an affine map that can contain a min/max for groups of its results,
4297 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4298 /// `result` attributes with the map (flat list of expressions) and the grouping
4299 /// (list of integers that specify how many expressions to put into each
4300 /// min/max) attributes. Deduplicates repeated operands.
4301 ///
4302 /// parallel-bound ::= `(` parallel-group-list `)`
4303 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4304 /// parallel-group ::= simple-group | min-max-group
4305 /// simple-group ::= expr-of-ssa-ids
4306 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4307 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4308 ///
4309 /// Examples:
4310 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4311 /// (%0, max(%1 - 2 * %2))
4312 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4313  OperationState &result,
4314  MinMaxKind kind) {
4315  // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4316  // see: https://reviews.llvm.org/D134227#3821753
4317  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4318 
4319  StringRef mapName = kind == MinMaxKind::Min
4320  ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4321  : AffineParallelOp::getLowerBoundsMapAttrStrName();
4322  StringRef groupsName =
4323  kind == MinMaxKind::Min
4324  ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4325  : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4326 
4327  if (failed(parser.parseLParen()))
4328  return failure();
4329 
4330  if (succeeded(parser.parseOptionalRParen())) {
4331  result.addAttribute(
4332  mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4333  result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4334  return success();
4335  }
4336 
4337  SmallVector<AffineExpr> flatExprs;
4340  SmallVector<int32_t> numMapsPerGroup;
4342  auto parseOperands = [&]() {
4343  if (succeeded(parser.parseOptionalKeyword(
4344  kind == MinMaxKind::Min ? "min" : "max"))) {
4345  mapOperands.clear();
4346  AffineMapAttr map;
4347  if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4348  result.attributes,
4350  return failure();
4351  result.attributes.erase(tmpAttrStrName);
4352  llvm::append_range(flatExprs, map.getValue().getResults());
4353  auto operandsRef = llvm::ArrayRef(mapOperands);
4354  auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4356  auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4358  flatDimOperands.append(map.getValue().getNumResults(), dims);
4359  flatSymOperands.append(map.getValue().getNumResults(), syms);
4360  numMapsPerGroup.push_back(map.getValue().getNumResults());
4361  } else {
4362  if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4363  flatSymOperands.emplace_back(),
4364  flatExprs.emplace_back())))
4365  return failure();
4366  numMapsPerGroup.push_back(1);
4367  }
4368  return success();
4369  };
4370  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4371  return failure();
4372 
4373  unsigned totalNumDims = 0;
4374  unsigned totalNumSyms = 0;
4375  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4376  unsigned numDims = flatDimOperands[i].size();
4377  unsigned numSyms = flatSymOperands[i].size();
4378  flatExprs[i] = flatExprs[i]
4379  .shiftDims(numDims, totalNumDims)
4380  .shiftSymbols(numSyms, totalNumSyms);
4381  totalNumDims += numDims;
4382  totalNumSyms += numSyms;
4383  }
4384 
4385  // Deduplicate map operands.
4386  SmallVector<Value> dimOperands, symOperands;
4387  SmallVector<AffineExpr> dimRplacements, symRepacements;
4388  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4389  dimRplacements, AffineExprKind::DimId) ||
4390  deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4391  symRepacements, AffineExprKind::SymbolId))
4392  return failure();
4393 
4394  result.operands.append(dimOperands.begin(), dimOperands.end());
4395  result.operands.append(symOperands.begin(), symOperands.end());
4396 
4397  Builder &builder = parser.getBuilder();
4398  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4399  parser.getContext());
4400  flatMap = flatMap.replaceDimsAndSymbols(
4401  dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4402 
4403  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4404  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4405  return success();
4406 }
4407 
4408 //
4409 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4410 // `to` parallel-bound steps? region attr-dict?
4411 // steps ::= `steps` `(` integer-literals `)`
4412 //
4413 ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4414  OperationState &result) {
4415  auto &builder = parser.getBuilder();
4416  auto indexType = builder.getIndexType();
4419  parser.parseEqual() ||
4420  parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4421  parser.parseKeyword("to") ||
4422  parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4423  return failure();
4424 
4425  AffineMapAttr stepsMapAttr;
4426  NamedAttrList stepsAttrs;
4428  if (failed(parser.parseOptionalKeyword("step"))) {
4429  SmallVector<int64_t, 4> steps(ivs.size(), 1);
4430  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4431  builder.getI64ArrayAttr(steps));
4432  } else {
4433  if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4434  AffineParallelOp::getStepsAttrStrName(),
4435  stepsAttrs,
4437  return failure();
4438 
4439  // Convert steps from an AffineMap into an I64ArrayAttr.
4441  auto stepsMap = stepsMapAttr.getValue();
4442  for (const auto &result : stepsMap.getResults()) {
4443  auto constExpr = dyn_cast<AffineConstantExpr>(result);
4444  if (!constExpr)
4445  return parser.emitError(parser.getNameLoc(),
4446  "steps must be constant integers");
4447  steps.push_back(constExpr.getValue());
4448  }
4449  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4450  builder.getI64ArrayAttr(steps));
4451  }
4452 
4453  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4454  // quoted strings are a member of the enum AtomicRMWKind.
4455  SmallVector<Attribute, 4> reductions;
4456  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4457  if (parser.parseLParen())
4458  return failure();
4459  auto parseAttributes = [&]() -> ParseResult {
4460  // Parse a single quoted string via the attribute parsing, and then
4461  // verify it is a member of the enum and convert to it's integer
4462  // representation.
4463  StringAttr attrVal;
4464  NamedAttrList attrStorage;
4465  auto loc = parser.getCurrentLocation();
4466  if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4467  attrStorage))
4468  return failure();
4469  std::optional<arith::AtomicRMWKind> reduction =
4470  arith::symbolizeAtomicRMWKind(attrVal.getValue());
4471  if (!reduction)
4472  return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4473  reductions.push_back(
4474  builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4475  // While we keep getting commas, keep parsing.
4476  return success();
4477  };
4478  if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4479  return failure();
4480  }
4481  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4482  builder.getArrayAttr(reductions));
4483 
4484  // Parse return types of reductions (if any)
4485  if (parser.parseOptionalArrowTypeList(result.types))
4486  return failure();
4487 
4488  // Now parse the body.
4489  Region *body = result.addRegion();
4490  for (auto &iv : ivs)
4491  iv.type = indexType;
4492  if (parser.parseRegion(*body, ivs) ||
4493  parser.parseOptionalAttrDict(result.attributes))
4494  return failure();
4495 
4496  // Add a terminator if none was parsed.
4497  AffineParallelOp::ensureTerminator(*body, builder, result.location);
4498  return success();
4499 }
4500 
4501 //===----------------------------------------------------------------------===//
4502 // AffineYieldOp
4503 //===----------------------------------------------------------------------===//
4504 
4505 LogicalResult AffineYieldOp::verify() {
4506  auto *parentOp = (*this)->getParentOp();
4507  auto results = parentOp->getResults();
4508  auto operands = getOperands();
4509 
4510  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4511  return emitOpError() << "only terminates affine.if/for/parallel regions";
4512  if (parentOp->getNumResults() != getNumOperands())
4513  return emitOpError() << "parent of yield must have same number of "
4514  "results as the yield operands";
4515  for (auto it : llvm::zip(results, operands)) {
4516  if (std::get<0>(it).getType() != std::get<1>(it).getType())
4517  return emitOpError() << "types mismatch between yield op and its parent";
4518  }
4519 
4520  return success();
4521 }
4522 
4523 //===----------------------------------------------------------------------===//
4524 // AffineVectorLoadOp
4525 //===----------------------------------------------------------------------===//
4526 
4527 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4528  VectorType resultType, AffineMap map,
4529  ValueRange operands) {
4530  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4531  result.addOperands(operands);
4532  if (map)
4533  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4534  result.types.push_back(resultType);
4535 }
4536 
4537 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4538  VectorType resultType, Value memref,
4539  AffineMap map, ValueRange mapOperands) {
4540  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4541  result.addOperands(memref);
4542  result.addOperands(mapOperands);
4543  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4544  result.types.push_back(resultType);
4545 }
4546 
4547 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4548  VectorType resultType, Value memref,
4549  ValueRange indices) {
4550  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4551  int64_t rank = memrefType.getRank();
4552  // Create identity map for memrefs with at least one dimension or () -> ()
4553  // for zero-dimensional memrefs.
4554  auto map =
4555  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4556  build(builder, result, resultType, memref, map, indices);
4557 }
4558 
// Registers the shared SimplifyAffineOp pattern, specialized for
// affine.vector_load.
void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
}
4563 
4564 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4565  OperationState &result) {
4566  auto &builder = parser.getBuilder();
4567  auto indexTy = builder.getIndexType();
4568 
4569  MemRefType memrefType;
4570  VectorType resultType;
4571  OpAsmParser::UnresolvedOperand memrefInfo;
4572  AffineMapAttr mapAttr;
4574  return failure(
4575  parser.parseOperand(memrefInfo) ||
4576  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4577  AffineVectorLoadOp::getMapAttrStrName(),
4578  result.attributes) ||
4579  parser.parseOptionalAttrDict(result.attributes) ||
4580  parser.parseColonType(memrefType) || parser.parseComma() ||
4581  parser.parseType(resultType) ||
4582  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4583  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4584  parser.addTypeToList(resultType, result.types));
4585 }
4586 
4588  p << " " << getMemRef() << '[';
4589  if (AffineMapAttr mapAttr =
4590  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4591  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4592  p << ']';
4593  p.printOptionalAttrDict((*this)->getAttrs(),
4594  /*elidedAttrs=*/{getMapAttrStrName()});
4595  p << " : " << getMemRefType() << ", " << getType();
4596 }
4597 
4598 /// Verify common invariants of affine.vector_load and affine.vector_store.
4599 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4600  VectorType vectorType) {
4601  // Check that memref and vector element types match.
4602  if (memrefType.getElementType() != vectorType.getElementType())
4603  return op->emitOpError(
4604  "requires memref and vector types of the same elemental type");
4605  return success();
4606 }
4607 
4608 LogicalResult AffineVectorLoadOp::verify() {
4609  MemRefType memrefType = getMemRefType();
4610  if (failed(verifyMemoryOpIndexing(
4611  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4612  getMapOperands(), memrefType,
4613  /*numIndexOperands=*/getNumOperands() - 1)))
4614  return failure();
4615 
4616  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4617  return failure();
4618 
4619  return success();
4620 }
4621 
4622 //===----------------------------------------------------------------------===//
4623 // AffineVectorStoreOp
4624 //===----------------------------------------------------------------------===//
4625 
4626 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4627  Value valueToStore, Value memref, AffineMap map,
4628  ValueRange mapOperands) {
4629  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4630  result.addOperands(valueToStore);
4631  result.addOperands(memref);
4632  result.addOperands(mapOperands);
4633  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4634 }
4635 
4636 // Use identity map.
4637 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4638  Value valueToStore, Value memref,
4639  ValueRange indices) {
4640  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4641  int64_t rank = memrefType.getRank();
4642  // Create identity map for memrefs with at least one dimension or () -> ()
4643  // for zero-dimensional memrefs.
4644  auto map =
4645  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4646  build(builder, result, valueToStore, memref, map, indices);
4647 }
// Registers the shared SimplifyAffineOp pattern, specialized for
// affine.vector_store.
void AffineVectorStoreOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
}
4652 
4653 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4654  OperationState &result) {
4655  auto indexTy = parser.getBuilder().getIndexType();
4656 
4657  MemRefType memrefType;
4658  VectorType resultType;
4659  OpAsmParser::UnresolvedOperand storeValueInfo;
4660  OpAsmParser::UnresolvedOperand memrefInfo;
4661  AffineMapAttr mapAttr;
4663  return failure(
4664  parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4665  parser.parseOperand(memrefInfo) ||
4666  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4667  AffineVectorStoreOp::getMapAttrStrName(),
4668  result.attributes) ||
4669  parser.parseOptionalAttrDict(result.attributes) ||
4670  parser.parseColonType(memrefType) || parser.parseComma() ||
4671  parser.parseType(resultType) ||
4672  parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4673  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4674  parser.resolveOperands(mapOperands, indexTy, result.operands));
4675 }
4676 
4678  p << " " << getValueToStore();
4679  p << ", " << getMemRef() << '[';
4680  if (AffineMapAttr mapAttr =
4681  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4682  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4683  p << ']';
4684  p.printOptionalAttrDict((*this)->getAttrs(),
4685  /*elidedAttrs=*/{getMapAttrStrName()});
4686  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4687 }
4688 
4689 LogicalResult AffineVectorStoreOp::verify() {
4690  MemRefType memrefType = getMemRefType();
4691  if (failed(verifyMemoryOpIndexing(
4692  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4693  getMapOperands(), memrefType,
4694  /*numIndexOperands=*/getNumOperands() - 2)))
4695  return failure();
4696 
4697  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4698  return failure();
4699 
4700  return success();
4701 }
4702 
4703 //===----------------------------------------------------------------------===//
4704 // DelinearizeIndexOp
4705 //===----------------------------------------------------------------------===//
4706 
4707 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4708  OperationState &odsState,
4709  Value linearIndex, ValueRange dynamicBasis,
4710  ArrayRef<int64_t> staticBasis,
4711  bool hasOuterBound) {
4712  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4713  : staticBasis.size() + 1,
4714  linearIndex.getType());
4715  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4716  staticBasis);
4717 }
4718 
4719 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4720  OperationState &odsState,
4721  Value linearIndex, ValueRange basis,
4722  bool hasOuterBound) {
4723  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4724  hasOuterBound = false;
4725  basis = basis.drop_front();
4726  }
4727  SmallVector<Value> dynamicBasis;
4728  SmallVector<int64_t> staticBasis;
4729  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4730  staticBasis);
4731  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4732  hasOuterBound);
4733 }
4734 
4735 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4736  OperationState &odsState,
4737  Value linearIndex,
4738  ArrayRef<OpFoldResult> basis,
4739  bool hasOuterBound) {
4740  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4741  hasOuterBound = false;
4742  basis = basis.drop_front();
4743  }
4744  SmallVector<Value> dynamicBasis;
4745  SmallVector<int64_t> staticBasis;
4746  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4747  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4748  hasOuterBound);
4749 }
4750 
// Builds a delinearize op from a fully static basis; there are no dynamic
// basis operands in this case.
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex, ArrayRef<int64_t> basis,
                                     bool hasOuterBound) {
  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
}
4757 
4758 LogicalResult AffineDelinearizeIndexOp::verify() {
4759  ArrayRef<int64_t> staticBasis = getStaticBasis();
4760  if (getNumResults() != staticBasis.size() &&
4761  getNumResults() != staticBasis.size() + 1)
4762  return emitOpError("should return an index for each basis element and up "
4763  "to one extra index");
4764 
4765  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4766  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4767  return emitOpError(
4768  "mismatch between dynamic and static basis (kDynamic marker but no "
4769  "corresponding dynamic basis entry) -- this can only happen due to an "
4770  "incorrect fold/rewrite");
4771 
4772  if (!llvm::all_of(staticBasis, [](int64_t v) {
4773  return v > 0 || ShapedType::isDynamic(v);
4774  }))
4775  return emitOpError("no basis element may be statically non-positive");
4776 
4777  return success();
4778 }
4779 
4780 /// Given mixed basis of affine.delinearize_index/linearize_index replace
4781 /// constant SSA values with the constant integer value and return the new
4782 /// static basis. In case no such candidate for replacement exists, this utility
4783 /// returns std::nullopt.
4784 static std::optional<SmallVector<int64_t>>
4786  MutableOperandRange mutableDynamicBasis,
4787  ArrayRef<Attribute> dynamicBasis) {
4788  uint64_t dynamicBasisIndex = 0;
4789  for (OpFoldResult basis : dynamicBasis) {
4790  if (basis) {
4791  mutableDynamicBasis.erase(dynamicBasisIndex);
4792  } else {
4793  ++dynamicBasisIndex;
4794  }
4795  }
4796 
4797  // No constant SSA value exists.
4798  if (dynamicBasisIndex == dynamicBasis.size())
4799  return std::nullopt;
4800 
4801  SmallVector<int64_t> staticBasis;
4802  for (OpFoldResult basis : mixedBasis) {
4803  std::optional<int64_t> basisVal = getConstantIntValue(basis);
4804  if (!basisVal)
4805  staticBasis.push_back(ShapedType::kDynamic);
4806  else
4807  staticBasis.push_back(*basisVal);
4808  }
4809 
4810  return staticBasis;
4811 }
4812 
4813 LogicalResult
4814 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4816  std::optional<SmallVector<int64_t>> maybeStaticBasis =
4817  foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
4818  adaptor.getDynamicBasis());
4819  if (maybeStaticBasis) {
4820  setStaticBasis(*maybeStaticBasis);
4821  return success();
4822  }
4823  // If we won't be doing any division or modulo (no basis or the one basis
4824  // element is purely advisory), simply return the input value.
4825  if (getNumResults() == 1) {
4826  result.push_back(getLinearIndex());
4827  return success();
4828  }
4829 
4830  if (adaptor.getLinearIndex() == nullptr)
4831  return failure();
4832 
4833  if (!adaptor.getDynamicBasis().empty())
4834  return failure();
4835 
4836  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
4837  Type attrType = getLinearIndex().getType();
4838 
4839  ArrayRef<int64_t> staticBasis = getStaticBasis();
4840  if (hasOuterBound())
4841  staticBasis = staticBasis.drop_front();
4842  for (int64_t modulus : llvm::reverse(staticBasis)) {
4843  result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
4844  highPart = llvm::divideFloorSigned(highPart, modulus);
4845  }
4846  result.push_back(IntegerAttr::get(attrType, highPart));
4847  std::reverse(result.begin(), result.end());
4848  return success();
4849 }
4850 
4851 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
4852  OpBuilder builder(getContext());
4853  if (hasOuterBound()) {
4854  if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4855  return getMixedValues(getStaticBasis().drop_front(),
4856  getDynamicBasis().drop_front(), builder);
4857 
4858  return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4859  builder);
4860  }
4861 
4862  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4863 }
4864 
4865 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
4866  SmallVector<OpFoldResult> ret = getMixedBasis();
4867  if (!hasOuterBound())
4868  ret.insert(ret.begin(), OpFoldResult());
4869  return ret;
4870 }
4871 
4872 namespace {
4873 
4874 // Drops delinearization indices that correspond to unit-extent basis
4875 struct DropUnitExtentBasis
4876  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4878 
4879  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4880  PatternRewriter &rewriter) const override {
4881  SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
4882  std::optional<Value> zero = std::nullopt;
4883  Location loc = delinearizeOp->getLoc();
4884  auto getZero = [&]() -> Value {
4885  if (!zero)
4886  zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4887  return zero.value();
4888  };
4889 
4890  // Replace all indices corresponding to unit-extent basis with 0.
4891  // Remaining basis can be used to get a new `affine.delinearize_index` op.
4892  SmallVector<OpFoldResult> newBasis;
4893  for (auto [index, basis] :
4894  llvm::enumerate(delinearizeOp.getPaddedBasis())) {
4895  std::optional<int64_t> basisVal =
4896  basis ? getConstantIntValue(basis) : std::nullopt;
4897  if (basisVal == 1)
4898  replacements[index] = getZero();
4899  else
4900  newBasis.push_back(basis);
4901  }
4902 
4903  if (newBasis.size() == delinearizeOp.getNumResults())
4904  return rewriter.notifyMatchFailure(delinearizeOp,
4905  "no unit basis elements");
4906 
4907  if (!newBasis.empty()) {
4908  // Will drop the leading nullptr from `basis` if there was no outer bound.
4909  auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4910  loc, delinearizeOp.getLinearIndex(), newBasis);
4911  int newIndex = 0;
4912  // Map back the new delinearized indices to the values they replace.
4913  for (auto &replacement : replacements) {
4914  if (replacement)
4915  continue;
4916  replacement = newDelinearizeOp->getResult(newIndex++);
4917  }
4918  }
4919 
4920  rewriter.replaceOp(delinearizeOp, replacements);
4921  return success();
4922  }
4923 };
4924 
4925 /// If a `affine.delinearize_index`'s input is a `affine.linearize_index
4926 /// disjoint` and the two operations end with the same basis elements,
4927 /// cancel those parts of the operations out because they are inverses
4928 /// of each other.
4929 ///
4930 /// If the operations have the same basis, cancel them entirely.
4931 ///
4932 /// The `disjoint` flag is needed on the `affine.linearize_index` because
4933 /// otherwise, there is no guarantee that the inputs to the linearization are
4934 /// in-bounds the way the outputs of the delinearization would be.
4935 struct CancelDelinearizeOfLinearizeDisjointExactTail
4936  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4938 
4939  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4940  PatternRewriter &rewriter) const override {
4941  auto linearizeOp = delinearizeOp.getLinearIndex()
4942  .getDefiningOp<affine::AffineLinearizeIndexOp>();
4943  if (!linearizeOp)
4944  return rewriter.notifyMatchFailure(delinearizeOp,
4945  "index doesn't come from linearize");
4946 
4947  if (!linearizeOp.getDisjoint())
4948  return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
4949 
4950  ValueRange linearizeIns = linearizeOp.getMultiIndex();
4951  // Note: we use the full basis so we don't lose outer bounds later.
4952  SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
4953  SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
4954  size_t numMatches = 0;
4955  for (auto [linSize, delinSize] : llvm::zip(
4956  llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
4957  if (linSize != delinSize)
4958  break;
4959  ++numMatches;
4960  }
4961 
4962  if (numMatches == 0)
4963  return rewriter.notifyMatchFailure(
4964  delinearizeOp, "final basis element doesn't match linearize");
4965 
4966  // The easy case: everything lines up and the basis match sup completely.
4967  if (numMatches == linearizeBasis.size() &&
4968  numMatches == delinearizeBasis.size() &&
4969  linearizeIns.size() == delinearizeOp.getNumResults()) {
4970  rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4971  return success();
4972  }
4973 
4974  Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
4975  linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
4976  ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
4977  linearizeOp.getDisjoint());
4978  auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
4979  delinearizeOp.getLoc(), newLinearize,
4980  ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
4981  delinearizeOp.hasOuterBound());
4982  SmallVector<Value> mergedResults(newDelinearize.getResults());
4983  mergedResults.append(linearizeIns.take_back(numMatches).begin(),
4984  linearizeIns.take_back(numMatches).end());
4985  rewriter.replaceOp(delinearizeOp, mergedResults);
4986  return success();
4987  }
4988 };
4989 
4990 /// If the input to a delinearization is a disjoint linearization, and the
4991 /// last k > 1 components of the delinearization basis multiply to the
4992 /// last component of the linearization basis, break the linearization and
4993 /// delinearization into two parts, peeling off the last input to linearization.
4994 ///
4995 /// For example:
4996 /// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
4997 /// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
4998 /// becomes
4999 /// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
5000 /// %1:2 = affine.delinearize_index %0 by (2, 3) : index
5001 /// %2:2 = affine.delinearize_index %x by (8, 4) : index
5002 /// where the original %1:4 is replaced by %1:2 ++ %2:2
5003 struct SplitDelinearizeSpanningLastLinearizeArg final
5004  : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5006 
5007  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5008  PatternRewriter &rewriter) const override {
5009  auto linearizeOp = delinearizeOp.getLinearIndex()
5010  .getDefiningOp<affine::AffineLinearizeIndexOp>();
5011  if (!linearizeOp)
5012  return rewriter.notifyMatchFailure(delinearizeOp,
5013  "index doesn't come from linearize");
5014 
5015  if (!linearizeOp.getDisjoint())
5016  return rewriter.notifyMatchFailure(linearizeOp,
5017  "linearize isn't disjoint");
5018 
5019  int64_t target = linearizeOp.getStaticBasis().back();
5020  if (ShapedType::isDynamic(target))
5021  return rewriter.notifyMatchFailure(
5022  linearizeOp, "linearize ends with dynamic basis value");
5023 
5024  int64_t sizeToSplit = 1;
5025  size_t elemsToSplit = 0;
5026  ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
5027  for (int64_t basisElem : llvm::reverse(basis)) {
5028  if (ShapedType::isDynamic(basisElem))
5029  return rewriter.notifyMatchFailure(
5030  delinearizeOp, "dynamic basis element while scanning for split");
5031  sizeToSplit *= basisElem;
5032  elemsToSplit += 1;
5033 
5034  if (sizeToSplit > target)
5035  return rewriter.notifyMatchFailure(delinearizeOp,
5036  "overshot last argument size");
5037  if (sizeToSplit == target)
5038  break;
5039  }
5040 
5041  if (sizeToSplit < target)
5042  return rewriter.notifyMatchFailure(
5043  delinearizeOp, "product of known basis elements doesn't exceed last "
5044  "linearize argument");
5045 
5046  if (elemsToSplit < 2)
5047  return rewriter.notifyMatchFailure(
5048  delinearizeOp,
5049  "need at least two elements to form the basis product");
5050 
5051  Value linearizeWithoutBack =
5052  rewriter.create<affine::AffineLinearizeIndexOp>(
5053  linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
5054  linearizeOp.getDynamicBasis(),
5055  linearizeOp.getStaticBasis().drop_back(),
5056  linearizeOp.getDisjoint());
5057  auto delinearizeWithoutSplitPart =
5058  rewriter.create<affine::AffineDelinearizeIndexOp>(
5059  delinearizeOp.getLoc(), linearizeWithoutBack,
5060  delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
5061  delinearizeOp.hasOuterBound());
5062  auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
5063  delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
5064  basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
5065  SmallVector<Value> results = llvm::to_vector(
5066  llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
5067  delinearizeBack.getResults()));
5068  rewriter.replaceOp(delinearizeOp, results);
5069 
5070  return success();
5071  }
5072 };
5073 } // namespace
5074 
5075 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
5076  RewritePatternSet &patterns, MLIRContext *context) {
5077  patterns
5078  .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
5079  DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
5080  context);
5081 }
5082 
5083 //===----------------------------------------------------------------------===//
5084 // LinearizeIndexOp
5085 //===----------------------------------------------------------------------===//
5086 
5087 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5088  OperationState &odsState,
5089  ValueRange multiIndex, ValueRange basis,
5090  bool disjoint) {
5091  if (!basis.empty() && basis.front() == Value())
5092  basis = basis.drop_front();
5093  SmallVector<Value> dynamicBasis;
5094  SmallVector<int64_t> staticBasis;
5095  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
5096  staticBasis);
5097  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5098 }
5099 
5100 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5101  OperationState &odsState,
5102  ValueRange multiIndex,
5103  ArrayRef<OpFoldResult> basis,
5104  bool disjoint) {
5105  if (!basis.empty() && basis.front() == OpFoldResult())
5106  basis = basis.drop_front();
5107  SmallVector<Value> dynamicBasis;
5108  SmallVector<int64_t> staticBasis;
5109  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
5110  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5111 }
5112 
5113 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5114  OperationState &odsState,
5115  ValueRange multiIndex,
5116  ArrayRef<int64_t> basis, bool disjoint) {
5117  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
5118 }
5119 
5120 LogicalResult AffineLinearizeIndexOp::verify() {
5121  size_t numIndexes = getMultiIndex().size();
5122  size_t numBasisElems = getStaticBasis().size();
5123  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5124  return emitOpError("should be passed a basis element for each index except "
5125  "possibly the first");
5126 
5127  auto dynamicMarkersCount =
5128  llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
5129  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5130  return emitOpError(
5131  "mismatch between dynamic and static basis (kDynamic marker but no "
5132  "corresponding dynamic basis entry) -- this can only happen due to an "
5133  "incorrect fold/rewrite");
5134 
5135  return success();
5136 }
5137 
5138 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
5139  std::optional<SmallVector<int64_t>> maybeStaticBasis =
5140  foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
5141  adaptor.getDynamicBasis());
5142  if (maybeStaticBasis) {
5143  setStaticBasis(*maybeStaticBasis);
5144  return getResult();
5145  }
5146  // No indices linearizes to zero.
5147  if (getMultiIndex().empty())
5148  return IntegerAttr::get(getResult().getType(), 0);
5149 
5150  // One single index linearizes to itself.
5151  if (getMultiIndex().size() == 1)
5152  return getMultiIndex().front();
5153 
5154  if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))
5155  return nullptr;
5156 
5157  if (!adaptor.getDynamicBasis().empty())
5158  return nullptr;
5159 
5160  int64_t result = 0;
5161  int64_t stride = 1;
5162  for (auto [length, indexAttr] :
5163  llvm::zip_first(llvm::reverse(getStaticBasis()),
5164  llvm::reverse(adaptor.getMultiIndex()))) {
5165  result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
5166  stride = stride * length;
5167  }
5168  // Handle the index element with no basis element.
5169  if (!hasOuterBound())
5170  result =
5171  result +
5172  cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
5173 
5174  return IntegerAttr::get(getResult().getType(), result);
5175 }
5176 
5177 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
5178  OpBuilder builder(getContext());
5179  if (hasOuterBound()) {
5180  if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5181  return getMixedValues(getStaticBasis().drop_front(),
5182  getDynamicBasis().drop_front(), builder);
5183 
5184  return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5185  builder);
5186  }
5187 
5188  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5189 }
5190 
5191 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
5192  SmallVector<OpFoldResult> ret = getMixedBasis();
5193  if (!hasOuterBound())
5194  ret.insert(ret.begin(), OpFoldResult());
5195  return ret;
5196 }
5197 
5198 namespace {
5199 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
5200 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
5201 /// %...d)`.
5202 
5203 /// Note that `disjoint` is required here, because, without it, we could have
5204 /// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
5205 /// is a valid operation where the `%c64` cannot be trivially dropped.
5206 ///
5207 /// Alternatively, if `%x` in the above is a known constant 0, remove it even if
5208 /// the operation isn't asserted to be `disjoint`.
5209 struct DropLinearizeUnitComponentsIfDisjointOrZero final
5210  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5212 
5213  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5214  PatternRewriter &rewriter) const override {
5215  ValueRange multiIndex = op.getMultiIndex();
5216  size_t numIndices = multiIndex.size();
5217  SmallVector<Value> newIndices;
5218  newIndices.reserve(numIndices);
5219  SmallVector<OpFoldResult> newBasis;
5220  newBasis.reserve(numIndices);
5221 
5222  if (!op.hasOuterBound()) {
5223  newIndices.push_back(multiIndex.front());
5224  multiIndex = multiIndex.drop_front();
5225  }
5226 
5227  SmallVector<OpFoldResult> basis = op.getMixedBasis();
5228  for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5229  std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
5230  if (!basisEntry || *basisEntry != 1) {
5231  newIndices.push_back(index);
5232  newBasis.push_back(basisElem);
5233  continue;
5234  }
5235 
5236  std::optional<int64_t> indexValue = getConstantIntValue(index);
5237  if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5238  newIndices.push_back(index);
5239  newBasis.push_back(basisElem);
5240  continue;
5241  }
5242  }
5243  if (newIndices.size() == numIndices)
5244  return rewriter.notifyMatchFailure(op,
5245  "no unit basis entries to replace");
5246 
5247  if (newIndices.size() == 0) {
5248  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
5249  return success();
5250  }
5251  rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5252  op, newIndices, newBasis, op.getDisjoint());
5253  return success();
5254  }
5255 };
5256 
5258  ArrayRef<OpFoldResult> terms) {
5259  int64_t nDynamic = 0;
5260  SmallVector<Value> dynamicPart;
5261  AffineExpr result = builder.getAffineConstantExpr(1);
5262  for (OpFoldResult term : terms) {
5263  if (!term)
5264  return term;
5265  std::optional<int64_t> maybeConst = getConstantIntValue(term);
5266  if (maybeConst) {
5267  result = result * builder.getAffineConstantExpr(*maybeConst);
5268  } else {
5269  dynamicPart.push_back(cast<Value>(term));
5270  result = result * builder.getAffineSymbolExpr(nDynamic++);
5271  }
5272  }
5273  if (auto constant = dyn_cast<AffineConstantExpr>(result))
5274  return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
5275  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
5276 }
5277 
/// If consecutive outputs of a delinearize_index are linearized with the same
/// bounds, canonicalize away the redundant arithmetic.
///
/// That is, if we have
/// ```
/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
///   by (...e, B1, B2, ..., BK, ...f)
/// ```
///
/// We can rewrite this to
/// ```
/// B = B1 * B2 ... BK
/// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b)
/// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
/// ```
/// where we replace all results of %s unaffected by the change with results
/// from %sMerged.
///
/// As a special case, if all results of the delinearize are merged in this way
/// we can replace those usages with %x, thus cancelling the delinearization
/// entirely, as in
/// ```
/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
/// ```
/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
struct CancelLinearizeOfDelinearizePortion final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

private:
  // Struct representing a case where the cancellation pattern
  // applies. A `Match` means that `length` inputs to the linearize operation
  // starting at `linStart` can be cancelled with `length` outputs of
  // `delinearize`, starting from `delinStart`.
  struct Match {
    AffineDelinearizeIndexOp delinearize;
    unsigned linStart = 0;
    unsigned delinStart = 0;
    unsigned length = 0;
  };

public:
  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Match> matches;

    // Padded bases have one entry per operand/result, with a null entry
    // standing in for a missing outer bound.
    const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
    ArrayRef<OpFoldResult> linBasisRef = linBasis;

    ValueRange multiIndex = linearizeOp.getMultiIndex();
    unsigned numLinArgs = multiIndex.size();
    unsigned linArgIdx = 0;
    // We only want to replace one run from the same delinearize op per
    // pattern invocation lest we run into invalidation issues.
    llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
    while (linArgIdx < numLinArgs) {
      auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
      if (!asResult) {
        linArgIdx++;
        continue;
      }

      auto delinearizeOp =
          dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
      if (!delinearizeOp) {
        linArgIdx++;
        continue;
      }

      /// Result 0 of the delinearize and argument 0 of the linearize can
      /// leave their maximum value unspecified. However, even if this happens
      /// we can still sometimes start the match process. Specifically, if
      /// - The argument we're matching is result 0 and argument 0 (so the
      ///   bounds don't matter). For example,
      ///
      ///     %0:2 = affine.delinearize_index %x into (8) : index, index
      ///     %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...)
      ///   allows cancellation
      /// - The delinearization doesn't specify a bound, but the linearization
      ///   is `disjoint`, which asserts that the bound on the linearization is
      ///   correct.
      unsigned delinArgIdx = asResult.getResultNumber();
      SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
      OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
      OpFoldResult firstLinBound = linBasis[linArgIdx];
      bool boundsMatch = firstDelinBound == firstLinBound;
      bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
      bool knownByDisjoint =
          linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
      if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
        linArgIdx++;
        continue;
      }

      // Extend the run: consecutive linearize args must be consecutive
      // delinearize results with identical basis entries.
      unsigned j = 1;
      unsigned numDelinOuts = delinearizeOp.getNumResults();
      for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
           ++j) {
        if (multiIndex[linArgIdx + j] !=
            delinearizeOp.getResult(delinArgIdx + j))
          break;
        if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
          break;
      }
      // If there're multiple matches against the same delinearize_index,
      // only rewrite the first one we find to prevent invalidations. The next
      // ones will be taken care of by subsequent pattern invocations.
      if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
        linArgIdx++;
        continue;
      }
      matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
      linArgIdx += j;
    }

    if (matches.empty())
      return rewriter.notifyMatchFailure(
          linearizeOp, "no run of delinearize outputs to deal with");

    // Record all the delinearize replacements so we can do them after creating
    // the new linearization operation, since the new operation might use
    // outputs of something we're replacing.
    SmallVector<SmallVector<Value>> delinearizeReplacements;

    SmallVector<Value> newIndex;
    newIndex.reserve(numLinArgs);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numLinArgs);
    unsigned prevMatchEnd = 0;
    for (Match m : matches) {
      // Copy over the operands/basis entries between runs unchanged.
      unsigned gap = m.linStart - prevMatchEnd;
      llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
      llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
      // Update here so we don't forget this during early continues
      prevMatchEnd = m.linStart + m.length;

      PatternRewriter::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(m.delinearize);

      ArrayRef<OpFoldResult> basisToMerge =
          linBasisRef.slice(m.linStart, m.length);
      // We use the slice from the linearize's basis above because of the
      // "bounds inferred from `disjoint`" case above.
      OpFoldResult newSize =
          computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);

      // Trivial case where we can just skip past the delinearize all together
      if (m.length == m.delinearize.getNumResults()) {
        newIndex.push_back(m.delinearize.getLinearIndex());
        newBasis.push_back(newSize);
        // Pad out set of replacements so we don't do anything with this one.
        delinearizeReplacements.push_back(SmallVector<Value>());
        continue;
      }

      SmallVector<Value> newDelinResults;
      SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
      newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
                          newDelinBasis.begin() + m.delinStart + m.length);
      newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
      auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
          m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
          newDelinBasis);

      // Since there may be other uses of the indices we just merged together,
      // create a residual affine.delinearize_index that delinearizes the
      // merged output into its component parts.
      Value combinedElem = newDelinearize.getResult(m.delinStart);
      auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
          m.delinearize.getLoc(), combinedElem, basisToMerge);

      // Swap all the uses of the unaffected delinearize outputs to the new
      // delinearization so that the old code can be removed if this
      // linearize_index is the only user of the merged results.
      llvm::append_range(newDelinResults,
                         newDelinearize.getResults().take_front(m.delinStart));
      llvm::append_range(newDelinResults, residualDelinearize.getResults());
      llvm::append_range(
          newDelinResults,
          newDelinearize.getResults().drop_front(m.delinStart + 1));

      delinearizeReplacements.push_back(newDelinResults);
      newIndex.push_back(combinedElem);
      newBasis.push_back(newSize);
    }
    // Copy over the trailing operands/basis entries after the last run.
    llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
    llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
    rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
        linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());

    // Now that the new linearize op exists, retarget the delinearizations.
    for (auto [m, newResults] :
         llvm::zip_equal(matches, delinearizeReplacements)) {
      if (newResults.empty())
        continue;
      rewriter.replaceOp(m.delinearize, newResults);
    }

    return success();
  }
};
5480 
5481 /// Strip leading zero from affine.linearize_index.
5482 ///
5483 /// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
5484 /// to `affine.linearize_index [...a] by (...b)` in all cases.
5485 struct DropLinearizeLeadingZero final
5486  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5488 
5489  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5490  PatternRewriter &rewriter) const override {
5491  Value leadingIdx = op.getMultiIndex().front();
5492  if (!matchPattern(leadingIdx, m_Zero()))
5493  return failure();
5494 
5495  if (op.getMultiIndex().size() == 1) {
5496  rewriter.replaceOp(op, leadingIdx);
5497  return success();
5498  }
5499 
5500  SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
5501  ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
5502  if (op.hasOuterBound())
5503  newMixedBasis = newMixedBasis.drop_front();
5504 
5505  rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5506  op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5507  return success();
5508  }
5509 };
5510 } // namespace
5511 
5512 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5513  RewritePatternSet &patterns, MLIRContext *context) {
5514  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5515  DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5516 }
5517 
5518 //===----------------------------------------------------------------------===//
5519 // TableGen'd op method definitions
5520 //===----------------------------------------------------------------------===//
5521 
5522 #define GET_OP_CLASSES
5523 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
Definition: AffineOps.cpp:2846
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
Definition: AffineOps.cpp:2567
static LogicalResult verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
Definition: AffineOps.cpp:3239
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
Definition: AffineOps.cpp:3417
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
Definition: AffineOps.cpp:4050
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
Definition: AffineOps.cpp:63
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
Definition: AffineOps.cpp:4197
static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
Definition: AffineOps.cpp:1035
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
Definition: AffineOps.cpp:3453
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
Definition: AffineOps.cpp:2421
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
Definition: AffineOps.cpp:821
static bool isValidAffineIndexOperand(Value value, Region *region)
Definition: AffineOps.cpp:491
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1544
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:756
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
Definition: AffineOps.cpp:2110
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:748
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1450
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
Definition: AffineOps.cpp:4599
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
Definition: AffineOps.cpp:924
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
Definition: AffineOps.cpp:3430
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
Definition: AffineOps.cpp:3131
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
Definition: AffineOps.cpp:350
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
Definition: AffineOps.cpp:696
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
Definition: AffineOps.cpp:4312
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
Definition: AffineOps.cpp:2855
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
Definition: AffineOps.cpp:496
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
Definition: AffineOps.cpp:658
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
Definition: AffineOps.cpp:1198
static void legalizeDemotedDims(MapOrSet &mapOrSet, SmallVectorImpl< Value > &operands)
A valid affine dimension may appear as a symbol in affine.apply operations.
Definition: AffineOps.cpp:1498
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1387
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
Definition: AffineOps.cpp:2805
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
Definition: AffineOps.cpp:2375
static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map, ValueRange dims, ValueRange syms)
Assuming dimOrSym is a quantity in the apply op map map and defined by minOp = affine_min(x_1,...
Definition: AffineOps.cpp:1060
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
Definition: AffineOps.cpp:528
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1402
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
Definition: AffineOps.cpp:369
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
Definition: AffineOps.cpp:724
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Definition: AffineOps.cpp:4265
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms, bool replaceAffineMin)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
Definition: AffineOps.cpp:1142
#define DBGS()
Definition: AffineOps.cpp:43
static LogicalResult verifyAffineMinMaxOp(T op)
Definition: AffineOps.cpp:3404
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
Definition: AffineOps.cpp:2285
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
Definition: AffineOps.cpp:1301
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
Definition: AffineOps.cpp:4785
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
Definition: AffineOps.cpp:3616
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
union mlir::linalg::@1221::ArityGroupAndKind::Kind kind
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:75
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:80
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:961
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:243
MLIRContext * getContext() const
Definition: AffineExpr.cpp:33
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
Definition: AffineExpr.cpp:181
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:655
MLIRContext * getContext() const
Definition: AffineMap.cpp:339
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
Definition: AffineMap.h:221
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:267
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
Definition: AffineMap.h:228
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineMap.cpp:496
unsigned getNumInputs() const
Definition: AffineMap.cpp:399
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineMap.h:280
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:511
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
Definition: AffineMap.cpp:124
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:647
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
Definition: AffineMap.cpp:430
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:308
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
AffineMap getDimIdentityMap()
Definition: Builders.cpp:378
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:382
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:363
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:367
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition: Builders.cpp:174
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:107
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
NoneType getNoneType()
Definition: Builders.cpp:83
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:95
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
Definition: Builders.cpp:371
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition: Builders.cpp:373
MLIRContext * getContext() const
Definition: Builders.h:55
AffineMap getSymbolIdentityMap()
Definition: Builders.cpp:391
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:276
IndexType getIndexType()
Definition: Builders.cpp:50
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44
unsigned getNumDims() const
Definition: IntegerSet.cpp:15
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
Definition: IntegerSet.cpp:57
unsigned getNumInputs() const
Definition: IntegerSet.cpp:17
ArrayRef< AffineExpr > getConstraints() const
Definition: IntegerSet.cpp:41
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
Definition: IntegerSet.cpp:51
unsigned getNumSymbols() const
Definition: IntegerSet.cpp:16
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:257
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
operand_range::iterator operand_iterator
Definition: Operation.h:372
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:748
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:810
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:681
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:593
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:577
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
A variable that can be added to the constraint set as a "column".
static bool compare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Return "true" if "lhs cmp rhs" was proven to hold.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
AffineBound represents a lower or upper bound in the for operation.
Definition: AffineOps.h:529
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
Definition: AffineOps.h:106
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Definition: AffineOps.h:315
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
Definition: AffineOps.cpp:4157
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
unsigned getNumResults() const
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
Definition: AffineOps.cpp:2868
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1278
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
Definition: AffineOps.cpp:2782
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
Definition: AffineOps.cpp:291
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
Definition: AffineOps.cpp:2754
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1376
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: AffineOps.cpp:2758
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1621
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1441
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Definition: AffineOps.cpp:2746
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1331
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1434
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Definition: AffineOps.cpp:251
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
Definition: AffineOps.cpp:276
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
Definition: AffineOps.cpp:1262
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
Definition: AffineOps.cpp:1626
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
Definition: AffineOps.cpp:2789
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Definition: AffineOps.cpp:413
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp parent of val if val is among its induction variables, or null otherwise.
Definition: AffineOps.cpp:2769
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
Definition: AffineOps.cpp:266
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
Definition: AffineOps.cpp:506
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
Definition: AffineOps.cpp:2750
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
Definition: AffineOps.cpp:1396
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:766
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
Definition: AffineMap.cpp:776
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
AffineExprKind
Definition: AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:70
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
Definition: AffineMap.cpp:738
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:631
Canonicalize the affine map result expression order of an affine min/max operation.
Definition: AffineOps.cpp:3670
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3673
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3687
Remove duplicated expressions in affine min/max ops.
Definition: AffineOps.cpp:3486
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3489
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
Definition: AffineOps.cpp:3529
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3532
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.