MLIR  21.0.0git
AffineOps.cpp
Go to the documentation of this file.
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/IntegerSet.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/OpDefinition.h"
19 #include "mlir/IR/PatternMatch.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/ScopeExit.h"
25 #include "llvm/ADT/SmallBitVector.h"
26 #include "llvm/ADT/SmallVectorExtras.h"
27 #include "llvm/ADT/TypeSwitch.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/MathExtras.h"
30 #include <numeric>
31 #include <optional>
32 
33 using namespace mlir;
34 using namespace mlir::affine;
35 
36 using llvm::divideCeilSigned;
37 using llvm::divideFloorSigned;
38 using llvm::mod;
39 
40 #define DEBUG_TYPE "affine-ops"
41 
42 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
43 
44 /// A utility function to check if a value is defined at the top level of
45 /// `region` or is an argument of `region`. A value of index type defined at the
46 /// top level of a `AffineScope` region is always a valid symbol for all
47 /// uses in that region.
49  if (auto arg = llvm::dyn_cast<BlockArgument>(value))
50  return arg.getParentRegion() == region;
51  return value.getDefiningOp()->getParentRegion() == region;
52 }
53 
54 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
55 /// region remains legal if the operation that uses it is inlined into `dest`
56 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
57 /// `isValidSymbol`, depending on the value being required to remain a valid
58 /// dimension or symbol.
59 static bool
61  const IRMapping &mapping,
62  function_ref<bool(Value, Region *)> legalityCheck) {
63  // If the value is a valid dimension for any other reason than being
64  // a top-level value, it will remain valid: constants get inlined
65  // with the function, transitive affine applies also get inlined and
66  // will be checked themselves, etc.
67  if (!isTopLevelValue(value, src))
68  return true;
69 
70  // If it's a top-level value because it's a block operand, i.e. a
71  // function argument, check whether the value replacing it after
72  // inlining is a valid dimension in the new region.
73  if (llvm::isa<BlockArgument>(value))
74  return legalityCheck(mapping.lookup(value), dest);
75 
76  // If it's a top-level value because it's defined in the region,
77  // it can only be inlined if the defining op is a constant or a
78  // `dim`, which can appear anywhere and be valid, since the defining
79  // op won't be top-level anymore after inlining.
80  Attribute operandCst;
81  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
82  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
83  isDimLikeOp;
84 }
85 
86 /// Checks if all values known to be legal affine dimensions or symbols in `src`
87 /// remain so if their respective users are inlined into `dest`.
88 static bool
90  const IRMapping &mapping,
91  function_ref<bool(Value, Region *)> legalityCheck) {
92  return llvm::all_of(values, [&](Value v) {
93  return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
94  });
95 }
96 
97 /// Checks if an affine read or write operation remains legal after inlining
98 /// from `src` to `dest`.
99 template <typename OpTy>
100 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
101  const IRMapping &mapping) {
102  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
103  AffineWriteOpInterface>::value,
104  "only ops with affine read/write interface are supported");
105 
106  AffineMap map = op.getAffineMap();
107  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
108  ValueRange symbolOperands =
109  op.getMapOperands().take_back(map.getNumSymbols());
111  dimOperands, src, dest, mapping,
112  static_cast<bool (*)(Value, Region *)>(isValidDim)))
113  return false;
115  symbolOperands, src, dest, mapping,
116  static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
117  return false;
118  return true;
119 }
120 
121 /// Checks if an affine apply operation remains legal after inlining from `src`
122 /// to `dest`.
123 // Use "unused attribute" marker to silence clang-tidy warning stemming from
124 // the inability to see through "llvm::TypeSwitch".
125 template <>
126 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
127  Region *src, Region *dest,
128  const IRMapping &mapping) {
129  // If it's a valid dimension, we need to check that it remains so.
130  if (isValidDim(op.getResult(), src))
132  op.getMapOperands(), src, dest, mapping,
133  static_cast<bool (*)(Value, Region *)>(isValidDim));
134 
135  // Otherwise it must be a valid symbol, check that it remains so.
137  op.getMapOperands(), src, dest, mapping,
138  static_cast<bool (*)(Value, Region *)>(isValidSymbol));
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // AffineDialect Interfaces
143 //===----------------------------------------------------------------------===//
144 
145 namespace {
146 /// This class defines the interface for handling inlining with affine
147 /// operations.
148 struct AffineInlinerInterface : public DialectInlinerInterface {
150 
151  //===--------------------------------------------------------------------===//
152  // Analysis Hooks
153  //===--------------------------------------------------------------------===//
154 
155  /// Returns true if the given region 'src' can be inlined into the region
156  /// 'dest' that is attached to an operation registered to the current dialect.
157  /// 'wouldBeCloned' is set if the region is cloned into its new location
158  /// rather than moved, indicating there may be other users.
159  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
160  IRMapping &valueMapping) const final {
161  // We can inline into affine loops and conditionals if this doesn't break
162  // affine value categorization rules.
163  Operation *destOp = dest->getParentOp();
164  if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
165  return false;
166 
167  // Multi-block regions cannot be inlined into affine constructs, all of
168  // which require single-block regions.
169  if (!llvm::hasSingleElement(*src))
170  return false;
171 
172  // Side-effecting operations that the affine dialect cannot understand
173  // should not be inlined.
174  Block &srcBlock = src->front();
175  for (Operation &op : srcBlock) {
176  // Ops with no side effects are fine,
177  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
178  if (iface.hasNoEffect())
179  continue;
180  }
181 
182  // Assuming the inlined region is valid, we only need to check if the
183  // inlining would change it.
184  bool remainsValid =
186  .Case<AffineApplyOp, AffineReadOpInterface,
187  AffineWriteOpInterface>([&](auto op) {
188  return remainsLegalAfterInline(op, src, dest, valueMapping);
189  })
190  .Default([](Operation *) {
191  // Conservatively disallow inlining ops we cannot reason about.
192  return false;
193  });
194 
195  if (!remainsValid)
196  return false;
197  }
198 
199  return true;
200  }
201 
202  /// Returns true if the given operation 'op', that is registered to this
203  /// dialect, can be inlined into the given region, false otherwise.
204  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
205  IRMapping &valueMapping) const final {
206  // Always allow inlining affine operations into a region that is marked as
207  // affine scope, or into affine loops and conditionals. There are some edge
208  // cases when inlining *into* affine structures, but that is handled in the
209  // other 'isLegalToInline' hook above.
210  Operation *parentOp = region->getParentOp();
211  return parentOp->hasTrait<OpTrait::AffineScope>() ||
212  isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
213  }
214 
215  /// Affine regions should be analyzed recursively.
216  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
217 };
218 } // namespace
219 
220 //===----------------------------------------------------------------------===//
221 // AffineDialect
222 //===----------------------------------------------------------------------===//
223 
224 void AffineDialect::initialize() {
225  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
226 #define GET_OP_LIST
227 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
228  >();
229  addInterfaces<AffineInlinerInterface>();
230  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
231  AffineMinOp>();
232 }
233 
234 /// Materialize a single constant operation from a given attribute value with
235 /// the desired resultant type.
237  Attribute value, Type type,
238  Location loc) {
239  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
240  return builder.create<ub::PoisonOp>(loc, type, poison);
241  return arith::ConstantOp::materialize(builder, value, type, loc);
242 }
243 
244 /// A utility function to check if a value is defined at the top level of an
245 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
246 /// conservatively assume it is not top-level. A value of index type defined at
247 /// the top level is always a valid symbol.
249  if (auto arg = llvm::dyn_cast<BlockArgument>(value)) {
250  // The block owning the argument may be unlinked, e.g. when the surrounding
251  // region has not yet been attached to an Op, at which point the parent Op
252  // is null.
253  Operation *parentOp = arg.getOwner()->getParentOp();
254  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
255  }
256  // The defining Op may live in an unlinked block so its parent Op may be null.
257  Operation *parentOp = value.getDefiningOp()->getParentOp();
258  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
259 }
260 
261 /// Returns the closest region enclosing `op` that is held by an operation with
262 /// trait `AffineScope`; `nullptr` if there is no such region.
264  auto *curOp = op;
265  while (auto *parentOp = curOp->getParentOp()) {
266  if (parentOp->hasTrait<OpTrait::AffineScope>())
267  return curOp->getParentRegion();
268  curOp = parentOp;
269  }
270  return nullptr;
271 }
272 
274  Operation *curOp = op;
275  while (auto *parentOp = curOp->getParentOp()) {
276  if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
277  return curOp->getParentRegion();
278  curOp = parentOp;
279  }
280  return nullptr;
281 }
282 
283 // A Value can be used as a dimension id iff it meets one of the following
284 // conditions:
285 // *) It is valid as a symbol.
286 // *) It is an induction variable.
287 // *) It is the result of affine apply operation with dimension id arguments.
289  // The value must be an index type.
290  if (!value.getType().isIndex())
291  return false;
292 
293  if (auto *defOp = value.getDefiningOp())
294  return isValidDim(value, getAffineScope(defOp));
295 
296  // This value has to be a block argument for an op that has the
297  // `AffineScope` trait or for an affine.for or affine.parallel.
298  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
299  return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
300  isa<AffineForOp, AffineParallelOp>(parentOp));
301 }
302 
303 // Value can be used as a dimension id iff it meets one of the following
304 // conditions:
305 // *) It is valid as a symbol.
306 // *) It is an induction variable.
307 // *) It is the result of an affine apply operation with dimension id operands.
308 bool mlir::affine::isValidDim(Value value, Region *region) {
309  // The value must be an index type.
310  if (!value.getType().isIndex())
311  return false;
312 
313  // All valid symbols are okay.
314  if (isValidSymbol(value, region))
315  return true;
316 
317  auto *op = value.getDefiningOp();
318  if (!op) {
319  // This value has to be a block argument for an affine.for or an
320  // affine.parallel.
321  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
322  return isa<AffineForOp, AffineParallelOp>(parentOp);
323  }
324 
325  // Affine apply operation is ok if all of its operands are ok.
326  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
327  return applyOp.isValidDim(region);
328  // The dim op is okay if its operand memref/tensor is defined at the top
329  // level.
330  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
331  return isTopLevelValue(dimOp.getShapedValue());
332  return false;
333 }
334 
335 /// Returns true if the 'index' dimension of the `memref` defined by
336 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
337 /// for `region`.
338 template <typename AnyMemRefDefOp>
339 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
340  Region *region) {
341  MemRefType memRefType = memrefDefOp.getType();
342 
343  // Dimension index is out of bounds.
344  if (index >= memRefType.getRank()) {
345  return false;
346  }
347 
348  // Statically shaped.
349  if (!memRefType.isDynamicDim(index))
350  return true;
351  // Get the position of the dimension among dynamic dimensions;
352  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
353  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
354  region);
355 }
356 
357 /// Returns true if the result of the dim op is a valid symbol for `region`.
358 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
359  // The dim op is okay if its source is defined at the top level.
360  if (isTopLevelValue(dimOp.getShapedValue()))
361  return true;
362 
363  // Conservatively handle remaining BlockArguments as non-valid symbols.
364  // E.g. scf.for iterArgs.
365  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
366  return false;
367 
368  // The dim op is also okay if its operand memref is a view/subview whose
369  // corresponding size is a valid symbol.
370  std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
371 
372  // Be conservative if we can't understand the dimension.
373  if (!index.has_value())
374  return false;
375 
376  // Skip over all memref.cast ops (if any).
377  Operation *op = dimOp.getShapedValue().getDefiningOp();
378  while (auto castOp = dyn_cast<memref::CastOp>(op)) {
379  // Bail on unranked memrefs.
380  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
381  return false;
382  op = castOp.getSource().getDefiningOp();
383  if (!op)
384  return false;
385  }
386 
387  int64_t i = index.value();
389  .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
390  [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
391  .Default([](Operation *) { return false; });
392 }
393 
394 // A value can be used as a symbol (at all its use sites) iff it meets one of
395 // the following conditions:
396 // *) It is a constant.
397 // *) Its defining op or block arg appearance is immediately enclosed by an op
398 // with `AffineScope` trait.
399 // *) It is the result of an affine.apply operation with symbol operands.
400 // *) It is a result of the dim op on a memref whose corresponding size is a
401 // valid symbol.
403  if (!value)
404  return false;
405 
406  // The value must be an index type.
407  if (!value.getType().isIndex())
408  return false;
409 
410  // Check that the value is a top level value.
411  if (isTopLevelValue(value))
412  return true;
413 
414  if (auto *defOp = value.getDefiningOp())
415  return isValidSymbol(value, getAffineScope(defOp));
416 
417  return false;
418 }
419 
420 /// A value can be used as a symbol for `region` iff it meets one of the
421 /// following conditions:
422 /// *) It is a constant.
423 /// *) It is a result of a `Pure` operation whose operands are valid symbolic
424 /// *) identifiers.
425 /// *) It is a result of the dim op on a memref whose corresponding size is
426 /// a valid symbol.
427 /// *) It is defined at the top level of 'region' or is its argument.
428 /// *) It dominates `region`'s parent op.
429 /// If `region` is null, conservatively assume the symbol definition scope does
430 /// not exist and only accept the values that would be symbols regardless of
431 /// the surrounding region structure, i.e. the first three cases above.
433  // The value must be an index type.
434  if (!value.getType().isIndex())
435  return false;
436 
437  // A top-level value is a valid symbol.
438  if (region && ::isTopLevelValue(value, region))
439  return true;
440 
441  auto *defOp = value.getDefiningOp();
442  if (!defOp) {
443  // A block argument that is not a top-level value is a valid symbol if it
444  // dominates region's parent op.
445  Operation *regionOp = region ? region->getParentOp() : nullptr;
446  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
447  if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
448  return isValidSymbol(value, parentOpRegion);
449  return false;
450  }
451 
452  // Constant operation is ok.
453  Attribute operandCst;
454  if (matchPattern(defOp, m_Constant(&operandCst)))
455  return true;
456 
457  // `Pure` operation that whose operands are valid symbolic identifiers.
458  if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
459  return affine::isValidSymbol(operand, region);
460  })) {
461  return true;
462  }
463 
464  // Dim op results could be valid symbols at any level.
465  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
466  return isDimOpValidSymbol(dimOp, region);
467 
468  // Check for values dominating `region`'s parent op.
469  Operation *regionOp = region ? region->getParentOp() : nullptr;
470  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
471  if (auto *parentRegion = region->getParentOp()->getParentRegion())
472  return isValidSymbol(value, parentRegion);
473 
474  return false;
475 }
476 
477 // Returns true if 'value' is a valid index to an affine operation (e.g.
478 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
479 // `region` provides the polyhedral symbol scope. Returns false otherwise.
480 static bool isValidAffineIndexOperand(Value value, Region *region) {
481  return isValidDim(value, region) || isValidSymbol(value, region);
482 }
483 
484 /// Prints dimension and symbol list.
487  unsigned numDims, OpAsmPrinter &printer) {
488  OperandRange operands(begin, end);
489  printer << '(' << operands.take_front(numDims) << ')';
490  if (operands.size() > numDims)
491  printer << '[' << operands.drop_front(numDims) << ']';
492 }
493 
494 /// Parses dimension and symbol list and returns true if parsing failed.
496  OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
498  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
499  return failure();
500  // Store number of dimensions for validation by caller.
501  numDims = opInfos.size();
502 
503  // Parse the optional symbol operands.
504  auto indexTy = parser.getBuilder().getIndexType();
505  return failure(parser.parseOperandList(
507  parser.resolveOperands(opInfos, indexTy, operands));
508 }
509 
510 /// Utility function to verify that a set of operands are valid dimension and
511 /// symbol identifiers. The operands should be laid out such that the dimension
512 /// operands are before the symbol operands. This function returns failure if
513 /// there was an invalid operand. An operation is provided to emit any necessary
514 /// errors.
515 template <typename OpTy>
516 static LogicalResult
518  unsigned numDims) {
519  unsigned opIt = 0;
520  for (auto operand : operands) {
521  if (opIt++ < numDims) {
522  if (!isValidDim(operand, getAffineScope(op)))
523  return op.emitOpError("operand cannot be used as a dimension id");
524  } else if (!isValidSymbol(operand, getAffineScope(op))) {
525  return op.emitOpError("operand cannot be used as a symbol");
526  }
527  }
528  return success();
529 }
530 
531 //===----------------------------------------------------------------------===//
532 // AffineApplyOp
533 //===----------------------------------------------------------------------===//
534 
535 AffineValueMap AffineApplyOp::getAffineValueMap() {
536  return AffineValueMap(getAffineMap(), getOperands(), getResult());
537 }
538 
539 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
540  auto &builder = parser.getBuilder();
541  auto indexTy = builder.getIndexType();
542 
543  AffineMapAttr mapAttr;
544  unsigned numDims;
545  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
546  parseDimAndSymbolList(parser, result.operands, numDims) ||
547  parser.parseOptionalAttrDict(result.attributes))
548  return failure();
549  auto map = mapAttr.getValue();
550 
551  if (map.getNumDims() != numDims ||
552  numDims + map.getNumSymbols() != result.operands.size()) {
553  return parser.emitError(parser.getNameLoc(),
554  "dimension or symbol index mismatch");
555  }
556 
557  result.types.append(map.getNumResults(), indexTy);
558  return success();
559 }
560 
562  p << " " << getMapAttr();
563  printDimAndSymbolList(operand_begin(), operand_end(),
564  getAffineMap().getNumDims(), p);
565  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
566 }
567 
568 LogicalResult AffineApplyOp::verify() {
569  // Check input and output dimensions match.
570  AffineMap affineMap = getMap();
571 
572  // Verify that operand count matches affine map dimension and symbol count.
573  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
574  return emitOpError(
575  "operand count and affine map dimension and symbol count must match");
576 
577  // Verify that the map only produces one result.
578  if (affineMap.getNumResults() != 1)
579  return emitOpError("mapping must produce one value");
580 
581  return success();
582 }
583 
584 // The result of the affine apply operation can be used as a dimension id if all
585 // its operands are valid dimension ids.
587  return llvm::all_of(getOperands(),
588  [](Value op) { return affine::isValidDim(op); });
589 }
590 
591 // The result of the affine apply operation can be used as a dimension id if all
592 // its operands are valid dimension ids with the parent operation of `region`
593 // defining the polyhedral scope for symbols.
594 bool AffineApplyOp::isValidDim(Region *region) {
595  return llvm::all_of(getOperands(),
596  [&](Value op) { return ::isValidDim(op, region); });
597 }
598 
599 // The result of the affine apply operation can be used as a symbol if all its
600 // operands are symbols.
602  return llvm::all_of(getOperands(),
603  [](Value op) { return affine::isValidSymbol(op); });
604 }
605 
606 // The result of the affine apply operation can be used as a symbol in `region`
607 // if all its operands are symbols in `region`.
608 bool AffineApplyOp::isValidSymbol(Region *region) {
609  return llvm::all_of(getOperands(), [&](Value operand) {
610  return affine::isValidSymbol(operand, region);
611  });
612 }
613 
614 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
615  auto map = getAffineMap();
616 
617  // Fold dims and symbols to existing values.
618  auto expr = map.getResult(0);
619  if (auto dim = dyn_cast<AffineDimExpr>(expr))
620  return getOperand(dim.getPosition());
621  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
622  return getOperand(map.getNumDims() + sym.getPosition());
623 
624  // Otherwise, default to folding the map.
626  bool hasPoison = false;
627  auto foldResult =
628  map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
629  if (hasPoison)
631  if (failed(foldResult))
632  return {};
633  return result[0];
634 }
635 
636 /// Returns the largest known divisor of `e`. Exploits information from the
637 /// values in `operands`.
638 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
639  // This method isn't aware of `operands`.
640  int64_t div = e.getLargestKnownDivisor();
641 
642  // We now make use of operands for the case `e` is a dim expression.
643  // TODO: More powerful simplification would have to modify
644  // getLargestKnownDivisor to take `operands` and exploit that information as
645  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
646  // can't be part of the IR library but of the `Analysis` library. The IR
647  // library can only really depend on simple O(1) checks.
648  auto dimExpr = dyn_cast<AffineDimExpr>(e);
649  // If it's not a dim expr, `div` is the best we have.
650  if (!dimExpr)
651  return div;
652 
653  // We simply exploit information from loop IVs.
654  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
655  // desired simplifications are expected to be part of other
656  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
657  // LoopAnalysis library.
658  Value operand = operands[dimExpr.getPosition()];
659  int64_t operandDivisor = 1;
660  // TODO: With the right accessors, this can be extended to
661  // LoopLikeOpInterface.
662  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
663  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
664  operandDivisor = forOp.getStepAsInt();
665  } else {
666  uint64_t lbLargestKnownDivisor =
667  forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
668  operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
669  }
670  }
671  return operandDivisor;
672 }
673 
674 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
675 /// being an affine dim expression or a constant.
677  int64_t k) {
678  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
679  int64_t constVal = constExpr.getValue();
680  return constVal >= 0 && constVal < k;
681  }
682  auto dimExpr = dyn_cast<AffineDimExpr>(e);
683  if (!dimExpr)
684  return false;
685  Value operand = operands[dimExpr.getPosition()];
686  // TODO: With the right accessors, this can be extended to
687  // LoopLikeOpInterface.
688  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
689  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
690  forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
691  return true;
692  }
693  }
694 
695  // We don't consider other cases like `operand` being defined by a constant or
696  // an affine.apply op since such cases will already be handled by other
697  // patterns and propagation of loop IVs or constant would happen.
698  return false;
699 }
700 
701 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
702 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
703 /// expression is in that form.
704 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
705  AffineExpr &quotientTimesDiv, AffineExpr &rem) {
706  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
707  if (!bin || bin.getKind() != AffineExprKind::Add)
708  return false;
709 
710  AffineExpr llhs = bin.getLHS();
711  AffineExpr rlhs = bin.getRHS();
712  div = getLargestKnownDivisor(llhs, operands);
713  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
714  quotientTimesDiv = llhs;
715  rem = rlhs;
716  return true;
717  }
718  div = getLargestKnownDivisor(rlhs, operands);
719  if (isNonNegativeBoundedBy(llhs, operands, div)) {
720  quotientTimesDiv = rlhs;
721  rem = llhs;
722  return true;
723  }
724  return false;
725 }
726 
727 /// Gets the constant lower bound on an `iv`.
728 static std::optional<int64_t> getLowerBound(Value iv) {
729  AffineForOp forOp = getForInductionVarOwner(iv);
730  if (forOp && forOp.hasConstantLowerBound())
731  return forOp.getConstantLowerBound();
732  return std::nullopt;
733 }
734 
735 /// Gets the constant upper bound on an affine.for `iv`.
736 static std::optional<int64_t> getUpperBound(Value iv) {
737  AffineForOp forOp = getForInductionVarOwner(iv);
738  if (!forOp || !forOp.hasConstantUpperBound())
739  return std::nullopt;
740 
741  // If its lower bound is also known, we can get a more precise bound
742  // whenever the step is not one.
743  if (forOp.hasConstantLowerBound()) {
744  return forOp.getConstantUpperBound() - 1 -
745  (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
746  forOp.getStepAsInt();
747  }
748  return forOp.getConstantUpperBound() - 1;
749 }
750 
751 /// Determine a constant upper bound for `expr` if one exists while exploiting
752 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
753 /// is guaranteed to be less than or equal to it.
754 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
755  unsigned numSymbols,
756  ArrayRef<Value> operands) {
757  // Get the constant lower or upper bounds on the operands.
758  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
759  constLowerBounds.reserve(operands.size());
760  constUpperBounds.reserve(operands.size());
761  for (Value operand : operands) {
762  constLowerBounds.push_back(getLowerBound(operand));
763  constUpperBounds.push_back(getUpperBound(operand));
764  }
765 
766  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
767  return constExpr.getValue();
768 
769  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
770  constUpperBounds,
771  /*isUpper=*/true);
772 }
773 
774 /// Determine a constant lower bound for `expr` if one exists while exploiting
775 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
776 /// is guaranteed to be less than or equal to it.
777 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
778  unsigned numSymbols,
779  ArrayRef<Value> operands) {
780  // Get the constant lower or upper bounds on the operands.
781  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
782  constLowerBounds.reserve(operands.size());
783  constUpperBounds.reserve(operands.size());
784  for (Value operand : operands) {
785  constLowerBounds.push_back(getLowerBound(operand));
786  constUpperBounds.push_back(getUpperBound(operand));
787  }
788 
789  std::optional<int64_t> lowerBound;
790  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
791  lowerBound = constExpr.getValue();
792  } else {
793  lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
794  constLowerBounds, constUpperBounds,
795  /*isUpper=*/false);
796  }
797  return lowerBound;
798 }
799 
/// Simplify `expr` in place while exploiting constant lower/upper bound
/// information of the values in `operands`. `numDims`/`numSymbols` describe
/// how `operands` split into dimension and symbol operands. Only floordiv,
/// ceildiv, and mod expressions with a positive constant RHS are simplified.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
                                    unsigned numSymbols,
                                    ArrayRef<Value> operands) {
  // We do this only for certain floordiv/mod expressions.
  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binExpr)
    return;

  // Simplify the child expressions first (recursively), then rebuild `expr`
  // from the simplified children.
  AffineExpr lhs = binExpr.getLHS();
  AffineExpr rhs = binExpr.getRHS();
  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);

  // The reconstruction above may have folded `expr` into a non-binary or
  // different-kind expression; only floordiv/ceildiv/mod proceed further.
  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
                   expr.getKind() != AffineExprKind::CeilDiv &&
                   expr.getKind() != AffineExprKind::Mod)) {
    return;
  }

  // The `lhs` and `rhs` may be different post construction of simplified expr.
  lhs = binExpr.getLHS();
  rhs = binExpr.getRHS();
  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
  if (!rhsConst)
    return;

  int64_t rhsConstVal = rhsConst.getValue();
  // Undefined expressions aren't touched; IR can still be valid with them.
  if (rhsConstVal <= 0)
    return;

  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
  MLIRContext *context = expr.getContext();
  std::optional<int64_t> lhsLbConst =
      getLowerBound(lhs, numDims, numSymbols, operands);
  std::optional<int64_t> lhsUbConst =
      getUpperBound(lhs, numDims, numSymbols, operands);
  if (lhsLbConst && lhsUbConst) {
    int64_t lhsLbConstVal = *lhsLbConst;
    int64_t lhsUbConstVal = *lhsUbConst;
    // lhs floordiv c is a single value if lhs is bounded in a range of size
    // `c` that has the same quotient.
    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
        divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
            divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
      expr = getAffineConstantExpr(
          divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
      return;
    }
    // lhs ceildiv c is a single value if the entire range has the same ceil
    // quotient.
    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
        divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
            divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
      expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
                                   context);
      return;
    }
    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
      expr = lhs;
      return;
    }
  }

  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
  AffineExpr quotientTimesDiv, rem;
  int64_t divisor;
  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
    if (rhsConstVal % divisor == 0 &&
        binExpr.getKind() == AffineExprKind::FloorDiv) {
      expr = quotientTimesDiv.floorDiv(rhsConst);
    } else if (divisor % rhsConstVal == 0 &&
               binExpr.getKind() == AffineExprKind::Mod) {
      expr = rem % rhsConst;
    }
    return;
  }

  // Handle the simple case when the LHS expression can be either upper
  // bounded or is a known multiple of RHS constant.
  // lhs floordiv c -> 0 if 0 <= lhs < c,
  // lhs mod c -> 0 if lhs % c = 0.
  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
       binExpr.getKind() == AffineExprKind::FloorDiv) ||
      (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
       binExpr.getKind() == AffineExprKind::Mod)) {
    expr = getAffineConstantExpr(0, expr.getContext());
  }
}
898 
899 /// Simplify the expressions in `map` while making use of lower or upper bounds
900 /// of its operands. If `isMax` is true, the map is to be treated as a max of
901 /// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
902 /// d1) can be simplified to (8) if the operands are respectively lower bounded
903 /// by 2 and 0 (the second expression can't be lower than 8).
905  ArrayRef<Value> operands,
906  bool isMax) {
907  // Can't simplify.
908  if (operands.empty())
909  return;
910 
911  // Get the upper or lower bound on an affine.for op IV using its range.
912  // Get the constant lower or upper bounds on the operands.
913  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
914  constLowerBounds.reserve(operands.size());
915  constUpperBounds.reserve(operands.size());
916  for (Value operand : operands) {
917  constLowerBounds.push_back(getLowerBound(operand));
918  constUpperBounds.push_back(getUpperBound(operand));
919  }
920 
921  // We will compute the lower and upper bounds on each of the expressions
922  // Then, we will check (depending on max or min) as to whether a specific
923  // bound is redundant by checking if its highest (in case of max) and its
924  // lowest (in the case of min) value is already lower than (or higher than)
925  // the lower bound (or upper bound in the case of min) of another bound.
926  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
927  lowerBounds.reserve(map.getNumResults());
928  upperBounds.reserve(map.getNumResults());
929  for (AffineExpr e : map.getResults()) {
930  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
931  lowerBounds.push_back(constExpr.getValue());
932  upperBounds.push_back(constExpr.getValue());
933  } else {
934  lowerBounds.push_back(
936  constLowerBounds, constUpperBounds,
937  /*isUpper=*/false));
938  upperBounds.push_back(
940  constLowerBounds, constUpperBounds,
941  /*isUpper=*/true));
942  }
943  }
944 
945  // Collect expressions that are not redundant.
946  SmallVector<AffineExpr, 4> irredundantExprs;
947  for (auto exprEn : llvm::enumerate(map.getResults())) {
948  AffineExpr e = exprEn.value();
949  unsigned i = exprEn.index();
950  // Some expressions can be turned into constants.
951  if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
952  e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
953 
954  // Check if the expression is redundant.
955  if (isMax) {
956  if (!upperBounds[i]) {
957  irredundantExprs.push_back(e);
958  continue;
959  }
960  // If there exists another expression such that its lower bound is greater
961  // than this expression's upper bound, it's redundant.
962  if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
963  auto otherLowerBound = en.value();
964  unsigned pos = en.index();
965  if (pos == i || !otherLowerBound)
966  return false;
967  if (*otherLowerBound > *upperBounds[i])
968  return true;
969  if (*otherLowerBound < *upperBounds[i])
970  return false;
971  // Equality case. When both expressions are considered redundant, we
972  // don't want to get both of them. We keep the one that appears
973  // first.
974  if (upperBounds[pos] && lowerBounds[i] &&
975  lowerBounds[i] == upperBounds[i] &&
976  otherLowerBound == *upperBounds[pos] && i < pos)
977  return false;
978  return true;
979  }))
980  irredundantExprs.push_back(e);
981  } else {
982  if (!lowerBounds[i]) {
983  irredundantExprs.push_back(e);
984  continue;
985  }
986  // Likewise for the `min` case. Use the complement of the condition above.
987  if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
988  auto otherUpperBound = en.value();
989  unsigned pos = en.index();
990  if (pos == i || !otherUpperBound)
991  return false;
992  if (*otherUpperBound < *lowerBounds[i])
993  return true;
994  if (*otherUpperBound > *lowerBounds[i])
995  return false;
996  if (lowerBounds[pos] && upperBounds[i] &&
997  lowerBounds[i] == upperBounds[i] &&
998  otherUpperBound == lowerBounds[pos] && i < pos)
999  return false;
1000  return true;
1001  }))
1002  irredundantExprs.push_back(e);
1003  }
1004  }
1005 
1006  // Create the map without the redundant expressions.
1007  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
1008  map.getContext());
1009 }
1010 
1011 /// Simplify the map while exploiting information on the values in `operands`.
1012 // Use "unused attribute" marker to silence warning stemming from the inability
1013 // to see through the template expansion.
1014 static void LLVM_ATTRIBUTE_UNUSED
1016  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1017  SmallVector<AffineExpr> newResults;
1018  newResults.reserve(map.getNumResults());
1019  for (AffineExpr expr : map.getResults()) {
1021  operands);
1022  newResults.push_back(expr);
1023  }
1024  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1025  map.getContext());
1026 }
1027 
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
/// When `dimOrSymbolPosition >= dims.size()`,
/// AffineSymbolExpr@[pos - dims.size()] is replaced.
/// Mutate `map`,`dims` and `syms` in place as follows:
///   1. `dims` and `syms` are only appended to.
///   2. `map` dim and symbols are gradually shifted to higher positions.
///   3. Old `dim` and `sym` entries are replaced by nullptr
/// This avoids the need for any bookkeeping.
/// Returns failure when the operand at `dimOrSymbolPosition` is null or is not
/// produced by an affine.apply; the caller treats that as "nothing to do".
static LogicalResult replaceDimOrSym(AffineMap *map,
                                     unsigned dimOrSymbolPosition,
                                     SmallVectorImpl<Value> &dims,
                                     SmallVectorImpl<Value> &syms) {
  MLIRContext *ctx = map->getContext();
  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
  // `pos` is the index within the dim list or the symbol list respectively.
  unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                  : dimOrSymbolPosition - dims.size();
  Value &v = isDimReplacement ? dims[pos] : syms[pos];
  if (!v)
    return failure();

  auto affineApply = v.getDefiningOp<AffineApplyOp>();
  if (!affineApply)
    return failure();

  // At this point we will perform a replacement of `v`, set the entry in `dim`
  // or `sym` to nullptr immediately.
  v = nullptr;

  // Compute the map, dims and symbols coming from the AffineApplyOp.
  AffineMap composeMap = affineApply.getAffineMap();
  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
                                     affineApply.getMapOperands().end());
  // Canonicalize the map to promote dims to symbols when possible. This is to
  // avoid generating invalid maps.
  canonicalizeMapAndOperands(&composeMap, &composeOperands);
  // Shift the apply's dims/symbols past the existing ones so the appended
  // operands below line up with the expression's positions.
  AffineExpr replacementExpr =
      composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
  ValueRange composeDims =
      ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
  ValueRange composeSyms =
      ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
                                          : getAffineSymbolExpr(pos, ctx);

  // Append the dims and symbols where relevant and perform the replacement.
  dims.append(composeDims.begin(), composeDims.end());
  syms.append(composeSyms.begin(), composeSyms.end());
  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());

  return success();
}
1082 
1083 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1084 /// iteratively. Perform canonicalization of map and operands as well as
1085 /// AffineMap simplification. `map` and `operands` are mutated in place.
1087  SmallVectorImpl<Value> *operands) {
1088  if (map->getNumResults() == 0) {
1089  canonicalizeMapAndOperands(map, operands);
1090  *map = simplifyAffineMap(*map);
1091  return;
1092  }
1093 
1094  MLIRContext *ctx = map->getContext();
1095  SmallVector<Value, 4> dims(operands->begin(),
1096  operands->begin() + map->getNumDims());
1097  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1098  operands->end());
1099 
1100  // Iterate over dims and symbols coming from AffineApplyOp and replace until
1101  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1102  // and `syms` can only increase by construction.
1103  // The implementation uses a `while` loop to support the case of symbols
1104  // that may be constructed from dims ;this may be overkill.
1105  while (true) {
1106  bool changed = false;
1107  for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1108  if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
1109  break;
1110  if (!changed)
1111  break;
1112  }
1113 
1114  // Clear operands so we can fill them anew.
1115  operands->clear();
1116 
1117  // At this point we may have introduced null operands, prune them out before
1118  // canonicalizing map and operands.
1119  unsigned nDims = 0, nSyms = 0;
1120  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1121  dimReplacements.reserve(dims.size());
1122  symReplacements.reserve(syms.size());
1123  for (auto *container : {&dims, &syms}) {
1124  bool isDim = (container == &dims);
1125  auto &repls = isDim ? dimReplacements : symReplacements;
1126  for (const auto &en : llvm::enumerate(*container)) {
1127  Value v = en.value();
1128  if (!v) {
1129  assert(isDim ? !map->isFunctionOfDim(en.index())
1130  : !map->isFunctionOfSymbol(en.index()) &&
1131  "map is function of unexpected expr@pos");
1132  repls.push_back(getAffineConstantExpr(0, ctx));
1133  continue;
1134  }
1135  repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1136  : getAffineSymbolExpr(nSyms++, ctx));
1137  operands->push_back(v);
1138  }
1139  }
1140  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1141  nSyms);
1142 
1143  // Canonicalize and simplify before returning.
1144  canonicalizeMapAndOperands(map, operands);
1145  *map = simplifyAffineMap(*map);
1146 }
1147 
1149  AffineMap *map, SmallVectorImpl<Value> *operands) {
1150  while (llvm::any_of(*operands, [](Value v) {
1151  return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1152  })) {
1153  composeAffineMapAndOperands(map, operands);
1154  }
1155 }
1156 
1157 AffineApplyOp
1159  ArrayRef<OpFoldResult> operands) {
1160  SmallVector<Value> valueOperands;
1161  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1162  composeAffineMapAndOperands(&map, &valueOperands);
1163  assert(map);
1164  return b.create<AffineApplyOp>(loc, map, valueOperands);
1165 }
1166 
1167 AffineApplyOp
1169  ArrayRef<OpFoldResult> operands) {
1170  return makeComposedAffineApply(
1171  b, loc,
1173  .front(),
1174  operands);
1175 }
1176 
1177 /// Composes the given affine map with the given list of operands, pulling in
1178 /// the maps from any affine.apply operations that supply the operands.
1180  SmallVectorImpl<Value> &operands) {
1181  // Compose and canonicalize each expression in the map individually because
1182  // composition only applies to single-result maps, collecting potentially
1183  // duplicate operands in a single list with shifted dimensions and symbols.
1184  SmallVector<Value> dims, symbols;
1186  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1187  SmallVector<Value> submapOperands(operands.begin(), operands.end());
1188  AffineMap submap = map.getSubMap({i});
1189  fullyComposeAffineMapAndOperands(&submap, &submapOperands);
1190  canonicalizeMapAndOperands(&submap, &submapOperands);
1191  unsigned numNewDims = submap.getNumDims();
1192  submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1193  llvm::append_range(dims,
1194  ArrayRef<Value>(submapOperands).take_front(numNewDims));
1195  llvm::append_range(symbols,
1196  ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1197  exprs.push_back(submap.getResult(0));
1198  }
1199 
1200  // Canonicalize the map created from composed expressions to deduplicate the
1201  // dimension and symbol operands.
1202  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1203  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1204  canonicalizeMapAndOperands(&map, &operands);
1205 }
1206 
1209  AffineMap map,
1210  ArrayRef<OpFoldResult> operands) {
1211  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1212 
1213  // Create new builder without a listener, so that no notification is
1214  // triggered if the op is folded.
1215  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1216  // workaround is no longer needed.
1217  OpBuilder newBuilder(b.getContext());
1219 
1220  // Create op.
1221  AffineApplyOp applyOp =
1222  makeComposedAffineApply(newBuilder, loc, map, operands);
1223 
1224  // Get constant operands.
1225  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1226  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1227  matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1228 
1229  // Try to fold the operation.
1230  SmallVector<OpFoldResult> foldResults;
1231  if (failed(applyOp->fold(constOperands, foldResults)) ||
1232  foldResults.empty()) {
1233  if (OpBuilder::Listener *listener = b.getListener())
1234  listener->notifyOperationInserted(applyOp, /*previous=*/{});
1235  return applyOp.getResult();
1236  }
1237 
1238  applyOp->erase();
1239  return llvm::getSingleElement(foldResults);
1240 }
1241 
1244  AffineExpr expr,
1245  ArrayRef<OpFoldResult> operands) {
1247  b, loc,
1249  .front(),
1250  operands);
1251 }
1252 
1255  OpBuilder &b, Location loc, AffineMap map,
1256  ArrayRef<OpFoldResult> operands) {
1257  return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
1258  [&](unsigned i) {
1259  return makeComposedFoldedAffineApply(
1260  b, loc, map.getSubMap({i}), operands);
1261  });
1262 }
1263 
1264 template <typename OpTy>
1266  ArrayRef<OpFoldResult> operands) {
1267  SmallVector<Value> valueOperands;
1268  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1269  composeMultiResultAffineMap(map, valueOperands);
1270  return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1271 }
1272 
1273 AffineMinOp
1275  ArrayRef<OpFoldResult> operands) {
1276  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1277 }
1278 
1279 template <typename OpTy>
1281  AffineMap map,
1282  ArrayRef<OpFoldResult> operands) {
1283  // Create new builder without a listener, so that no notification is
1284  // triggered if the op is folded.
1285  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1286  // workaround is no longer needed.
1287  OpBuilder newBuilder(b.getContext());
1289 
1290  // Create op.
1291  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1292 
1293  // Get constant operands.
1294  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1295  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1296  matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1297 
1298  // Try to fold the operation.
1299  SmallVector<OpFoldResult> foldResults;
1300  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1301  foldResults.empty()) {
1302  if (OpBuilder::Listener *listener = b.getListener())
1303  listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1304  return minMaxOp.getResult();
1305  }
1306 
1307  minMaxOp->erase();
1308  return llvm::getSingleElement(foldResults);
1309 }
1310 
1313  AffineMap map,
1314  ArrayRef<OpFoldResult> operands) {
1315  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1316 }
1317 
1320  AffineMap map,
1321  ArrayRef<OpFoldResult> operands) {
1322  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1323 }
1324 
1325 // A symbol may appear as a dim in affine.apply operations. This function
1326 // canonicalizes dims that are valid symbols into actual symbols.
1327 template <class MapOrSet>
1328 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1329  SmallVectorImpl<Value> *operands) {
1330  if (!mapOrSet || operands->empty())
1331  return;
1332 
1333  assert(mapOrSet->getNumInputs() == operands->size() &&
1334  "map/set inputs must match number of operands");
1335 
1336  auto *context = mapOrSet->getContext();
1337  SmallVector<Value, 8> resultOperands;
1338  resultOperands.reserve(operands->size());
1339  SmallVector<Value, 8> remappedSymbols;
1340  remappedSymbols.reserve(operands->size());
1341  unsigned nextDim = 0;
1342  unsigned nextSym = 0;
1343  unsigned oldNumSyms = mapOrSet->getNumSymbols();
1344  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1345  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1346  if (i < mapOrSet->getNumDims()) {
1347  if (isValidSymbol((*operands)[i])) {
1348  // This is a valid symbol that appears as a dim, canonicalize it.
1349  dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1350  remappedSymbols.push_back((*operands)[i]);
1351  } else {
1352  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1353  resultOperands.push_back((*operands)[i]);
1354  }
1355  } else {
1356  resultOperands.push_back((*operands)[i]);
1357  }
1358  }
1359 
1360  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1361  *operands = resultOperands;
1362  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
1363  oldNumSyms + nextSym);
1364 
1365  assert(mapOrSet->getNumInputs() == operands->size() &&
1366  "map/set inputs must match number of operands");
1367 }
1368 
// Works for either an affine map or an integer set. Canonicalization steps:
//   1. promote symbol-valid dim operands to symbols,
//   2. drop unused dims/symbols,
//   3. deduplicate repeated operands,
//   4. fold constant symbol operands directly into the map/set.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());

  // Deduplicate dim operands: the first occurrence of a value gets a new
  // position; later occurrences reuse its expression. Unused dims are
  // dropped (they get no remapping and no result operand).
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}
1446 
1448  AffineMap *map, SmallVectorImpl<Value> *operands) {
1449  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1450 }
1451 
1453  IntegerSet *set, SmallVectorImpl<Value> *operands) {
1454  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1455 }
1456 
1457 namespace {
1458 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1459 /// maps that supply results into them.
1460 ///
1461 template <typename AffineOpTy>
1462 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1464 
1465  /// Replace the affine op with another instance of it with the supplied
1466  /// map and mapOperands.
1467  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1468  AffineMap map, ArrayRef<Value> mapOperands) const;
1469 
1470  LogicalResult matchAndRewrite(AffineOpTy affineOp,
1471  PatternRewriter &rewriter) const override {
1472  static_assert(
1473  llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1474  AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1475  AffineVectorStoreOp, AffineVectorLoadOp>::value,
1476  "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1477  "expected");
1478  auto map = affineOp.getAffineMap();
1479  AffineMap oldMap = map;
1480  auto oldOperands = affineOp.getMapOperands();
1481  SmallVector<Value, 8> resultOperands(oldOperands);
1482  composeAffineMapAndOperands(&map, &resultOperands);
1483  canonicalizeMapAndOperands(&map, &resultOperands);
1484  simplifyMapWithOperands(map, resultOperands);
1485  if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1486  resultOperands.begin()))
1487  return failure();
1488 
1489  replaceAffineOp(rewriter, affineOp, map, resultOperands);
1490  return success();
1491  }
1492 };
1493 
1494 // Specialize the template to account for the different build signatures for
1495 // affine load, store, and apply ops.
1496 template <>
1497 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1498  PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1499  ArrayRef<Value> mapOperands) const {
1500  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1501  mapOperands);
1502 }
1503 template <>
1504 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1505  PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1506  ArrayRef<Value> mapOperands) const {
1507  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1508  prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1509  prefetch.getLocalityHint(), prefetch.getIsDataCache());
1510 }
1511 template <>
1512 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1513  PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1514  ArrayRef<Value> mapOperands) const {
1515  rewriter.replaceOpWithNewOp<AffineStoreOp>(
1516  store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1517 }
1518 template <>
1519 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1520  PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1521  ArrayRef<Value> mapOperands) const {
1522  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1523  vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1524  mapOperands);
1525 }
1526 template <>
1527 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1528  PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1529  ArrayRef<Value> mapOperands) const {
1530  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1531  vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1532  mapOperands);
1533 }
1534 
1535 // Generic version for ops that don't have extra operands.
1536 template <typename AffineOpTy>
1537 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1538  PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1539  ArrayRef<Value> mapOperands) const {
1540  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1541 }
1542 } // namespace
1543 
/// Registers the affine.apply canonicalization that composes producing
/// affine.apply ops into this one and canonicalizes its map and operands.
void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
}
1548 
1549 //===----------------------------------------------------------------------===//
1550 // AffineDmaStartOp
1551 //===----------------------------------------------------------------------===//
1552 
// TODO: Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_start op. Operands are appended in the fixed order
/// the op expects: src memref + its map operands, dst memref + its map
/// operands, tag memref + its map operands, the element count, and finally
/// the optional (stride, elementsPerStride) pair.
void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
                             Value srcMemRef, AffineMap srcMap,
                             ValueRange srcIndices, Value destMemRef,
                             AffineMap dstMap, ValueRange destIndices,
                             Value tagMemRef, AffineMap tagMap,
                             ValueRange tagIndices, Value numElements,
                             Value stride, Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
  result.addOperands(destIndices);
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
  // Stride and elementsPerStride are optional and only added together.
  if (stride) {
    result.addOperands({stride, elementsPerStride});
  }
}
1575 
1577  p << " " << getSrcMemRef() << '[';
1578  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1579  p << "], " << getDstMemRef() << '[';
1580  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1581  p << "], " << getTagMemRef() << '[';
1582  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1583  p << "], " << getNumElements();
1584  if (isStrided()) {
1585  p << ", " << getStride();
1586  p << ", " << getNumElementsPerStride();
1587  }
1588  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1589  << getTagMemRefType();
1590 }
1591 
1592 // Parse AffineDmaStartOp.
1593 // Ex:
1594 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1595 // %stride, %num_elt_per_stride
1596 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1597 //
// NOTE(review): the function signature line (presumably
// `ParseResult AffineDmaStartOp::parse(OpAsmParser &parser, ...)`) was lost
// in extraction — restore from upstream before building.
1599  OperationState &result) {
1600  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1601  AffineMapAttr srcMapAttr;
// NOTE(review): the declarations of the map-operand vectors (srcMapOperands,
// dstMapOperands, tagMapOperands) and of strideInfo appear to have been lost
// in extraction here and below — restore from upstream.
1603  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1604  AffineMapAttr dstMapAttr;
1606  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1607  AffineMapAttr tagMapAttr;
1609  OpAsmParser::UnresolvedOperand numElementsInfo;
1611 
1612  SmallVector<Type, 3> types;
1613  auto indexType = parser.getBuilder().getIndexType();
1614 
1615  // Parse and resolve the following list of operands:
1616  // *) dst memref followed by its affine maps operands (in square brackets).
1617  // *) src memref followed by its affine map operands (in square brackets).
1618  // *) tag memref followed by its affine map operands (in square brackets).
1619  // *) number of elements transferred by DMA operation.
1620  if (parser.parseOperand(srcMemRefInfo) ||
1621  parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1622  getSrcMapAttrStrName(),
1623  result.attributes) ||
1624  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1625  parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1626  getDstMapAttrStrName(),
1627  result.attributes) ||
1628  parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1629  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1630  getTagMapAttrStrName(),
1631  result.attributes) ||
1632  parser.parseComma() || parser.parseOperand(numElementsInfo))
1633  return failure();
1634 
1635  // Parse optional stride and elements per stride.
1636  if (parser.parseTrailingOperandList(strideInfo))
1637  return failure();
1638 
// Strides are all-or-nothing: either both the stride and the
// elements-per-stride operands are present, or neither is.
1639  if (!strideInfo.empty() && strideInfo.size() != 2) {
1640  return parser.emitError(parser.getNameLoc(),
1641  "expected two stride related operands");
1642  }
1643  bool isStrided = strideInfo.size() == 2;
1644 
1645  if (parser.parseColonTypeList(types))
1646  return failure();
1647 
1648  if (types.size() != 3)
1649  return parser.emitError(parser.getNameLoc(), "expected three types");
1650 
// Resolve operands in declaration order: each memref, then its map operands
// (all of 'index' type), and finally the element count.
1651  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1652  parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1653  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1654  parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1655  parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1656  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1657  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1658  return failure();
1659 
1660  if (isStrided) {
1661  if (parser.resolveOperands(strideInfo, indexType, result.operands))
1662  return failure();
1663  }
1664 
1665  // Check that src/dst/tag operand counts match their map.numInputs.
1666  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1667  dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1668  tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1669  return parser.emitError(parser.getNameLoc(),
1670  "memref operand count not equal to map.numInputs");
1671  return success();
1672 }
1673 
1674 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1675  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1676  return emitOpError("expected DMA source to be of memref type");
1677  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1678  return emitOpError("expected DMA destination to be of memref type");
1679  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1680  return emitOpError("expected DMA tag to be of memref type");
1681 
1682  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1683  getDstMap().getNumInputs() +
1684  getTagMap().getNumInputs();
1685  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1686  getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1687  return emitOpError("incorrect number of operands");
1688  }
1689 
1690  Region *scope = getAffineScope(*this);
1691  for (auto idx : getSrcIndices()) {
1692  if (!idx.getType().isIndex())
1693  return emitOpError("src index to dma_start must have 'index' type");
1694  if (!isValidAffineIndexOperand(idx, scope))
1695  return emitOpError(
1696  "src index must be a valid dimension or symbol identifier");
1697  }
1698  for (auto idx : getDstIndices()) {
1699  if (!idx.getType().isIndex())
1700  return emitOpError("dst index to dma_start must have 'index' type");
1701  if (!isValidAffineIndexOperand(idx, scope))
1702  return emitOpError(
1703  "dst index must be a valid dimension or symbol identifier");
1704  }
1705  for (auto idx : getTagIndices()) {
1706  if (!idx.getType().isIndex())
1707  return emitOpError("tag index to dma_start must have 'index' type");
1708  if (!isValidAffineIndexOperand(idx, scope))
1709  return emitOpError(
1710  "tag index must be a valid dimension or symbol identifier");
1711  }
1712  return success();
1713 }
1714 
// Folder: delegates to memref::foldMemRefCast to rebind memref operands that
// come through casts; the op itself has no results to fold to.
1715 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1716  SmallVectorImpl<OpFoldResult> &results) {
1717  /// dma_start(memrefcast) -> dma_start
1718  return memref::foldMemRefCast(*this);
1719 }
1720 
// Declares the DMA's memory effects: a read of the source memref, a write of
// the destination memref, and a read of the tag memref.
// NOTE(review): the effects-vector parameter type line and the trailing
// resource arguments of each emplace_back were lost in extraction — restore
// from upstream.
1721 void AffineDmaStartOp::getEffects(
1723  &effects) {
1724  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
1726  effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
1728  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1730 }
1731 
1732 //===----------------------------------------------------------------------===//
1733 // AffineDmaWaitOp
1734 //===----------------------------------------------------------------------===//
1735 
1736 // TODO: Check that map operands are loop IVs or symbols.
// Builds an affine.dma_wait. Operand order is significant: the tag memref,
// then the tag-map index operands, then the element count; the tag map
// itself is stored as an attribute.
1737 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1738  Value tagMemRef, AffineMap tagMap,
1739  ValueRange tagIndices, Value numElements) {
1740  result.addOperands(tagMemRef);
1741  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1742  result.addOperands(tagIndices);
1743  result.addOperands(numElements);
1744 }
1745 
// Prints affine.dma_wait: the tag memref subscripted by its affine-map
// operands, then (after the comma) the element count, then the tag type.
// NOTE(review): the print() signature line and the line printing the
// element-count operand were lost in extraction — restore from upstream.
1747  p << " " << getTagMemRef() << '[';
1748  SmallVector<Value, 2> operands(getTagIndices());
1749  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1750  p << "], ";
1752  p << " : " << getTagMemRef().getType();
1753 }
1754 
1755 // Parse AffineDmaWaitOp.
1756 // Eg:
1757 // affine.dma_wait %tag[%index], %num_elements
1758 // : memref<1 x i32, (d0) -> (d0), 4>
1759 //
// NOTE(review): the parse() signature line and the declaration of the
// tagMapOperands vector were lost in extraction — restore from upstream.
1761  OperationState &result) {
1762  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1763  AffineMapAttr tagMapAttr;
1765  Type type;
1766  auto indexType = parser.getBuilder().getIndexType();
1767  OpAsmParser::UnresolvedOperand numElementsInfo;
1768 
1769  // Parse tag memref, its map operands, and dma size.
1770  if (parser.parseOperand(tagMemRefInfo) ||
1771  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1772  getTagMapAttrStrName(),
1773  result.attributes) ||
1774  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1775  parser.parseColonType(type) ||
1776  parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1777  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1778  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1779  return failure();
1780 
1781  if (!llvm::isa<MemRefType>(type))
1782  return parser.emitError(parser.getNameLoc(),
1783  "expected tag to be of memref type");
1784 
// The number of parsed subscripts must match the tag map's input count.
1785  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1786  return parser.emitError(parser.getNameLoc(),
1787  "tag memref operand count != to map.numInputs");
1788  return success();
1789 }
1790 
1791 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
1792  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
1793  return emitOpError("expected DMA tag to be of memref type");
1794  Region *scope = getAffineScope(*this);
1795  for (auto idx : getTagIndices()) {
1796  if (!idx.getType().isIndex())
1797  return emitOpError("index to dma_wait must have 'index' type");
1798  if (!isValidAffineIndexOperand(idx, scope))
1799  return emitOpError(
1800  "index must be a valid dimension or symbol identifier");
1801  }
1802  return success();
1803 }
1804 
// Folder: delegates to memref::foldMemRefCast to rebind a tag memref that
// comes through a cast; the op itself has no results to fold to.
1805 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1806  SmallVectorImpl<OpFoldResult> &results) {
1807  /// dma_wait(memrefcast) -> dma_wait
1808  return memref::foldMemRefCast(*this);
1809 }
1810 
// Declares that affine.dma_wait reads the tag memref.
// NOTE(review): the effects-vector parameter type line and the trailing
// resource argument were lost in extraction — restore from upstream.
1811 void AffineDmaWaitOp::getEffects(
1813  &effects) {
1814  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
1816 }
1817 
1818 //===----------------------------------------------------------------------===//
1819 // AffineForOp
1820 //===----------------------------------------------------------------------===//
1821 
1822 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1823 /// bodyBuilder are empty/null, we include default terminator op.
1824 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1825  ValueRange lbOperands, AffineMap lbMap,
1826  ValueRange ubOperands, AffineMap ubMap, int64_t step,
1827  ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1828  assert(((!lbMap && lbOperands.empty()) ||
1829  lbOperands.size() == lbMap.getNumInputs()) &&
1830  "lower bound operand count does not match the affine map");
1831  assert(((!ubMap && ubOperands.empty()) ||
1832  ubOperands.size() == ubMap.getNumInputs()) &&
1833  "upper bound operand count does not match the affine map");
1834  assert(step > 0 && "step has to be a positive integer constant");
1835 
1836  OpBuilder::InsertionGuard guard(builder);
1837 
1838  // Set variadic segment sizes.
1839  result.addAttribute(
1840  getOperandSegmentSizeAttr(),
1841  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
1842  static_cast<int32_t>(ubOperands.size()),
1843  static_cast<int32_t>(iterArgs.size())}));
1844 
1845  for (Value val : iterArgs)
1846  result.addTypes(val.getType());
1847 
1848  // Add an attribute for the step.
1849  result.addAttribute(getStepAttrName(result.name),
1850  builder.getIntegerAttr(builder.getIndexType(), step));
1851 
1852  // Add the lower bound.
1853  result.addAttribute(getLowerBoundMapAttrName(result.name),
1854  AffineMapAttr::get(lbMap));
1855  result.addOperands(lbOperands);
1856 
1857  // Add the upper bound.
1858  result.addAttribute(getUpperBoundMapAttrName(result.name),
1859  AffineMapAttr::get(ubMap));
1860  result.addOperands(ubOperands);
1861 
1862  result.addOperands(iterArgs);
1863  // Create a region and a block for the body. The argument of the region is
1864  // the loop induction variable.
1865  Region *bodyRegion = result.addRegion();
1866  Block *bodyBlock = builder.createBlock(bodyRegion);
1867  Value inductionVar =
1868  bodyBlock->addArgument(builder.getIndexType(), result.location);
1869  for (Value val : iterArgs)
1870  bodyBlock->addArgument(val.getType(), val.getLoc());
1871 
1872  // Create the default terminator if the builder is not provided and if the
1873  // iteration arguments are not provided. Otherwise, leave this to the caller
1874  // because we don't know which values to return from the loop.
1875  if (iterArgs.empty() && !bodyBuilder) {
1876  ensureTerminator(*bodyRegion, builder, result.location);
1877  } else if (bodyBuilder) {
1878  OpBuilder::InsertionGuard guard(builder);
1879  builder.setInsertionPointToStart(bodyBlock);
1880  bodyBuilder(builder, result.location, inductionVar,
1881  bodyBlock->getArguments().drop_front());
1882  }
1883 }
1884 
1885 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1886  int64_t ub, int64_t step, ValueRange iterArgs,
1887  BodyBuilderFn bodyBuilder) {
1888  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1889  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1890  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1891  bodyBuilder);
1892 }
1893 
1894 LogicalResult AffineForOp::verifyRegions() {
1895  // Check that the body defines as single block argument for the induction
1896  // variable.
1897  auto *body = getBody();
1898  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1899  return emitOpError("expected body to have a single index argument for the "
1900  "induction variable");
1901 
1902  // Verify that the bound operands are valid dimension/symbols.
1903  /// Lower bound.
// NOTE(review): the line that performs the actual bound-operand verification
// call (one line after each `if` below) was lost in extraction — restore
// from upstream.
1904  if (getLowerBoundMap().getNumInputs() > 0)
1906  getLowerBoundMap().getNumDims())))
1907  return failure();
1908  /// Upper bound.
1909  if (getUpperBoundMap().getNumInputs() > 0)
1911  getUpperBoundMap().getNumDims())))
1912  return failure();
// A bound map must produce at least one result expression.
1913  if (getLowerBoundMap().getNumResults() < 1)
1914  return emitOpError("expected lower bound map to have at least one result");
1915  if (getUpperBoundMap().getNumResults() < 1)
1916  return emitOpError("expected upper bound map to have at least one result");
1917 
1918  unsigned opNumResults = getNumResults();
1919  if (opNumResults == 0)
1920  return success();
1921 
1922  // If ForOp defines values, check that the number and types of the defined
1923  // values match ForOp initial iter operands and backedge basic block
1924  // arguments.
1925  if (getNumIterOperands() != opNumResults)
1926  return emitOpError(
1927  "mismatch between the number of loop-carried values and results");
1928  if (getNumRegionIterArgs() != opNumResults)
1929  return emitOpError(
1930  "mismatch between the number of basic block args and results");
1931 
1932  return success();
1933 }
1934 
1935 /// Parse a for operation loop bounds.
1936 static ParseResult parseBound(bool isLower, OperationState &result,
1937  OpAsmParser &p) {
1938  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1939  // the map has multiple results.
1940  bool failedToParsedMinMax =
1941  failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1942 
1943  auto &builder = p.getBuilder();
1944  auto boundAttrStrName =
1945  isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
1946  : AffineForOp::getUpperBoundMapAttrName(result.name);
1947 
1948  // Parse ssa-id as identity map.
// NOTE(review): the declaration of the `boundOpInfos` operand vector was lost
// in extraction — restore from upstream.
1950  if (p.parseOperandList(boundOpInfos))
1951  return failure();
1952 
1953  if (!boundOpInfos.empty()) {
1954  // Check that only one operand was parsed.
1955  if (boundOpInfos.size() > 1)
1956  return p.emitError(p.getNameLoc(),
1957  "expected only one loop bound operand");
1958 
1959  // TODO: improve error message when SSA value is not of index type.
1960  // Currently it is 'use of value ... expects different type than prior uses'
1961  if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1962  result.operands))
1963  return failure();
1964 
1965  // Create an identity map using symbol id. This representation is optimized
1966  // for storage. Analysis passes may expand it into a multi-dimensional map
1967  // if desired.
1968  AffineMap map = builder.getSymbolIdentityMap();
1969  result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
1970  return success();
1971  }
1972 
1973  // Get the attribute location.
1974  SMLoc attrLoc = p.getCurrentLocation();
1975 
1976  Attribute boundAttr;
1977  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
1978  result.attributes))
1979  return failure();
1980 
1981  // Parse full form - affine map followed by dim and symbol list.
1982  if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(boundAttr)) {
1983  unsigned currentNumOperands = result.operands.size();
1984  unsigned numDims;
1985  if (parseDimAndSymbolList(p, result.operands, numDims))
1986  return failure();
1987 
1988  auto map = affineMapAttr.getValue();
1989  if (map.getNumDims() != numDims)
1990  return p.emitError(
1991  p.getNameLoc(),
1992  "dim operand count and affine map dim count must match");
1993 
1994  unsigned numDimAndSymbolOperands =
1995  result.operands.size() - currentNumOperands;
1996  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1997  return p.emitError(
1998  p.getNameLoc(),
1999  "symbol operand count and affine map symbol count must match");
2000 
2001  // If the map has multiple results, make sure that we parsed the min/max
2002  // prefix.
2003  if (map.getNumResults() > 1 && failedToParsedMinMax) {
2004  if (isLower) {
2005  return p.emitError(attrLoc, "lower loop bound affine map with "
2006  "multiple results requires 'max' prefix");
2007  }
2008  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
2009  "results requires 'min' prefix");
2010  }
2011  return success();
2012  }
2013 
2014  // Parse custom assembly form.
2015  if (auto integerAttr = llvm::dyn_cast<IntegerAttr>(boundAttr)) {
// A bare integer: drop the raw attribute just parsed and re-add it as an
// operand-less constant bound map.
2016  result.attributes.pop_back();
2017  result.addAttribute(
2018  boundAttrStrName,
2019  AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2020  return success();
2021  }
2022 
2023  return p.emitError(
2024  p.getNameLoc(),
2025  "expected valid affine map representation for loop bounds");
2026 }
2027 
2028 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2029  auto &builder = parser.getBuilder();
2030  OpAsmParser::Argument inductionVariable;
2031  inductionVariable.type = builder.getIndexType();
2032  // Parse the induction variable followed by '='.
2033  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2034  return failure();
2035 
2036  // Parse loop bounds.
// Track how many operands each bound contributed so we can record the
// variadic segment sizes below.
2037  int64_t numOperands = result.operands.size();
2038  if (parseBound(/*isLower=*/true, result, parser))
2039  return failure();
2040  int64_t numLbOperands = result.operands.size() - numOperands;
2041  if (parser.parseKeyword("to", " between bounds"))
2042  return failure();
2043  numOperands = result.operands.size();
2044  if (parseBound(/*isLower=*/false, result, parser))
2045  return failure();
2046  int64_t numUbOperands = result.operands.size() - numOperands;
2047 
2048  // Parse the optional loop step, we default to 1 if one is not present.
2049  if (parser.parseOptionalKeyword("step")) {
2050  result.addAttribute(
2051  getStepAttrName(result.name),
2052  builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2053  } else {
2054  SMLoc stepLoc = parser.getCurrentLocation();
2055  IntegerAttr stepAttr;
2056  if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2057  getStepAttrName(result.name).data(),
2058  result.attributes))
2059  return failure();
2060 
2061  if (stepAttr.getValue().isNegative())
2062  return parser.emitError(
2063  stepLoc,
2064  "expected step to be representable as a positive signed integer");
2065  }
2066 
2067  // Parse the optional initial iteration arguments.
// NOTE(review): the declarations of the `regionArgs` and `operands` vectors
// were lost in extraction — restore from upstream.
2070 
2071  // Induction variable.
2072  regionArgs.push_back(inductionVariable);
2073 
2074  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2075  // Parse assignment list and results type list.
2076  if (parser.parseAssignmentList(regionArgs, operands) ||
2077  parser.parseArrowTypeList(result.types))
2078  return failure();
2079  // Resolve input operands.
2080  for (auto argOperandType :
2081  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2082  Type type = std::get<2>(argOperandType);
2083  std::get<0>(argOperandType).type = type;
2084  if (parser.resolveOperand(std::get<1>(argOperandType), type,
2085  result.operands))
2086  return failure();
2087  }
2088  }
2089 
2090  result.addAttribute(
2091  getOperandSegmentSizeAttr(),
2092  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2093  static_cast<int32_t>(numUbOperands),
2094  static_cast<int32_t>(operands.size())}));
2095 
2096  // Parse the body region.
2097  Region *body = result.addRegion();
// regionArgs holds the induction variable plus one arg per loop result.
2098  if (regionArgs.size() != result.types.size() + 1)
2099  return parser.emitError(
2100  parser.getNameLoc(),
2101  "mismatch between the number of loop-carried values and results");
2102  if (parser.parseRegion(*body, regionArgs))
2103  return failure();
2104 
2105  AffineForOp::ensureTerminator(*body, builder, result.location);
2106 
2107  // Parse the optional attribute list.
2108  return parser.parseOptionalAttrDict(result.attributes);
2109 }
2110 
2111 static void printBound(AffineMapAttr boundMap,
2112  Operation::operand_range boundOperands,
2113  const char *prefix, OpAsmPrinter &p) {
2114  AffineMap map = boundMap.getValue();
2115 
2116  // Check if this bound should be printed using custom assembly form.
2117  // The decision to restrict printing custom assembly form to trivial cases
2118  // comes from the will to roundtrip MLIR binary -> text -> binary in a
2119  // lossless way.
2120  // Therefore, custom assembly form parsing and printing is only supported for
2121  // zero-operand constant maps and single symbol operand identity maps.
2122  if (map.getNumResults() == 1) {
2123  AffineExpr expr = map.getResult(0);
2124 
2125  // Print constant bound.
2126  if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2127  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2128  p << constExpr.getValue();
2129  return;
2130  }
2131  }
2132 
2133  // Print bound that consists of a single SSA symbol if the map is over a
2134  // single symbol.
2135  if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2136  if (dyn_cast<AffineSymbolExpr>(expr)) {
2137  p.printOperand(*boundOperands.begin());
2138  return;
2139  }
2140  }
2141  } else {
2142  // Map has multiple results. Print 'min' or 'max' prefix.
2143  p << prefix << ' ';
2144  }
2145 
2146  // Print the map and its operands.
2147  p << boundMap;
2148  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2149  map.getNumDims(), p);
2150 }
2151 
2152 unsigned AffineForOp::getNumIterOperands() {
2153  AffineMap lbMap = getLowerBoundMapAttr().getValue();
2154  AffineMap ubMap = getUpperBoundMapAttr().getValue();
2155 
2156  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2157 }
2158 
2159 std::optional<MutableArrayRef<OpOperand>>
2160 AffineForOp::getYieldedValuesMutable() {
2161  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2162 }
2163 
// Prints affine.for: induction variable, bounds (with the mandatory 'max' /
// 'min' prefixes handled by printBound), the step when not 1, optional
// iter_args with result types, then the body region and remaining attrs.
// NOTE(review): the print() signature line and the printOptionalAttrDict call
// line were lost in extraction — restore from upstream.
2165  p << ' ';
2166  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2167  /*omitType=*/true);
2168  p << " = ";
2169  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2170  p << " to ";
2171  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2172 
// The step is elided when it has the default value 1.
2173  if (getStepAsInt() != 1)
2174  p << " step " << getStepAsInt();
2175 
2176  bool printBlockTerminators = false;
2177  if (getNumIterOperands() > 0) {
2178  p << " iter_args(";
2179  auto regionArgs = getRegionIterArgs();
2180  auto operands = getInits();
2181 
2182  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2183  p << std::get<0>(it) << " = " << std::get<1>(it);
2184  });
2185  p << ") -> (" << getResultTypes() << ")";
// With results, the terminator carries the yielded values and must print.
2186  printBlockTerminators = true;
2187  }
2188 
2189  p << ' ';
2190  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2191  printBlockTerminators);
2193  (*this)->getAttrs(),
2194  /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2195  getUpperBoundMapAttrName(getOperation()->getName()),
2196  getStepAttrName(getOperation()->getName()),
2197  getOperandSegmentSizeAttr()});
2198 }
2199 
2200 /// Fold the constant bounds of a loop.
2201 static LogicalResult foldLoopBounds(AffineForOp forOp) {
2202  auto foldLowerOrUpperBound = [&forOp](bool lower) {
2203  // Check to see if each of the operands is the result of a constant. If
2204  // so, get the value. If not, ignore it.
2205  SmallVector<Attribute, 8> operandConstants;
2206  auto boundOperands =
2207  lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2208  for (auto operand : boundOperands) {
2209  Attribute operandCst;
2210  matchPattern(operand, m_Constant(&operandCst));
2211  operandConstants.push_back(operandCst);
2212  }
2213 
2214  AffineMap boundMap =
2215  lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2216  assert(boundMap.getNumResults() >= 1 &&
2217  "bound maps should have at least one result");
2218  SmallVector<Attribute, 4> foldedResults;
2219  if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2220  return failure();
2221 
2222  // Compute the max or min as applicable over the results.
2223  assert(!foldedResults.empty() && "bounds should have at least one result");
2224  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2225  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2226  auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2227  maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2228  : llvm::APIntOps::smin(maxOrMin, foldedResult);
2229  }
2230  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2231  : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2232  return success();
2233  };
2234 
2235  // Try to fold the lower bound.
2236  bool folded = false;
2237  if (!forOp.hasConstantLowerBound())
2238  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2239 
2240  // Try to fold the upper bound.
2241  if (!forOp.hasConstantUpperBound())
2242  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2243  return success(folded);
2244 }
2245 
2246 /// Canonicalize the bounds of the given loop.
2247 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2248  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2249  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2250 
2251  auto lbMap = forOp.getLowerBoundMap();
2252  auto ubMap = forOp.getUpperBoundMap();
// Remember the incoming maps so we can tell below whether anything changed.
2253  auto prevLbMap = lbMap;
2254  auto prevUbMap = ubMap;
2255 
// Lower bounds are maxima and upper bounds minima, hence the isMax flags.
// NOTE(review): the ub min/max simplification runs before the ub
// compose/canonicalize pass below — confirm this ordering is intentional
// before reshuffling.
2256  composeAffineMapAndOperands(&lbMap, &lbOperands);
2257  canonicalizeMapAndOperands(&lbMap, &lbOperands);
2258  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
2259  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2260  lbMap = removeDuplicateExprs(lbMap);
2261 
2262  composeAffineMapAndOperands(&ubMap, &ubOperands);
2263  canonicalizeMapAndOperands(&ubMap, &ubOperands);
2264  ubMap = removeDuplicateExprs(ubMap);
2265 
2266  // Any canonicalization change always leads to updated map(s).
2267  if (lbMap == prevLbMap && ubMap == prevUbMap)
2268  return failure();
2269 
2270  if (lbMap != prevLbMap)
2271  forOp.setLowerBound(lbOperands, lbMap);
2272  if (ubMap != prevUbMap)
2273  forOp.setUpperBound(ubOperands, ubMap);
2274  return success();
2275 }
2276 
2277 namespace {
2278 /// Returns constant trip count in trivial cases.
2279 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2280  int64_t step = forOp.getStepAsInt();
// Only loops with constant lb/ub and a positive step have a trivially
// computable trip count.
2281  if (!forOp.hasConstantBounds() || step <= 0)
2282  return std::nullopt;
2283  int64_t lb = forOp.getConstantLowerBound();
2284  int64_t ub = forOp.getConstantUpperBound();
// ceildiv(ub - lb, step), clamped to zero for empty ranges.
2285  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2286 }
2287 
2288 /// This is a pattern to fold trivially empty loop bodies.
2289 /// TODO: This should be moved into the folding hook.
2290 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
// NOTE(review): the inherited-constructor using-declaration was lost in
// extraction here — restore from upstream.
2292 
2293  LogicalResult matchAndRewrite(AffineForOp forOp,
2294  PatternRewriter &rewriter) const override {
2295  // Check that the body only contains a yield.
2296  if (!llvm::hasSingleElement(*forOp.getBody()))
2297  return failure();
2298  if (forOp.getNumResults() == 0)
2299  return success();
2300  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2301  if (tripCount && *tripCount == 0) {
2302  // The initial values of the iteration arguments would be the op's
2303  // results.
2304  rewriter.replaceOp(forOp, forOp.getInits());
2305  return success();
2306  }
// For each yielded value, determine what the loop's result would be if the
// body never executes a real iteration.
2307  SmallVector<Value, 4> replacements;
2308  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2309  auto iterArgs = forOp.getRegionIterArgs();
2310  bool hasValDefinedOutsideLoop = false;
2311  bool iterArgsNotInOrder = false;
2312  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2313  Value val = yieldOp.getOperand(i);
2314  auto *iterArgIt = llvm::find(iterArgs, val);
2315  // TODO: It should be possible to perform a replacement by computing the
2316  // last value of the IV based on the bounds and the step.
2317  if (val == forOp.getInductionVar())
2318  return failure();
2319  if (iterArgIt == iterArgs.end()) {
2320  // `val` is defined outside of the loop.
2321  assert(forOp.isDefinedOutsideOfLoop(val) &&
2322  "must be defined outside of the loop");
2323  hasValDefinedOutsideLoop = true;
2324  replacements.push_back(val);
2325  } else {
2326  unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2327  if (pos != i)
2328  iterArgsNotInOrder = true;
2329  replacements.push_back(forOp.getInits()[pos]);
2330  }
2331  }
2332  // Bail out when the trip count is unknown and the loop returns any value
2333  // defined outside of the loop or any iterArg out of order.
2334  if (!tripCount.has_value() &&
2335  (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2336  return failure();
2337  // Bail out when the loop iterates more than once and it returns any iterArg
2338  // out of order.
2339  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2340  return failure();
2341  rewriter.replaceOp(forOp, replacements);
2342  return success();
2343  }
2344 };
2345 } // namespace
2346 
// Registers the empty-loop folding pattern for affine.for.
2347 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2348  MLIRContext *context) {
2349  results.add<AffineForEmptyLoopFolder>(context);
2350 }
2351 
// The operands forwarded on entry — into the region's iter args, or straight
// to the results when the loop runs zero times — are the init values.
2352 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2353  assert((point.isParent() || point == getRegion()) && "invalid region point");
2354 
2355  // The initial operands map to the loop arguments after the induction
2356  // variable or are forwarded to the results when the trip count is zero.
2357  return getInits();
2358 }
2359 
// Reports the possible control-flow successors of the loop body / parent op,
// refined by the trivially-known trip count where available.
// NOTE(review): the signature's continuation line (the point/regions
// parameters) was lost in extraction — restore from upstream.
2360 void AffineForOp::getSuccessorRegions(
2362  assert((point.isParent() || point == getRegion()) && "expected loop region");
2363  // The loop may typically branch back to its body or to the parent operation.
2364  // If the predecessor is the parent op and the trip count is known to be at
2365  // least one, branch into the body using the iterator arguments. And in cases
2366  // we know the trip count is zero, it can only branch back to its parent.
2367  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2368  if (point.isParent() && tripCount.has_value()) {
2369  if (tripCount.value() > 0) {
2370  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2371  return;
2372  }
2373  if (tripCount.value() == 0) {
2374  regions.push_back(RegionSuccessor(getResults()));
2375  return;
2376  }
2377  }
2378 
2379  // From the loop body, if the trip count is one, we can only branch back to
2380  // the parent.
2381  if (!point.isParent() && tripCount && *tripCount == 1) {
2382  regions.push_back(RegionSuccessor(getResults()));
2383  return;
2384  }
2385 
2386  // In all other cases, the loop may branch back to itself or the parent
2387  // operation.
2388  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2389  regions.push_back(RegionSuccessor(getResults()));
2390 }
2391 
2392 /// Returns true if the affine.for has zero iterations in trivial cases.
2393 static bool hasTrivialZeroTripCount(AffineForOp op) {
2394  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2395  return tripCount && *tripCount == 0;
2396 }
2397 
2398 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2399  SmallVectorImpl<OpFoldResult> &results) {
2400  bool folded = succeeded(foldLoopBounds(*this));
2401  folded |= succeeded(canonicalizeLoopBounds(*this));
2402  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2403  // The initial values of the loop-carried variables (iter_args) are the
2404  // results of the op. But this must be avoided for an affine.for op that
2405  // does not return any results. Since ops that do not return results cannot
2406  // be folded away, we would enter an infinite loop of folds on the same
2407  // affine.for op.
2408  results.assign(getInits().begin(), getInits().end());
2409  folded = true;
2410  }
2411  return success(folded);
2412 }
2413 
// NOTE(review): the signatures of these two accessors (presumably
// AffineForOp::getLowerBound() and AffineForOp::getUpperBound()) were lost in
// extraction — each wraps the loop's bound operands and map in an AffineBound
// view. Restore the signature lines from upstream.
2415  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2416 }
2417 
2419  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2420 }
2421 
// Replaces the lower bound with `map` applied to `lbOperands`; operand count
// must match the map's inputs and the map must have at least one result.
2422 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2423  assert(lbOperands.size() == map.getNumInputs());
2424  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2425  getLowerBoundOperandsMutable().assign(lbOperands);
2426  setLowerBoundMap(map);
2427 }
2428 
// Replaces the upper bound with `map` applied to `ubOperands`; operand count
// must match the map's inputs and the map must have at least one result.
2429 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2430  assert(ubOperands.size() == map.getNumInputs());
2431  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2432  getUpperBoundOperandsMutable().assign(ubOperands);
2433  setUpperBoundMap(map);
2434 }
2435 
// True iff the lower bound map is a single constant expression.
2436 bool AffineForOp::hasConstantLowerBound() {
2437  return getLowerBoundMap().isSingleConstant();
2438 }
2439 
// True iff the upper bound map is a single constant expression.
2440 bool AffineForOp::hasConstantUpperBound() {
2441  return getUpperBoundMap().isSingleConstant();
2442 }
2443 
// Returns the constant lower bound; only meaningful when
// hasConstantLowerBound() is true.
2444 int64_t AffineForOp::getConstantLowerBound() {
2445  return getLowerBoundMap().getSingleConstantResult();
2446 }
2447 
2448 int64_t AffineForOp::getConstantUpperBound() {
2449  return getUpperBoundMap().getSingleConstantResult();
2450 }
2451 
2452 void AffineForOp::setConstantLowerBound(int64_t value) {
2453  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2454 }
2455 
2456 void AffineForOp::setConstantUpperBound(int64_t value) {
2457  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2458 }
2459 
2460 AffineForOp::operand_range AffineForOp::getControlOperands() {
2461  return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2462  getUpperBoundOperands().size()};
2463 }
2464 
2465 bool AffineForOp::matchingBoundOperandList() {
2466  auto lbMap = getLowerBoundMap();
2467  auto ubMap = getUpperBoundMap();
2468  if (lbMap.getNumDims() != ubMap.getNumDims() ||
2469  lbMap.getNumSymbols() != ubMap.getNumSymbols())
2470  return false;
2471 
2472  unsigned numOperands = lbMap.getNumInputs();
2473  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2474  // Compare Value 's.
2475  if (getOperand(i) != getOperand(numOperands + i))
2476  return false;
2477  }
2478  return true;
2479 }
2480 
2481 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2482 
2483 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2484  return SmallVector<Value>{getInductionVar()};
2485 }
2486 
2487 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2488  if (!hasConstantLowerBound())
2489  return std::nullopt;
2490  OpBuilder b(getContext());
2492  OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2493 }
2494 
2495 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2496  OpBuilder b(getContext());
2498  OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2499 }
2500 
2501 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2502  if (!hasConstantUpperBound())
2503  return {};
2504  OpBuilder b(getContext());
2506  OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2507 }
2508 
/// Replaces this loop with a new affine.for that additionally carries
/// `newInitOperands` as iter_args, yielding the values produced by
/// `newYieldValuesFn`. The old loop is erased; its results are replaced by
/// the corresponding prefix of the new loop's results.
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    RewriterBase &rewriter, ValueRange newInitOperands,
    bool replaceInitOperandUsesInLoop,
    const NewYieldValuesFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(getOperation());
  auto inits = llvm::to_vector(getInits());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  AffineForOp newLoop = rewriter.create<AffineForOp>(
      getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
      getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);

  // Generate the new yield values and append them to the affine.yield
  // operation (the terminator of the old body, which is moved over below).
  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().append(newYieldedValues);
    });
  }

  // Move the loop body to the new op, remapping the old block arguments
  // (induction variable and original iter_args) to the leading arguments of
  // the new loop's body.
  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // Replace all uses of `newInitOperands` with the corresponding basic block
    // arguments, but only for uses nested inside the new loop.
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
                                 [&](OpOperand &use) {
                                   Operation *user = use.getOwner();
                                   return newLoop->isProperAncestor(user);
                                 });
    }
  }

  // Replace the old loop: its results map to the first getNumResults()
  // results of the new loop (the extra results belong to the new iter_args).
  rewriter.replaceOp(getOperation(),
                     newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
2560 
2561 Speculation::Speculatability AffineForOp::getSpeculatability() {
2562  // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2563  // Start and End.
2564  //
2565  // For Step != 1, the loop may not terminate. We can add more smarts here if
2566  // needed.
2567  return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2569 }
2570 
/// Returns true if the provided value is the induction variable of an
/// AffineForOp.
2574  return getForInductionVarOwner(val) != AffineForOp();
2575 }
2576 
2578  return getAffineParallelInductionVarOwner(val) != nullptr;
2579 }
2580 
2583 }
2584 
2586  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2587  if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2588  return AffineForOp();
2589  if (auto forOp =
2590  ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2591  // Check to make sure `val` is the induction variable, not an iter_arg.
2592  return forOp.getInductionVar() == val ? forOp : AffineForOp();
2593  return AffineForOp();
2594 }
2595 
2597  auto ivArg = llvm::dyn_cast<BlockArgument>(val);
2598  if (!ivArg || !ivArg.getOwner())
2599  return nullptr;
2600  Operation *containingOp = ivArg.getOwner()->getParentOp();
2601  auto parallelOp = dyn_cast<AffineParallelOp>(containingOp);
2602  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2603  return parallelOp;
2604  return nullptr;
2605 }
2606 
2607 /// Extracts the induction variables from a list of AffineForOps and returns
2608 /// them.
2610  SmallVectorImpl<Value> *ivs) {
2611  ivs->reserve(forInsts.size());
2612  for (auto forInst : forInsts)
2613  ivs->push_back(forInst.getInductionVar());
2614 }
2615 
2618  ivs.reserve(affineOps.size());
2619  for (Operation *op : affineOps) {
2620  // Add constraints from forOp's bounds.
2621  if (auto forOp = dyn_cast<AffineForOp>(op))
2622  ivs.push_back(forOp.getInductionVar());
2623  else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2624  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2625  ivs.push_back(parallelOp.getBody()->getArgument(i));
2626  }
2627 }
2628 
2629 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2630 /// operations.
2631 template <typename BoundListTy, typename LoopCreatorTy>
2633  OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2634  ArrayRef<int64_t> steps,
2635  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2636  LoopCreatorTy &&loopCreatorFn) {
2637  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2638  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2639 
2640  // If there are no loops to be constructed, construct the body anyway.
2641  OpBuilder::InsertionGuard guard(builder);
2642  if (lbs.empty()) {
2643  if (bodyBuilderFn)
2644  bodyBuilderFn(builder, loc, ValueRange());
2645  return;
2646  }
2647 
2648  // Create the loops iteratively and store the induction variables.
2650  ivs.reserve(lbs.size());
2651  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2652  // Callback for creating the loop body, always creates the terminator.
2653  auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
2654  ValueRange iterArgs) {
2655  ivs.push_back(iv);
2656  // In the innermost loop, call the body builder.
2657  if (i == e - 1 && bodyBuilderFn) {
2658  OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
2659  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2660  }
2661  nestedBuilder.create<AffineYieldOp>(nestedLoc);
2662  };
2663 
2664  // Delegate actual loop creation to the callback in order to dispatch
2665  // between constant- and variable-bound loops.
2666  auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2667  builder.setInsertionPointToStart(loop.getBody());
2668  }
2669 }
2670 
2671 /// Creates an affine loop from the bounds known to be constants.
2672 static AffineForOp
2674  int64_t ub, int64_t step,
2675  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2676  return builder.create<AffineForOp>(loc, lb, ub, step,
2677  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2678 }
2679 
2680 /// Creates an affine loop from the bounds that may or may not be constants.
2681 static AffineForOp
2683  int64_t step,
2684  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2685  std::optional<int64_t> lbConst = getConstantIntValue(lb);
2686  std::optional<int64_t> ubConst = getConstantIntValue(ub);
2687  if (lbConst && ubConst)
2688  return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2689  ubConst.value(), step, bodyBuilderFn);
2690  return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2691  builder.getDimIdentityMap(), step,
2692  /*iterArgs=*/std::nullopt, bodyBuilderFn);
2693 }
2694 
2696  OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2698  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2699  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2701 }
2702 
2704  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2705  ArrayRef<int64_t> steps,
2706  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2707  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2709 }
2710 
2711 //===----------------------------------------------------------------------===//
2712 // AffineIfOp
2713 //===----------------------------------------------------------------------===//
2714 
2715 namespace {
2716 /// Remove else blocks that have nothing other than a zero value yield.
2717 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
2719 
2720  LogicalResult matchAndRewrite(AffineIfOp ifOp,
2721  PatternRewriter &rewriter) const override {
2722  if (ifOp.getElseRegion().empty() ||
2723  !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2724  return failure();
2725 
2726  rewriter.startOpModification(ifOp);
2727  rewriter.eraseBlock(ifOp.getElseBlock());
2728  rewriter.finalizeOpModification(ifOp);
2729  return success();
2730  }
2731 };
2732 
2733 /// Removes affine.if cond if the condition is always true or false in certain
2734 /// trivial cases. Promotes the then/else block in the parent operation block.
2735 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
2737 
2738  LogicalResult matchAndRewrite(AffineIfOp op,
2739  PatternRewriter &rewriter) const override {
2740 
2741  auto isTriviallyFalse = [](IntegerSet iSet) {
2742  return iSet.isEmptyIntegerSet();
2743  };
2744 
2745  auto isTriviallyTrue = [](IntegerSet iSet) {
2746  return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2747  iSet.getConstraint(0) == 0);
2748  };
2749 
2750  IntegerSet affineIfConditions = op.getIntegerSet();
2751  Block *blockToMove;
2752  if (isTriviallyFalse(affineIfConditions)) {
2753  // The absence, or equivalently, the emptiness of the else region need not
2754  // be checked when affine.if is returning results because if an affine.if
2755  // operation is returning results, it always has a non-empty else region.
2756  if (op.getNumResults() == 0 && !op.hasElse()) {
2757  // If the else region is absent, or equivalently, empty, remove the
2758  // affine.if operation (which is not returning any results).
2759  rewriter.eraseOp(op);
2760  return success();
2761  }
2762  blockToMove = op.getElseBlock();
2763  } else if (isTriviallyTrue(affineIfConditions)) {
2764  blockToMove = op.getThenBlock();
2765  } else {
2766  return failure();
2767  }
2768  Operation *blockToMoveTerminator = blockToMove->getTerminator();
2769  // Promote the "blockToMove" block to the parent operation block between the
2770  // prologue and epilogue of "op".
2771  rewriter.inlineBlockBefore(blockToMove, op);
2772  // Replace the "op" operation with the operands of the
2773  // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2774  // the affine.yield operation present in the "blockToMove" block. It has no
2775  // operands when affine.if is not returning results and therefore, in that
2776  // case, replaceOp just erases "op". When affine.if is not returning
2777  // results, the affine.yield operation can be omitted. It gets inserted
2778  // implicitly.
2779  rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2780  // Erase the "blockToMoveTerminator" operation since it is now in the parent
2781  // operation block, which already has its own terminator.
2782  rewriter.eraseOp(blockToMoveTerminator);
2783  return success();
2784  }
2785 };
2786 } // namespace
2787 
2788 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
2789 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
2790 void AffineIfOp::getSuccessorRegions(
2792  // If the predecessor is an AffineIfOp, then branching into both `then` and
2793  // `else` region is valid.
2794  if (point.isParent()) {
2795  regions.reserve(2);
2796  regions.push_back(
2797  RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
 // If the "else" region is empty, branch back into the parent.
2799  if (getElseRegion().empty()) {
2800  regions.push_back(getResults());
2801  } else {
2802  regions.push_back(
2803  RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
2804  }
2805  return;
2806  }
2807 
2808  // If the predecessor is the `else`/`then` region, then branching into parent
2809  // op is valid.
2810  regions.push_back(RegionSuccessor(getResults()));
2811 }
2812 
2813 LogicalResult AffineIfOp::verify() {
2814  // Verify that we have a condition attribute.
2815  // FIXME: This should be specified in the arguments list in ODS.
2816  auto conditionAttr =
2817  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2818  if (!conditionAttr)
2819  return emitOpError("requires an integer set attribute named 'condition'");
2820 
2821  // Verify that there are enough operands for the condition.
2822  IntegerSet condition = conditionAttr.getValue();
2823  if (getNumOperands() != condition.getNumInputs())
2824  return emitOpError("operand count and condition integer set dimension and "
2825  "symbol count must match");
2826 
2827  // Verify that the operands are valid dimension/symbols.
2828  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
2829  condition.getNumDims())))
2830  return failure();
2831 
2832  return success();
2833 }
2834 
/// Parses an affine.if op of the form:
///   affine.if <integer-set-attr> (dims)[symbols] (-> types)?
///     { then-region } (else { else-region })? attr-dict?
ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the condition attribute set.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(conditionAttr,
                            AffineIfOp::getConditionAttrStrName(),
                            result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();

  // Verify the condition operands: the parsed dim/symbol operand counts must
  // agree with the integer set's dim/symbol counts.
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        parser.getNameLoc(),
        "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        parser.getNameLoc(),
        "symbol operand count and integer set symbol count must match");

  // Optional result types, e.g. `-> (i32, f32)`.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Create the regions for 'then' and 'else'. The latter must be created even
  // if it remains empty for the validity of the operation.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  // Parse the 'then' region and add the implicit terminator if it was elided.
  if (parser.parseRegion(*thenRegion, {}, {}))
    return failure();
  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
                               result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return failure();
    AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
                                 result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
2885 
2887  auto conditionAttr =
2888  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
2889  p << " " << conditionAttr;
2890  printDimAndSymbolList(operand_begin(), operand_end(),
2891  conditionAttr.getValue().getNumDims(), p);
2892  p.printOptionalArrowTypeList(getResultTypes());
2893  p << ' ';
2894  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
2895  /*printBlockTerminators=*/getNumResults());
2896 
2897  // Print the 'else' regions if it has any blocks.
2898  auto &elseRegion = this->getElseRegion();
2899  if (!elseRegion.empty()) {
2900  p << " else ";
2901  p.printRegion(elseRegion,
2902  /*printEntryBlockArgs=*/false,
2903  /*printBlockTerminators=*/getNumResults());
2904  }
2905 
2906  // Print the attribute list.
2907  p.printOptionalAttrDict((*this)->getAttrs(),
2908  /*elidedAttrs=*/getConditionAttrStrName());
2909 }
2910 
2911 IntegerSet AffineIfOp::getIntegerSet() {
2912  return (*this)
2913  ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
2914  .getValue();
2915 }
2916 
2917 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2918  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
2919 }
2920 
2921 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2922  setIntegerSet(set);
2923  (*this)->setOperands(operands);
2924 }
2925 
2926 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2927  TypeRange resultTypes, IntegerSet set, ValueRange args,
2928  bool withElseRegion) {
2929  assert(resultTypes.empty() || withElseRegion);
2930  OpBuilder::InsertionGuard guard(builder);
2931 
2932  result.addTypes(resultTypes);
2933  result.addOperands(args);
2934  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
2935 
2936  Region *thenRegion = result.addRegion();
2937  builder.createBlock(thenRegion);
2938  if (resultTypes.empty())
2939  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2940 
2941  Region *elseRegion = result.addRegion();
2942  if (withElseRegion) {
2943  builder.createBlock(elseRegion);
2944  if (resultTypes.empty())
2945  AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2946  }
2947 }
2948 
2949 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2950  IntegerSet set, ValueRange args, bool withElseRegion) {
2951  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2952  withElseRegion);
2953 }
2954 
2955 /// Compose any affine.apply ops feeding into `operands` of the integer set
2956 /// `set` by composing the maps of such affine.apply ops with the integer
2957 /// set constraints.
2959  SmallVectorImpl<Value> &operands) {
2960  // We will simply reuse the API of the map composition by viewing the LHSs of
2961  // the equalities and inequalities of `set` as the affine exprs of an affine
2962  // map. Convert to equivalent map, compose, and convert back to set.
2963  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
2964  set.getConstraints(), set.getContext());
2965  // Check if any composition is possible.
2966  if (llvm::none_of(operands,
2967  [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
2968  return;
2969 
2970  composeAffineMapAndOperands(&map, &operands);
2971  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
2972  set.getEqFlags());
2973 }
2974 
2975 /// Canonicalize an affine if op's conditional (integer set + operands).
2976 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2977  auto set = getIntegerSet();
2978  SmallVector<Value, 4> operands(getOperands());
2979  composeSetAndOperands(set, operands);
2980  canonicalizeSetAndOperands(&set, &operands);
2981 
2982  // Check if the canonicalization or composition led to any change.
2983  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
2984  return failure();
2985 
2986  setConditional(set, operands);
2987  return success();
2988 }
2989 
2990 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2991  MLIRContext *context) {
2992  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2993 }
2994 
2995 //===----------------------------------------------------------------------===//
2996 // AffineLoadOp
2997 //===----------------------------------------------------------------------===//
2998 
2999 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3000  AffineMap map, ValueRange operands) {
3001  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3002  result.addOperands(operands);
3003  if (map)
3004  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3005  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
3006  result.types.push_back(memrefType.getElementType());
3007 }
3008 
3009 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3010  Value memref, AffineMap map, ValueRange mapOperands) {
3011  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3012  result.addOperands(memref);
3013  result.addOperands(mapOperands);
3014  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3015  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3016  result.types.push_back(memrefType.getElementType());
3017 }
3018 
3019 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3020  Value memref, ValueRange indices) {
3021  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3022  int64_t rank = memrefType.getRank();
3023  // Create identity map for memrefs with at least one dimension or () -> ()
3024  // for zero-dimensional memrefs.
3025  auto map =
3026  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3027  build(builder, result, memref, map, indices);
3028 }
3029 
3030 ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
3031  auto &builder = parser.getBuilder();
3032  auto indexTy = builder.getIndexType();
3033 
3034  MemRefType type;
3035  OpAsmParser::UnresolvedOperand memrefInfo;
3036  AffineMapAttr mapAttr;
3038  return failure(
3039  parser.parseOperand(memrefInfo) ||
3040  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3041  AffineLoadOp::getMapAttrStrName(),
3042  result.attributes) ||
3043  parser.parseOptionalAttrDict(result.attributes) ||
3044  parser.parseColonType(type) ||
3045  parser.resolveOperand(memrefInfo, type, result.operands) ||
3046  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3047  parser.addTypeToList(type.getElementType(), result.types));
3048 }
3049 
3051  p << " " << getMemRef() << '[';
3052  if (AffineMapAttr mapAttr =
3053  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3054  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3055  p << ']';
3056  p.printOptionalAttrDict((*this)->getAttrs(),
3057  /*elidedAttrs=*/{getMapAttrStrName()});
3058  p << " : " << getMemRefType();
3059 }
3060 
3061 /// Verify common indexing invariants of affine.load, affine.store,
3062 /// affine.vector_load and affine.vector_store.
3063 template <typename AffineMemOpTy>
3064 static LogicalResult
3065 verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr,
3066  Operation::operand_range mapOperands,
3067  MemRefType memrefType, unsigned numIndexOperands) {
3068  AffineMap map = mapAttr.getValue();
3069  if (map.getNumResults() != memrefType.getRank())
3070  return op->emitOpError("affine map num results must equal memref rank");
3071  if (map.getNumInputs() != numIndexOperands)
3072  return op->emitOpError("expects as many subscripts as affine map inputs");
3073 
3074  for (auto idx : mapOperands) {
3075  if (!idx.getType().isIndex())
3076  return op->emitOpError("index to load must have 'index' type");
3077  }
3078  if (failed(verifyDimAndSymbolIdentifiers(op, mapOperands, map.getNumDims())))
3079  return failure();
3080 
3081  return success();
3082 }
3083 
3084 LogicalResult AffineLoadOp::verify() {
3085  auto memrefType = getMemRefType();
3086  if (getType() != memrefType.getElementType())
3087  return emitOpError("result type must match element type of memref");
3088 
3089  if (failed(verifyMemoryOpIndexing(
3090  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3091  getMapOperands(), memrefType,
3092  /*numIndexOperands=*/getNumOperands() - 1)))
3093  return failure();
3094 
3095  return success();
3096 }
3097 
/// Registers the SimplifyAffineOp pattern (defined earlier in this file) for
/// affine.load, which simplifies the op's map/operands in place.
void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
}
3102 
/// Folds a load through memref.cast, and folds loads from constant
/// memref.global values to the loaded attribute when possible.
OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
  /// load(memrefcast) -> load
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();

  // Fold load from a global constant memref.
  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobalOp)
    return {};
  // Get to the memref.global defining the symbol.
  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
  if (!symbolTableOp)
    return {};
  auto global = dyn_cast_or_null<memref::GlobalOp>(
      SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
  if (!global)
    return {};

  // Check if the global memref is a constant.
  auto cstAttr =
      llvm::dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
  if (!cstAttr)
    return {};
  // If it's a splat constant, we can fold irrespective of indices.
  if (auto splatAttr = llvm::dyn_cast<SplatElementsAttr>(cstAttr))
    return splatAttr.getSplatValue<Attribute>();
  // Otherwise, we can fold only if we know the indices, i.e., the access map
  // has all-constant results (no dim/symbol operands contribute).
  if (!getAffineMap().isConstant())
    return {};
  // Convert the constant map results to the unsigned index vector expected by
  // DenseElementsAttr's indexed accessor.
  auto indices = llvm::to_vector<4>(
      llvm::map_range(getAffineMap().getConstantResults(),
                      [](int64_t v) -> uint64_t { return v; }));
  return cstAttr.getValues<Attribute>()[indices];
}
3137 
3138 //===----------------------------------------------------------------------===//
3139 // AffineStoreOp
3140 //===----------------------------------------------------------------------===//
3141 
3142 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3143  Value valueToStore, Value memref, AffineMap map,
3144  ValueRange mapOperands) {
3145  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3146  result.addOperands(valueToStore);
3147  result.addOperands(memref);
3148  result.addOperands(mapOperands);
3149  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3150 }
3151 
3152 // Use identity map.
3153 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3154  Value valueToStore, Value memref,
3155  ValueRange indices) {
3156  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3157  int64_t rank = memrefType.getRank();
3158  // Create identity map for memrefs with at least one dimension or () -> ()
3159  // for zero-dimensional memrefs.
3160  auto map =
3161  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3162  build(builder, result, valueToStore, memref, map, indices);
3163 }
3164 
3165 ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3166  auto indexTy = parser.getBuilder().getIndexType();
3167 
3168  MemRefType type;
3169  OpAsmParser::UnresolvedOperand storeValueInfo;
3170  OpAsmParser::UnresolvedOperand memrefInfo;
3171  AffineMapAttr mapAttr;
3173  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3174  parser.parseOperand(memrefInfo) ||
3175  parser.parseAffineMapOfSSAIds(
3176  mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3177  result.attributes) ||
3178  parser.parseOptionalAttrDict(result.attributes) ||
3179  parser.parseColonType(type) ||
3180  parser.resolveOperand(storeValueInfo, type.getElementType(),
3181  result.operands) ||
3182  parser.resolveOperand(memrefInfo, type, result.operands) ||
3183  parser.resolveOperands(mapOperands, indexTy, result.operands));
3184 }
3185 
3187  p << " " << getValueToStore();
3188  p << ", " << getMemRef() << '[';
3189  if (AffineMapAttr mapAttr =
3190  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3191  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3192  p << ']';
3193  p.printOptionalAttrDict((*this)->getAttrs(),
3194  /*elidedAttrs=*/{getMapAttrStrName()});
3195  p << " : " << getMemRefType();
3196 }
3197 
3198 LogicalResult AffineStoreOp::verify() {
3199  // The value to store must have the same type as memref element type.
3200  auto memrefType = getMemRefType();
3201  if (getValueToStore().getType() != memrefType.getElementType())
3202  return emitOpError(
3203  "value to store must have the same type as memref element type");
3204 
3205  if (failed(verifyMemoryOpIndexing(
3206  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3207  getMapOperands(), memrefType,
3208  /*numIndexOperands=*/getNumOperands() - 2)))
3209  return failure();
3210 
3211  return success();
3212 }
3213 
/// Registers the SimplifyAffineOp pattern (defined earlier in this file) for
/// affine.store, which simplifies the op's map/operands in place.
void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
}
3218 
/// Folds the memref operand through memref.cast ops; the op is updated in
/// place (a store produces no results to return).
LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return memref::foldMemRefCast(*this, getValueToStore());
}
3224 
3225 //===----------------------------------------------------------------------===//
3226 // AffineMinMaxOpBase
3227 //===----------------------------------------------------------------------===//
3228 
3229 template <typename T>
3230 static LogicalResult verifyAffineMinMaxOp(T op) {
3231  // Verify that operand count matches affine map dimension and symbol count.
3232  if (op.getNumOperands() !=
3233  op.getMap().getNumDims() + op.getMap().getNumSymbols())
3234  return op.emitOpError(
3235  "operand count and affine map dimension and symbol count must match");
3236 
3237  if (op.getMap().getNumResults() == 0)
3238  return op.emitOpError("affine map expect at least one result");
3239  return success();
3240 }
3241 
3242 template <typename T>
3243 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3244  p << ' ' << op->getAttr(T::getMapAttrStrName());
3245  auto operands = op.getOperands();
3246  unsigned numDims = op.getMap().getNumDims();
3247  p << '(' << operands.take_front(numDims) << ')';
3248 
3249  if (operands.size() != numDims)
3250  p << '[' << operands.drop_front(numDims) << ']';
3251  p.printOptionalAttrDict(op->getAttrs(),
3252  /*elidedAttrs=*/{T::getMapAttrStrName()});
3253 }
3254 
3255 template <typename T>
3256 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3257  OperationState &result) {
3258  auto &builder = parser.getBuilder();
3259  auto indexType = builder.getIndexType();
3262  AffineMapAttr mapAttr;
3263  return failure(
3264  parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3265  result.attributes) ||
3266  parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
3267  parser.parseOperandList(symInfos,
3269  parser.parseOptionalAttrDict(result.attributes) ||
3270  parser.resolveOperands(dimInfos, indexType, result.operands) ||
3271  parser.resolveOperands(symInfos, indexType, result.operands) ||
3272  parser.addTypeToList(indexType, result.types));
3273 }
3274 
3275 /// Fold an affine min or max operation with the given operands. The operand
3276 /// list may contain nulls, which are interpreted as the operand not being a
3277 /// constant.
3278 template <typename T>
3280  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3281  "expected affine min or max op");
3282 
3283  // Fold the affine map.
3284  // TODO: Fold more cases:
3285  // min(some_affine, some_affine + constant, ...), etc.
3286  SmallVector<int64_t, 2> results;
3287  auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3288 
3289  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3290  return op.getOperand(0);
3291 
3292  // If some of the map results are not constant, try changing the map in-place.
3293  if (results.empty()) {
3294  // If the map is the same, report that folding did not happen.
3295  if (foldedMap == op.getMap())
3296  return {};
3297  op->setAttr("map", AffineMapAttr::get(foldedMap));
3298  return op.getResult();
3299  }
3300 
3301  // Otherwise, completely fold the op into a constant.
3302  auto resultIt = std::is_same<T, AffineMinOp>::value
3303  ? llvm::min_element(results)
3304  : llvm::max_element(results);
3305  if (resultIt == results.end())
3306  return {};
3307  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3308 }
3309 
3310 /// Remove duplicated expressions in affine min/max ops.
3311 template <typename T>
3314 
3315  LogicalResult matchAndRewrite(T affineOp,
3316  PatternRewriter &rewriter) const override {
3317  AffineMap oldMap = affineOp.getAffineMap();
3318 
3319  SmallVector<AffineExpr, 4> newExprs;
3320  for (AffineExpr expr : oldMap.getResults()) {
3321  // This is a linear scan over newExprs, but it should be fine given that
3322  // we typically just have a few expressions per op.
3323  if (!llvm::is_contained(newExprs, expr))
3324  newExprs.push_back(expr);
3325  }
3326 
3327  if (newExprs.size() == oldMap.getNumResults())
3328  return failure();
3329 
3330  auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3331  newExprs, rewriter.getContext());
3332  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3333 
3334  return success();
3335  }
3336 };
3337 
3338 /// Merge an affine min/max op to its consumers if its consumer is also an
3339 /// affine min/max op.
3340 ///
3341 /// This pattern requires the producer affine min/max op is bound to a
3342 /// dimension/symbol that is used as a standalone expression in the consumer
3343 /// affine op's map.
3344 ///
3345 /// For example, a pattern like the following:
3346 ///
3347 /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3348 /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3349 ///
3350 /// Can be turned into:
3351 ///
3352 /// %1 = affine.min affine_map<
3353 /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3354 template <typename T>
3357 
3358  LogicalResult matchAndRewrite(T affineOp,
3359  PatternRewriter &rewriter) const override {
3360  AffineMap oldMap = affineOp.getAffineMap();
3361  ValueRange dimOperands =
3362  affineOp.getMapOperands().take_front(oldMap.getNumDims());
3363  ValueRange symOperands =
3364  affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3365 
3366  auto newDimOperands = llvm::to_vector<8>(dimOperands);
3367  auto newSymOperands = llvm::to_vector<8>(symOperands);
3368  SmallVector<AffineExpr, 4> newExprs;
3369  SmallVector<T, 4> producerOps;
3370 
3371  // Go over each expression to see whether it's a single dimension/symbol
3372  // with the corresponding operand which is the result of another affine
3373  // min/max op. If So it can be merged into this affine op.
3374  for (AffineExpr expr : oldMap.getResults()) {
3375  if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3376  Value symValue = symOperands[symExpr.getPosition()];
3377  if (auto producerOp = symValue.getDefiningOp<T>()) {
3378  producerOps.push_back(producerOp);
3379  continue;
3380  }
3381  } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3382  Value dimValue = dimOperands[dimExpr.getPosition()];
3383  if (auto producerOp = dimValue.getDefiningOp<T>()) {
3384  producerOps.push_back(producerOp);
3385  continue;
3386  }
3387  }
3388  // For the above cases we will remove the expression by merging the
3389  // producer affine min/max's affine expressions. Otherwise we need to
3390  // keep the existing expression.
3391  newExprs.push_back(expr);
3392  }
3393 
3394  if (producerOps.empty())
3395  return failure();
3396 
3397  unsigned numUsedDims = oldMap.getNumDims();
3398  unsigned numUsedSyms = oldMap.getNumSymbols();
3399 
3400  // Now go over all producer affine ops and merge their expressions.
3401  for (T producerOp : producerOps) {
3402  AffineMap producerMap = producerOp.getAffineMap();
3403  unsigned numProducerDims = producerMap.getNumDims();
3404  unsigned numProducerSyms = producerMap.getNumSymbols();
3405 
3406  // Collect all dimension/symbol values.
3407  ValueRange dimValues =
3408  producerOp.getMapOperands().take_front(numProducerDims);
3409  ValueRange symValues =
3410  producerOp.getMapOperands().take_back(numProducerSyms);
3411  newDimOperands.append(dimValues.begin(), dimValues.end());
3412  newSymOperands.append(symValues.begin(), symValues.end());
3413 
3414  // For expressions we need to shift to avoid overlap.
3415  for (AffineExpr expr : producerMap.getResults()) {
3416  newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3417  .shiftSymbols(numProducerSyms, numUsedSyms));
3418  }
3419 
3420  numUsedDims += numProducerDims;
3421  numUsedSyms += numProducerSyms;
3422  }
3423 
3424  auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3425  rewriter.getContext());
3426  auto newOperands =
3427  llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3428  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3429 
3430  return success();
3431  }
3432 };
3433 
3434 /// Canonicalize the result expression order of an affine map and return success
3435 /// if the order changed.
3436 ///
3437 /// The function flattens the map's affine expressions to coefficient arrays and
3438 /// sorts them in lexicographic order. A coefficient array contains a multiplier
3439 /// for every dimension/symbol and a constant term. The canonicalization fails
3440 /// if a result expression is not pure or if the flattening requires local
3441 /// variables that, unlike dimensions and symbols, have no global order.
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
  SmallVector<SmallVector<int64_t>> flattenedExprs;
  for (const AffineExpr &resultExpr : map.getResults()) {
    // Fail if the expression is not pure.
    if (!resultExpr.isPureAffine())
      return failure();

    SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
    auto flattenResult = flattener.walkPostOrder(resultExpr);
    if (failed(flattenResult))
      return failure();

    // Fail if the flattened expression has local variables.
    // The flattened coefficient vector has one entry per dim, per symbol,
    // plus the constant term; any extra entries are locals, which have no
    // global order to sort by.
    if (flattener.operandExprStack.back().size() !=
        map.getNumDims() + map.getNumSymbols() + 1)
      return failure();

    flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
                                flattener.operandExprStack.back().end());
  }

  // Fail if sorting is not necessary.
  if (llvm::is_sorted(flattenedExprs))
    return failure();

  // Reorder the result expressions according to their flattened form.
  // Sort a permutation of result indices rather than the expressions
  // themselves, keyed by lexicographic order of the coefficient vectors.
  SmallVector<unsigned> resultPermutation =
      llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
    return flattenedExprs[lhs] < flattenedExprs[rhs];
  });
  SmallVector<AffineExpr> newExprs;
  for (unsigned idx : resultPermutation)
    newExprs.push_back(map.getResult(idx));

  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
                       map.getContext());
  return success();
}
3481 
3482 /// Canonicalize the affine map result expression order of an affine min/max
3483 /// operation.
3484 ///
3485 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3486 /// expressions and replaces the operation if the order changed.
3487 ///
3488 /// For example, the following operation:
3489 ///
3490 /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3491 ///
3492 /// Turns into:
3493 ///
3494 /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3495 template <typename T>
3498 
3499  LogicalResult matchAndRewrite(T affineOp,
3500  PatternRewriter &rewriter) const override {
3501  AffineMap map = affineOp.getAffineMap();
3502  if (failed(canonicalizeMapExprAndTermOrder(map)))
3503  return failure();
3504  rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3505  return success();
3506  }
3507 };
3508 
3509 template <typename T>
3512 
3513  LogicalResult matchAndRewrite(T affineOp,
3514  PatternRewriter &rewriter) const override {
3515  if (affineOp.getMap().getNumResults() != 1)
3516  return failure();
3517  rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3518  affineOp.getOperands());
3519  return success();
3520  }
3521 };
3522 
3523 //===----------------------------------------------------------------------===//
3524 // AffineMinOp
3525 //===----------------------------------------------------------------------===//
3526 //
3527 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3528 //
3529 
/// Delegates to the shared min/max folder.
OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(*this, adaptor.getOperands());
}
3533 
3534 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3535  MLIRContext *context) {
3538  MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3540  context);
3541 }
3542 
/// Delegates to the shared min/max verifier.
LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3544 
/// Delegates to the shared min/max parser.
ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
}
3548 
/// Delegates to the shared min/max printer.
void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3550 
3551 //===----------------------------------------------------------------------===//
3552 // AffineMaxOp
3553 //===----------------------------------------------------------------------===//
3554 //
3555 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3556 //
3557 
/// Delegates to the shared min/max folder.
OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(*this, adaptor.getOperands());
}
3561 
3562 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3563  MLIRContext *context) {
3566  MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3568  context);
3569 }
3570 
/// Delegates to the shared min/max verifier.
LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3572 
/// Delegates to the shared min/max parser.
ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
}
3576 
/// Delegates to the shared min/max printer.
void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3578 
3579 //===----------------------------------------------------------------------===//
3580 // AffinePrefetchOp
3581 //===----------------------------------------------------------------------===//
3582 
3583 //
3584 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3585 //
3586 ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3587  OperationState &result) {
3588  auto &builder = parser.getBuilder();
3589  auto indexTy = builder.getIndexType();
3590 
3591  MemRefType type;
3592  OpAsmParser::UnresolvedOperand memrefInfo;
3593  IntegerAttr hintInfo;
3594  auto i32Type = parser.getBuilder().getIntegerType(32);
3595  StringRef readOrWrite, cacheType;
3596 
3597  AffineMapAttr mapAttr;
3599  if (parser.parseOperand(memrefInfo) ||
3600  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3601  AffinePrefetchOp::getMapAttrStrName(),
3602  result.attributes) ||
3603  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3604  parser.parseComma() || parser.parseKeyword("locality") ||
3605  parser.parseLess() ||
3606  parser.parseAttribute(hintInfo, i32Type,
3607  AffinePrefetchOp::getLocalityHintAttrStrName(),
3608  result.attributes) ||
3609  parser.parseGreater() || parser.parseComma() ||
3610  parser.parseKeyword(&cacheType) ||
3611  parser.parseOptionalAttrDict(result.attributes) ||
3612  parser.parseColonType(type) ||
3613  parser.resolveOperand(memrefInfo, type, result.operands) ||
3614  parser.resolveOperands(mapOperands, indexTy, result.operands))
3615  return failure();
3616 
3617  if (readOrWrite != "read" && readOrWrite != "write")
3618  return parser.emitError(parser.getNameLoc(),
3619  "rw specifier has to be 'read' or 'write'");
3620  result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3621  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
3622 
3623  if (cacheType != "data" && cacheType != "instr")
3624  return parser.emitError(parser.getNameLoc(),
3625  "cache type has to be 'data' or 'instr'");
3626 
3627  result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3628  parser.getBuilder().getBoolAttr(cacheType == "data"));
3629 
3630  return success();
3631 }
3632 
3634  p << " " << getMemref() << '[';
3635  AffineMapAttr mapAttr =
3636  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3637  if (mapAttr)
3638  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3639  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3640  << "locality<" << getLocalityHint() << ">, "
3641  << (getIsDataCache() ? "data" : "instr");
3643  (*this)->getAttrs(),
3644  /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3645  getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
3646  p << " : " << getMemRefType();
3647 }
3648 
3649 LogicalResult AffinePrefetchOp::verify() {
3650  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3651  if (mapAttr) {
3652  AffineMap map = mapAttr.getValue();
3653  if (map.getNumResults() != getMemRefType().getRank())
3654  return emitOpError("affine.prefetch affine map num results must equal"
3655  " memref rank");
3656  if (map.getNumInputs() + 1 != getNumOperands())
3657  return emitOpError("too few operands");
3658  } else {
3659  if (getNumOperands() != 1)
3660  return emitOpError("too few operands");
3661  }
3662 
3663  Region *scope = getAffineScope(*this);
3664  for (auto idx : getMapOperands()) {
3665  if (!isValidAffineIndexOperand(idx, scope))
3666  return emitOpError(
3667  "index must be a valid dimension or symbol identifier");
3668  }
3669  return success();
3670 }
3671 
/// Registers the SimplifyAffineOp pattern for affine.prefetch.
void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // prefetch(memrefcast) -> prefetch
  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
}
3677 
/// Folds a prefetch through a memref.cast of its memref operand.
LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// prefetch(memrefcast) -> prefetch
  return memref::foldMemRefCast(*this);
}
3683 
3684 //===----------------------------------------------------------------------===//
3685 // AffineParallelOp
3686 //===----------------------------------------------------------------------===//
3687 
3688 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3689  TypeRange resultTypes,
3690  ArrayRef<arith::AtomicRMWKind> reductions,
3691  ArrayRef<int64_t> ranges) {
3692  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3693  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3694  return builder.getConstantAffineMap(value);
3695  }));
3696  SmallVector<int64_t> steps(ranges.size(), 1);
3697  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3698  /*ubArgs=*/{}, steps);
3699 }
3700 
3701 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3702  TypeRange resultTypes,
3703  ArrayRef<arith::AtomicRMWKind> reductions,
3704  ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
3705  ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
3706  ArrayRef<int64_t> steps) {
3707  assert(llvm::all_of(lbMaps,
3708  [lbMaps](AffineMap m) {
3709  return m.getNumDims() == lbMaps[0].getNumDims() &&
3710  m.getNumSymbols() == lbMaps[0].getNumSymbols();
3711  }) &&
3712  "expected all lower bounds maps to have the same number of dimensions "
3713  "and symbols");
3714  assert(llvm::all_of(ubMaps,
3715  [ubMaps](AffineMap m) {
3716  return m.getNumDims() == ubMaps[0].getNumDims() &&
3717  m.getNumSymbols() == ubMaps[0].getNumSymbols();
3718  }) &&
3719  "expected all upper bounds maps to have the same number of dimensions "
3720  "and symbols");
3721  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
3722  "expected lower bound maps to have as many inputs as lower bound "
3723  "operands");
3724  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
3725  "expected upper bound maps to have as many inputs as upper bound "
3726  "operands");
3727 
3728  OpBuilder::InsertionGuard guard(builder);
3729  result.addTypes(resultTypes);
3730 
3731  // Convert the reductions to integer attributes.
3732  SmallVector<Attribute, 4> reductionAttrs;
3733  for (arith::AtomicRMWKind reduction : reductions)
3734  reductionAttrs.push_back(
3735  builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
3736  result.addAttribute(getReductionsAttrStrName(),
3737  builder.getArrayAttr(reductionAttrs));
3738 
3739  // Concatenates maps defined in the same input space (same dimensions and
3740  // symbols), assumes there is at least one map.
3741  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
3742  SmallVectorImpl<int32_t> &groups) {
3743  if (maps.empty())
3744  return AffineMap::get(builder.getContext());
3746  groups.reserve(groups.size() + maps.size());
3747  exprs.reserve(maps.size());
3748  for (AffineMap m : maps) {
3749  llvm::append_range(exprs, m.getResults());
3750  groups.push_back(m.getNumResults());
3751  }
3752  return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
3753  maps[0].getContext());
3754  };
3755 
3756  // Set up the bounds.
3757  SmallVector<int32_t> lbGroups, ubGroups;
3758  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
3759  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
3760  result.addAttribute(getLowerBoundsMapAttrStrName(),
3761  AffineMapAttr::get(lbMap));
3762  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
3763  builder.getI32TensorAttr(lbGroups));
3764  result.addAttribute(getUpperBoundsMapAttrStrName(),
3765  AffineMapAttr::get(ubMap));
3766  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
3767  builder.getI32TensorAttr(ubGroups));
3768  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
3769  result.addOperands(lbArgs);
3770  result.addOperands(ubArgs);
3771 
3772  // Create a region and a block for the body.
3773  auto *bodyRegion = result.addRegion();
3774  Block *body = builder.createBlock(bodyRegion);
3775 
3776  // Add all the block arguments.
3777  for (unsigned i = 0, e = steps.size(); i < e; ++i)
3778  body->addArgument(IndexType::get(builder.getContext()), result.location);
3779  if (resultTypes.empty())
3780  ensureTerminator(*bodyRegion, builder, result.location);
3781 }
3782 
/// Returns the single body region of the parallel loop.
SmallVector<Region *> AffineParallelOp::getLoopRegions() {
  return {&getRegion()};
}
3786 
/// Number of induction variables, equal to the number of steps.
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
3788 
/// Lower-bound operands occupy the leading positions of the operand list.
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
}
3792 
/// Upper-bound operands follow the lower-bound operands in the operand list.
AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
}
3796 
3797 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
3798  auto values = getLowerBoundsGroups().getValues<int32_t>();
3799  unsigned start = 0;
3800  for (unsigned i = 0; i < pos; ++i)
3801  start += values[i];
3802  return getLowerBoundsMap().getSliceMap(start, values[pos]);
3803 }
3804 
3805 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
3806  auto values = getUpperBoundsGroups().getValues<int32_t>();
3807  unsigned start = 0;
3808  for (unsigned i = 0; i < pos; ++i)
3809  start += values[i];
3810  return getUpperBoundsMap().getSliceMap(start, values[pos]);
3811 }
3812 
/// Bundles the lower-bounds map with its operands into an AffineValueMap.
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
}
3816 
/// Bundles the upper-bounds map with its operands into an AffineValueMap.
AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
}
3820 
3821 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
3822  if (hasMinMaxBounds())
3823  return std::nullopt;
3824 
3825  // Try to convert all the ranges to constant expressions.
3827  AffineValueMap rangesValueMap;
3828  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
3829  &rangesValueMap);
3830  out.reserve(rangesValueMap.getNumResults());
3831  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
3832  auto expr = rangesValueMap.getResult(i);
3833  auto cst = dyn_cast<AffineConstantExpr>(expr);
3834  if (!cst)
3835  return std::nullopt;
3836  out.push_back(cst.getValue());
3837  }
3838  return out;
3839 }
3840 
/// Returns the single block of the body region.
Block *AffineParallelOp::getBody() { return &getRegion().front(); }
3842 
/// Returns a builder positioned just before the body block's terminator.
OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));
}
3846 
3847 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
3848  assert(lbOperands.size() == map.getNumInputs() &&
3849  "operands to map must match number of inputs");
3850 
3851  auto ubOperands = getUpperBoundsOperands();
3852 
3853  SmallVector<Value, 4> newOperands(lbOperands);
3854  newOperands.append(ubOperands.begin(), ubOperands.end());
3855  (*this)->setOperands(newOperands);
3856 
3857  setLowerBoundsMapAttr(AffineMapAttr::get(map));
3858 }
3859 
3860 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
3861  assert(ubOperands.size() == map.getNumInputs() &&
3862  "operands to map must match number of inputs");
3863 
3864  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
3865  newOperands.append(ubOperands.begin(), ubOperands.end());
3866  (*this)->setOperands(newOperands);
3867 
3868  setUpperBoundsMapAttr(AffineMapAttr::get(map));
3869 }
3870 
/// Replaces the steps attribute with `newSteps`.
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
3874 
3875 // check whether resultType match op or not in affine.parallel
3876 static bool isResultTypeMatchAtomicRMWKind(Type resultType,
3877  arith::AtomicRMWKind op) {
3878  switch (op) {
3879  case arith::AtomicRMWKind::addf:
3880  return isa<FloatType>(resultType);
3881  case arith::AtomicRMWKind::addi:
3882  return isa<IntegerType>(resultType);
3883  case arith::AtomicRMWKind::assign:
3884  return true;
3885  case arith::AtomicRMWKind::mulf:
3886  return isa<FloatType>(resultType);
3887  case arith::AtomicRMWKind::muli:
3888  return isa<IntegerType>(resultType);
3889  case arith::AtomicRMWKind::maximumf:
3890  return isa<FloatType>(resultType);
3891  case arith::AtomicRMWKind::minimumf:
3892  return isa<FloatType>(resultType);
3893  case arith::AtomicRMWKind::maxs: {
3894  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3895  return intType && intType.isSigned();
3896  }
3897  case arith::AtomicRMWKind::mins: {
3898  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3899  return intType && intType.isSigned();
3900  }
3901  case arith::AtomicRMWKind::maxu: {
3902  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3903  return intType && intType.isUnsigned();
3904  }
3905  case arith::AtomicRMWKind::minu: {
3906  auto intType = llvm::dyn_cast<IntegerType>(resultType);
3907  return intType && intType.isUnsigned();
3908  }
3909  case arith::AtomicRMWKind::ori:
3910  return isa<IntegerType>(resultType);
3911  case arith::AtomicRMWKind::andi:
3912  return isa<IntegerType>(resultType);
3913  default:
3914  return false;
3915  }
3916 }
3917 
/// Verifies structural invariants of affine.parallel: matching counts of
/// region arguments, bound groups, and steps; non-empty bound groups whose
/// sizes sum to the bound maps' result counts; one valid reduction per
/// result; and dim/symbol validity of all bound operands.
LogicalResult AffineParallelOp::verify() {
  auto numDims = getNumDims();
  if (getLowerBoundsGroups().getNumElements() != numDims ||
      getUpperBoundsGroups().getNumElements() != numDims ||
      getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
    return emitOpError() << "the number of region arguments ("
                         << getBody()->getNumArguments()
                         << ") and the number of map groups for lower ("
                         << getLowerBoundsGroups().getNumElements()
                         << ") and upper bound ("
                         << getUpperBoundsGroups().getNumElements()
                         << "), and the number of steps (" << getSteps().size()
                         << ") must all match";
  }

  // Each group must contribute at least one map result, and the group sizes
  // must account for every result of the concatenated lower-bounds map.
  unsigned expectedNumLBResults = 0;
  for (APInt v : getLowerBoundsGroups()) {
    unsigned results = v.getZExtValue();
    if (results == 0)
      return emitOpError()
             << "expected lower bound map to have at least one result";
    expectedNumLBResults += results;
  }
  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
    return emitOpError() << "expected lower bounds map to have "
                         << expectedNumLBResults << " results";
  // Same check for the upper-bounds map.
  unsigned expectedNumUBResults = 0;
  for (APInt v : getUpperBoundsGroups()) {
    unsigned results = v.getZExtValue();
    if (results == 0)
      return emitOpError()
             << "expected upper bound map to have at least one result";
    expectedNumUBResults += results;
  }
  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
    return emitOpError() << "expected upper bounds map to have "
                         << expectedNumUBResults << " results";

  if (getReductions().size() != getNumResults())
    return emitOpError("a reduction must be specified for each output");

  // Verify reduction ops are all valid and each result type matches reduction
  // ops
  for (auto it : llvm::enumerate((getReductions()))) {
    Attribute attr = it.value();
    auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
    if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
      return emitOpError("invalid reduction attribute");
    auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
    if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
      return emitOpError("result type cannot match reduction attribute");
  }

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bounds.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
                                           getLowerBoundsMap().getNumDims())))
    return failure();
  /// Upper bounds.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
                                           getUpperBoundsMap().getNumDims())))
    return failure();
  return success();
}
3982 
/// Simplifies this value map in place via composeAffineMapAndOperands;
/// reports failure when neither the map nor the operands changed.
LogicalResult AffineValueMap::canonicalize() {
  SmallVector<Value, 4> newOperands{operands};
  auto newMap = getAffineMap();
  composeAffineMapAndOperands(&newMap, &newOperands);
  if (newMap == getAffineMap() && newOperands == operands)
    return failure();
  reset(newMap, newOperands);
  return success();
}
3992 
3993 /// Canonicalize the bounds of the given loop.
3994 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3995  AffineValueMap lb = op.getLowerBoundsValueMap();
3996  bool lbCanonicalized = succeeded(lb.canonicalize());
3997 
3998  AffineValueMap ub = op.getUpperBoundsValueMap();
3999  bool ubCanonicalized = succeeded(ub.canonicalize());
4000 
4001  // Any canonicalization change always leads to updated map(s).
4002  if (!lbCanonicalized && !ubCanonicalized)
4003  return failure();
4004 
4005  if (lbCanonicalized)
4006  op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
4007  if (ubCanonicalized)
4008  op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
4009 
4010  return success();
4011 }
4012 
/// Folding for affine.parallel only canonicalizes its bound maps in place.
LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  return canonicalizeLoopBounds(*this);
}
4017 
4018 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
4019 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
4020 /// identifies which of the those expressions form max/min groups. `operands`
4021 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
4022 /// or "max".
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
                             DenseIntElementsAttr group, ValueRange operands,
                             StringRef keyword) {
  AffineMap map = mapAttr.getValue();
  unsigned numDims = map.getNumDims();
  ValueRange dimOperands = operands.take_front(numDims);
  ValueRange symOperands = operands.drop_front(numDims);
  unsigned start = 0;
  for (llvm::APInt groupSize : group) {
    // Comma-separate groups after the first one.
    if (start != 0)
      p << ", ";

    unsigned size = groupSize.getZExtValue();
    if (size == 1) {
      // A singleton group prints as a bare expression, no min/max keyword.
      p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
      ++start;
    } else {
      // A multi-expression group prints as keyword(list-of-exprs).
      p << keyword << '(';
      AffineMap submap = map.getSliceMap(start, size);
      p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
      p << ')';
      start += size;
    }
  }
}
4048 
4050  p << " (" << getBody()->getArguments() << ") = (";
4051  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4052  getLowerBoundsOperands(), "max");
4053  p << ") to (";
4054  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4055  getUpperBoundsOperands(), "min");
4056  p << ')';
4057  SmallVector<int64_t, 8> steps = getSteps();
4058  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4059  if (!elideSteps) {
4060  p << " step (";
4061  llvm::interleaveComma(steps, p);
4062  p << ')';
4063  }
4064  if (getNumResults()) {
4065  p << " reduce (";
4066  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4067  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4068  llvm::cast<IntegerAttr>(attr).getInt());
4069  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4070  });
4071  p << ") -> (" << getResultTypes() << ")";
4072  }
4073 
4074  p << ' ';
4075  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4076  /*printBlockTerminators=*/getNumResults());
4078  (*this)->getAttrs(),
4079  /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4080  AffineParallelOp::getLowerBoundsMapAttrStrName(),
4081  AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4082  AffineParallelOp::getUpperBoundsMapAttrStrName(),
4083  AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4084  AffineParallelOp::getStepsAttrStrName()});
4085 }
4086 
4087 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
4088 /// unique operands. Also populates `replacements with affine expressions of
4089 /// `kind` that can be used to update affine maps previously accepting a
4090 /// `operands` to accept `uniqueOperands` instead.
4092  OpAsmParser &parser,
4094  SmallVectorImpl<Value> &uniqueOperands,
4097  "expected operands to be dim or symbol expression");
4098 
4099  Type indexType = parser.getBuilder().getIndexType();
4100  for (const auto &list : operands) {
4101  SmallVector<Value> valueOperands;
4102  if (parser.resolveOperands(list, indexType, valueOperands))
4103  return failure();
4104  for (Value operand : valueOperands) {
4105  unsigned pos = std::distance(uniqueOperands.begin(),
4106  llvm::find(uniqueOperands, operand));
4107  if (pos == uniqueOperands.size())
4108  uniqueOperands.push_back(operand);
4109  replacements.push_back(
4111  ? getAffineDimExpr(pos, parser.getContext())
4112  : getAffineSymbolExpr(pos, parser.getContext()));
4113  }
4114  }
4115  return success();
4116 }
4117 
namespace {
/// Whether a parsed group of bound expressions forms a `min` or a `max` group.
enum class MinMaxKind { Min, Max };
} // namespace
4121 
4122 /// Parses an affine map that can contain a min/max for groups of its results,
4123 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4124 /// `result` attributes with the map (flat list of expressions) and the grouping
4125 /// (list of integers that specify how many expressions to put into each
4126 /// min/max) attributes. Deduplicates repeated operands.
4127 ///
4128 /// parallel-bound ::= `(` parallel-group-list `)`
4129 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4130 /// parallel-group ::= simple-group | min-max-group
4131 /// simple-group ::= expr-of-ssa-ids
4132 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4133 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4134 ///
4135 /// Examples:
4136 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4137 /// (%0, max(%1 - 2 * %2))
4138 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4139  OperationState &result,
4140  MinMaxKind kind) {
4141  // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4142  // see: https://reviews.llvm.org/D134227#3821753
4143  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4144 
4145  StringRef mapName = kind == MinMaxKind::Min
4146  ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4147  : AffineParallelOp::getLowerBoundsMapAttrStrName();
4148  StringRef groupsName =
4149  kind == MinMaxKind::Min
4150  ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4151  : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4152 
4153  if (failed(parser.parseLParen()))
4154  return failure();
4155 
4156  if (succeeded(parser.parseOptionalRParen())) {
4157  result.addAttribute(
4158  mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4159  result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4160  return success();
4161  }
4162 
4163  SmallVector<AffineExpr> flatExprs;
4166  SmallVector<int32_t> numMapsPerGroup;
4168  auto parseOperands = [&]() {
4169  if (succeeded(parser.parseOptionalKeyword(
4170  kind == MinMaxKind::Min ? "min" : "max"))) {
4171  mapOperands.clear();
4172  AffineMapAttr map;
4173  if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4174  result.attributes,
4176  return failure();
4177  result.attributes.erase(tmpAttrStrName);
4178  llvm::append_range(flatExprs, map.getValue().getResults());
4179  auto operandsRef = llvm::ArrayRef(mapOperands);
4180  auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4182  auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4184  flatDimOperands.append(map.getValue().getNumResults(), dims);
4185  flatSymOperands.append(map.getValue().getNumResults(), syms);
4186  numMapsPerGroup.push_back(map.getValue().getNumResults());
4187  } else {
4188  if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4189  flatSymOperands.emplace_back(),
4190  flatExprs.emplace_back())))
4191  return failure();
4192  numMapsPerGroup.push_back(1);
4193  }
4194  return success();
4195  };
4196  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4197  return failure();
4198 
4199  unsigned totalNumDims = 0;
4200  unsigned totalNumSyms = 0;
4201  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4202  unsigned numDims = flatDimOperands[i].size();
4203  unsigned numSyms = flatSymOperands[i].size();
4204  flatExprs[i] = flatExprs[i]
4205  .shiftDims(numDims, totalNumDims)
4206  .shiftSymbols(numSyms, totalNumSyms);
4207  totalNumDims += numDims;
4208  totalNumSyms += numSyms;
4209  }
4210 
4211  // Deduplicate map operands.
4212  SmallVector<Value> dimOperands, symOperands;
4213  SmallVector<AffineExpr> dimRplacements, symRepacements;
4214  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4215  dimRplacements, AffineExprKind::DimId) ||
4216  deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4217  symRepacements, AffineExprKind::SymbolId))
4218  return failure();
4219 
4220  result.operands.append(dimOperands.begin(), dimOperands.end());
4221  result.operands.append(symOperands.begin(), symOperands.end());
4222 
4223  Builder &builder = parser.getBuilder();
4224  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4225  parser.getContext());
4226  flatMap = flatMap.replaceDimsAndSymbols(
4227  dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4228 
4229  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4230  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4231  return success();
4232 }
4233 
//
// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
//               `to` parallel-bound steps? region attr-dict?
// steps     ::= `step` `(` integer-literals `)`
//
4239 ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4240  OperationState &result) {
4241  auto &builder = parser.getBuilder();
4242  auto indexType = builder.getIndexType();
4245  parser.parseEqual() ||
4246  parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4247  parser.parseKeyword("to") ||
4248  parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4249  return failure();
4250 
4251  AffineMapAttr stepsMapAttr;
4252  NamedAttrList stepsAttrs;
4254  if (failed(parser.parseOptionalKeyword("step"))) {
4255  SmallVector<int64_t, 4> steps(ivs.size(), 1);
4256  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4257  builder.getI64ArrayAttr(steps));
4258  } else {
4259  if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4260  AffineParallelOp::getStepsAttrStrName(),
4261  stepsAttrs,
4263  return failure();
4264 
4265  // Convert steps from an AffineMap into an I64ArrayAttr.
4267  auto stepsMap = stepsMapAttr.getValue();
4268  for (const auto &result : stepsMap.getResults()) {
4269  auto constExpr = dyn_cast<AffineConstantExpr>(result);
4270  if (!constExpr)
4271  return parser.emitError(parser.getNameLoc(),
4272  "steps must be constant integers");
4273  steps.push_back(constExpr.getValue());
4274  }
4275  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4276  builder.getI64ArrayAttr(steps));
4277  }
4278 
4279  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4280  // quoted strings are a member of the enum AtomicRMWKind.
4281  SmallVector<Attribute, 4> reductions;
4282  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4283  if (parser.parseLParen())
4284  return failure();
4285  auto parseAttributes = [&]() -> ParseResult {
4286  // Parse a single quoted string via the attribute parsing, and then
4287  // verify it is a member of the enum and convert to it's integer
4288  // representation.
4289  StringAttr attrVal;
4290  NamedAttrList attrStorage;
4291  auto loc = parser.getCurrentLocation();
4292  if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4293  attrStorage))
4294  return failure();
4295  std::optional<arith::AtomicRMWKind> reduction =
4296  arith::symbolizeAtomicRMWKind(attrVal.getValue());
4297  if (!reduction)
4298  return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4299  reductions.push_back(
4300  builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4301  // While we keep getting commas, keep parsing.
4302  return success();
4303  };
4304  if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4305  return failure();
4306  }
4307  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4308  builder.getArrayAttr(reductions));
4309 
4310  // Parse return types of reductions (if any)
4311  if (parser.parseOptionalArrowTypeList(result.types))
4312  return failure();
4313 
4314  // Now parse the body.
4315  Region *body = result.addRegion();
4316  for (auto &iv : ivs)
4317  iv.type = indexType;
4318  if (parser.parseRegion(*body, ivs) ||
4319  parser.parseOptionalAttrDict(result.attributes))
4320  return failure();
4321 
4322  // Add a terminator if none was parsed.
4323  AffineParallelOp::ensureTerminator(*body, builder, result.location);
4324  return success();
4325 }
4326 
4327 //===----------------------------------------------------------------------===//
4328 // AffineYieldOp
4329 //===----------------------------------------------------------------------===//
4330 
4331 LogicalResult AffineYieldOp::verify() {
4332  auto *parentOp = (*this)->getParentOp();
4333  auto results = parentOp->getResults();
4334  auto operands = getOperands();
4335 
4336  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4337  return emitOpError() << "only terminates affine.if/for/parallel regions";
4338  if (parentOp->getNumResults() != getNumOperands())
4339  return emitOpError() << "parent of yield must have same number of "
4340  "results as the yield operands";
4341  for (auto it : llvm::zip(results, operands)) {
4342  if (std::get<0>(it).getType() != std::get<1>(it).getType())
4343  return emitOpError() << "types mismatch between yield op and its parent";
4344  }
4345 
4346  return success();
4347 }
4348 
4349 //===----------------------------------------------------------------------===//
4350 // AffineVectorLoadOp
4351 //===----------------------------------------------------------------------===//
4352 
4353 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4354  VectorType resultType, AffineMap map,
4355  ValueRange operands) {
4356  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4357  result.addOperands(operands);
4358  if (map)
4359  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4360  result.types.push_back(resultType);
4361 }
4362 
4363 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4364  VectorType resultType, Value memref,
4365  AffineMap map, ValueRange mapOperands) {
4366  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4367  result.addOperands(memref);
4368  result.addOperands(mapOperands);
4369  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4370  result.types.push_back(resultType);
4371 }
4372 
4373 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4374  VectorType resultType, Value memref,
4375  ValueRange indices) {
4376  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4377  int64_t rank = memrefType.getRank();
4378  // Create identity map for memrefs with at least one dimension or () -> ()
4379  // for zero-dimensional memrefs.
4380  auto map =
4381  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4382  build(builder, result, resultType, memref, map, indices);
4383 }
4384 
4385 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4386  MLIRContext *context) {
4387  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4388 }
4389 
4390 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4391  OperationState &result) {
4392  auto &builder = parser.getBuilder();
4393  auto indexTy = builder.getIndexType();
4394 
4395  MemRefType memrefType;
4396  VectorType resultType;
4397  OpAsmParser::UnresolvedOperand memrefInfo;
4398  AffineMapAttr mapAttr;
4400  return failure(
4401  parser.parseOperand(memrefInfo) ||
4402  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4403  AffineVectorLoadOp::getMapAttrStrName(),
4404  result.attributes) ||
4405  parser.parseOptionalAttrDict(result.attributes) ||
4406  parser.parseColonType(memrefType) || parser.parseComma() ||
4407  parser.parseType(resultType) ||
4408  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4409  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4410  parser.addTypeToList(resultType, result.types));
4411 }
4412 
4414  p << " " << getMemRef() << '[';
4415  if (AffineMapAttr mapAttr =
4416  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4417  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4418  p << ']';
4419  p.printOptionalAttrDict((*this)->getAttrs(),
4420  /*elidedAttrs=*/{getMapAttrStrName()});
4421  p << " : " << getMemRefType() << ", " << getType();
4422 }
4423 
4424 /// Verify common invariants of affine.vector_load and affine.vector_store.
4425 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4426  VectorType vectorType) {
4427  // Check that memref and vector element types match.
4428  if (memrefType.getElementType() != vectorType.getElementType())
4429  return op->emitOpError(
4430  "requires memref and vector types of the same elemental type");
4431  return success();
4432 }
4433 
4434 LogicalResult AffineVectorLoadOp::verify() {
4435  MemRefType memrefType = getMemRefType();
4436  if (failed(verifyMemoryOpIndexing(
4437  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4438  getMapOperands(), memrefType,
4439  /*numIndexOperands=*/getNumOperands() - 1)))
4440  return failure();
4441 
4442  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4443  return failure();
4444 
4445  return success();
4446 }
4447 
4448 //===----------------------------------------------------------------------===//
4449 // AffineVectorStoreOp
4450 //===----------------------------------------------------------------------===//
4451 
4452 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4453  Value valueToStore, Value memref, AffineMap map,
4454  ValueRange mapOperands) {
4455  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4456  result.addOperands(valueToStore);
4457  result.addOperands(memref);
4458  result.addOperands(mapOperands);
4459  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4460 }
4461 
4462 // Use identity map.
4463 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4464  Value valueToStore, Value memref,
4465  ValueRange indices) {
4466  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4467  int64_t rank = memrefType.getRank();
4468  // Create identity map for memrefs with at least one dimension or () -> ()
4469  // for zero-dimensional memrefs.
4470  auto map =
4471  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4472  build(builder, result, valueToStore, memref, map, indices);
4473 }
4474 void AffineVectorStoreOp::getCanonicalizationPatterns(
4475  RewritePatternSet &results, MLIRContext *context) {
4476  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4477 }
4478 
4479 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4480  OperationState &result) {
4481  auto indexTy = parser.getBuilder().getIndexType();
4482 
4483  MemRefType memrefType;
4484  VectorType resultType;
4485  OpAsmParser::UnresolvedOperand storeValueInfo;
4486  OpAsmParser::UnresolvedOperand memrefInfo;
4487  AffineMapAttr mapAttr;
4489  return failure(
4490  parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4491  parser.parseOperand(memrefInfo) ||
4492  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4493  AffineVectorStoreOp::getMapAttrStrName(),
4494  result.attributes) ||
4495  parser.parseOptionalAttrDict(result.attributes) ||
4496  parser.parseColonType(memrefType) || parser.parseComma() ||
4497  parser.parseType(resultType) ||
4498  parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4499  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4500  parser.resolveOperands(mapOperands, indexTy, result.operands));
4501 }
4502 
4504  p << " " << getValueToStore();
4505  p << ", " << getMemRef() << '[';
4506  if (AffineMapAttr mapAttr =
4507  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4508  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4509  p << ']';
4510  p.printOptionalAttrDict((*this)->getAttrs(),
4511  /*elidedAttrs=*/{getMapAttrStrName()});
4512  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4513 }
4514 
4515 LogicalResult AffineVectorStoreOp::verify() {
4516  MemRefType memrefType = getMemRefType();
4517  if (failed(verifyMemoryOpIndexing(
4518  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4519  getMapOperands(), memrefType,
4520  /*numIndexOperands=*/getNumOperands() - 2)))
4521  return failure();
4522 
4523  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4524  return failure();
4525 
4526  return success();
4527 }
4528 
4529 //===----------------------------------------------------------------------===//
4530 // DelinearizeIndexOp
4531 //===----------------------------------------------------------------------===//
4532 
4533 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4534  OperationState &odsState,
4535  Value linearIndex, ValueRange dynamicBasis,
4536  ArrayRef<int64_t> staticBasis,
4537  bool hasOuterBound) {
4538  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4539  : staticBasis.size() + 1,
4540  linearIndex.getType());
4541  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4542  staticBasis);
4543 }
4544 
4545 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4546  OperationState &odsState,
4547  Value linearIndex, ValueRange basis,
4548  bool hasOuterBound) {
4549  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4550  hasOuterBound = false;
4551  basis = basis.drop_front();
4552  }
4553  SmallVector<Value> dynamicBasis;
4554  SmallVector<int64_t> staticBasis;
4555  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4556  staticBasis);
4557  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4558  hasOuterBound);
4559 }
4560 
4561 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4562  OperationState &odsState,
4563  Value linearIndex,
4564  ArrayRef<OpFoldResult> basis,
4565  bool hasOuterBound) {
4566  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4567  hasOuterBound = false;
4568  basis = basis.drop_front();
4569  }
4570  SmallVector<Value> dynamicBasis;
4571  SmallVector<int64_t> staticBasis;
4572  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4573  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4574  hasOuterBound);
4575 }
4576 
4577 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4578  OperationState &odsState,
4579  Value linearIndex, ArrayRef<int64_t> basis,
4580  bool hasOuterBound) {
4581  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
4582 }
4583 
4584 LogicalResult AffineDelinearizeIndexOp::verify() {
4585  ArrayRef<int64_t> staticBasis = getStaticBasis();
4586  if (getNumResults() != staticBasis.size() &&
4587  getNumResults() != staticBasis.size() + 1)
4588  return emitOpError("should return an index for each basis element and up "
4589  "to one extra index");
4590 
4591  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4592  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4593  return emitOpError(
4594  "mismatch between dynamic and static basis (kDynamic marker but no "
4595  "corresponding dynamic basis entry) -- this can only happen due to an "
4596  "incorrect fold/rewrite");
4597 
4598  if (!llvm::all_of(staticBasis, [](int64_t v) {
4599  return v > 0 || ShapedType::isDynamic(v);
4600  }))
4601  return emitOpError("no basis element may be statically non-positive");
4602 
4603  return success();
4604 }
4605 
4606 /// Given mixed basis of affine.delinearize_index/linearize_index replace
4607 /// constant SSA values with the constant integer value and return the new
4608 /// static basis. In case no such candidate for replacement exists, this utility
4609 /// returns std::nullopt.
4610 static std::optional<SmallVector<int64_t>>
4612  MutableOperandRange mutableDynamicBasis,
4613  ArrayRef<Attribute> dynamicBasis) {
4614  uint64_t dynamicBasisIndex = 0;
4615  for (OpFoldResult basis : dynamicBasis) {
4616  if (basis) {
4617  mutableDynamicBasis.erase(dynamicBasisIndex);
4618  } else {
4619  ++dynamicBasisIndex;
4620  }
4621  }
4622 
4623  // No constant SSA value exists.
4624  if (dynamicBasisIndex == dynamicBasis.size())
4625  return std::nullopt;
4626 
4627  SmallVector<int64_t> staticBasis;
4628  for (OpFoldResult basis : mixedBasis) {
4629  std::optional<int64_t> basisVal = getConstantIntValue(basis);
4630  if (!basisVal)
4631  staticBasis.push_back(ShapedType::kDynamic);
4632  else
4633  staticBasis.push_back(*basisVal);
4634  }
4635 
4636  return staticBasis;
4637 }
4638 
4639 LogicalResult
4640 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4642  std::optional<SmallVector<int64_t>> maybeStaticBasis =
4643  foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
4644  adaptor.getDynamicBasis());
4645  if (maybeStaticBasis) {
4646  setStaticBasis(*maybeStaticBasis);
4647  return success();
4648  }
4649  // If we won't be doing any division or modulo (no basis or the one basis
4650  // element is purely advisory), simply return the input value.
4651  if (getNumResults() == 1) {
4652  result.push_back(getLinearIndex());
4653  return success();
4654  }
4655 
4656  if (adaptor.getLinearIndex() == nullptr)
4657  return failure();
4658 
4659  if (!adaptor.getDynamicBasis().empty())
4660  return failure();
4661 
4662  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
4663  Type attrType = getLinearIndex().getType();
4664 
4665  ArrayRef<int64_t> staticBasis = getStaticBasis();
4666  if (hasOuterBound())
4667  staticBasis = staticBasis.drop_front();
4668  for (int64_t modulus : llvm::reverse(staticBasis)) {
4669  result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
4670  highPart = llvm::divideFloorSigned(highPart, modulus);
4671  }
4672  result.push_back(IntegerAttr::get(attrType, highPart));
4673  std::reverse(result.begin(), result.end());
4674  return success();
4675 }
4676 
4677 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
4678  OpBuilder builder(getContext());
4679  if (hasOuterBound()) {
4680  if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4681  return getMixedValues(getStaticBasis().drop_front(),
4682  getDynamicBasis().drop_front(), builder);
4683 
4684  return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4685  builder);
4686  }
4687 
4688  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4689 }
4690 
4691 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
4692  SmallVector<OpFoldResult> ret = getMixedBasis();
4693  if (!hasOuterBound())
4694  ret.insert(ret.begin(), OpFoldResult());
4695  return ret;
4696 }
4697 
4698 namespace {
4699 
4700 // Drops delinearization indices that correspond to unit-extent basis
4701 struct DropUnitExtentBasis
4702  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4704 
4705  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4706  PatternRewriter &rewriter) const override {
4707  SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
4708  std::optional<Value> zero = std::nullopt;
4709  Location loc = delinearizeOp->getLoc();
4710  auto getZero = [&]() -> Value {
4711  if (!zero)
4712  zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4713  return zero.value();
4714  };
4715 
4716  // Replace all indices corresponding to unit-extent basis with 0.
4717  // Remaining basis can be used to get a new `affine.delinearize_index` op.
4718  SmallVector<OpFoldResult> newBasis;
4719  for (auto [index, basis] :
4720  llvm::enumerate(delinearizeOp.getPaddedBasis())) {
4721  std::optional<int64_t> basisVal =
4722  basis ? getConstantIntValue(basis) : std::nullopt;
4723  if (basisVal && *basisVal == 1)
4724  replacements[index] = getZero();
4725  else
4726  newBasis.push_back(basis);
4727  }
4728 
4729  if (newBasis.size() == delinearizeOp.getNumResults())
4730  return rewriter.notifyMatchFailure(delinearizeOp,
4731  "no unit basis elements");
4732 
4733  if (!newBasis.empty()) {
4734  // Will drop the leading nullptr from `basis` if there was no outer bound.
4735  auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4736  loc, delinearizeOp.getLinearIndex(), newBasis);
4737  int newIndex = 0;
4738  // Map back the new delinearized indices to the values they replace.
4739  for (auto &replacement : replacements) {
4740  if (replacement)
4741  continue;
4742  replacement = newDelinearizeOp->getResult(newIndex++);
4743  }
4744  }
4745 
4746  rewriter.replaceOp(delinearizeOp, replacements);
4747  return success();
4748  }
4749 };
4750 
4751 /// If a `affine.delinearize_index`'s input is a `affine.linearize_index
4752 /// disjoint` and the two operations end with the same basis elements,
4753 /// cancel those parts of the operations out because they are inverses
4754 /// of each other.
4755 ///
4756 /// If the operations have the same basis, cancel them entirely.
4757 ///
4758 /// The `disjoint` flag is needed on the `affine.linearize_index` because
4759 /// otherwise, there is no guarantee that the inputs to the linearization are
4760 /// in-bounds the way the outputs of the delinearization would be.
4761 struct CancelDelinearizeOfLinearizeDisjointExactTail
4762  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4764 
4765  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4766  PatternRewriter &rewriter) const override {
4767  auto linearizeOp = delinearizeOp.getLinearIndex()
4768  .getDefiningOp<affine::AffineLinearizeIndexOp>();
4769  if (!linearizeOp)
4770  return rewriter.notifyMatchFailure(delinearizeOp,
4771  "index doesn't come from linearize");
4772 
4773  if (!linearizeOp.getDisjoint())
4774  return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
4775 
4776  ValueRange linearizeIns = linearizeOp.getMultiIndex();
4777  // Note: we use the full basis so we don't lose outer bounds later.
4778  SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
4779  SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
4780  size_t numMatches = 0;
4781  for (auto [linSize, delinSize] : llvm::zip(
4782  llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
4783  if (linSize != delinSize)
4784  break;
4785  ++numMatches;
4786  }
4787 
4788  if (numMatches == 0)
4789  return rewriter.notifyMatchFailure(
4790  delinearizeOp, "final basis element doesn't match linearize");
4791 
4792  // The easy case: everything lines up and the basis match sup completely.
4793  if (numMatches == linearizeBasis.size() &&
4794  numMatches == delinearizeBasis.size() &&
4795  linearizeIns.size() == delinearizeOp.getNumResults()) {
4796  rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4797  return success();
4798  }
4799 
4800  Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
4801  linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
4802  ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
4803  linearizeOp.getDisjoint());
4804  auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
4805  delinearizeOp.getLoc(), newLinearize,
4806  ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
4807  delinearizeOp.hasOuterBound());
4808  SmallVector<Value> mergedResults(newDelinearize.getResults());
4809  mergedResults.append(linearizeIns.take_back(numMatches).begin(),
4810  linearizeIns.take_back(numMatches).end());
4811  rewriter.replaceOp(delinearizeOp, mergedResults);
4812  return success();
4813  }
4814 };
4815 
4816 /// If the input to a delinearization is a disjoint linearization, and the
4817 /// last k > 1 components of the delinearization basis multiply to the
4818 /// last component of the linearization basis, break the linearization and
4819 /// delinearization into two parts, peeling off the last input to linearization.
4820 ///
4821 /// For example:
4822 /// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
4823 /// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
4824 /// becomes
4825 /// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
4826 /// %1:2 = affine.delinearize_index %0 by (2, 3) : index
4827 /// %2:2 = affine.delinearize_index %x by (8, 4) : index
4828 /// where the original %1:4 is replaced by %1:2 ++ %2:2
4829 struct SplitDelinearizeSpanningLastLinearizeArg final
4830  : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4832 
4833  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4834  PatternRewriter &rewriter) const override {
4835  auto linearizeOp = delinearizeOp.getLinearIndex()
4836  .getDefiningOp<affine::AffineLinearizeIndexOp>();
4837  if (!linearizeOp)
4838  return rewriter.notifyMatchFailure(delinearizeOp,
4839  "index doesn't come from linearize");
4840 
4841  if (!linearizeOp.getDisjoint())
4842  return rewriter.notifyMatchFailure(linearizeOp,
4843  "linearize isn't disjoint");
4844 
4845  int64_t target = linearizeOp.getStaticBasis().back();
4846  if (ShapedType::isDynamic(target))
4847  return rewriter.notifyMatchFailure(
4848  linearizeOp, "linearize ends with dynamic basis value");
4849 
4850  int64_t sizeToSplit = 1;
4851  size_t elemsToSplit = 0;
4852  ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
4853  for (int64_t basisElem : llvm::reverse(basis)) {
4854  if (ShapedType::isDynamic(basisElem))
4855  return rewriter.notifyMatchFailure(
4856  delinearizeOp, "dynamic basis element while scanning for split");
4857  sizeToSplit *= basisElem;
4858  elemsToSplit += 1;
4859 
4860  if (sizeToSplit > target)
4861  return rewriter.notifyMatchFailure(delinearizeOp,
4862  "overshot last argument size");
4863  if (sizeToSplit == target)
4864  break;
4865  }
4866 
4867  if (sizeToSplit < target)
4868  return rewriter.notifyMatchFailure(
4869  delinearizeOp, "product of known basis elements doesn't exceed last "
4870  "linearize argument");
4871 
4872  if (elemsToSplit < 2)
4873  return rewriter.notifyMatchFailure(
4874  delinearizeOp,
4875  "need at least two elements to form the basis product");
4876 
4877  Value linearizeWithoutBack =
4878  rewriter.create<affine::AffineLinearizeIndexOp>(
4879  linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
4880  linearizeOp.getDynamicBasis(),
4881  linearizeOp.getStaticBasis().drop_back(),
4882  linearizeOp.getDisjoint());
4883  auto delinearizeWithoutSplitPart =
4884  rewriter.create<affine::AffineDelinearizeIndexOp>(
4885  delinearizeOp.getLoc(), linearizeWithoutBack,
4886  delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
4887  delinearizeOp.hasOuterBound());
4888  auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
4889  delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
4890  basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
4891  SmallVector<Value> results = llvm::to_vector(
4892  llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
4893  delinearizeBack.getResults()));
4894  rewriter.replaceOp(delinearizeOp, results);
4895 
4896  return success();
4897  }
4898 };
4899 } // namespace
4900 
4901 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
4902  RewritePatternSet &patterns, MLIRContext *context) {
4903  patterns
4904  .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
4905  DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
4906  context);
4907 }
4908 
4909 //===----------------------------------------------------------------------===//
4910 // LinearizeIndexOp
4911 //===----------------------------------------------------------------------===//
4912 
4913 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4914  OperationState &odsState,
4915  ValueRange multiIndex, ValueRange basis,
4916  bool disjoint) {
4917  if (!basis.empty() && basis.front() == Value())
4918  basis = basis.drop_front();
4919  SmallVector<Value> dynamicBasis;
4920  SmallVector<int64_t> staticBasis;
4921  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4922  staticBasis);
4923  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4924 }
4925 
4926 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
4927  OperationState &odsState,
4928  ValueRange multiIndex,
4929  ArrayRef<OpFoldResult> basis,
4930  bool disjoint) {
4931  if (!basis.empty() && basis.front() == OpFoldResult())
4932  basis = basis.drop_front();
4933  SmallVector<Value> dynamicBasis;
4934  SmallVector<int64_t> staticBasis;
4935  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4936  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
4937 }
4938 
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ValueRange multiIndex,
                                   ArrayRef<int64_t> basis, bool disjoint) {
  // All basis elements are static constants here, so there are no dynamic
  // basis operands to forward.
  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
}
4945 
4946 LogicalResult AffineLinearizeIndexOp::verify() {
4947  size_t numIndexes = getMultiIndex().size();
4948  size_t numBasisElems = getStaticBasis().size();
4949  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
4950  return emitOpError("should be passed a basis element for each index except "
4951  "possibly the first");
4952 
4953  auto dynamicMarkersCount =
4954  llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
4955  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4956  return emitOpError(
4957  "mismatch between dynamic and static basis (kDynamic marker but no "
4958  "corresponding dynamic basis entry) -- this can only happen due to an "
4959  "incorrect fold/rewrite");
4960 
4961  return success();
4962 }
4963 
OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
  // In-place canonicalization: replace dynamic basis operands that are known
  // constants with entries in the static basis, signalling a change by
  // returning the op's own result.
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
                                 adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);
    return getResult();
  }
  // No indices linearizes to zero.
  if (getMultiIndex().empty())
    return IntegerAttr::get(getResult().getType(), 0);

  // One single index linearizes to itself.
  if (getMultiIndex().size() == 1)
    return getMultiIndex().front();

  // Beyond this point, full constant folding needs every index to be a known
  // constant attribute...
  if (llvm::any_of(adaptor.getMultiIndex(),
                   [](Attribute a) { return a == nullptr; }))
    return nullptr;

  // ...and every basis element to be static.
  if (!adaptor.getDynamicBasis().empty())
    return nullptr;

  // Accumulate sum(index[i] * stride[i]) from the innermost (last) element
  // outward. zip_first pairs each basis element with an index, starting from
  // the back; when there's no outer bound, the multi-index is one longer than
  // the basis, and its front element is handled separately below.
  int64_t result = 0;
  int64_t stride = 1;
  for (auto [length, indexAttr] :
       llvm::zip_first(llvm::reverse(getStaticBasis()),
                       llvm::reverse(adaptor.getMultiIndex()))) {
    result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
    stride = stride * length;
  }
  // Handle the index element with no basis element.
  if (!hasOuterBound())
    result =
        result +
        cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;

  return IntegerAttr::get(getResult().getType(), result);
}
5003 
5004 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
5005  OpBuilder builder(getContext());
5006  if (hasOuterBound()) {
5007  if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5008  return getMixedValues(getStaticBasis().drop_front(),
5009  getDynamicBasis().drop_front(), builder);
5010 
5011  return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5012  builder);
5013  }
5014 
5015  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5016 }
5017 
5018 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
5019  SmallVector<OpFoldResult> ret = getMixedBasis();
5020  if (!hasOuterBound())
5021  ret.insert(ret.begin(), OpFoldResult());
5022  return ret;
5023 }
5024 
5025 namespace {
/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
/// %...d)`.
///
/// Note that `disjoint` is required here because, without it,
/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
/// would be a valid operation where the `%c64` cannot be trivially dropped.
///
/// Alternatively, if `%x` in the above is a known constant 0, remove it even if
/// the operation isn't asserted to be `disjoint`.
struct DropLinearizeUnitComponentsIfDisjointOrZero final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {

  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    ValueRange multiIndex = op.getMultiIndex();
    size_t numIndices = multiIndex.size();
    SmallVector<Value> newIndices;
    newIndices.reserve(numIndices);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numIndices);

    // A leading index with no basis element (no outer bound) can never be a
    // droppable unit component; always keep it.
    if (!op.hasOuterBound()) {
      newIndices.push_back(multiIndex.front());
      multiIndex = multiIndex.drop_front();
    }

    SmallVector<OpFoldResult> basis = op.getMixedBasis();
    for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
      // Keep any component whose basis element isn't a known constant 1.
      std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
      if (!basisEntry || *basisEntry != 1) {
        newIndices.push_back(index);
        newBasis.push_back(basisElem);
        continue;
      }

      // Unit basis element: droppable when the op is `disjoint` (the index is
      // then asserted to be in [0, 1), i.e. 0) or when the index is a known
      // constant 0. Otherwise keep it.
      std::optional<int64_t> indexValue = getConstantIntValue(index);
      if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
        newIndices.push_back(index);
        newBasis.push_back(basisElem);
        continue;
      }
    }
    if (newIndices.size() == numIndices)
      return rewriter.notifyMatchFailure(op,
                                         "no unit basis entries to replace");

    // Every component was dropped: the linearization folds to constant 0.
    if (newIndices.size() == 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
      return success();
    }
    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
        op, newIndices, newBasis, op.getDisjoint());
    return success();
  }
};
5083 
5084 /// Return the product of `terms`, creating an `affine.apply` if any of them are
5085 /// non-constant values. If any of `terms` is `nullptr`, return `nullptr`.
5086 static OpFoldResult computeProduct(Location loc, OpBuilder &builder,
5087  ArrayRef<OpFoldResult> terms) {
5088  int64_t nDynamic = 0;
5089  SmallVector<Value> dynamicPart;
5090  AffineExpr result = builder.getAffineConstantExpr(1);
5091  for (OpFoldResult term : terms) {
5092  if (!term)
5093  return term;
5094  std::optional<int64_t> maybeConst = getConstantIntValue(term);
5095  if (maybeConst) {
5096  result = result * builder.getAffineConstantExpr(*maybeConst);
5097  } else {
5098  dynamicPart.push_back(cast<Value>(term));
5099  result = result * builder.getAffineSymbolExpr(nDynamic++);
5100  }
5101  }
5102  if (auto constant = dyn_cast<AffineConstantExpr>(result))
5103  return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
5104  return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
5105 }
5106 
/// If consecutive outputs of a delinearize_index are linearized with the same
/// bounds, canonicalize away the redundant arithmetic.
///
/// That is, if we have
/// ```
/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
///   by (...e, B1, B2, ..., BK, ...f)
/// ```
///
/// We can rewrite this to
/// ```
/// B = B1 * B2 ... BK
/// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b)
/// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
/// ```
/// where we replace all results of %s unaffected by the change with results
/// from %sMerged.
///
/// As a special case, if all results of the delinearize are merged in this way
/// we can replace those usages with %x, thus cancelling the delinearization
/// entirely, as in
/// ```
/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
/// ```
/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
struct CancelLinearizeOfDelinearizePortion final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {

private:
  // Struct representing a case where the cancellation pattern
  // applies. A `Match` means that `length` inputs to the linearize operation
  // starting at `linStart` can be cancelled with `length` outputs of
  // `delinearize`, starting from `delinStart`.
  struct Match {
    AffineDelinearizeIndexOp delinearize;
    unsigned linStart = 0;
    unsigned delinStart = 0;
    unsigned length = 0;
  };

public:
  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Match> matches;

    // Padded bases have one entry per argument/result, with a null
    // OpFoldResult standing in for an absent outer bound.
    const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
    ArrayRef<OpFoldResult> linBasisRef = linBasis;

    ValueRange multiIndex = linearizeOp.getMultiIndex();
    unsigned numLinArgs = multiIndex.size();
    unsigned linArgIdx = 0;
    // We only want to replace one run from the same delinearize op per
    // pattern invocation lest we run into invalidation issues.
    llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
    while (linArgIdx < numLinArgs) {
      // Only results of a delinearize_index can start a cancellable run.
      auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
      if (!asResult) {
        linArgIdx++;
        continue;
      }

      auto delinearizeOp =
          dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
      if (!delinearizeOp) {
        linArgIdx++;
        continue;
      }

      /// Result 0 of the delinearize and argument 0 of the linearize can
      /// leave their maximum value unspecified. However, even if this happens
      /// we can still sometimes start the match process. Specifically, if
      /// - The argument we're matching is result 0 and argument 0 (so the
      ///   bounds don't matter). For example,
      ///
      ///     %0:2 = affine.delinearize_index %x into (8) : index, index
      ///     %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...)
      ///   allows cancellation
      /// - The delinearization doesn't specify a bound, but the linearization
      ///  is `disjoint`, which asserts that the bound on the linearization is
      ///  correct.
      unsigned delinArgIdx = asResult.getResultNumber();
      SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
      OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
      OpFoldResult firstLinBound = linBasis[linArgIdx];
      bool boundsMatch = firstDelinBound == firstLinBound;
      bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
      bool knownByDisjoint =
          linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
      if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
        linArgIdx++;
        continue;
      }

      // Extend the run: subsequent linearize arguments must be the
      // consecutive delinearize results with matching basis elements.
      unsigned j = 1;
      unsigned numDelinOuts = delinearizeOp.getNumResults();
      for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
           ++j) {
        if (multiIndex[linArgIdx + j] !=
            delinearizeOp.getResult(delinArgIdx + j))
          break;
        if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
          break;
      }
      // If there're multiple matches against the same delinearize_index,
      // only rewrite the first one we find to prevent invalidations. The next
      // ones will be taken care of by subsequent pattern invocations.
      if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
        linArgIdx++;
        continue;
      }
      matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
      linArgIdx += j;
    }

    if (matches.empty())
      return rewriter.notifyMatchFailure(
          linearizeOp, "no run of delinearize outputs to deal with");

    // Record all the delinearize replacements so we can do them after creating
    // the new linearization operation, since the new operation might use
    // outputs of something we're replacing.
    SmallVector<SmallVector<Value>> delinearizeReplacements;

    SmallVector<Value> newIndex;
    newIndex.reserve(numLinArgs);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numLinArgs);
    unsigned prevMatchEnd = 0;
    for (Match m : matches) {
      // Copy through the unmatched arguments between this match and the last.
      unsigned gap = m.linStart - prevMatchEnd;
      llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
      llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
      // Update here so we don't forget this during early continues
      prevMatchEnd = m.linStart + m.length;

      PatternRewriter::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(m.delinearize);

      ArrayRef<OpFoldResult> basisToMerge =
          linBasisRef.slice(m.linStart, m.length);
      // We use the slice from the linearize's basis above because of the
      // "bounds inferred from `disjoint`" case above.
      OpFoldResult newSize =
          computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);

      // Trivial case where we can just skip past the delinearize all together
      if (m.length == m.delinearize.getNumResults()) {
        newIndex.push_back(m.delinearize.getLinearIndex());
        newBasis.push_back(newSize);
        // Pad out set of replacements so we don't do anything with this one.
        delinearizeReplacements.push_back(SmallVector<Value>());
        continue;
      }

      SmallVector<Value> newDelinResults;
      SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
      newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
                          newDelinBasis.begin() + m.delinStart + m.length);
      newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
      auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
          m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
          newDelinBasis);

      // Since there may be other uses of the indices we just merged together,
      // create a residual affine.delinearize_index that delinearizes the
      // merged output into its component parts.
      Value combinedElem = newDelinearize.getResult(m.delinStart);
      auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
          m.delinearize.getLoc(), combinedElem, basisToMerge);

      // Swap all the uses of the unaffected delinearize outputs to the new
      // delinearization so that the old code can be removed if this
      // linearize_index is the only user of the merged results.
      llvm::append_range(newDelinResults,
                         newDelinearize.getResults().take_front(m.delinStart));
      llvm::append_range(newDelinResults, residualDelinearize.getResults());
      llvm::append_range(
          newDelinResults,
          newDelinearize.getResults().drop_front(m.delinStart + 1));

      delinearizeReplacements.push_back(newDelinResults);
      newIndex.push_back(combinedElem);
      newBasis.push_back(newSize);
    }
    // Copy through the arguments after the final match.
    llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
    llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
    rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
        linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());

    // Now that the new linearize exists, apply the recorded delinearize
    // replacements (empty entries mean "nothing to replace").
    for (auto [m, newResults] :
         llvm::zip_equal(matches, delinearizeReplacements)) {
      if (newResults.empty())
        continue;
      rewriter.replaceOp(m.delinearize, newResults);
    }

    return success();
  }
};
5309 
/// Strip leading zero from affine.linearize_index.
///
/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
/// to `affine.linearize_index [...a] by (...b)` in all cases.
struct DropLinearizeLeadingZero final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {

  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    // Only fire when the leading index is a known constant zero.
    Value leadingIdx = op.getMultiIndex().front();
    if (!matchPattern(leadingIdx, m_Zero()))
      return failure();

    // A one-element linearization of zero is just that zero.
    if (op.getMultiIndex().size() == 1) {
      rewriter.replaceOp(op, leadingIdx);
      return success();
    }

    // Drop the leading basis element only if it is present: the op may
    // already lack an outer bound for its first index.
    SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
    ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
    if (op.hasOuterBound())
      newMixedBasis = newMixedBasis.drop_front();

    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
        op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
    return success();
  }
};
5339 } // namespace
5340 
5341 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5342  RewritePatternSet &patterns, MLIRContext *context) {
5343  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5344  DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5345 }
5346 
5347 //===----------------------------------------------------------------------===//
5348 // TableGen'd op method definitions
5349 //===----------------------------------------------------------------------===//
5350 
5351 #define GET_OP_CLASSES
5352 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
Definition: AffineOps.cpp:2673
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
Definition: AffineOps.cpp:2393
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
Definition: AffineOps.cpp:1179
static LogicalResult verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
Definition: AffineOps.cpp:3065
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
Definition: AffineOps.cpp:3243
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
Definition: AffineOps.cpp:3876
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
Definition: AffineOps.cpp:60
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
Definition: AffineOps.cpp:4023
static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
Definition: AffineOps.cpp:1015
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
Definition: AffineOps.cpp:3279
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
Definition: AffineOps.cpp:2247
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
Definition: AffineOps.cpp:801
static bool isValidAffineIndexOperand(Value value, Region *region)
Definition: AffineOps.cpp:480
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1371
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
Definition: AffineOps.cpp:1086
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:736
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
Definition: AffineOps.cpp:1936
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:728
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
Definition: AffineOps.cpp:2958
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
Definition: AffineOps.cpp:1038
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1328
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
Definition: AffineOps.cpp:4425
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
Definition: AffineOps.cpp:904
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
Definition: AffineOps.cpp:3256
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
Definition: AffineOps.cpp:339
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
Definition: AffineOps.cpp:676
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
Definition: AffineOps.cpp:4138
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
Definition: AffineOps.cpp:2682
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
Definition: AffineOps.cpp:485
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
Definition: AffineOps.cpp:638
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1265
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
Definition: AffineOps.cpp:2632
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
Definition: AffineOps.cpp:2201
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
Definition: AffineOps.cpp:517
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1280
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
Definition: AffineOps.cpp:358
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
Definition: AffineOps.cpp:704
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Definition: AffineOps.cpp:4091
static LogicalResult verifyAffineMinMaxOp(T op)
Definition: AffineOps.cpp:3230
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
Definition: AffineOps.cpp:2111
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
Definition: AffineOps.cpp:4611
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
Definition: AffineOps.cpp:3442
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
union mlir::linalg::@1183::ArityGroupAndKind::Kind kind
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:76
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:81
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:921
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:243
MLIRContext * getContext() const
Definition: AffineExpr.cpp:33
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:662
MLIRContext * getContext() const
Definition: AffineMap.cpp:343
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
Definition: AffineMap.h:221
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:267
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
Definition: AffineMap.h:228
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineMap.cpp:500
unsigned getNumInputs() const
Definition: AffineMap.cpp:403
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineMap.h:280
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:515
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
Definition: AffineMap.cpp:128
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:654
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
Definition: AffineMap.cpp:434
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:312
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
AffineMap getDimIdentityMap()
Definition: Builders.cpp:379
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:383
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:364
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:368
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition: Builders.cpp:175
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
NoneType getNoneType()
Definition: Builders.cpp:84
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
Definition: Builders.cpp:372
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition: Builders.cpp:374
MLIRContext * getContext() const
Definition: Builders.h:56
AffineMap getSymbolIdentityMap()
Definition: Builders.cpp:392
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:277
IndexType getIndexType()
Definition: Builders.cpp:51
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44
unsigned getNumDims() const
Definition: IntegerSet.cpp:15
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
Definition: IntegerSet.cpp:57
unsigned getNumInputs() const
Definition: IntegerSet.cpp:17
ArrayRef< AffineExpr > getConstraints() const
Definition: IntegerSet.cpp:41
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
Definition: IntegerSet.cpp:51
unsigned getNumSymbols() const
Definition: IntegerSet.cpp:16
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:115
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:267
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:435
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
operand_range::iterator operand_iterator
Definition: Operation.h:372
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:803
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:865
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:412
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:736
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:648
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:632
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:554
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
AffineBound represents a lower or upper bound in the for operation.
Definition: AffineOps.h:518
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
Definition: AffineOps.h:101
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Definition: AffineOps.h:310
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
Definition: AffineOps.cpp:3983
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
unsigned getNumResults() const
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
Definition: AffineOps.cpp:2695
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
Definition: AffineOps.cpp:1148
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
Definition: AffineOps.cpp:2609
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
Definition: AffineOps.cpp:288
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1254
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
Definition: AffineOps.cpp:2581
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: AffineOps.cpp:2585
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1158
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1447
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1319
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Definition: AffineOps.cpp:2573
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1312
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Definition: AffineOps.cpp:248
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
Definition: AffineOps.cpp:273
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
Definition: AffineOps.cpp:1452
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
Definition: AffineOps.cpp:2616
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Definition: AffineOps.cpp:402
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1208
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns true if the provided value is among the induction variables of an AffineParallelOp.
Definition: AffineOps.cpp:2596
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
Definition: AffineOps.cpp:263
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
Definition: AffineOps.cpp:495
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
Definition: AffineOps.cpp:2577
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
Definition: AffineOps.cpp:1274
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the type of the given value can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:773
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
Definition: AffineMap.cpp:783
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
AffineExprKind
Definition: AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:70
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:645
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:419
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
Definition: AffineMap.cpp:745
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:631
Canonicalize the affine map result expression order of an affine min/max operation.
Definition: AffineOps.cpp:3496
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3499
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3513
Remove duplicated expressions in affine min/max ops.
Definition: AffineOps.cpp:3312
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3315
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
Definition: AffineOps.cpp:3355
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3358
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:368
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.