MLIR  22.0.0git
AffineOps.cpp
Go to the documentation of this file.
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/IRMapping.h"
17 #include "mlir/IR/IntegerSet.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/OpDefinition.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/Value.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
27 #include "llvm/ADT/SmallVectorExtras.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/DebugLog.h"
30 #include "llvm/Support/LogicalResult.h"
31 #include "llvm/Support/MathExtras.h"
32 #include <numeric>
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::affine;
37 
38 using llvm::divideCeilSigned;
39 using llvm::divideFloorSigned;
40 using llvm::mod;
41 
42 #define DEBUG_TYPE "affine-ops"
43 
44 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
45 
46 /// A utility function to check if a value is defined at the top level of
47 /// `region` or is an argument of `region`. A value of index type defined at the
48 /// top level of a `AffineScope` region is always a valid symbol for all
49 /// uses in that region.
51  if (auto arg = dyn_cast<BlockArgument>(value))
52  return arg.getParentRegion() == region;
53  return value.getDefiningOp()->getParentRegion() == region;
54 }
55 
56 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
57 /// region remains legal if the operation that uses it is inlined into `dest`
58 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
59 /// `isValidSymbol`, depending on the value being required to remain a valid
60 /// dimension or symbol.
61 static bool
63  const IRMapping &mapping,
64  function_ref<bool(Value, Region *)> legalityCheck) {
65  // If the value is a valid dimension for any other reason than being
66  // a top-level value, it will remain valid: constants get inlined
67  // with the function, transitive affine applies also get inlined and
68  // will be checked themselves, etc.
69  if (!isTopLevelValue(value, src))
70  return true;
71 
72  // If it's a top-level value because it's a block operand, i.e. a
73  // function argument, check whether the value replacing it after
74  // inlining is a valid dimension in the new region.
75  if (llvm::isa<BlockArgument>(value))
76  return legalityCheck(mapping.lookup(value), dest);
77 
78  // If it's a top-level value because it's defined in the region,
79  // it can only be inlined if the defining op is a constant or a
80  // `dim`, which can appear anywhere and be valid, since the defining
81  // op won't be top-level anymore after inlining.
82  Attribute operandCst;
83  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
84  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
85  isDimLikeOp;
86 }
87 
88 /// Checks if all values known to be legal affine dimensions or symbols in `src`
89 /// remain so if their respective users are inlined into `dest`.
90 static bool
92  const IRMapping &mapping,
93  function_ref<bool(Value, Region *)> legalityCheck) {
94  return llvm::all_of(values, [&](Value v) {
95  return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
96  });
97 }
98 
99 /// Checks if an affine read or write operation remains legal after inlining
100 /// from `src` to `dest`.
101 template <typename OpTy>
102 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
103  const IRMapping &mapping) {
104  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
105  AffineWriteOpInterface>::value,
106  "only ops with affine read/write interface are supported");
107 
108  AffineMap map = op.getAffineMap();
109  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
110  ValueRange symbolOperands =
111  op.getMapOperands().take_back(map.getNumSymbols());
113  dimOperands, src, dest, mapping,
114  static_cast<bool (*)(Value, Region *)>(isValidDim)))
115  return false;
117  symbolOperands, src, dest, mapping,
118  static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
119  return false;
120  return true;
121 }
122 
123 /// Checks if an affine apply operation remains legal after inlining from `src`
124 /// to `dest`.
125 // Use "unused attribute" marker to silence clang-tidy warning stemming from
126 // the inability to see through "llvm::TypeSwitch".
127 template <>
128 [[maybe_unused]] bool remainsLegalAfterInline(AffineApplyOp op, Region *src,
129  Region *dest,
130  const IRMapping &mapping) {
131  // If it's a valid dimension, we need to check that it remains so.
132  if (isValidDim(op.getResult(), src))
134  op.getMapOperands(), src, dest, mapping,
135  static_cast<bool (*)(Value, Region *)>(isValidDim));
136 
137  // Otherwise it must be a valid symbol, check that it remains so.
139  op.getMapOperands(), src, dest, mapping,
140  static_cast<bool (*)(Value, Region *)>(isValidSymbol));
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // AffineDialect Interfaces
145 //===----------------------------------------------------------------------===//
146 
147 namespace {
148 /// This class defines the interface for handling inlining with affine
149 /// operations.
150 struct AffineInlinerInterface : public DialectInlinerInterface {
152 
153  //===--------------------------------------------------------------------===//
154  // Analysis Hooks
155  //===--------------------------------------------------------------------===//
156 
157  /// Returns true if the given region 'src' can be inlined into the region
158  /// 'dest' that is attached to an operation registered to the current dialect.
159  /// 'wouldBeCloned' is set if the region is cloned into its new location
160  /// rather than moved, indicating there may be other users.
161  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
162  IRMapping &valueMapping) const final {
163  // We can inline into affine loops and conditionals if this doesn't break
164  // affine value categorization rules.
165  Operation *destOp = dest->getParentOp();
166  if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
167  return false;
168 
169  // Multi-block regions cannot be inlined into affine constructs, all of
170  // which require single-block regions.
171  if (!src->hasOneBlock())
172  return false;
173 
174  // Side-effecting operations that the affine dialect cannot understand
175  // should not be inlined.
176  Block &srcBlock = src->front();
177  for (Operation &op : srcBlock) {
178  // Ops with no side effects are fine,
179  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
180  if (iface.hasNoEffect())
181  continue;
182  }
183 
184  // Assuming the inlined region is valid, we only need to check if the
185  // inlining would change it.
186  bool remainsValid =
188  .Case<AffineApplyOp, AffineReadOpInterface,
189  AffineWriteOpInterface>([&](auto op) {
190  return remainsLegalAfterInline(op, src, dest, valueMapping);
191  })
192  .Default([](Operation *) {
193  // Conservatively disallow inlining ops we cannot reason about.
194  return false;
195  });
196 
197  if (!remainsValid)
198  return false;
199  }
200 
201  return true;
202  }
203 
204  /// Returns true if the given operation 'op', that is registered to this
205  /// dialect, can be inlined into the given region, false otherwise.
206  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
207  IRMapping &valueMapping) const final {
208  // Always allow inlining affine operations into a region that is marked as
209  // affine scope, or into affine loops and conditionals. There are some edge
210  // cases when inlining *into* affine structures, but that is handled in the
211  // other 'isLegalToInline' hook above.
212  Operation *parentOp = region->getParentOp();
213  return parentOp->hasTrait<OpTrait::AffineScope>() ||
214  isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
215  }
216 
217  /// Affine regions should be analyzed recursively.
218  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
219 };
220 } // namespace
221 
222 //===----------------------------------------------------------------------===//
223 // AffineDialect
224 //===----------------------------------------------------------------------===//
225 
/// Registers the affine dialect's operations and interfaces: the two DMA ops
/// named explicitly, every op from the generated op list, the inliner
/// interface, and the promised `ValueBoundsOpInterface` implementations for
/// `AffineApplyOp`, `AffineMaxOp` and `AffineMinOp`.
226 void AffineDialect::initialize() {
227  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
228 #define GET_OP_LIST
229 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
230  >();
231  addInterfaces<AffineInlinerInterface>();
232  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
233  AffineMinOp>();
234 }
235 
236 /// Materialize a single constant operation from a given attribute value with
237 /// the desired resultant type.
239  Attribute value, Type type,
240  Location loc) {
241  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
242  return ub::PoisonOp::create(builder, loc, type, poison);
243  return arith::ConstantOp::materialize(builder, value, type, loc);
244 }
245 
246 /// A utility function to check if a value is defined at the top level of an
247 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
248 /// conservatively assume it is not top-level. A value of index type defined at
249 /// the top level is always a valid symbol.
251  if (auto arg = dyn_cast<BlockArgument>(value)) {
252  // The block owning the argument may be unlinked, e.g. when the surrounding
253  // region has not yet been attached to an Op, at which point the parent Op
254  // is null.
255  Operation *parentOp = arg.getOwner()->getParentOp();
256  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
257  }
258  // The defining Op may live in an unlinked block so its parent Op may be null.
259  Operation *parentOp = value.getDefiningOp()->getParentOp();
260  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
261 }
262 
263 /// Returns the closest region enclosing `op` that is held by an operation with
264 /// trait `AffineScope`; `nullptr` if there is no such region.
266  auto *curOp = op;
267  while (auto *parentOp = curOp->getParentOp()) {
268  if (parentOp->hasTrait<OpTrait::AffineScope>())
269  return curOp->getParentRegion();
270  curOp = parentOp;
271  }
272  return nullptr;
273 }
274 
276  Operation *curOp = op;
277  while (auto *parentOp = curOp->getParentOp()) {
278  if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
279  return curOp->getParentRegion();
280  curOp = parentOp;
281  }
282  return nullptr;
283 }
284 
285 // A Value can be used as a dimension id iff it meets one of the following
286 // conditions:
287 // *) It is valid as a symbol.
288 // *) It is an induction variable.
289 // *) It is the result of affine apply operation with dimension id arguments.
291  // The value must be an index type.
292  if (!value.getType().isIndex())
293  return false;
294 
295  if (auto *defOp = value.getDefiningOp())
296  return isValidDim(value, getAffineScope(defOp));
297 
298  // This value has to be a block argument for an op that has the
299  // `AffineScope` trait or an induction var of an affine.for or
300  // affine.parallel.
301  if (isAffineInductionVar(value))
302  return true;
303  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
304  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
305 }
306 
307 // Value can be used as a dimension id iff it meets one of the following
308 // conditions:
309 // *) It is valid as a symbol.
310 // *) It is an induction variable.
311 // *) It is the result of an affine apply operation with dimension id operands.
312 // *) It is the result of a more specialized index transformation (ex.
313 // delinearize_index or linearize_index) with dimension id operands.
314 bool mlir::affine::isValidDim(Value value, Region *region) {
315  // The value must be an index type.
316  if (!value.getType().isIndex())
317  return false;
318 
319  // All valid symbols are okay.
320  if (isValidSymbol(value, region))
321  return true;
322 
323  auto *op = value.getDefiningOp();
324  if (!op) {
325  // This value has to be an induction var for an affine.for or an
326  // affine.parallel.
327  return isAffineInductionVar(value);
328  }
329 
330  // Affine apply operation is ok if all of its operands are ok.
331  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
332  return applyOp.isValidDim(region);
333  // delinearize_index and linearize_index are special forms of apply
334  // and so are valid dimensions if all their arguments are valid dimensions.
335  if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
336  return llvm::all_of(op->getOperands(),
337  [&](Value arg) { return ::isValidDim(arg, region); });
338  // The dim op is okay if its operand memref/tensor is defined at the top
339  // level.
340  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
341  return isTopLevelValue(dimOp.getShapedValue());
342  return false;
343 }
344 
345 /// Returns true if the 'index' dimension of the `memref` defined by
346 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
347 /// for `region`.
348 template <typename AnyMemRefDefOp>
349 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
350  Region *region) {
351  MemRefType memRefType = memrefDefOp.getType();
352 
353  // Dimension index is out of bounds.
354  if (index >= memRefType.getRank()) {
355  return false;
356  }
357 
358  // Statically shaped.
359  if (!memRefType.isDynamicDim(index))
360  return true;
361  // Get the position of the dimension among dynamic dimensions;
362  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
363  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
364  region);
365 }
366 
367 /// Returns true if the result of the dim op is a valid symbol for `region`.
368 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
369  // The dim op is okay if its source is defined at the top level.
370  if (isTopLevelValue(dimOp.getShapedValue()))
371  return true;
372 
373  // Conservatively handle remaining BlockArguments as non-valid symbols.
374  // E.g. scf.for iterArgs.
375  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
376  return false;
377 
378  // The dim op is also okay if its operand memref is a view/subview whose
379  // corresponding size is a valid symbol.
380  std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
381 
382  // Be conservative if we can't understand the dimension.
383  if (!index.has_value())
384  return false;
385 
386  // Skip over all memref.cast ops (if any).
387  Operation *op = dimOp.getShapedValue().getDefiningOp();
388  while (auto castOp = dyn_cast<memref::CastOp>(op)) {
389  // Bail on unranked memrefs.
390  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
391  return false;
392  op = castOp.getSource().getDefiningOp();
393  if (!op)
394  return false;
395  }
396 
397  int64_t i = index.value();
399  .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
400  [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
401  .Default([](Operation *) { return false; });
402 }
403 
404 // A value can be used as a symbol (at all its use sites) iff it meets one of
405 // the following conditions:
406 // *) It is a constant.
407 // *) Its defining op or block arg appearance is immediately enclosed by an op
408 // with `AffineScope` trait.
409 // *) It is the result of an affine.apply operation with symbol operands.
410 // *) It is a result of the dim op on a memref whose corresponding size is a
411 // valid symbol.
413  if (!value)
414  return false;
415 
416  // The value must be an index type.
417  if (!value.getType().isIndex())
418  return false;
419 
420  // Check that the value is a top level value.
421  if (isTopLevelValue(value))
422  return true;
423 
424  if (auto *defOp = value.getDefiningOp())
425  return isValidSymbol(value, getAffineScope(defOp));
426 
427  return false;
428 }
429 
430 /// A utility function to check if a value is defined at the top level of
431 /// `region` or is an argument of `region` or is defined above the region.
432 static bool isTopLevelValueOrAbove(Value value, Region *region) {
433  Region *parentRegion = value.getParentRegion();
434  do {
435  if (parentRegion == region)
436  return true;
437  Operation *regionOp = region->getParentOp();
438  if (regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
439  break;
440  region = region->getParentOp()->getParentRegion();
441  } while (region);
442  return false;
443 }
444 
445 /// A value can be used as a symbol for `region` iff it meets one of the
446 /// following conditions:
447 /// *) It is a constant.
448 /// *) It is a result of a `Pure` operation whose operands are valid symbolic
449 ///    identifiers.
450 /// *) It is a result of the dim op on a memref whose corresponding size is
451 /// a valid symbol.
452 /// *) It is defined at the top level of 'region' or is its argument.
453 /// *) It dominates `region`'s parent op.
454 /// If `region` is null, conservatively assume the symbol definition scope does
455 /// not exist and only accept the values that would be symbols regardless of
456 /// the surrounding region structure, i.e. the first three cases above.
458  // The value must be an index type.
459  if (!value.getType().isIndex())
460  return false;
461 
462  // A top-level value is a valid symbol.
463  if (region && isTopLevelValueOrAbove(value, region))
464  return true;
465 
466  auto *defOp = value.getDefiningOp();
467  if (!defOp)
468  return false;
469 
470  // Constant operation is ok.
471  Attribute operandCst;
472  if (matchPattern(defOp, m_Constant(&operandCst)))
473  return true;
474 
475  // `Pure` operation whose operands are valid symbolic identifiers.
476  if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
477  return affine::isValidSymbol(operand, region);
478  })) {
479  return true;
480  }
481 
482  // Dim op results could be valid symbols at any level.
483  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
484  return isDimOpValidSymbol(dimOp, region);
485 
486  return false;
487 }
488 
489 // Returns true if 'value' is a valid index to an affine operation (e.g.
490 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
491 // `region` provides the polyhedral symbol scope. Returns false otherwise.
492 static bool isValidAffineIndexOperand(Value value, Region *region) {
493  return isValidDim(value, region) || isValidSymbol(value, region);
494 }
495 
496 /// Prints dimension and symbol list.
499  unsigned numDims, OpAsmPrinter &printer) {
500  OperandRange operands(begin, end);
501  printer << '(' << operands.take_front(numDims) << ')';
502  if (operands.size() > numDims)
503  printer << '[' << operands.drop_front(numDims) << ']';
504 }
505 
506 /// Parses dimension and symbol list and returns true if parsing failed.
508  OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
510  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
511  return failure();
512  // Store number of dimensions for validation by caller.
513  numDims = opInfos.size();
514 
515  // Parse the optional symbol operands.
516  auto indexTy = parser.getBuilder().getIndexType();
517  return failure(parser.parseOperandList(
519  parser.resolveOperands(opInfos, indexTy, operands));
520 }
521 
522 /// Utility function to verify that a set of operands are valid dimension and
523 /// symbol identifiers. The operands should be laid out such that the dimension
524 /// operands are before the symbol operands. This function returns failure if
525 /// there was an invalid operand. An operation is provided to emit any necessary
526 /// errors.
527 template <typename OpTy>
528 static LogicalResult
530  unsigned numDims) {
531  unsigned opIt = 0;
532  for (auto operand : operands) {
533  if (opIt++ < numDims) {
534  if (!isValidDim(operand, getAffineScope(op)))
535  return op.emitOpError("operand cannot be used as a dimension id");
536  } else if (!isValidSymbol(operand, getAffineScope(op))) {
537  return op.emitOpError("operand cannot be used as a symbol");
538  }
539  }
540  return success();
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // AffineApplyOp
545 //===----------------------------------------------------------------------===//
546 
547 AffineValueMap AffineApplyOp::getAffineValueMap() {
548  return AffineValueMap(getAffineMap(), getOperands(), getResult());
549 }
550 
551 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
552  auto &builder = parser.getBuilder();
553  auto indexTy = builder.getIndexType();
554 
555  AffineMapAttr mapAttr;
556  unsigned numDims;
557  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
558  parseDimAndSymbolList(parser, result.operands, numDims) ||
559  parser.parseOptionalAttrDict(result.attributes))
560  return failure();
561  auto map = mapAttr.getValue();
562 
563  if (map.getNumDims() != numDims ||
564  numDims + map.getNumSymbols() != result.operands.size()) {
565  return parser.emitError(parser.getNameLoc(),
566  "dimension or symbol index mismatch");
567  }
568 
569  result.types.append(map.getNumResults(), indexTy);
570  return success();
571 }
572 
574  p << " " << getMapAttr();
575  printDimAndSymbolList(operand_begin(), operand_end(),
576  getAffineMap().getNumDims(), p);
577  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
578 }
579 
580 LogicalResult AffineApplyOp::verify() {
581  // Check input and output dimensions match.
582  AffineMap affineMap = getMap();
583 
584  // Verify that operand count matches affine map dimension and symbol count.
585  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
586  return emitOpError(
587  "operand count and affine map dimension and symbol count must match");
588 
589  // Verify that the map only produces one result.
590  if (affineMap.getNumResults() != 1)
591  return emitOpError("mapping must produce one value");
592 
593  // Do not allow valid dims to be used in symbol positions. We do allow
594  // affine.apply to use operands for values that may neither qualify as affine
595  // dims or affine symbols due to usage outside of affine ops, analyses, etc.
596  Region *region = getAffineScope(*this);
597  for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
598  if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
599  return emitError("dimensional operand cannot be used as a symbol");
600  }
601 
602  return success();
603 }
604 
605 // The result of the affine apply operation can be used as a dimension id if all
606 // its operands are valid dimension ids.
608  return llvm::all_of(getOperands(),
609  [](Value op) { return affine::isValidDim(op); });
610 }
611 
612 // The result of the affine apply operation can be used as a dimension id if all
613 // its operands are valid dimension ids with the parent operation of `region`
614 // defining the polyhedral scope for symbols.
615 bool AffineApplyOp::isValidDim(Region *region) {
616  return llvm::all_of(getOperands(),
617  [&](Value op) { return ::isValidDim(op, region); });
618 }
619 
620 // The result of the affine apply operation can be used as a symbol if all its
621 // operands are symbols.
623  return llvm::all_of(getOperands(),
624  [](Value op) { return affine::isValidSymbol(op); });
625 }
626 
627 // The result of the affine apply operation can be used as a symbol in `region`
628 // if all its operands are symbols in `region`.
629 bool AffineApplyOp::isValidSymbol(Region *region) {
630  return llvm::all_of(getOperands(), [&](Value operand) {
631  return affine::isValidSymbol(operand, region);
632  });
633 }
634 
635 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
636  auto map = getAffineMap();
637 
638  // Fold dims and symbols to existing values.
639  auto expr = map.getResult(0);
640  if (auto dim = dyn_cast<AffineDimExpr>(expr))
641  return getOperand(dim.getPosition());
642  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
643  return getOperand(map.getNumDims() + sym.getPosition());
644 
645  // Otherwise, default to folding the map.
647  bool hasPoison = false;
648  auto foldResult =
649  map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
650  if (hasPoison)
652  if (failed(foldResult))
653  return {};
654  return result[0];
655 }
656 
657 /// Returns the largest known divisor of `e`. Exploits information from the
658 /// values in `operands`.
659 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
660  // This method isn't aware of `operands`.
661  int64_t div = e.getLargestKnownDivisor();
662 
663  // We now make use of operands for the case `e` is a dim expression.
664  // TODO: More powerful simplification would have to modify
665  // getLargestKnownDivisor to take `operands` and exploit that information as
666  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
667  // can't be part of the IR library but of the `Analysis` library. The IR
668  // library can only really depend on simple O(1) checks.
669  auto dimExpr = dyn_cast<AffineDimExpr>(e);
670  // If it's not a dim expr, `div` is the best we have.
671  if (!dimExpr)
672  return div;
673 
674  // We simply exploit information from loop IVs.
675  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
676  // desired simplifications are expected to be part of other
677  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
678  // LoopAnalysis library.
679  Value operand = operands[dimExpr.getPosition()];
680  int64_t operandDivisor = 1;
681  // TODO: With the right accessors, this can be extended to
682  // LoopLikeOpInterface.
683  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
684  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
685  operandDivisor = forOp.getStepAsInt();
686  } else {
687  uint64_t lbLargestKnownDivisor =
688  forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
689  operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
690  }
691  }
692  return operandDivisor;
693 }
694 
695 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
696 /// being an affine dim expression or a constant.
698  int64_t k) {
699  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
700  int64_t constVal = constExpr.getValue();
701  return constVal >= 0 && constVal < k;
702  }
703  auto dimExpr = dyn_cast<AffineDimExpr>(e);
704  if (!dimExpr)
705  return false;
706  Value operand = operands[dimExpr.getPosition()];
707  // TODO: With the right accessors, this can be extended to
708  // LoopLikeOpInterface.
709  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
710  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
711  forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
712  return true;
713  }
714  }
715 
716  // We don't consider other cases like `operand` being defined by a constant or
717  // an affine.apply op since such cases will already be handled by other
718  // patterns and propagation of loop IVs or constant would happen.
719  return false;
720 }
721 
722 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
723 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
724 /// expression is in that form.
725 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
726  AffineExpr &quotientTimesDiv, AffineExpr &rem) {
727  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
728  if (!bin || bin.getKind() != AffineExprKind::Add)
729  return false;
730 
731  AffineExpr llhs = bin.getLHS();
732  AffineExpr rlhs = bin.getRHS();
733  div = getLargestKnownDivisor(llhs, operands);
734  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
735  quotientTimesDiv = llhs;
736  rem = rlhs;
737  return true;
738  }
739  div = getLargestKnownDivisor(rlhs, operands);
740  if (isNonNegativeBoundedBy(llhs, operands, div)) {
741  quotientTimesDiv = rlhs;
742  rem = llhs;
743  return true;
744  }
745  return false;
746 }
747 
748 /// Gets the constant lower bound on an `iv`.
749 static std::optional<int64_t> getLowerBound(Value iv) {
750  AffineForOp forOp = getForInductionVarOwner(iv);
751  if (forOp && forOp.hasConstantLowerBound())
752  return forOp.getConstantLowerBound();
753  return std::nullopt;
754 }
755 
756 /// Gets the constant upper bound on an affine.for `iv`.
757 static std::optional<int64_t> getUpperBound(Value iv) {
758  AffineForOp forOp = getForInductionVarOwner(iv);
759  if (!forOp || !forOp.hasConstantUpperBound())
760  return std::nullopt;
761 
762  // If its lower bound is also known, we can get a more precise bound
763  // whenever the step is not one.
764  if (forOp.hasConstantLowerBound()) {
765  return forOp.getConstantUpperBound() - 1 -
766  (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
767  forOp.getStepAsInt();
768  }
769  return forOp.getConstantUpperBound() - 1;
770 }
771 
772 /// Determine a constant upper bound for `expr` if one exists while exploiting
773 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
774 /// is guaranteed to be less than or equal to it.
775 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
776  unsigned numSymbols,
777  ArrayRef<Value> operands) {
778  // Get the constant lower or upper bounds on the operands.
779  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
780  constLowerBounds.reserve(operands.size());
781  constUpperBounds.reserve(operands.size());
782  for (Value operand : operands) {
783  constLowerBounds.push_back(getLowerBound(operand));
784  constUpperBounds.push_back(getUpperBound(operand));
785  }
786 
787  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
788  return constExpr.getValue();
789 
790  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
791  constUpperBounds,
792  /*isUpper=*/true);
793 }
794 
795 /// Determine a constant lower bound for `expr` if one exists while exploiting
796 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
797 /// is guaranteed to be less than or equal to it.
798 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
799  unsigned numSymbols,
800  ArrayRef<Value> operands) {
801  // Get the constant lower or upper bounds on the operands.
802  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
803  constLowerBounds.reserve(operands.size());
804  constUpperBounds.reserve(operands.size());
805  for (Value operand : operands) {
806  constLowerBounds.push_back(getLowerBound(operand));
807  constUpperBounds.push_back(getUpperBound(operand));
808  }
809 
810  std::optional<int64_t> lowerBound;
811  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
812  lowerBound = constExpr.getValue();
813  } else {
814  lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
815  constLowerBounds, constUpperBounds,
816  /*isUpper=*/false);
817  }
818  return lowerBound;
819 }
820 
821 /// Simplify `expr` while exploiting information from the values in `operands`.
822 static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
823  unsigned numSymbols,
824  ArrayRef<Value> operands) {
825  // We do this only for certain floordiv/mod expressions.
826  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
827  if (!binExpr)
828  return;
829 
830  // Simplify the child expressions first.
831  AffineExpr lhs = binExpr.getLHS();
832  AffineExpr rhs = binExpr.getRHS();
833  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
834  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
835  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
836 
837  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
838  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
839  expr.getKind() != AffineExprKind::CeilDiv &&
840  expr.getKind() != AffineExprKind::Mod)) {
841  return;
842  }
843 
844  // The `lhs` and `rhs` may be different post construction of simplified expr.
845  lhs = binExpr.getLHS();
846  rhs = binExpr.getRHS();
847  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
848  if (!rhsConst)
849  return;
850 
851  int64_t rhsConstVal = rhsConst.getValue();
852  // Undefined exprsessions aren't touched; IR can still be valid with them.
853  if (rhsConstVal <= 0)
854  return;
855 
856  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
857  MLIRContext *context = expr.getContext();
858  std::optional<int64_t> lhsLbConst =
859  getLowerBound(lhs, numDims, numSymbols, operands);
860  std::optional<int64_t> lhsUbConst =
861  getUpperBound(lhs, numDims, numSymbols, operands);
862  if (lhsLbConst && lhsUbConst) {
863  int64_t lhsLbConstVal = *lhsLbConst;
864  int64_t lhsUbConstVal = *lhsUbConst;
865  // lhs floordiv c is a single value lhs is bounded in a range `c` that has
866  // the same quotient.
867  if (binExpr.getKind() == AffineExprKind::FloorDiv &&
868  divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
869  divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
870  expr = getAffineConstantExpr(
871  divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
872  return;
873  }
874  // lhs ceildiv c is a single value if the entire range has the same ceil
875  // quotient.
876  if (binExpr.getKind() == AffineExprKind::CeilDiv &&
877  divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
878  divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
879  expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
880  context);
881  return;
882  }
883  // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
884  if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
885  lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
886  expr = lhs;
887  return;
888  }
889  }
890 
891  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
892  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
893  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
894  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
895  AffineExpr quotientTimesDiv, rem;
896  int64_t divisor;
897  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
898  if (rhsConstVal % divisor == 0 &&
899  binExpr.getKind() == AffineExprKind::FloorDiv) {
900  expr = quotientTimesDiv.floorDiv(rhsConst);
901  } else if (divisor % rhsConstVal == 0 &&
902  binExpr.getKind() == AffineExprKind::Mod) {
903  expr = rem % rhsConst;
904  }
905  return;
906  }
907 
908  // Handle the simple case when the LHS expression can be either upper
909  // bounded or is a known multiple of RHS constant.
910  // lhs floordiv c -> 0 if 0 <= lhs < c,
911  // lhs mod c -> 0 if lhs % c = 0.
912  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
913  binExpr.getKind() == AffineExprKind::FloorDiv) ||
914  (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
915  binExpr.getKind() == AffineExprKind::Mod)) {
916  expr = getAffineConstantExpr(0, expr.getContext());
917  }
918 }
919 
920 /// Simplify the expressions in `map` while making use of lower or upper bounds
921 /// of its operands. If `isMax` is true, the map is to be treated as a max of
922 /// its result expressions, and min otherwise. E.g.: min (d0, d1) -> (8, 4 * d0 +
923 /// d1) can be simplified to (8) if the operands are respectively lower bounded
924 /// by 2 and 0 (the second expression can't be lower than 8).
///
/// `map` is rebuilt in place from the results that were not proven redundant;
/// its dim and symbol counts are kept. Nothing is done when `operands` is
/// empty.
/// NOTE(review): the line carrying this function's name and first parameter
/// is absent from this listing, as are the first lines of the two calls that
/// compute the per-expression bounds below.
926  ArrayRef<Value> operands,
927  bool isMax) {
928  // Can't simplify.
929  if (operands.empty())
930  return;
931 
932  // Get the upper or lower bound on an affine.for op IV using its range.
933  // Get the constant lower or upper bounds on the operands.
934  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
935  constLowerBounds.reserve(operands.size());
936  constUpperBounds.reserve(operands.size());
937  for (Value operand : operands) {
938  constLowerBounds.push_back(getLowerBound(operand));
939  constUpperBounds.push_back(getUpperBound(operand));
940  }
941 
942  // Compute a constant lower and upper bound for each result expression
943  // (either may be unknown). A result is redundant for `max` when its upper
944  // bound does not exceed the lower bound of another result, and redundant
945  // for `min` when its lower bound is not below the upper bound of another
946  // result; ties are broken in the loops further down.
947  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
948  lowerBounds.reserve(map.getNumResults());
949  upperBounds.reserve(map.getNumResults());
950  for (AffineExpr e : map.getResults()) {
951  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
  // A constant result is both its own lower and upper bound.
952  lowerBounds.push_back(constExpr.getValue());
953  upperBounds.push_back(constExpr.getValue());
954  } else {
955  lowerBounds.push_back(
957  constLowerBounds, constUpperBounds,
958  /*isUpper=*/false));
959  upperBounds.push_back(
961  constLowerBounds, constUpperBounds,
962  /*isUpper=*/true));
963  }
964  }
965 
966  // Collect expressions that are not redundant.
967  SmallVector<AffineExpr, 4> irredundantExprs;
968  for (auto exprEn : llvm::enumerate(map.getResults())) {
969  AffineExpr e = exprEn.value();
970  unsigned i = exprEn.index();
971  // Some expressions can be turned into constants.
972  if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
973  e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
974 
975  // Check if the expression is redundant.
976  if (isMax) {
  // Without a known upper bound, nothing can be proven: keep the result.
977  if (!upperBounds[i]) {
978  irredundantExprs.push_back(e);
979  continue;
980  }
981  // If there exists another expression such that its lower bound is greater
982  // than this expression's upper bound, it's redundant.
983  if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
984  auto otherLowerBound = en.value();
985  unsigned pos = en.index();
986  if (pos == i || !otherLowerBound)
987  return false;
988  if (*otherLowerBound > *upperBounds[i])
989  return true;
990  if (*otherLowerBound < *upperBounds[i])
991  return false;
992  // Equality case. When both expressions are considered redundant, we
993  // don't want to get both of them. We keep the one that appears
994  // first.
995  if (upperBounds[pos] && lowerBounds[i] &&
996  lowerBounds[i] == upperBounds[i] &&
997  otherLowerBound == *upperBounds[pos] && i < pos)
998  return false;
999  return true;
1000  }))
1001  irredundantExprs.push_back(e);
1002  } else {
  // Without a known lower bound, nothing can be proven: keep the result.
1003  if (!lowerBounds[i]) {
1004  irredundantExprs.push_back(e);
1005  continue;
1006  }
1007  // Likewise for the `min` case, with the roles of lower and upper bounds
  // swapped and the comparisons mirrored.
1008  if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
1009  auto otherUpperBound = en.value();
1010  unsigned pos = en.index();
1011  if (pos == i || !otherUpperBound)
1012  return false;
1013  if (*otherUpperBound < *lowerBounds[i])
1014  return true;
1015  if (*otherUpperBound > *lowerBounds[i])
1016  return false;
  // Equality case: same first-occurrence tie-break as for `max`.
1017  if (lowerBounds[pos] && upperBounds[i] &&
1018  lowerBounds[i] == upperBounds[i] &&
1019  otherUpperBound == lowerBounds[pos] && i < pos)
1020  return false;
1021  return true;
1022  }))
1023  irredundantExprs.push_back(e);
1024  }
1025  }
1026 
1027  // Create the map without the redundant expressions.
1028  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
1029  map.getContext());
1030 }
1031 
1032 /// Simplify the map while exploiting information on the values in `operands`.
/// `map` is replaced in place by a map with the same number of dims and
/// symbols, built from the expressions collected in the loop below.
/// `operands` must have exactly one entry per map input (asserted).
1033 // Use "unused attribute" marker to silence warning stemming from the inability
1034 // to see through the template expansion.
1035 [[maybe_unused]] static void simplifyMapWithOperands(AffineMap &map,
1036  ArrayRef<Value> operands) {
1037  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1038  SmallVector<AffineExpr> newResults;
1039  newResults.reserve(map.getNumResults());
1040  for (AffineExpr expr : map.getResults()) {
  // NOTE(review): the first line of the call that ends on the next line is
  // absent from this listing; presumably it simplifies the local copy `expr`
  // in place via simplifyExprAndOperands — confirm against the real file.
1042  operands);
1043  newResults.push_back(expr);
1044  }
1045  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1046  map.getContext());
1047 }
1048 
1049 /// Assuming `dimOrSym` is a quantity in the apply op map `map` and defined by
1050 /// `minOp = affine_min(x_1, ..., x_n)`. This function checks that:
1051 /// `0 < affine_min(x_1, ..., x_n)` and proceeds with replacing the patterns:
1052 /// ```
1053 /// dimOrSym.ceildiv(x_k)
1054 /// (dimOrSym + x_k - 1).floordiv(x_k)
1055 /// ```
1056 /// by `1` for all `k` in `1, ..., n`. This is possible because
/// `0 < dimOrSym <= x_k`, so the quotient rounds up to exactly 1.
1057 ///
/// `dims` and `syms` are the operands of `map`; they are used to re-express
/// each `x_k` in terms of `map`'s own dims and symbols. Returns success only
/// if `*map` was actually changed.
1058 ///
1059 /// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
1060 /// `minOp` is positive.
1061 static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
1062  AffineExpr dimOrSym,
1063  AffineMap *map,
1064  ValueRange dims,
1065  ValueRange syms) {
1066  LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";
1067  AffineMap affineMinMap = minOp.getAffineMap();
1068 
1069  // Check the value is positive.
1070  for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
1071  // Compare each expression in the minimum against 0.
  // NOTE(review): two lines of this condition are absent from this listing
  // (the start of the comparison call and the start of its last argument).
1073  getAsIndexOpFoldResult(minOp.getContext(), 0),
1074  ValueBoundsConstraintSet::ComparisonOperator::LT,
1076  minOp.getOperands())))
1077  return failure();
1078  }
1079 
1080  /// Convert affine symbols and dimensions in minOp to symbols or dimensions in
1081  /// the apply op affine map.
  // Operands of `minOp` that do not occur among `dims` / `syms` are recorded
  // as unmapped, by their position in `minOp`.
1082  DenseMap<AffineExpr, AffineExpr> dimSymConversionTable;
1083  SmallVector<unsigned> unmappedDims, unmappedSyms;
1084  for (auto [i, dim] : llvm::enumerate(minOp.getDimOperands())) {
1085  auto it = llvm::find(dims, dim);
1086  if (it == dims.end()) {
1087  unmappedDims.push_back(i);
1088  continue;
1089  }
1090  dimSymConversionTable[getAffineDimExpr(i, minOp.getContext())] =
1091  getAffineDimExpr(it.getIndex(), minOp.getContext());
1092  }
1093  for (auto [i, sym] : llvm::enumerate(minOp.getSymbolOperands())) {
1094  auto it = llvm::find(syms, sym);
1095  if (it == syms.end()) {
1096  unmappedSyms.push_back(i);
1097  continue;
1098  }
1099  dimSymConversionTable[getAffineSymbolExpr(i, minOp.getContext())] =
1100  getAffineSymbolExpr(it.getIndex(), minOp.getContext());
1101  }
1102 
1103  // Create the replacement map.
  // NOTE(review): the declaration of `repl`, used below, is absent from this
  // listing.
1105  AffineExpr c1 = getAffineConstantExpr(1, minOp.getContext());
1106  for (AffineExpr expr : affineMinMap.getResults()) {
1107  // If we cannot express the result in terms of the apply map symbols and
1108  // dims then continue.
1109  if (llvm::any_of(unmappedDims,
1110  [&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
1111  llvm::any_of(unmappedSyms,
1112  [&](unsigned i) { return expr.isFunctionOfSymbol(i); }))
1113  continue;
1114 
1115  AffineExpr convertedExpr = expr.replace(dimSymConversionTable);
1116 
1117  // dimOrSym.ceilDiv(expr) -> 1
1118  repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
1119  // (dimOrSym + expr - 1).floorDiv(expr) -> 1
1120  repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
1121  }
  // Apply all collected replacements at once, keeping the dim/symbol counts;
  // report success only if the map actually changed.
1122  AffineMap initialMap = *map;
1123  *map = initialMap.replace(repl, initialMap.getNumDims(),
1124  initialMap.getNumSymbols());
1125  return success(*map != initialMap);
1126 }
1127 
1128 /// Recursively traverse `e`. If `e` or one of its sub-expressions has the form
1129 /// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`,
1130 /// place a map between e and `newVal` + sum({e1, e2, .. eK} - exprsToRemove)
1131 /// into `replacementsMap`. If no entries were added to `replacementsMap`,
1132 /// nothing was found.
/// NOTE(review): the line carrying this function's return type and name is
/// absent from this listing.
1134  AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
1135  AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
1136  auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
1137  if (!binOp)
1138  return;
1139  AffineExpr lhs = binOp.getLHS();
1140  AffineExpr rhs = binOp.getRHS();
  // Not an add: there is no chain rooted here, only look inside the children.
1141  if (binOp.getKind() != AffineExprKind::Add) {
1142  shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap);
1143  shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap);
1144  return;
1145  }
  // `ourTracker` starts as a copy of `exprsToRemove`; terms are erased from it
  // as they are found, so it is empty iff every expected term was seen.
1146  SmallVector<AffineExpr> toPreserve;
1147  llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove);
1148  AffineExpr thisTerm = rhs;
1149  AffineExpr nextTerm = lhs;
1150 
  // Walk the chain from its rightmost term: each step takes the RHS as the
  // current term and descends into the LHS while that LHS is itself an add.
1151  while (thisTerm) {
1152  if (!ourTracker.erase(thisTerm)) {
  // Not one of the terms to remove: keep it, and search inside it too.
1153  toPreserve.push_back(thisTerm);
1154  shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal,
1155  replacementsMap);
1156  }
1157  auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
1158  if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) {
  // End of the chain: process `nextTerm` (if any) as the last term, then
  // stop on the following iteration because `thisTerm` becomes null.
1159  thisTerm = nextTerm;
1160  nextTerm = AffineExpr();
1161  } else {
1162  thisTerm = nextBinOp.getRHS();
1163  nextTerm = nextBinOp.getLHS();
1164  }
1165  }
  // Some expected term was not part of this chain: no replacement for `e`.
1166  if (!ourTracker.empty())
1167  return;
1168  // We reverse the terms to be preserved here in order to preserve
1169  // associativity between them.
1170  AffineExpr newExpr = newVal;
1171  for (AffineExpr preserved : llvm::reverse(toPreserve))
1172  newExpr = newExpr + preserved;
1173  replacementsMap.insert({e, newExpr});
1174 }
1175 
1176 /// If this map contains the expression `x_1 + x_1 * C_1 + ... x_n * C_N +
1177 /// ...` (not necessarily in order) where the set of the `x_i` is the set of
1178 /// outputs of an `affine.delinearize_index` whose inverse is that expression,
1179 /// replace that expression with the input of that delinearize_index op.
1180 ///
1181 /// `resultToReplace` is the input that was detected as the potential start to
1182 /// this replacement chain - if it isn't the rightmost result of the
1183 /// delinearization, this method fails. (This is intended to ensure we don't
1184 /// have redundant scans over the same expression).
1185 ///
1186 /// While this currently only handles delinearizations with a constant basis,
1187 /// that isn't a fundamental limitation.
1188 ///
1189 /// This is a utility function for `replaceDimOrSym` below.
/// NOTE(review): the line carrying the return type and function name, and the
/// line declaring the `dims` / `syms` parameters, are absent from this listing.
1191  AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
1193  if (!delinOp.getDynamicBasis().empty())
1194  return failure();
1195  if (resultToReplace != delinOp.getMultiIndex().back())
1196  return failure();
1197 
  // For each result of `delinOp`, find the dim or symbol expression of `map`
  // bound to it. If a result appears among both `dims` and `syms`, the symbol
  // binding wins because it is written last.
1198  MLIRContext *ctx = delinOp.getContext();
1199  SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr());
1200  for (auto [pos, dim] : llvm::enumerate(dims)) {
1201  auto asResult = dyn_cast_if_present<OpResult>(dim);
1202  if (!asResult)
1203  continue;
1204  if (asResult.getOwner() == delinOp.getOperation())
1205  resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx);
1206  }
1207  for (auto [pos, sym] : llvm::enumerate(syms)) {
1208  auto asResult = dyn_cast_if_present<OpResult>(sym);
1209  if (!asResult)
1210  continue;
1211  if (asResult.getOwner() == delinOp.getOperation())
1212  resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx);
1213  }
  // Every result must be an operand of the map, otherwise give up.
1214  if (llvm::is_contained(resToExpr, AffineExpr()))
1215  return failure();
1216 
  // Build the set of terms `result * stride` that make up the inverse of the
  // delinearization, walking results and basis sizes from the rightmost one.
1217  bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>);
1218  int64_t stride = 1;
1219  llvm::SmallDenseSet<AffineExpr, 4> expectedExprs;
1220  // This isn't zip_equal since sometimes the delinearize basis is missing a
1221  // size for the first result.
1222  for (auto [binding, size] : llvm::zip(
1223  llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) {
1224  expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx));
1225  stride *= size;
1226  }
1227  if (resToExpr.size() != delinOp.getStaticBasis().size())
1228  expectedExprs.insert(resToExpr[0] * stride);
1229 
  // The linear index will be appended as a new dim (if all results were bound
  // to dims) or as a new symbol; `delinInExpr` names that not-yet-added slot.
1230  DenseMap<AffineExpr, AffineExpr> replacements;
1231  AffineExpr delinInExpr = isDimReplacement
1232  ? getAffineDimExpr(dims.size(), ctx)
1233  : getAffineSymbolExpr(syms.size(), ctx);
1234 
1235  for (AffineExpr e : map->getResults())
1236  shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements);
1237  if (replacements.empty())
1238  return failure();
1239 
  // Something matched: add the linear index operand and rewrite the map.
1240  AffineMap origMap = *map;
1241  if (isDimReplacement)
1242  dims.push_back(delinOp.getLinearIndex());
1243  else
1244  syms.push_back(delinOp.getLinearIndex());
1245  *map = origMap.replace(replacements, dims.size(), syms.size());
1246 
1247  // Blank out dead dimensions and symbols
1248  for (AffineExpr e : resToExpr) {
1249  if (auto d = dyn_cast<AffineDimExpr>(e)) {
1250  unsigned pos = d.getPosition();
1251  if (!map->isFunctionOfDim(pos))
1252  dims[pos] = nullptr;
1253  }
1254  if (auto s = dyn_cast<AffineSymbolExpr>(e)) {
1255  unsigned pos = s.getPosition();
1256  if (!map->isFunctionOfSymbol(pos))
1257  syms[pos] = nullptr;
1258  }
1259  }
1260  return success();
1261 }
1262 
1263 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
1264 /// defining AffineApplyOp expression and operands.
1265 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
1266 /// When `dimOrSymbolPosition >= dims.size()`,
1267 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
1268 /// Mutate `map`,`dims` and `syms` in place as follows:
1269 /// 1. `dims` and `syms` are only appended to.
1270 /// 2. `map` dim and symbols are gradually shifted to higher positions.
1271 /// 3. Old `dim` and `sym` entries are replaced by nullptr
1272 /// This avoids the need for any bookkeeping.
1273 /// If `replaceAffineMin` is set to true, additionally triggers more expensive
1274 /// replacements involving affine_min operations.
1275 static LogicalResult replaceDimOrSym(AffineMap *map,
1276  unsigned dimOrSymbolPosition,
1277  SmallVectorImpl<Value> &dims,
1278  SmallVectorImpl<Value> &syms,
1279  bool replaceAffineMin) {
1280  MLIRContext *ctx = map->getContext();
1281  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1282  unsigned pos = isDimReplacement ? dimOrSymbolPosition
1283  : dimOrSymbolPosition - dims.size();
1284  Value &v = isDimReplacement ? dims[pos] : syms[pos];
1285  if (!v)
1286  return failure();
1287 
1288  if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
1289  AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx)
1290  : getAffineSymbolExpr(pos, ctx);
1291  return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map, dims,
1292  syms);
1293  }
1294 
1295  if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
1296  return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims,
1297  syms);
1298  }
1299 
1300  auto affineApply = v.getDefiningOp<AffineApplyOp>();
1301  if (!affineApply)
1302  return failure();
1303 
1304  // At this point we will perform a replacement of `v`, set the entry in `dim`
1305  // or `sym` to nullptr immediately.
1306  v = nullptr;
1307 
1308  // Compute the map, dims and symbols coming from the AffineApplyOp.
1309  AffineMap composeMap = affineApply.getAffineMap();
1310  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1311  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1312  affineApply.getMapOperands().end());
1313  // Canonicalize the map to promote dims to symbols when possible. This is to
1314  // avoid generating invalid maps.
1315  canonicalizeMapAndOperands(&composeMap, &composeOperands);
1316  AffineExpr replacementExpr =
1317  composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
1318  ValueRange composeDims =
1319  ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
1320  ValueRange composeSyms =
1321  ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
1322  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
1323  : getAffineSymbolExpr(pos, ctx);
1324 
1325  // Append the dims and symbols where relevant and perform the replacement.
1326  dims.append(composeDims.begin(), composeDims.end());
1327  syms.append(composeSyms.begin(), composeSyms.end());
1328  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1329 
1330  return success();
1331 }
1332 
1333 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1334 /// iteratively. Perform canonicalization of map and operands as well as
1335 /// AffineMap simplification. `map` and `operands` are mutated in place.
/// `composeAffineMin` is forwarded to `replaceDimOrSym` for every position.
/// NOTE(review): the line carrying this function's return type, name and first
/// parameter is absent from this listing.
1337  SmallVectorImpl<Value> *operands,
1338  bool composeAffineMin = false) {
  // A map without results has nothing to compose; only canonicalize/simplify.
1339  if (map->getNumResults() == 0) {
1340  canonicalizeMapAndOperands(map, operands);
1341  *map = simplifyAffineMap(*map);
1342  return;
1343  }
1344 
  // Split the operand list into dim operands followed by symbol operands.
1345  MLIRContext *ctx = map->getContext();
1346  SmallVector<Value, 4> dims(operands->begin(),
1347  operands->begin() + map->getNumDims());
1348  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1349  operands->end());
1350 
1351  // Iterate over dims and symbols coming from AffineApplyOp and replace until
1352  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1353  // and `syms` can only increase by construction.
1354  // The implementation uses a `while` loop to support the case of symbols
1355  // that may be constructed from dims; this may be overkill.
  // Each pass restarts from position 0 after the first successful replacement.
1356  while (true) {
1357  bool changed = false;
1358  for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1359  if ((changed |=
1360  succeeded(replaceDimOrSym(map, pos, dims, syms, composeAffineMin))))
1361  break;
1362  if (!changed)
1363  break;
1364  }
1365 
1366  // Clear operands so we can fill them anew.
1367  operands->clear();
1368 
1369  // At this point we may have introduced null operands, prune them out before
1370  // canonicalizing map and operands.
  // Surviving entries are renumbered densely; null entries are replaced by
  // the constant 0 in the map.
1371  unsigned nDims = 0, nSyms = 0;
1372  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1373  dimReplacements.reserve(dims.size());
1374  symReplacements.reserve(syms.size());
1375  for (auto *container : {&dims, &syms}) {
1376  bool isDim = (container == &dims);
1377  auto &repls = isDim ? dimReplacements : symReplacements;
1378  for (const auto &en : llvm::enumerate(*container)) {
1379  Value v = en.value();
1380  if (!v) {
  // NOTE(review): `&&` binds tighter than `?:`, so the message string in
  // this assert is attached only to the symbol branch of the conditional.
1381  assert(isDim ? !map->isFunctionOfDim(en.index())
1382  : !map->isFunctionOfSymbol(en.index()) &&
1383  "map is function of unexpected expr@pos");
1384  repls.push_back(getAffineConstantExpr(0, ctx));
1385  continue;
1386  }
1387  repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1388  : getAffineSymbolExpr(nSyms++, ctx));
1389  operands->push_back(v);
1390  }
1391  }
1392  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1393  nSyms);
1394 
1395  // Canonicalize and simplify before returning.
1396  canonicalizeMapAndOperands(map, operands);
1397  *map = simplifyAffineMap(*map);
1398 }
1399 
// Repeats composeAffineMapAndOperands on `map` / `operands` for as long as
// any operand is defined by an AffineApplyOp.
// NOTE(review): the line carrying this function's return type and name is
// absent from this listing.
1401  AffineMap *map, SmallVectorImpl<Value> *operands, bool composeAffineMin) {
1402  while (llvm::any_of(*operands, [](Value v) {
1403  return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1404  })) {
1405  composeAffineMapAndOperands(map, operands, composeAffineMin);
1406  }
1407  // Additional trailing step for AffineMinOps in case no chains of AffineApply.
  // It runs once, and only if requested and an operand is defined by an
  // AffineMinOp.
1408  if (composeAffineMin && llvm::any_of(*operands, [](Value v) {
1409  return isa_and_nonnull<AffineMinOp>(v.getDefiningOp());
1410  })) {
1411  composeAffineMapAndOperands(map, operands, composeAffineMin);
1412  }
1413 }
1414 
// Creates an AffineApplyOp from `map` and `operands` after folding attribute
// operands into the map and composing the map with the remaining value
// operands.
// NOTE(review): the line carrying the function name and its leading
// parameters (`b`, `loc`, `map` are used below) is absent from this listing.
1415 AffineApplyOp
1417  ArrayRef<OpFoldResult> operands,
1418  bool composeAffineMin) {
  // `valueOperands` receives the operands left over after attribute folding.
1419  SmallVector<Value> valueOperands;
1420  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1421  composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
1422  assert(map);
1423  return AffineApplyOp::create(b, loc, map, valueOperands);
1424 }
1425 
// Overload that forwards to makeComposedAffineApply with the same `operands`
// and `composeAffineMin`.
// NOTE(review): the line carrying the function name and leading parameters,
// and the line that builds the forwarded argument ending in `.front()`, are
// absent from this listing; presumably a single-result map is built from one
// expression — confirm against the real file.
1426 AffineApplyOp
1428  ArrayRef<OpFoldResult> operands,
1429  bool composeAffineMin) {
1430  return makeComposedAffineApply(
1431  b, loc,
1433  .front(),
1434  operands, composeAffineMin);
1435 }
1436 
1437 /// Composes the given affine map with the given list of operands, pulling in
1438 /// the maps from any affine.apply operations that supply the operands.
/// `map` and `operands` are both replaced by the composed, canonicalized
/// versions.
/// NOTE(review): the line carrying this function's return type, name and first
/// parameter is absent from this listing.
1440  SmallVectorImpl<Value> &operands,
1441  bool composeAffineMin = false) {
1442  // Compose and canonicalize each expression in the map individually because
1443  // composition only applies to single-result maps, collecting potentially
1444  // duplicate operands in a single list with shifted dimensions and symbols.
1445  SmallVector<Value> dims, symbols;
  // NOTE(review): the declaration of `exprs`, used below, is absent from this
  // listing.
1447  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
  // Each result starts from a fresh copy of the original operand list.
1448  SmallVector<Value> submapOperands(operands.begin(), operands.end());
1449  AffineMap submap = map.getSubMap({i});
1450  fullyComposeAffineMapAndOperands(&submap, &submapOperands,
1451  composeAffineMin);
1452  canonicalizeMapAndOperands(&submap, &submapOperands);
  // Shift the submap past the dims/symbols gathered so far, then append its
  // own operands: the first `numNewDims` are dims, the rest are symbols.
1453  unsigned numNewDims = submap.getNumDims();
1454  submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1455  llvm::append_range(dims,
1456  ArrayRef<Value>(submapOperands).take_front(numNewDims));
1457  llvm::append_range(symbols,
1458  ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1459  exprs.push_back(submap.getResult(0));
1460  }
1461 
1462  // Canonicalize the map created from composed expressions to deduplicate the
1463  // dimension and symbol operands.
1464  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1465  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1466  canonicalizeMapAndOperands(&map, &operands);
1467 }
1468 
// Builds a composed affine.apply for a single-result `map` and immediately
// tries to fold it. Returns the fold result if folding succeeds (the
// temporary op is erased), otherwise the result of the created op.
// NOTE(review): the line carrying this function's return type and name is
// absent from this listing.
1470  OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
1471  bool composeAffineMin) {
1472  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1473 
1474  // Create new builder without a listener, so that no notification is
1475  // triggered if the op is folded.
1476  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1477  // workaround is no longer needed.
1478  OpBuilder newBuilder(b.getContext());
  // NOTE(review): one line is absent from this listing here; presumably it
  // positions `newBuilder` at `b`'s insertion point — confirm.
1480 
1481  // Create op.
1482  AffineApplyOp applyOp =
1483  makeComposedAffineApply(newBuilder, loc, map, operands, composeAffineMin);
1484 
1485  // Get constant operands.
  // Entries for operands that do not match a constant stay null.
1486  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1487  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1488  matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1489 
1490  // Try to fold the operation.
1491  SmallVector<OpFoldResult> foldResults;
1492  if (failed(applyOp->fold(constOperands, foldResults)) ||
1493  foldResults.empty()) {
  // The op stays: report the insertion to the original builder's listener,
  // which the listener-less `newBuilder` did not notify.
1494  if (OpBuilder::Listener *listener = b.getListener())
1495  listener->notifyOperationInserted(applyOp, /*previous=*/{});
1496  return applyOp.getResult();
1497  }
1498 
  // Folded: the temporary op is no longer needed.
1499  applyOp->erase();
1500  return llvm::getSingleElement(foldResults);
1501 }
1502 
// Overload taking a single AffineExpr `expr` instead of a map; `operands` and
// `composeAffineMin` are passed through unchanged.
// NOTE(review): three lines are absent from this listing: the one with the
// return type and function name, the start of the forwarding call, and the
// line building the argument that ends in `.front()`. Presumably this
// forwards to the map-based makeComposedFoldedAffineApply — confirm.
1504  OpBuilder &b, Location loc, AffineExpr expr,
1505  ArrayRef<OpFoldResult> operands, bool composeAffineMin) {
1507  b, loc,
1509  .front(),
1510  operands, composeAffineMin);
1511 }
1512 
// Applies makeComposedFoldedAffineApply to each result of `map` separately
// (one single-result sub-map per result, all with the same `operands`) and
// returns the per-result values as a vector, in result order.
// NOTE(review): the lines carrying this function's return type and name are
// absent from this listing.
1515  OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
1516  bool composeAffineMin) {
1517  return llvm::map_to_vector(
1518  llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
1519  return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
1520  operands, composeAffineMin);
1521  });
1522 }
1523 
// Creates an op of type `OpTy` with index result type from `map` and
// `operands`, after folding attribute operands into the map and composing the
// multi-result map with the remaining value operands.
// NOTE(review): the line carrying the return type, function name and leading
// parameters (`b`, `loc`, `map` are used below) is absent from this listing.
1524 template <typename OpTy>
1526  ArrayRef<OpFoldResult> operands) {
  // `valueOperands` receives the operands left over after attribute folding.
1527  SmallVector<Value> valueOperands;
1528  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1529  composeMultiResultAffineMap(map, valueOperands);
1530  return OpTy::create(b, loc, b.getIndexType(), map, valueOperands);
1531 }
1532 
// Thin wrapper: creates a composed AffineMinOp via makeComposedMinMax.
// NOTE(review): the line carrying the function name and leading parameters is
// absent from this listing.
1533 AffineMinOp
1535  ArrayRef<OpFoldResult> operands) {
1536  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1537 }
1538 
// Builds a composed min/max op of type `OpTy` and immediately tries to fold
// it. Returns the fold result if folding succeeds (the temporary op is
// erased), otherwise the result of the created op.
// NOTE(review): the line carrying the return type, function name and leading
// parameters (`b` and `loc` are used below) is absent from this listing.
1539 template <typename OpTy>
1541  AffineMap map,
1542  ArrayRef<OpFoldResult> operands) {
1543  // Create new builder without a listener, so that no notification is
1544  // triggered if the op is folded.
1545  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1546  // workaround is no longer needed.
1547  OpBuilder newBuilder(b.getContext());
  // NOTE(review): one line is absent from this listing here; presumably it
  // positions `newBuilder` at `b`'s insertion point — confirm.
1549 
1550  // Create op.
1551  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1552 
1553  // Get constant operands.
  // Entries for operands that do not match a constant stay null.
1554  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1555  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1556  matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1557 
1558  // Try to fold the operation.
1559  SmallVector<OpFoldResult> foldResults;
1560  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1561  foldResults.empty()) {
  // The op stays: report the insertion to the original builder's listener,
  // which the listener-less `newBuilder` did not notify.
1562  if (OpBuilder::Listener *listener = b.getListener())
1563  listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1564  return minMaxOp.getResult();
1565  }
1566 
  // Folded: the temporary op is no longer needed.
1567  minMaxOp->erase();
1568  return llvm::getSingleElement(foldResults);
1569 }
1570 
// Thin wrapper: the AffineMinOp instantiation of makeComposedFoldedMinMax.
// NOTE(review): the lines carrying the return type, function name and leading
// parameters are absent from this listing.
1573  AffineMap map,
1574  ArrayRef<OpFoldResult> operands) {
1575  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1576 }
1577 
// Thin wrapper: the AffineMaxOp instantiation of makeComposedFoldedMinMax.
// NOTE(review): the lines carrying the return type, function name and leading
// parameters are absent from this listing.
1580  AffineMap map,
1581  ArrayRef<OpFoldResult> operands) {
1582  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1583 }
1584 
1585 // A symbol may appear as a dim in affine.apply operations. This function
1586 // canonicalizes dims that are valid symbols into actual symbols.
1587 template <class MapOrSet>
1588 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1589  SmallVectorImpl<Value> *operands) {
1590  if (!mapOrSet || operands->empty())
1591  return;
1592 
1593  assert(mapOrSet->getNumInputs() == operands->size() &&
1594  "map/set inputs must match number of operands");
1595 
1596  auto *context = mapOrSet->getContext();
1597  SmallVector<Value, 8> resultOperands;
1598  resultOperands.reserve(operands->size());
1599  SmallVector<Value, 8> remappedSymbols;
1600  remappedSymbols.reserve(operands->size());
1601  unsigned nextDim = 0;
1602  unsigned nextSym = 0;
1603  unsigned oldNumSyms = mapOrSet->getNumSymbols();
1604  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1605  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1606  if (i < mapOrSet->getNumDims()) {
1607  if (isValidSymbol((*operands)[i])) {
1608  // This is a valid symbol that appears as a dim, canonicalize it.
1609  dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1610  remappedSymbols.push_back((*operands)[i]);
1611  } else {
1612  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1613  resultOperands.push_back((*operands)[i]);
1614  }
1615  } else {
1616  resultOperands.push_back((*operands)[i]);
1617  }
1618  }
1619 
1620  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1621  *operands = resultOperands;
1622  *mapOrSet = mapOrSet->replaceDimsAndSymbols(
1623  dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
1624 
1625  assert(mapOrSet->getNumInputs() == operands->size() &&
1626  "map/set inputs must match number of operands");
1627 }
1628 
1629 /// A valid affine dimension may appear as a symbol in affine.apply operations.
1630 /// Given an application of `operands` to an affine map or integer set
1631 /// `mapOrSet`, this function canonicalizes symbols of `mapOrSet` that are valid
1632 /// dims, but not valid symbols into actual dims. Without such a legalization,
1633 /// the affine.apply will be invalid. This method is the exact inverse of
1634 /// canonicalizePromotedSymbols.
1635 template <class MapOrSet>
1636 static void legalizeDemotedDims(MapOrSet &mapOrSet,
1637  SmallVectorImpl<Value> &operands) {
1638  if (!mapOrSet || operands.empty())
1639  return;
1640 
1641  unsigned numOperands = operands.size();
1642 
1643  assert(mapOrSet.getNumInputs() == numOperands &&
1644  "map/set inputs must match number of operands");
1645 
1646  auto *context = mapOrSet.getContext();
1647  SmallVector<Value, 8> resultOperands;
1648  resultOperands.reserve(numOperands);
1649  SmallVector<Value, 8> remappedDims;
1650  remappedDims.reserve(numOperands);
1651  SmallVector<Value, 8> symOperands;
1652  symOperands.reserve(mapOrSet.getNumSymbols());
1653  unsigned nextSym = 0;
1654  unsigned nextDim = 0;
1655  unsigned oldNumDims = mapOrSet.getNumDims();
1656  SmallVector<AffineExpr, 8> symRemapping(mapOrSet.getNumSymbols());
1657  resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
1658  for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
1659  if (operands[i] && isValidDim(operands[i]) && !isValidSymbol(operands[i])) {
1660  // This is a valid dim that appears as a symbol, legalize it.
1661  symRemapping[i - oldNumDims] =
1662  getAffineDimExpr(oldNumDims + nextDim++, context);
1663  remappedDims.push_back(operands[i]);
1664  } else {
1665  symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
1666  symOperands.push_back(operands[i]);
1667  }
1668  }
1669 
1670  append_range(resultOperands, remappedDims);
1671  append_range(resultOperands, symOperands);
1672  operands = resultOperands;
1673  mapOrSet = mapOrSet.replaceDimsAndSymbols(
1674  /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
1675 
1676  assert(mapOrSet.getNumInputs() == operands.size() &&
1677  "map/set inputs must match number of operands");
1678 }
1679 
1680 // Works for either an affine map or an integer set.
1681 template <class MapOrSet>
1682 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1683  SmallVectorImpl<Value> *operands) {
1684  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1685  "Argument must be either of AffineMap or IntegerSet type");
1686 
1687  if (!mapOrSet || operands->empty())
1688  return;
1689 
1690  assert(mapOrSet->getNumInputs() == operands->size() &&
1691  "map/set inputs must match number of operands");
1692 
1693  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1694  legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);
1695 
1696  // Check to see what dims are used.
1697  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1698  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1699  mapOrSet->walkExprs([&](AffineExpr expr) {
1700  if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1701  usedDims[dimExpr.getPosition()] = true;
1702  else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1703  usedSyms[symExpr.getPosition()] = true;
1704  });
1705 
1706  auto *context = mapOrSet->getContext();
1707 
1708  SmallVector<Value, 8> resultOperands;
1709  resultOperands.reserve(operands->size());
1710 
1711  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1712  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1713  unsigned nextDim = 0;
1714  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1715  if (usedDims[i]) {
1716  // Remap dim positions for duplicate operands.
1717  auto it = seenDims.find((*operands)[i]);
1718  if (it == seenDims.end()) {
1719  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1720  resultOperands.push_back((*operands)[i]);
1721  seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1722  } else {
1723  dimRemapping[i] = it->second;
1724  }
1725  }
1726  }
1727  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1728  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1729  unsigned nextSym = 0;
1730  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1731  if (!usedSyms[i])
1732  continue;
1733  // Handle constant operands (only needed for symbolic operands since
1734  // constant operands in dimensional positions would have already been
1735  // promoted to symbolic positions above).
1736  IntegerAttr operandCst;
1737  if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1738  m_Constant(&operandCst))) {
1739  symRemapping[i] =
1740  getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1741  continue;
1742  }
1743  // Remap symbol positions for duplicate operands.
1744  auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1745  if (it == seenSymbols.end()) {
1746  symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1747  resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1748  seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1749  symRemapping[i]));
1750  } else {
1751  symRemapping[i] = it->second;
1752  }
1753  }
1754  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1755  nextDim, nextSym);
1756  *operands = resultOperands;
1757 }
1758 
1760  AffineMap *map, SmallVectorImpl<Value> *operands) {
1761  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1762 }
1763 
1765  IntegerSet *set, SmallVectorImpl<Value> *operands) {
1766  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
1767 }
1768 
1769 namespace {
1770 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1771 /// maps that supply results into them.
1772 ///
1773 template <typename AffineOpTy>
1774 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1776 
1777  /// Replace the affine op with another instance of it with the supplied
1778  /// map and mapOperands.
1779  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1780  AffineMap map, ArrayRef<Value> mapOperands) const;
1781 
1782  LogicalResult matchAndRewrite(AffineOpTy affineOp,
1783  PatternRewriter &rewriter) const override {
1784  static_assert(
1785  llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1786  AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1787  AffineVectorStoreOp, AffineVectorLoadOp>::value,
1788  "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1789  "expected");
1790  auto map = affineOp.getAffineMap();
1791  AffineMap oldMap = map;
1792  auto oldOperands = affineOp.getMapOperands();
1793  SmallVector<Value, 8> resultOperands(oldOperands);
1794  composeAffineMapAndOperands(&map, &resultOperands);
1795  canonicalizeMapAndOperands(&map, &resultOperands);
1796  simplifyMapWithOperands(map, resultOperands);
1797  if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1798  resultOperands.begin()))
1799  return failure();
1800 
1801  replaceAffineOp(rewriter, affineOp, map, resultOperands);
1802  return success();
1803  }
1804 };
1805 
1806 // Specialize the template to account for the different build signatures for
1807 // affine load, store, and apply ops.
1808 template <>
1809 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1810  PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1811  ArrayRef<Value> mapOperands) const {
1812  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1813  mapOperands);
1814 }
1815 template <>
1816 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1817  PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1818  ArrayRef<Value> mapOperands) const {
1819  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1820  prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1821  prefetch.getLocalityHint(), prefetch.getIsDataCache());
1822 }
1823 template <>
1824 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1825  PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1826  ArrayRef<Value> mapOperands) const {
1827  rewriter.replaceOpWithNewOp<AffineStoreOp>(
1828  store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1829 }
1830 template <>
1831 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1832  PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1833  ArrayRef<Value> mapOperands) const {
1834  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1835  vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1836  mapOperands);
1837 }
1838 template <>
1839 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1840  PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1841  ArrayRef<Value> mapOperands) const {
1842  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1843  vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1844  mapOperands);
1845 }
1846 
1847 // Generic version for ops that don't have extra operands.
1848 template <typename AffineOpTy>
1849 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1850  PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1851  ArrayRef<Value> mapOperands) const {
1852  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1853 }
1854 } // namespace
1855 
1856 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1857  MLIRContext *context) {
1858  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1859 }
1860 
1861 //===----------------------------------------------------------------------===//
1862 // AffineDmaStartOp
1863 //===----------------------------------------------------------------------===//
1864 
1865 // TODO: Check that map operands are loop IVs or symbols.
1866 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1867  Value srcMemRef, AffineMap srcMap,
1868  ValueRange srcIndices, Value destMemRef,
1869  AffineMap dstMap, ValueRange destIndices,
1870  Value tagMemRef, AffineMap tagMap,
1871  ValueRange tagIndices, Value numElements,
1872  Value stride, Value elementsPerStride) {
1873  result.addOperands(srcMemRef);
1874  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1875  result.addOperands(srcIndices);
1876  result.addOperands(destMemRef);
1877  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1878  result.addOperands(destIndices);
1879  result.addOperands(tagMemRef);
1880  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1881  result.addOperands(tagIndices);
1882  result.addOperands(numElements);
1883  if (stride) {
1884  result.addOperands({stride, elementsPerStride});
1885  }
1886 }
1887 
1888 AffineDmaStartOp AffineDmaStartOp::create(
1889  OpBuilder &builder, Location location, Value srcMemRef, AffineMap srcMap,
1890  ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
1891  ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
1892  ValueRange tagIndices, Value numElements, Value stride,
1893  Value elementsPerStride) {
1894  mlir::OperationState state(location, getOperationName());
1895  build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap,
1896  destIndices, tagMemRef, tagMap, tagIndices, numElements, stride,
1897  elementsPerStride);
1898  auto result = dyn_cast<AffineDmaStartOp>(builder.create(state));
1899  assert(result && "builder didn't return the right type");
1900  return result;
1901 }
1902 
1903 AffineDmaStartOp AffineDmaStartOp::create(
1904  ImplicitLocOpBuilder &builder, Value srcMemRef, AffineMap srcMap,
1905  ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
1906  ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
1907  ValueRange tagIndices, Value numElements, Value stride,
1908  Value elementsPerStride) {
1909  return create(builder, builder.getLoc(), srcMemRef, srcMap, srcIndices,
1910  destMemRef, dstMap, destIndices, tagMemRef, tagMap, tagIndices,
1911  numElements, stride, elementsPerStride);
1912 }
1913 
1915  p << " " << getSrcMemRef() << '[';
1916  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1917  p << "], " << getDstMemRef() << '[';
1918  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1919  p << "], " << getTagMemRef() << '[';
1920  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1921  p << "], " << getNumElements();
1922  if (isStrided()) {
1923  p << ", " << getStride();
1924  p << ", " << getNumElementsPerStride();
1925  }
1926  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1927  << getTagMemRefType();
1928 }
1929 
1930 // Parse AffineDmaStartOp.
1931 // Ex:
1932 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1933 // %stride, %num_elt_per_stride
1934 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1935 //
1937  OperationState &result) {
1938  OpAsmParser::UnresolvedOperand srcMemRefInfo;
1939  AffineMapAttr srcMapAttr;
1941  OpAsmParser::UnresolvedOperand dstMemRefInfo;
1942  AffineMapAttr dstMapAttr;
1944  OpAsmParser::UnresolvedOperand tagMemRefInfo;
1945  AffineMapAttr tagMapAttr;
1947  OpAsmParser::UnresolvedOperand numElementsInfo;
1949 
1950  SmallVector<Type, 3> types;
1951  auto indexType = parser.getBuilder().getIndexType();
1952 
1953  // Parse and resolve the following list of operands:
1954  // *) dst memref followed by its affine maps operands (in square brackets).
1955  // *) src memref followed by its affine map operands (in square brackets).
1956  // *) tag memref followed by its affine map operands (in square brackets).
1957  // *) number of elements transferred by DMA operation.
1958  if (parser.parseOperand(srcMemRefInfo) ||
1959  parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1960  getSrcMapAttrStrName(),
1961  result.attributes) ||
1962  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1963  parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1964  getDstMapAttrStrName(),
1965  result.attributes) ||
1966  parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1967  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1968  getTagMapAttrStrName(),
1969  result.attributes) ||
1970  parser.parseComma() || parser.parseOperand(numElementsInfo))
1971  return failure();
1972 
1973  // Parse optional stride and elements per stride.
1974  if (parser.parseTrailingOperandList(strideInfo))
1975  return failure();
1976 
1977  if (!strideInfo.empty() && strideInfo.size() != 2) {
1978  return parser.emitError(parser.getNameLoc(),
1979  "expected two stride related operands");
1980  }
1981  bool isStrided = strideInfo.size() == 2;
1982 
1983  if (parser.parseColonTypeList(types))
1984  return failure();
1985 
1986  if (types.size() != 3)
1987  return parser.emitError(parser.getNameLoc(), "expected three types");
1988 
1989  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1990  parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1991  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1992  parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1993  parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1994  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1995  parser.resolveOperand(numElementsInfo, indexType, result.operands))
1996  return failure();
1997 
1998  if (isStrided) {
1999  if (parser.resolveOperands(strideInfo, indexType, result.operands))
2000  return failure();
2001  }
2002 
2003  // Check that src/dst/tag operand counts match their map.numInputs.
2004  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
2005  dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
2006  tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
2007  return parser.emitError(parser.getNameLoc(),
2008  "memref operand count not equal to map.numInputs");
2009  return success();
2010 }
2011 
2012 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
2013  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
2014  return emitOpError("expected DMA source to be of memref type");
2015  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
2016  return emitOpError("expected DMA destination to be of memref type");
2017  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
2018  return emitOpError("expected DMA tag to be of memref type");
2019 
2020  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
2021  getDstMap().getNumInputs() +
2022  getTagMap().getNumInputs();
2023  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
2024  getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
2025  return emitOpError("incorrect number of operands");
2026  }
2027 
2028  Region *scope = getAffineScope(*this);
2029  for (auto idx : getSrcIndices()) {
2030  if (!idx.getType().isIndex())
2031  return emitOpError("src index to dma_start must have 'index' type");
2032  if (!isValidAffineIndexOperand(idx, scope))
2033  return emitOpError(
2034  "src index must be a valid dimension or symbol identifier");
2035  }
2036  for (auto idx : getDstIndices()) {
2037  if (!idx.getType().isIndex())
2038  return emitOpError("dst index to dma_start must have 'index' type");
2039  if (!isValidAffineIndexOperand(idx, scope))
2040  return emitOpError(
2041  "dst index must be a valid dimension or symbol identifier");
2042  }
2043  for (auto idx : getTagIndices()) {
2044  if (!idx.getType().isIndex())
2045  return emitOpError("tag index to dma_start must have 'index' type");
2046  if (!isValidAffineIndexOperand(idx, scope))
2047  return emitOpError(
2048  "tag index must be a valid dimension or symbol identifier");
2049  }
2050  return success();
2051 }
2052 
2053 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
2054  SmallVectorImpl<OpFoldResult> &results) {
2055  /// dma_start(memrefcast) -> dma_start
2056  return memref::foldMemRefCast(*this);
2057 }
2058 
2059 void AffineDmaStartOp::getEffects(
2061  &effects) {
2062  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
2064  effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
2066  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
2068 }
2069 
2070 //===----------------------------------------------------------------------===//
2071 // AffineDmaWaitOp
2072 //===----------------------------------------------------------------------===//
2073 
2074 // TODO: Check that map operands are loop IVs or symbols.
2075 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
2076  Value tagMemRef, AffineMap tagMap,
2077  ValueRange tagIndices, Value numElements) {
2078  result.addOperands(tagMemRef);
2079  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
2080  result.addOperands(tagIndices);
2081  result.addOperands(numElements);
2082 }
2083 
2084 AffineDmaWaitOp AffineDmaWaitOp::create(OpBuilder &builder, Location location,
2085  Value tagMemRef, AffineMap tagMap,
2086  ValueRange tagIndices,
2087  Value numElements) {
2088  mlir::OperationState state(location, getOperationName());
2089  build(builder, state, tagMemRef, tagMap, tagIndices, numElements);
2090  auto result = dyn_cast<AffineDmaWaitOp>(builder.create(state));
2091  assert(result && "builder didn't return the right type");
2092  return result;
2093 }
2094 
2095 AffineDmaWaitOp AffineDmaWaitOp::create(ImplicitLocOpBuilder &builder,
2096  Value tagMemRef, AffineMap tagMap,
2097  ValueRange tagIndices,
2098  Value numElements) {
2099  return create(builder, builder.getLoc(), tagMemRef, tagMap, tagIndices,
2100  numElements);
2101 }
2102 
2104  p << " " << getTagMemRef() << '[';
2105  SmallVector<Value, 2> operands(getTagIndices());
2106  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
2107  p << "], ";
2109  p << " : " << getTagMemRef().getType();
2110 }
2111 
2112 // Parse AffineDmaWaitOp.
2113 // Eg:
2114 // affine.dma_wait %tag[%index], %num_elements
2115 // : memref<1 x i32, (d0) -> (d0), 4>
2116 //
2118  OperationState &result) {
2119  OpAsmParser::UnresolvedOperand tagMemRefInfo;
2120  AffineMapAttr tagMapAttr;
2122  Type type;
2123  auto indexType = parser.getBuilder().getIndexType();
2124  OpAsmParser::UnresolvedOperand numElementsInfo;
2125 
2126  // Parse tag memref, its map operands, and dma size.
2127  if (parser.parseOperand(tagMemRefInfo) ||
2128  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
2129  getTagMapAttrStrName(),
2130  result.attributes) ||
2131  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
2132  parser.parseColonType(type) ||
2133  parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
2134  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
2135  parser.resolveOperand(numElementsInfo, indexType, result.operands))
2136  return failure();
2137 
2138  if (!llvm::isa<MemRefType>(type))
2139  return parser.emitError(parser.getNameLoc(),
2140  "expected tag to be of memref type");
2141 
2142  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
2143  return parser.emitError(parser.getNameLoc(),
2144  "tag memref operand count != to map.numInputs");
2145  return success();
2146 }
2147 
2148 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
2149  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
2150  return emitOpError("expected DMA tag to be of memref type");
2151  Region *scope = getAffineScope(*this);
2152  for (auto idx : getTagIndices()) {
2153  if (!idx.getType().isIndex())
2154  return emitOpError("index to dma_wait must have 'index' type");
2155  if (!isValidAffineIndexOperand(idx, scope))
2156  return emitOpError(
2157  "index must be a valid dimension or symbol identifier");
2158  }
2159  return success();
2160 }
2161 
2162 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
2163  SmallVectorImpl<OpFoldResult> &results) {
2164  /// dma_wait(memrefcast) -> dma_wait
2165  return memref::foldMemRefCast(*this);
2166 }
2167 
2168 void AffineDmaWaitOp::getEffects(
2170  &effects) {
2171  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
2173 }
2174 
2175 //===----------------------------------------------------------------------===//
2176 // AffineForOp
2177 //===----------------------------------------------------------------------===//
2178 
2179 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
2180 /// bodyBuilder are empty/null, we include default terminator op.
2181 void AffineForOp::build(OpBuilder &builder, OperationState &result,
2182  ValueRange lbOperands, AffineMap lbMap,
2183  ValueRange ubOperands, AffineMap ubMap, int64_t step,
2184  ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
2185  assert(((!lbMap && lbOperands.empty()) ||
2186  lbOperands.size() == lbMap.getNumInputs()) &&
2187  "lower bound operand count does not match the affine map");
2188  assert(((!ubMap && ubOperands.empty()) ||
2189  ubOperands.size() == ubMap.getNumInputs()) &&
2190  "upper bound operand count does not match the affine map");
2191  assert(step > 0 && "step has to be a positive integer constant");
2192 
2193  OpBuilder::InsertionGuard guard(builder);
2194 
2195  // Set variadic segment sizes.
2196  result.addAttribute(
2197  getOperandSegmentSizeAttr(),
2198  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
2199  static_cast<int32_t>(ubOperands.size()),
2200  static_cast<int32_t>(iterArgs.size())}));
2201 
2202  for (Value val : iterArgs)
2203  result.addTypes(val.getType());
2204 
2205  // Add an attribute for the step.
2206  result.addAttribute(getStepAttrName(result.name),
2207  builder.getIntegerAttr(builder.getIndexType(), step));
2208 
2209  // Add the lower bound.
2210  result.addAttribute(getLowerBoundMapAttrName(result.name),
2211  AffineMapAttr::get(lbMap));
2212  result.addOperands(lbOperands);
2213 
2214  // Add the upper bound.
2215  result.addAttribute(getUpperBoundMapAttrName(result.name),
2216  AffineMapAttr::get(ubMap));
2217  result.addOperands(ubOperands);
2218 
2219  result.addOperands(iterArgs);
2220  // Create a region and a block for the body. The argument of the region is
2221  // the loop induction variable.
2222  Region *bodyRegion = result.addRegion();
2223  Block *bodyBlock = builder.createBlock(bodyRegion);
2224  Value inductionVar =
2225  bodyBlock->addArgument(builder.getIndexType(), result.location);
2226  for (Value val : iterArgs)
2227  bodyBlock->addArgument(val.getType(), val.getLoc());
2228 
2229  // Create the default terminator if the builder is not provided and if the
2230  // iteration arguments are not provided. Otherwise, leave this to the caller
2231  // because we don't know which values to return from the loop.
2232  if (iterArgs.empty() && !bodyBuilder) {
2233  ensureTerminator(*bodyRegion, builder, result.location);
2234  } else if (bodyBuilder) {
2235  OpBuilder::InsertionGuard guard(builder);
2236  builder.setInsertionPointToStart(bodyBlock);
2237  bodyBuilder(builder, result.location, inductionVar,
2238  bodyBlock->getArguments().drop_front());
2239  }
2240 }
2241 
2242 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
2243  int64_t ub, int64_t step, ValueRange iterArgs,
2244  BodyBuilderFn bodyBuilder) {
2245  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
2246  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
2247  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
2248  bodyBuilder);
2249 }
2250 
2251 LogicalResult AffineForOp::verifyRegions() {
2252  // Check that the body defines as single block argument for the induction
2253  // variable.
2254  auto *body = getBody();
2255  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
2256  return emitOpError("expected body to have a single index argument for the "
2257  "induction variable");
2258 
2259  // Verify that the bound operands are valid dimension/symbols.
2260  /// Lower bound.
2261  if (getLowerBoundMap().getNumInputs() > 0)
2263  getLowerBoundMap().getNumDims())))
2264  return failure();
2265  /// Upper bound.
2266  if (getUpperBoundMap().getNumInputs() > 0)
2268  getUpperBoundMap().getNumDims())))
2269  return failure();
2270  if (getLowerBoundMap().getNumResults() < 1)
2271  return emitOpError("expected lower bound map to have at least one result");
2272  if (getUpperBoundMap().getNumResults() < 1)
2273  return emitOpError("expected upper bound map to have at least one result");
2274 
2275  unsigned opNumResults = getNumResults();
2276  if (opNumResults == 0)
2277  return success();
2278 
2279  // If ForOp defines values, check that the number and types of the defined
2280  // values match ForOp initial iter operands and backedge basic block
2281  // arguments.
2282  if (getNumIterOperands() != opNumResults)
2283  return emitOpError(
2284  "mismatch between the number of loop-carried values and results");
2285  if (getNumRegionIterArgs() != opNumResults)
2286  return emitOpError(
2287  "mismatch between the number of basic block args and results");
2288 
2289  return success();
2290 }
2291 
2292 /// Parse a for operation loop bounds.
2293 static ParseResult parseBound(bool isLower, OperationState &result,
2294  OpAsmParser &p) {
2295  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
2296  // the map has multiple results.
2297  bool failedToParsedMinMax =
2298  failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
2299 
2300  auto &builder = p.getBuilder();
2301  auto boundAttrStrName =
2302  isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
2303  : AffineForOp::getUpperBoundMapAttrName(result.name);
2304 
2305  // Parse ssa-id as identity map.
2307  if (p.parseOperandList(boundOpInfos))
2308  return failure();
2309 
2310  if (!boundOpInfos.empty()) {
2311  // Check that only one operand was parsed.
2312  if (boundOpInfos.size() > 1)
2313  return p.emitError(p.getNameLoc(),
2314  "expected only one loop bound operand");
2315 
2316  // TODO: improve error message when SSA value is not of index type.
2317  // Currently it is 'use of value ... expects different type than prior uses'
2318  if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
2319  result.operands))
2320  return failure();
2321 
2322  // Create an identity map using symbol id. This representation is optimized
2323  // for storage. Analysis passes may expand it into a multi-dimensional map
2324  // if desired.
2325  AffineMap map = builder.getSymbolIdentityMap();
2326  result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
2327  return success();
2328  }
2329 
2330  // Get the attribute location.
2331  SMLoc attrLoc = p.getCurrentLocation();
2332 
2333  Attribute boundAttr;
2334  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
2335  result.attributes))
2336  return failure();
2337 
2338  // Parse full form - affine map followed by dim and symbol list.
2339  if (auto affineMapAttr = dyn_cast<AffineMapAttr>(boundAttr)) {
2340  unsigned currentNumOperands = result.operands.size();
2341  unsigned numDims;
2342  if (parseDimAndSymbolList(p, result.operands, numDims))
2343  return failure();
2344 
2345  auto map = affineMapAttr.getValue();
2346  if (map.getNumDims() != numDims)
2347  return p.emitError(
2348  p.getNameLoc(),
2349  "dim operand count and affine map dim count must match");
2350 
2351  unsigned numDimAndSymbolOperands =
2352  result.operands.size() - currentNumOperands;
2353  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2354  return p.emitError(
2355  p.getNameLoc(),
2356  "symbol operand count and affine map symbol count must match");
2357 
2358  // If the map has multiple results, make sure that we parsed the min/max
2359  // prefix.
2360  if (map.getNumResults() > 1 && failedToParsedMinMax) {
2361  if (isLower) {
2362  return p.emitError(attrLoc, "lower loop bound affine map with "
2363  "multiple results requires 'max' prefix");
2364  }
2365  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
2366  "results requires 'min' prefix");
2367  }
2368  return success();
2369  }
2370 
2371  // Parse custom assembly form.
2372  if (auto integerAttr = dyn_cast<IntegerAttr>(boundAttr)) {
2373  result.attributes.pop_back();
2374  result.addAttribute(
2375  boundAttrStrName,
2376  AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2377  return success();
2378  }
2379 
2380  return p.emitError(
2381  p.getNameLoc(),
2382  "expected valid affine map representation for loop bounds");
2383 }
2384 
2385 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2386  auto &builder = parser.getBuilder();
2387  OpAsmParser::Argument inductionVariable;
2388  inductionVariable.type = builder.getIndexType();
2389  // Parse the induction variable followed by '='.
2390  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2391  return failure();
2392 
2393  // Parse loop bounds.
2394  int64_t numOperands = result.operands.size();
2395  if (parseBound(/*isLower=*/true, result, parser))
2396  return failure();
2397  int64_t numLbOperands = result.operands.size() - numOperands;
2398  if (parser.parseKeyword("to", " between bounds"))
2399  return failure();
2400  numOperands = result.operands.size();
2401  if (parseBound(/*isLower=*/false, result, parser))
2402  return failure();
2403  int64_t numUbOperands = result.operands.size() - numOperands;
2404 
2405  // Parse the optional loop step, we default to 1 if one is not present.
2406  if (parser.parseOptionalKeyword("step")) {
2407  result.addAttribute(
2408  getStepAttrName(result.name),
2409  builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2410  } else {
2411  SMLoc stepLoc = parser.getCurrentLocation();
2412  IntegerAttr stepAttr;
2413  if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2414  getStepAttrName(result.name).data(),
2415  result.attributes))
2416  return failure();
2417 
2418  if (stepAttr.getValue().isNegative())
2419  return parser.emitError(
2420  stepLoc,
2421  "expected step to be representable as a positive signed integer");
2422  }
2423 
2424  // Parse the optional initial iteration arguments.
2427 
2428  // Induction variable.
2429  regionArgs.push_back(inductionVariable);
2430 
2431  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2432  // Parse assignment list and results type list.
2433  if (parser.parseAssignmentList(regionArgs, operands) ||
2434  parser.parseArrowTypeList(result.types))
2435  return failure();
2436  // Resolve input operands.
2437  for (auto argOperandType :
2438  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2439  Type type = std::get<2>(argOperandType);
2440  std::get<0>(argOperandType).type = type;
2441  if (parser.resolveOperand(std::get<1>(argOperandType), type,
2442  result.operands))
2443  return failure();
2444  }
2445  }
2446 
2447  result.addAttribute(
2448  getOperandSegmentSizeAttr(),
2449  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2450  static_cast<int32_t>(numUbOperands),
2451  static_cast<int32_t>(operands.size())}));
2452 
2453  // Parse the body region.
2454  Region *body = result.addRegion();
2455  if (regionArgs.size() != result.types.size() + 1)
2456  return parser.emitError(
2457  parser.getNameLoc(),
2458  "mismatch between the number of loop-carried values and results");
2459  if (parser.parseRegion(*body, regionArgs))
2460  return failure();
2461 
2462  AffineForOp::ensureTerminator(*body, builder, result.location);
2463 
2464  // Parse the optional attribute list.
2465  return parser.parseOptionalAttrDict(result.attributes);
2466 }
2467 
2468 static void printBound(AffineMapAttr boundMap,
2469  Operation::operand_range boundOperands,
2470  const char *prefix, OpAsmPrinter &p) {
2471  AffineMap map = boundMap.getValue();
2472 
2473  // Check if this bound should be printed using custom assembly form.
2474  // The decision to restrict printing custom assembly form to trivial cases
2475  // comes from the will to roundtrip MLIR binary -> text -> binary in a
2476  // lossless way.
2477  // Therefore, custom assembly form parsing and printing is only supported for
2478  // zero-operand constant maps and single symbol operand identity maps.
2479  if (map.getNumResults() == 1) {
2480  AffineExpr expr = map.getResult(0);
2481 
2482  // Print constant bound.
2483  if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2484  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2485  p << constExpr.getValue();
2486  return;
2487  }
2488  }
2489 
2490  // Print bound that consists of a single SSA symbol if the map is over a
2491  // single symbol.
2492  if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2493  if (isa<AffineSymbolExpr>(expr)) {
2494  p.printOperand(*boundOperands.begin());
2495  return;
2496  }
2497  }
2498  } else {
2499  // Map has multiple results. Print 'min' or 'max' prefix.
2500  p << prefix << ' ';
2501  }
2502 
2503  // Print the map and its operands.
2504  p << boundMap;
2505  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2506  map.getNumDims(), p);
2507 }
2508 
2509 unsigned AffineForOp::getNumIterOperands() {
2510  AffineMap lbMap = getLowerBoundMapAttr().getValue();
2511  AffineMap ubMap = getUpperBoundMapAttr().getValue();
2512 
2513  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2514 }
2515 
2516 std::optional<MutableArrayRef<OpOperand>>
2517 AffineForOp::getYieldedValuesMutable() {
2518  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2519 }
2520 
2522  p << ' ';
2523  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2524  /*omitType=*/true);
2525  p << " = ";
2526  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2527  p << " to ";
2528  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2529 
2530  if (getStepAsInt() != 1)
2531  p << " step " << getStepAsInt();
2532 
2533  bool printBlockTerminators = false;
2534  if (getNumIterOperands() > 0) {
2535  p << " iter_args(";
2536  auto regionArgs = getRegionIterArgs();
2537  auto operands = getInits();
2538 
2539  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2540  p << std::get<0>(it) << " = " << std::get<1>(it);
2541  });
2542  p << ") -> (" << getResultTypes() << ")";
2543  printBlockTerminators = true;
2544  }
2545 
2546  p << ' ';
2547  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2548  printBlockTerminators);
2550  (*this)->getAttrs(),
2551  /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2552  getUpperBoundMapAttrName(getOperation()->getName()),
2553  getStepAttrName(getOperation()->getName()),
2554  getOperandSegmentSizeAttr()});
2555 }
2556 
2557 /// Fold the constant bounds of a loop.
2558 static LogicalResult foldLoopBounds(AffineForOp forOp) {
2559  auto foldLowerOrUpperBound = [&forOp](bool lower) {
2560  // Check to see if each of the operands is the result of a constant. If
2561  // so, get the value. If not, ignore it.
2562  SmallVector<Attribute, 8> operandConstants;
2563  auto boundOperands =
2564  lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2565  for (auto operand : boundOperands) {
2566  Attribute operandCst;
2567  matchPattern(operand, m_Constant(&operandCst));
2568  operandConstants.push_back(operandCst);
2569  }
2570 
2571  AffineMap boundMap =
2572  lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2573  assert(boundMap.getNumResults() >= 1 &&
2574  "bound maps should have at least one result");
2575  SmallVector<Attribute, 4> foldedResults;
2576  if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2577  return failure();
2578 
2579  // Compute the max or min as applicable over the results.
2580  assert(!foldedResults.empty() && "bounds should have at least one result");
2581  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2582  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2583  auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2584  maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2585  : llvm::APIntOps::smin(maxOrMin, foldedResult);
2586  }
2587  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2588  : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2589  return success();
2590  };
2591 
2592  // Try to fold the lower bound.
2593  bool folded = false;
2594  if (!forOp.hasConstantLowerBound())
2595  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2596 
2597  // Try to fold the upper bound.
2598  if (!forOp.hasConstantUpperBound())
2599  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2600  return success(folded);
2601 }
2602 
2603 /// Returns constant trip count in trivial cases.
2604 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2605  int64_t step = forOp.getStepAsInt();
2606  if (!forOp.hasConstantBounds() || step <= 0)
2607  return std::nullopt;
2608  int64_t lb = forOp.getConstantLowerBound();
2609  int64_t ub = forOp.getConstantUpperBound();
2610  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2611 }
2612 
2613 /// Fold the empty loop.
2615  if (!llvm::hasSingleElement(*forOp.getBody()))
2616  return {};
2617  if (forOp.getNumResults() == 0)
2618  return {};
2619  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2620  if (tripCount == 0) {
2621  // The initial values of the iteration arguments would be the op's
2622  // results.
2623  return forOp.getInits();
2624  }
2625  SmallVector<Value, 4> replacements;
2626  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2627  auto iterArgs = forOp.getRegionIterArgs();
2628  bool hasValDefinedOutsideLoop = false;
2629  bool iterArgsNotInOrder = false;
2630  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2631  Value val = yieldOp.getOperand(i);
2632  BlockArgument *iterArgIt = llvm::find(iterArgs, val);
2633  // TODO: It should be possible to perform a replacement by computing the
2634  // last value of the IV based on the bounds and the step.
2635  if (val == forOp.getInductionVar())
2636  return {};
2637  if (iterArgIt == iterArgs.end()) {
2638  // `val` is defined outside of the loop.
2639  assert(forOp.isDefinedOutsideOfLoop(val) &&
2640  "must be defined outside of the loop");
2641  hasValDefinedOutsideLoop = true;
2642  replacements.push_back(val);
2643  } else {
2644  unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2645  if (pos != i)
2646  iterArgsNotInOrder = true;
2647  replacements.push_back(forOp.getInits()[pos]);
2648  }
2649  }
2650  // Bail out when the trip count is unknown and the loop returns any value
2651  // defined outside of the loop or any iterArg out of order.
2652  if (!tripCount.has_value() &&
2653  (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2654  return {};
2655  // Bail out when the loop iterates more than once and it returns any iterArg
2656  // out of order.
2657  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2658  return {};
2659  return llvm::to_vector_of<OpFoldResult>(replacements);
2660 }
2661 
2662 /// Canonicalize the bounds of the given loop.
2663 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2664  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2665  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2666 
2667  auto lbMap = forOp.getLowerBoundMap();
2668  auto ubMap = forOp.getUpperBoundMap();
2669  auto prevLbMap = lbMap;
2670  auto prevUbMap = ubMap;
2671 
2672  composeAffineMapAndOperands(&lbMap, &lbOperands);
2673  canonicalizeMapAndOperands(&lbMap, &lbOperands);
2674  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
2675  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2676  lbMap = removeDuplicateExprs(lbMap);
2677 
2678  composeAffineMapAndOperands(&ubMap, &ubOperands);
2679  canonicalizeMapAndOperands(&ubMap, &ubOperands);
2680  ubMap = removeDuplicateExprs(ubMap);
2681 
2682  // Any canonicalization change always leads to updated map(s).
2683  if (lbMap == prevLbMap && ubMap == prevUbMap)
2684  return failure();
2685 
2686  if (lbMap != prevLbMap)
2687  forOp.setLowerBound(lbOperands, lbMap);
2688  if (ubMap != prevUbMap)
2689  forOp.setUpperBound(ubOperands, ubMap);
2690  return success();
2691 }
2692 
2693 /// Returns true if the affine.for has zero iterations in trivial cases.
2694 static bool hasTrivialZeroTripCount(AffineForOp op) {
2695  return getTrivialConstantTripCount(op) == 0;
2696 }
2697 
2698 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2699  SmallVectorImpl<OpFoldResult> &results) {
2700  bool folded = succeeded(foldLoopBounds(*this));
2701  folded |= succeeded(canonicalizeLoopBounds(*this));
2702  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2703  // The initial values of the loop-carried variables (iter_args) are the
2704  // results of the op. But this must be avoided for an affine.for op that
2705  // does not return any results. Since ops that do not return results cannot
2706  // be folded away, we would enter an infinite loop of folds on the same
2707  // affine.for op.
2708  results.assign(getInits().begin(), getInits().end());
2709  folded = true;
2710  }
2712  if (!foldResults.empty()) {
2713  results.assign(foldResults);
2714  folded = true;
2715  }
2716  return success(folded);
2717 }
2718 
2719 OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2720  assert((successor.isParent() || successor.getSuccessor() == &getRegion()) &&
2721  "invalid region point");
2722 
2723  // The initial operands map to the loop arguments after the induction
2724  // variable or are forwarded to the results when the trip count is zero.
2725  return getInits();
2726 }
2727 
2728 void AffineForOp::getSuccessorRegions(
2730  assert((point.isParent() ||
2732  &getRegion()) &&
2733  "expected loop region");
2734  // The loop may typically branch back to its body or to the parent operation.
2735  // If the predecessor is the parent op and the trip count is known to be at
2736  // least one, branch into the body using the iterator arguments. And in cases
2737  // we know the trip count is zero, it can only branch back to its parent.
2738  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2739  if (tripCount.has_value()) {
2740  if (!point.isParent()) {
2741  // From the loop body, if the trip count is one, we can only branch back
2742  // to the parent.
2743  if (tripCount == 1) {
2744  regions.push_back(RegionSuccessor(getOperation(), getResults()));
2745  return;
2746  }
2747  if (tripCount == 0)
2748  return;
2749  } else {
2750  if (tripCount.value() > 0) {
2751  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2752  return;
2753  }
2754  if (tripCount.value() == 0) {
2755  regions.push_back(RegionSuccessor(getOperation(), getResults()));
2756  return;
2757  }
2758  }
2759  }
2760 
2761  // In all other cases, the loop may branch back to itself or the parent
2762  // operation.
2763  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2764  regions.push_back(RegionSuccessor(getOperation(), getResults()));
2765 }
2766 
2768  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2769 }
2770 
2772  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2773 }
2774 
2775 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2776  assert(lbOperands.size() == map.getNumInputs());
2777  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2778  getLowerBoundOperandsMutable().assign(lbOperands);
2779  setLowerBoundMap(map);
2780 }
2781 
2782 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2783  assert(ubOperands.size() == map.getNumInputs());
2784  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2785  getUpperBoundOperandsMutable().assign(ubOperands);
2786  setUpperBoundMap(map);
2787 }
2788 
2789 bool AffineForOp::hasConstantLowerBound() {
2790  return getLowerBoundMap().isSingleConstant();
2791 }
2792 
2793 bool AffineForOp::hasConstantUpperBound() {
2794  return getUpperBoundMap().isSingleConstant();
2795 }
2796 
2797 int64_t AffineForOp::getConstantLowerBound() {
2798  return getLowerBoundMap().getSingleConstantResult();
2799 }
2800 
2801 int64_t AffineForOp::getConstantUpperBound() {
2802  return getUpperBoundMap().getSingleConstantResult();
2803 }
2804 
2805 void AffineForOp::setConstantLowerBound(int64_t value) {
2806  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2807 }
2808 
2809 void AffineForOp::setConstantUpperBound(int64_t value) {
2810  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2811 }
2812 
2813 AffineForOp::operand_range AffineForOp::getControlOperands() {
2814  return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2815  getUpperBoundOperands().size()};
2816 }
2817 
2818 bool AffineForOp::matchingBoundOperandList() {
2819  auto lbMap = getLowerBoundMap();
2820  auto ubMap = getUpperBoundMap();
2821  if (lbMap.getNumDims() != ubMap.getNumDims() ||
2822  lbMap.getNumSymbols() != ubMap.getNumSymbols())
2823  return false;
2824 
2825  unsigned numOperands = lbMap.getNumInputs();
2826  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2827  // Compare Value 's.
2828  if (getOperand(i) != getOperand(numOperands + i))
2829  return false;
2830  }
2831  return true;
2832 }
2833 
2834 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2835 
2836 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2837  return SmallVector<Value>{getInductionVar()};
2838 }
2839 
2840 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2841  if (!hasConstantLowerBound())
2842  return std::nullopt;
2843  OpBuilder b(getContext());
2845  OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2846 }
2847 
2848 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2849  OpBuilder b(getContext());
2851  OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2852 }
2853 
2854 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2855  if (!hasConstantUpperBound())
2856  return {};
2857  OpBuilder b(getContext());
2859  OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2860 }
2861 
2862 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2863  RewriterBase &rewriter, ValueRange newInitOperands,
2864  bool replaceInitOperandUsesInLoop,
2865  const NewYieldValuesFn &newYieldValuesFn) {
2866  // Create a new loop before the existing one, with the extra operands.
2867  OpBuilder::InsertionGuard g(rewriter);
2868  rewriter.setInsertionPoint(getOperation());
2869  auto inits = llvm::to_vector(getInits());
2870  inits.append(newInitOperands.begin(), newInitOperands.end());
2871  AffineForOp newLoop = AffineForOp::create(
2872  rewriter, getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2873  getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
2874 
2875  // Generate the new yield values and append them to the scf.yield operation.
2876  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2877  ArrayRef<BlockArgument> newIterArgs =
2878  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2879  {
2880  OpBuilder::InsertionGuard g(rewriter);
2881  rewriter.setInsertionPoint(yieldOp);
2882  SmallVector<Value> newYieldedValues =
2883  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2884  assert(newInitOperands.size() == newYieldedValues.size() &&
2885  "expected as many new yield values as new iter operands");
2886  rewriter.modifyOpInPlace(yieldOp, [&]() {
2887  yieldOp.getOperandsMutable().append(newYieldedValues);
2888  });
2889  }
2890 
2891  // Move the loop body to the new op.
2892  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2893  newLoop.getBody()->getArguments().take_front(
2894  getBody()->getNumArguments()));
2895 
2896  if (replaceInitOperandUsesInLoop) {
2897  // Replace all uses of `newInitOperands` with the corresponding basic block
2898  // arguments.
2899  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2900  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2901  [&](OpOperand &use) {
2902  Operation *user = use.getOwner();
2903  return newLoop->isProperAncestor(user);
2904  });
2905  }
2906  }
2907 
2908  // Replace the old loop.
2909  rewriter.replaceOp(getOperation(),
2910  newLoop->getResults().take_front(getNumResults()));
2911  return cast<LoopLikeOpInterface>(newLoop.getOperation());
2912 }
2913 
2914 Speculation::Speculatability AffineForOp::getSpeculatability() {
2915  // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2916  // Start and End.
2917  //
2918  // For Step != 1, the loop may not terminate. We can add more smarts here if
2919  // needed.
2920  return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2922 }
2923 
2924 /// Returns true if the provided value is the induction variable of a
2925 /// AffineForOp.
2927  return getForInductionVarOwner(val) != AffineForOp();
2928 }
2929 
2931  return getAffineParallelInductionVarOwner(val) != nullptr;
2932 }
2933 
2936 }
2937 
2939  auto ivArg = dyn_cast<BlockArgument>(val);
2940  if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2941  return AffineForOp();
2942  if (auto forOp =
2943  ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2944  // Check to make sure `val` is the induction variable, not an iter_arg.
2945  return forOp.getInductionVar() == val ? forOp : AffineForOp();
2946  return AffineForOp();
2947 }
2948 
2950  auto ivArg = dyn_cast<BlockArgument>(val);
2951  if (!ivArg || !ivArg.getOwner())
2952  return nullptr;
2953  Operation *containingOp = ivArg.getOwner()->getParentOp();
2954  auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
2955  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2956  return parallelOp;
2957  return nullptr;
2958 }
2959 
2960 /// Extracts the induction variables from a list of AffineForOps and returns
2961 /// them.
2963  SmallVectorImpl<Value> *ivs) {
2964  ivs->reserve(forInsts.size());
2965  for (auto forInst : forInsts)
2966  ivs->push_back(forInst.getInductionVar());
2967 }
2968 
2971  ivs.reserve(affineOps.size());
2972  for (Operation *op : affineOps) {
2973  // Collect the induction variables of affine.for and affine.parallel ops.
2974  if (auto forOp = dyn_cast<AffineForOp>(op))
2975  ivs.push_back(forOp.getInductionVar());
2976  else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2977  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2978  ivs.push_back(parallelOp.getBody()->getArgument(i));
2979  }
2980 }
2981 
2982 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2983 /// operations.
2984 template <typename BoundListTy, typename LoopCreatorTy>
2986  OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2987  ArrayRef<int64_t> steps,
2988  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2989  LoopCreatorTy &&loopCreatorFn) {
2990  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2991  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2992 
2993  // If there are no loops to be constructed, construct the body anyway.
2994  OpBuilder::InsertionGuard guard(builder);
2995  if (lbs.empty()) {
2996  if (bodyBuilderFn)
2997  bodyBuilderFn(builder, loc, ValueRange());
2998  return;
2999  }
3000 
3001  // Create the loops iteratively and store the induction variables.
3003  ivs.reserve(lbs.size());
3004  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
3005  // Callback for creating the loop body, always creates the terminator.
3006  auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
3007  ValueRange iterArgs) {
3008  ivs.push_back(iv);
3009  // In the innermost loop, call the body builder.
3010  if (i == e - 1 && bodyBuilderFn) {
3011  OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
3012  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
3013  }
3014  AffineYieldOp::create(nestedBuilder, nestedLoc);
3015  };
3016 
3017  // Delegate actual loop creation to the callback in order to dispatch
3018  // between constant- and variable-bound loops.
3019  auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
3020  builder.setInsertionPointToStart(loop.getBody());
3021  }
3022 }
3023 
3024 /// Creates an affine loop from the bounds known to be constants.
3025 static AffineForOp
3027  int64_t ub, int64_t step,
3028  AffineForOp::BodyBuilderFn bodyBuilderFn) {
3029  return AffineForOp::create(builder, loc, lb, ub, step,
3030  /*iterArgs=*/ValueRange(), bodyBuilderFn);
3031 }
3032 
3033 /// Creates an affine loop from the bounds that may or may not be constants.
3034 static AffineForOp
3036  int64_t step,
3037  AffineForOp::BodyBuilderFn bodyBuilderFn) {
3038  std::optional<int64_t> lbConst = getConstantIntValue(lb);
3039  std::optional<int64_t> ubConst = getConstantIntValue(ub);
3040  if (lbConst && ubConst)
3041  return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
3042  ubConst.value(), step, bodyBuilderFn);
3043  return AffineForOp::create(builder, loc, lb, builder.getDimIdentityMap(), ub,
3044  builder.getDimIdentityMap(), step,
3045  /*iterArgs=*/ValueRange(), bodyBuilderFn);
3046 }
3047 
3049  OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
3051  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
3052  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
3054 }
3055 
3057  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
3058  ArrayRef<int64_t> steps,
3059  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
3060  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
3062 }
3063 
3064 //===----------------------------------------------------------------------===//
3065 // AffineIfOp
3066 //===----------------------------------------------------------------------===//
3067 
3068 namespace {
3069 /// Remove else blocks that have nothing other than a zero value yield.
3070 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
3072 
3073  LogicalResult matchAndRewrite(AffineIfOp ifOp,
3074  PatternRewriter &rewriter) const override {
3075  if (ifOp.getElseRegion().empty() ||
3076  !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
3077  return failure();
3078 
3079  rewriter.startOpModification(ifOp);
3080  rewriter.eraseBlock(ifOp.getElseBlock());
3081  rewriter.finalizeOpModification(ifOp);
3082  return success();
3083  }
3084 };
3085 
3086 /// Removes affine.if cond if the condition is always true or false in certain
3087 /// trivial cases. Promotes the then/else block in the parent operation block.
3088 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
3090 
3091  LogicalResult matchAndRewrite(AffineIfOp op,
3092  PatternRewriter &rewriter) const override {
3093 
3094  auto isTriviallyFalse = [](IntegerSet iSet) {
3095  return iSet.isEmptyIntegerSet();
3096  };
3097 
3098  auto isTriviallyTrue = [](IntegerSet iSet) {
3099  return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
3100  iSet.getConstraint(0) == 0);
3101  };
3102 
3103  IntegerSet affineIfConditions = op.getIntegerSet();
3104  Block *blockToMove;
3105  if (isTriviallyFalse(affineIfConditions)) {
3106  // The absence, or equivalently, the emptiness of the else region need not
3107  // be checked when affine.if is returning results because if an affine.if
3108  // operation is returning results, it always has a non-empty else region.
3109  if (op.getNumResults() == 0 && !op.hasElse()) {
3110  // If the else region is absent, or equivalently, empty, remove the
3111  // affine.if operation (which is not returning any results).
3112  rewriter.eraseOp(op);
3113  return success();
3114  }
3115  blockToMove = op.getElseBlock();
3116  } else if (isTriviallyTrue(affineIfConditions)) {
3117  blockToMove = op.getThenBlock();
3118  } else {
3119  return failure();
3120  }
3121  Operation *blockToMoveTerminator = blockToMove->getTerminator();
3122  // Promote the "blockToMove" block to the parent operation block between the
3123  // prologue and epilogue of "op".
3124  rewriter.inlineBlockBefore(blockToMove, op);
3125  // Replace the "op" operation with the operands of the
3126  // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
3127  // the affine.yield operation present in the "blockToMove" block. It has no
3128  // operands when affine.if is not returning results and therefore, in that
3129  // case, replaceOp just erases "op". When affine.if is not returning
3130  // results, the affine.yield operation can be omitted. It gets inserted
3131  // implicitly.
3132  rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
3133  // Erase the "blockToMoveTerminator" operation since it is now in the parent
3134  // operation block, which already has its own terminator.
3135  rewriter.eraseOp(blockToMoveTerminator);
3136  return success();
3137  }
3138 };
3139 } // namespace
3140 
3141 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
3142 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
3143 void AffineIfOp::getSuccessorRegions(
3145  // If the predecessor is an AffineIfOp, then branching into both `then` and
3146  // `else` region is valid.
3147  if (point.isParent()) {
3148  regions.reserve(2);
3149  regions.push_back(
3150  RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
3151  // If the "else" region is empty, branch back into the parent.
3152  if (getElseRegion().empty()) {
3153  regions.push_back(RegionSuccessor(getOperation(), getResults()));
3154  } else {
3155  regions.push_back(
3156  RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
3157  }
3158  return;
3159  }
3160 
3161  // If the predecessor is the `else`/`then` region, then branching into parent
3162  // op is valid.
3163  regions.push_back(RegionSuccessor(getOperation(), getResults()));
3164 }
3165 
3166 LogicalResult AffineIfOp::verify() {
3167  // Verify that we have a condition attribute.
3168  // FIXME: This should be specified in the arguments list in ODS.
3169  auto conditionAttr =
3170  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3171  if (!conditionAttr)
3172  return emitOpError("requires an integer set attribute named 'condition'");
3173 
3174  // Verify that there are enough operands for the condition.
3175  IntegerSet condition = conditionAttr.getValue();
3176  if (getNumOperands() != condition.getNumInputs())
3177  return emitOpError("operand count and condition integer set dimension and "
3178  "symbol count must match");
3179 
3180  // Verify that the operands are valid dimension/symbols.
3181  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
3182  condition.getNumDims())))
3183  return failure();
3184 
3185  return success();
3186 }
3187 
3188 ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
3189  // Parse the condition attribute set.
3190  IntegerSetAttr conditionAttr;
3191  unsigned numDims;
3192  if (parser.parseAttribute(conditionAttr,
3193  AffineIfOp::getConditionAttrStrName(),
3194  result.attributes) ||
3195  parseDimAndSymbolList(parser, result.operands, numDims))
3196  return failure();
3197 
3198  // Verify the condition operands.
3199  auto set = conditionAttr.getValue();
3200  if (set.getNumDims() != numDims)
3201  return parser.emitError(
3202  parser.getNameLoc(),
3203  "dim operand count and integer set dim count must match");
3204  if (numDims + set.getNumSymbols() != result.operands.size())
3205  return parser.emitError(
3206  parser.getNameLoc(),
3207  "symbol operand count and integer set symbol count must match");
3208 
3209  if (parser.parseOptionalArrowTypeList(result.types))
3210  return failure();
3211 
3212  // Create the regions for 'then' and 'else'. The latter must be created even
3213  // if it remains empty for the validity of the operation.
3214  result.regions.reserve(2);
3215  Region *thenRegion = result.addRegion();
3216  Region *elseRegion = result.addRegion();
3217 
3218  // Parse the 'then' region.
3219  if (parser.parseRegion(*thenRegion, {}, {}))
3220  return failure();
3221  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
3222  result.location);
3223 
3224  // If we find an 'else' keyword then parse the 'else' region.
3225  if (!parser.parseOptionalKeyword("else")) {
3226  if (parser.parseRegion(*elseRegion, {}, {}))
3227  return failure();
3228  AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
3229  result.location);
3230  }
3231 
3232  // Parse the optional attribute list.
3233  if (parser.parseOptionalAttrDict(result.attributes))
3234  return failure();
3235 
3236  return success();
3237 }
3238 
3240  auto conditionAttr =
3241  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3242  p << " " << conditionAttr;
3243  printDimAndSymbolList(operand_begin(), operand_end(),
3244  conditionAttr.getValue().getNumDims(), p);
3245  p.printOptionalArrowTypeList(getResultTypes());
3246  p << ' ';
3247  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
3248  /*printBlockTerminators=*/getNumResults());
3249 
3250  // Print the 'else' regions if it has any blocks.
3251  auto &elseRegion = this->getElseRegion();
3252  if (!elseRegion.empty()) {
3253  p << " else ";
3254  p.printRegion(elseRegion,
3255  /*printEntryBlockArgs=*/false,
3256  /*printBlockTerminators=*/getNumResults());
3257  }
3258 
3259  // Print the attribute list.
3260  p.printOptionalAttrDict((*this)->getAttrs(),
3261  /*elidedAttrs=*/getConditionAttrStrName());
3262 }
3263 
3264 IntegerSet AffineIfOp::getIntegerSet() {
3265  return (*this)
3266  ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
3267  .getValue();
3268 }
3269 
3270 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
3271  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
3272 }
3273 
3274 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
3275  setIntegerSet(set);
3276  (*this)->setOperands(operands);
3277 }
3278 
3279 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3280  TypeRange resultTypes, IntegerSet set, ValueRange args,
3281  bool withElseRegion) {
3282  assert(resultTypes.empty() || withElseRegion);
3283  OpBuilder::InsertionGuard guard(builder);
3284 
3285  result.addTypes(resultTypes);
3286  result.addOperands(args);
3287  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
3288 
3289  Region *thenRegion = result.addRegion();
3290  builder.createBlock(thenRegion);
3291  if (resultTypes.empty())
3292  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
3293 
3294  Region *elseRegion = result.addRegion();
3295  if (withElseRegion) {
3296  builder.createBlock(elseRegion);
3297  if (resultTypes.empty())
3298  AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
3299  }
3300 }
3301 
3302 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3303  IntegerSet set, ValueRange args, bool withElseRegion) {
3304  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
3305  withElseRegion);
3306 }
3307 
3308 /// Compose any affine.apply ops feeding into `operands` of the integer set
3309 /// `set` by composing the maps of such affine.apply ops with the integer
3310 /// set constraints.
3312  SmallVectorImpl<Value> &operands,
3313  bool composeAffineMin = false) {
3314  // We will simply reuse the API of the map composition by viewing the LHSs of
3315  // the equalities and inequalities of `set` as the affine exprs of an affine
3316  // map. Convert to equivalent map, compose, and convert back to set.
3317  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
3318  set.getConstraints(), set.getContext());
3319  // Check if any composition is possible.
3320  if (llvm::none_of(operands,
3321  [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
3322  return;
3323 
3324  composeAffineMapAndOperands(&map, &operands, composeAffineMin);
3325  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
3326  set.getEqFlags());
3327 }
3328 
3329 /// Canonicalize an affine if op's conditional (integer set + operands).
3330 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3331  auto set = getIntegerSet();
3332  SmallVector<Value, 4> operands(getOperands());
3333  composeSetAndOperands(set, operands);
3334  canonicalizeSetAndOperands(&set, &operands);
3335 
3336  // Check if the canonicalization or composition led to any change.
3337  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
3338  return failure();
3339 
3340  setConditional(set, operands);
3341  return success();
3342 }
3343 
3344 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
3345  MLIRContext *context) {
3346  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
3347 }
3348 
3349 //===----------------------------------------------------------------------===//
3350 // AffineLoadOp
3351 //===----------------------------------------------------------------------===//
3352 
3353 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3354  AffineMap map, ValueRange operands) {
3355  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3356  result.addOperands(operands);
3357  if (map)
3358  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3359  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
3360  result.types.push_back(memrefType.getElementType());
3361 }
3362 
3363 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3364  Value memref, AffineMap map, ValueRange mapOperands) {
3365  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3366  result.addOperands(memref);
3367  result.addOperands(mapOperands);
3368  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3369  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3370  result.types.push_back(memrefType.getElementType());
3371 }
3372 
3373 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3374  Value memref, ValueRange indices) {
3375  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3376  int64_t rank = memrefType.getRank();
3377  // Create identity map for memrefs with at least one dimension or () -> ()
3378  // for zero-dimensional memrefs.
3379  auto map =
3380  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3381  build(builder, result, memref, map, indices);
3382 }
3383 
3384 ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
3385  auto &builder = parser.getBuilder();
3386  auto indexTy = builder.getIndexType();
3387 
3388  MemRefType type;
3389  OpAsmParser::UnresolvedOperand memrefInfo;
3390  AffineMapAttr mapAttr;
3392  return failure(
3393  parser.parseOperand(memrefInfo) ||
3394  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3395  AffineLoadOp::getMapAttrStrName(),
3396  result.attributes) ||
3397  parser.parseOptionalAttrDict(result.attributes) ||
3398  parser.parseColonType(type) ||
3399  parser.resolveOperand(memrefInfo, type, result.operands) ||
3400  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3401  parser.addTypeToList(type.getElementType(), result.types));
3402 }
3403 
3405  p << " " << getMemRef() << '[';
3406  if (AffineMapAttr mapAttr =
3407  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3408  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3409  p << ']';
3410  p.printOptionalAttrDict((*this)->getAttrs(),
3411  /*elidedAttrs=*/{getMapAttrStrName()});
3412  p << " : " << getMemRefType();
3413 }
3414 
3415 /// Verify common indexing invariants of affine.load, affine.store,
3416 /// affine.vector_load and affine.vector_store.
3417 template <typename AffineMemOpTy>
3418 static LogicalResult
3419 verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr,
3420  Operation::operand_range mapOperands,
3421  MemRefType memrefType, unsigned numIndexOperands) {
3422  AffineMap map = mapAttr.getValue();
3423  if (map.getNumResults() != memrefType.getRank())
3424  return op->emitOpError("affine map num results must equal memref rank");
3425  if (map.getNumInputs() != numIndexOperands)
3426  return op->emitOpError("expects as many subscripts as affine map inputs");
3427 
3428  for (auto idx : mapOperands) {
3429  if (!idx.getType().isIndex())
3430  return op->emitOpError("index to load must have 'index' type");
3431  }
3432  if (failed(verifyDimAndSymbolIdentifiers(op, mapOperands, map.getNumDims())))
3433  return failure();
3434 
3435  return success();
3436 }
3437 
3438 LogicalResult AffineLoadOp::verify() {
3439  auto memrefType = getMemRefType();
3440  if (getType() != memrefType.getElementType())
3441  return emitOpError("result type must match element type of memref");
3442 
3444  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3445  getMapOperands(), memrefType,
3446  /*numIndexOperands=*/getNumOperands() - 1)))
3447  return failure();
3448 
3449  return success();
3450 }
3451 
3452 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3453  MLIRContext *context) {
3454  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3455 }
3456 
3457 OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3458  /// load(memrefcast) -> load
3459  if (succeeded(memref::foldMemRefCast(*this)))
3460  return getResult();
3461 
3462  // Fold load from a global constant memref.
3463  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3464  if (!getGlobalOp)
3465  return {};
3466  // Get to the memref.global defining the symbol.
3467  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3468  if (!symbolTableOp)
3469  return {};
3470  auto global = dyn_cast_or_null<memref::GlobalOp>(
3471  SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
3472  if (!global)
3473  return {};
3474 
3475  // Check if the global memref is a constant.
3476  auto cstAttr =
3477  dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3478  if (!cstAttr)
3479  return {};
3480  // If it's a splat constant, we can fold irrespective of indices.
3481  if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr))
3482  return splatAttr.getSplatValue<Attribute>();
3483  // Otherwise, we can fold only if we know the indices.
3484  if (!getAffineMap().isConstant())
3485  return {};
3486  auto indices = llvm::to_vector<4>(
3487  llvm::map_range(getAffineMap().getConstantResults(),
3488  [](int64_t v) -> uint64_t { return v; }));
3489  return cstAttr.getValues<Attribute>()[indices];
3490 }
3491 
3492 //===----------------------------------------------------------------------===//
3493 // AffineStoreOp
3494 //===----------------------------------------------------------------------===//
3495 
3496 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3497  Value valueToStore, Value memref, AffineMap map,
3498  ValueRange mapOperands) {
3499  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3500  result.addOperands(valueToStore);
3501  result.addOperands(memref);
3502  result.addOperands(mapOperands);
3503  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3504 }
3505 
3506 // Use identity map.
3507 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3508  Value valueToStore, Value memref,
3509  ValueRange indices) {
3510  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3511  int64_t rank = memrefType.getRank();
3512  // Create identity map for memrefs with at least one dimension or () -> ()
3513  // for zero-dimensional memrefs.
3514  auto map =
3515  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3516  build(builder, result, valueToStore, memref, map, indices);
3517 }
3518 
3519 ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3520  auto indexTy = parser.getBuilder().getIndexType();
3521 
3522  MemRefType type;
3523  OpAsmParser::UnresolvedOperand storeValueInfo;
3524  OpAsmParser::UnresolvedOperand memrefInfo;
3525  AffineMapAttr mapAttr;
3527  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3528  parser.parseOperand(memrefInfo) ||
3529  parser.parseAffineMapOfSSAIds(
3530  mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3531  result.attributes) ||
3532  parser.parseOptionalAttrDict(result.attributes) ||
3533  parser.parseColonType(type) ||
3534  parser.resolveOperand(storeValueInfo, type.getElementType(),
3535  result.operands) ||
3536  parser.resolveOperand(memrefInfo, type, result.operands) ||
3537  parser.resolveOperands(mapOperands, indexTy, result.operands));
3538 }
3539 
3541  p << " " << getValueToStore();
3542  p << ", " << getMemRef() << '[';
3543  if (AffineMapAttr mapAttr =
3544  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3545  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3546  p << ']';
3547  p.printOptionalAttrDict((*this)->getAttrs(),
3548  /*elidedAttrs=*/{getMapAttrStrName()});
3549  p << " : " << getMemRefType();
3550 }
3551 
3552 LogicalResult AffineStoreOp::verify() {
3553  // The value to store must have the same type as memref element type.
3554  auto memrefType = getMemRefType();
3555  if (getValueToStore().getType() != memrefType.getElementType())
3556  return emitOpError(
3557  "value to store must have the same type as memref element type");
3558 
3560  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3561  getMapOperands(), memrefType,
3562  /*numIndexOperands=*/getNumOperands() - 2)))
3563  return failure();
3564 
3565  return success();
3566 }
3567 
3568 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3569  MLIRContext *context) {
3570  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3571 }
3572 
3573 LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3574  SmallVectorImpl<OpFoldResult> &results) {
3575  /// store(memrefcast) -> store
3576  return memref::foldMemRefCast(*this, getValueToStore());
3577 }
3578 
3579 //===----------------------------------------------------------------------===//
3580 // AffineMinMaxOpBase
3581 //===----------------------------------------------------------------------===//
3582 
3583 template <typename T>
3584 static LogicalResult verifyAffineMinMaxOp(T op) {
3585  // Verify that operand count matches affine map dimension and symbol count.
3586  if (op.getNumOperands() !=
3587  op.getMap().getNumDims() + op.getMap().getNumSymbols())
3588  return op.emitOpError(
3589  "operand count and affine map dimension and symbol count must match");
3590 
3591  if (op.getMap().getNumResults() == 0)
3592  return op.emitOpError("affine map expect at least one result");
3593  return success();
3594 }
3595 
3596 template <typename T>
3597 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3598  p << ' ' << op->getAttr(T::getMapAttrStrName());
3599  auto operands = op.getOperands();
3600  unsigned numDims = op.getMap().getNumDims();
3601  p << '(' << operands.take_front(numDims) << ')';
3602 
3603  if (operands.size() != numDims)
3604  p << '[' << operands.drop_front(numDims) << ']';
3605  p.printOptionalAttrDict(op->getAttrs(),
3606  /*elidedAttrs=*/{T::getMapAttrStrName()});
3607 }
3608 
3609 template <typename T>
3610 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3611  OperationState &result) {
3612  auto &builder = parser.getBuilder();
3613  auto indexType = builder.getIndexType();
3616  AffineMapAttr mapAttr;
3617  return failure(
3618  parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3619  result.attributes) ||
3620  parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
3621  parser.parseOperandList(symInfos,
3623  parser.parseOptionalAttrDict(result.attributes) ||
3624  parser.resolveOperands(dimInfos, indexType, result.operands) ||
3625  parser.resolveOperands(symInfos, indexType, result.operands) ||
3626  parser.addTypeToList(indexType, result.types));
3627 }
3628 
3629 /// Fold an affine min or max operation with the given operands. The operand
3630 /// list may contain nulls, which are interpreted as the operand not being a
3631 /// constant.
3632 template <typename T>
3634  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3635  "expected affine min or max op");
3636 
3637  // Fold the affine map.
3638  // TODO: Fold more cases:
3639  // min(some_affine, some_affine + constant, ...), etc.
3640  SmallVector<int64_t, 2> results;
3641  auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3642 
3643  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3644  return op.getOperand(0);
3645 
3646  // If some of the map results are not constant, try changing the map in-place.
3647  if (results.empty()) {
3648  // If the map is the same, report that folding did not happen.
3649  if (foldedMap == op.getMap())
3650  return {};
3651  op->setAttr("map", AffineMapAttr::get(foldedMap));
3652  return op.getResult();
3653  }
3654 
3655  // Otherwise, completely fold the op into a constant.
3656  auto resultIt = std::is_same<T, AffineMinOp>::value
3657  ? llvm::min_element(results)
3658  : llvm::max_element(results);
3659  if (resultIt == results.end())
3660  return {};
3661  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3662 }
3663 
3664 /// Remove duplicated expressions in affine min/max ops.
3665 template <typename T>
3668 
3669  LogicalResult matchAndRewrite(T affineOp,
3670  PatternRewriter &rewriter) const override {
3671  AffineMap oldMap = affineOp.getAffineMap();
3672 
3673  SmallVector<AffineExpr, 4> newExprs;
3674  for (AffineExpr expr : oldMap.getResults()) {
3675  // This is a linear scan over newExprs, but it should be fine given that
3676  // we typically just have a few expressions per op.
3677  if (!llvm::is_contained(newExprs, expr))
3678  newExprs.push_back(expr);
3679  }
3680 
3681  if (newExprs.size() == oldMap.getNumResults())
3682  return failure();
3683 
3684  auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3685  newExprs, rewriter.getContext());
3686  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3687 
3688  return success();
3689  }
3690 };
3691 
3692 /// Merge an affine min/max op to its consumers if its consumer is also an
3693 /// affine min/max op.
3694 ///
3695 /// This pattern requires the producer affine min/max op is bound to a
3696 /// dimension/symbol that is used as a standalone expression in the consumer
3697 /// affine op's map.
3698 ///
3699 /// For example, a pattern like the following:
3700 ///
3701 /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3702 /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3703 ///
3704 /// Can be turned into:
3705 ///
3706 /// %1 = affine.min affine_map<
3707 /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3708 template <typename T>
3711 
3712  LogicalResult matchAndRewrite(T affineOp,
3713  PatternRewriter &rewriter) const override {
3714  AffineMap oldMap = affineOp.getAffineMap();
3715  ValueRange dimOperands =
3716  affineOp.getMapOperands().take_front(oldMap.getNumDims());
3717  ValueRange symOperands =
3718  affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3719 
3720  auto newDimOperands = llvm::to_vector<8>(dimOperands);
3721  auto newSymOperands = llvm::to_vector<8>(symOperands);
3722  SmallVector<AffineExpr, 4> newExprs;
3723  SmallVector<T, 4> producerOps;
3724 
3725  // Go over each expression to see whether it's a single dimension/symbol
3726  // with the corresponding operand which is the result of another affine
3727  // min/max op. If So it can be merged into this affine op.
3728  for (AffineExpr expr : oldMap.getResults()) {
3729  if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3730  Value symValue = symOperands[symExpr.getPosition()];
3731  if (auto producerOp = symValue.getDefiningOp<T>()) {
3732  producerOps.push_back(producerOp);
3733  continue;
3734  }
3735  } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3736  Value dimValue = dimOperands[dimExpr.getPosition()];
3737  if (auto producerOp = dimValue.getDefiningOp<T>()) {
3738  producerOps.push_back(producerOp);
3739  continue;
3740  }
3741  }
3742  // For the above cases we will remove the expression by merging the
3743  // producer affine min/max's affine expressions. Otherwise we need to
3744  // keep the existing expression.
3745  newExprs.push_back(expr);
3746  }
3747 
3748  if (producerOps.empty())
3749  return failure();
3750 
3751  unsigned numUsedDims = oldMap.getNumDims();
3752  unsigned numUsedSyms = oldMap.getNumSymbols();
3753 
3754  // Now go over all producer affine ops and merge their expressions.
3755  for (T producerOp : producerOps) {
3756  AffineMap producerMap = producerOp.getAffineMap();
3757  unsigned numProducerDims = producerMap.getNumDims();
3758  unsigned numProducerSyms = producerMap.getNumSymbols();
3759 
3760  // Collect all dimension/symbol values.
3761  ValueRange dimValues =
3762  producerOp.getMapOperands().take_front(numProducerDims);
3763  ValueRange symValues =
3764  producerOp.getMapOperands().take_back(numProducerSyms);
3765  newDimOperands.append(dimValues.begin(), dimValues.end());
3766  newSymOperands.append(symValues.begin(), symValues.end());
3767 
3768  // For expressions we need to shift to avoid overlap.
3769  for (AffineExpr expr : producerMap.getResults()) {
3770  newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3771  .shiftSymbols(numProducerSyms, numUsedSyms));
3772  }
3773 
3774  numUsedDims += numProducerDims;
3775  numUsedSyms += numProducerSyms;
3776  }
3777 
3778  auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3779  rewriter.getContext());
3780  auto newOperands =
3781  llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3782  rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3783 
3784  return success();
3785  }
3786 };
3787 
3788 /// Canonicalize the result expression order of an affine map and return success
3789 /// if the order changed.
3790 ///
3791 /// The function flattens the map's affine expressions to coefficient arrays and
3792 /// sorts them in lexicographic order. A coefficient array contains a multiplier
3793 /// for every dimension/symbol and a constant term. The canonicalization fails
3794 /// if a result expression is not pure or if the flattening requires local
3795 /// variables that, unlike dimensions and symbols, have no global order.
3796 static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3797  SmallVector<SmallVector<int64_t>> flattenedExprs;
3798  for (const AffineExpr &resultExpr : map.getResults()) {
3799  // Fail if the expression is not pure.
3800  if (!resultExpr.isPureAffine())
3801  return failure();
3802 
3803  SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3804  auto flattenResult = flattener.walkPostOrder(resultExpr);
3805  if (failed(flattenResult))
3806  return failure();
3807 
3808  // Fail if the flattened expression has local variables.
3809  if (flattener.operandExprStack.back().size() !=
3810  map.getNumDims() + map.getNumSymbols() + 1)
3811  return failure();
3812 
3813  flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3814  flattener.operandExprStack.back().end());
3815  }
3816 
3817  // Fail if sorting is not necessary.
3818  if (llvm::is_sorted(flattenedExprs))
3819  return failure();
3820 
3821  // Reorder the result expressions according to their flattened form.
3822  SmallVector<unsigned> resultPermutation =
3823  llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3824  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3825  return flattenedExprs[lhs] < flattenedExprs[rhs];
3826  });
3827  SmallVector<AffineExpr> newExprs;
3828  for (unsigned idx : resultPermutation)
3829  newExprs.push_back(map.getResult(idx));
3830 
3831  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3832  map.getContext());
3833  return success();
3834 }
3835 
3836 /// Canonicalize the affine map result expression order of an affine min/max
3837 /// operation.
3838 ///
3839 /// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3840 /// expressions and replaces the operation if the order changed.
3841 ///
3842 /// For example, the following operation:
3843 ///
3844 /// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3845 ///
3846 /// Turns into:
3847 ///
3848 /// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3849 template <typename T>
3852 
3853  LogicalResult matchAndRewrite(T affineOp,
3854  PatternRewriter &rewriter) const override {
3855  AffineMap map = affineOp.getAffineMap();
3857  return failure();
3858  rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3859  return success();
3860  }
3861 };
3862 
3863 template <typename T>
3866 
3867  LogicalResult matchAndRewrite(T affineOp,
3868  PatternRewriter &rewriter) const override {
3869  if (affineOp.getMap().getNumResults() != 1)
3870  return failure();
3871  rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3872  affineOp.getOperands());
3873  return success();
3874  }
3875 };
3876 
3877 //===----------------------------------------------------------------------===//
3878 // AffineMinOp
3879 //===----------------------------------------------------------------------===//
3880 //
3881 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3882 //
3883 
3884 OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3885  return foldMinMaxOp(*this, adaptor.getOperands());
3886 }
3887 
3888 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3889  MLIRContext *context) {
3892  MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3894  context);
3895 }
3896 
3897 LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3898 
3899 ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
3900  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
3901 }
3902 
3903 void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3904 
3905 //===----------------------------------------------------------------------===//
3906 // AffineMaxOp
3907 //===----------------------------------------------------------------------===//
3908 //
3909 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3910 //
3911 
3912 OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3913  return foldMinMaxOp(*this, adaptor.getOperands());
3914 }
3915 
3916 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3917  MLIRContext *context) {
3920  MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3922  context);
3923 }
3924 
3925 LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3926 
3927 ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
3928  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
3929 }
3930 
3931 void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3932 
3933 //===----------------------------------------------------------------------===//
3934 // AffinePrefetchOp
3935 //===----------------------------------------------------------------------===//
3936 
3937 //
3938 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3939 //
3940 ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3941  OperationState &result) {
3942  auto &builder = parser.getBuilder();
3943  auto indexTy = builder.getIndexType();
3944 
3945  MemRefType type;
3946  OpAsmParser::UnresolvedOperand memrefInfo;
3947  IntegerAttr hintInfo;
3948  auto i32Type = parser.getBuilder().getIntegerType(32);
3949  StringRef readOrWrite, cacheType;
3950 
3951  AffineMapAttr mapAttr;
3953  if (parser.parseOperand(memrefInfo) ||
3954  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3955  AffinePrefetchOp::getMapAttrStrName(),
3956  result.attributes) ||
3957  parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3958  parser.parseComma() || parser.parseKeyword("locality") ||
3959  parser.parseLess() ||
3960  parser.parseAttribute(hintInfo, i32Type,
3961  AffinePrefetchOp::getLocalityHintAttrStrName(),
3962  result.attributes) ||
3963  parser.parseGreater() || parser.parseComma() ||
3964  parser.parseKeyword(&cacheType) ||
3965  parser.parseOptionalAttrDict(result.attributes) ||
3966  parser.parseColonType(type) ||
3967  parser.resolveOperand(memrefInfo, type, result.operands) ||
3968  parser.resolveOperands(mapOperands, indexTy, result.operands))
3969  return failure();
3970 
3971  if (readOrWrite != "read" && readOrWrite != "write")
3972  return parser.emitError(parser.getNameLoc(),
3973  "rw specifier has to be 'read' or 'write'");
3974  result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3975  parser.getBuilder().getBoolAttr(readOrWrite == "write"));
3976 
3977  if (cacheType != "data" && cacheType != "instr")
3978  return parser.emitError(parser.getNameLoc(),
3979  "cache type has to be 'data' or 'instr'");
3980 
3981  result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3982  parser.getBuilder().getBoolAttr(cacheType == "data"));
3983 
3984  return success();
3985 }
3986 
3988  p << " " << getMemref() << '[';
3989  AffineMapAttr mapAttr =
3990  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3991  if (mapAttr)
3992  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3993  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
3994  << "locality<" << getLocalityHint() << ">, "
3995  << (getIsDataCache() ? "data" : "instr");
3997  (*this)->getAttrs(),
3998  /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
3999  getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
4000  p << " : " << getMemRefType();
4001 }
4002 
4003 LogicalResult AffinePrefetchOp::verify() {
4004  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
4005  if (mapAttr) {
4006  AffineMap map = mapAttr.getValue();
4007  if (map.getNumResults() != getMemRefType().getRank())
4008  return emitOpError("affine.prefetch affine map num results must equal"
4009  " memref rank");
4010  if (map.getNumInputs() + 1 != getNumOperands())
4011  return emitOpError("too few operands");
4012  } else {
4013  if (getNumOperands() != 1)
4014  return emitOpError("too few operands");
4015  }
4016 
4017  Region *scope = getAffineScope(*this);
4018  for (auto idx : getMapOperands()) {
4019  if (!isValidAffineIndexOperand(idx, scope))
4020  return emitOpError(
4021  "index must be a valid dimension or symbol identifier");
4022  }
4023  return success();
4024 }
4025 
4026 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4027  MLIRContext *context) {
4028  // prefetch(memrefcast) -> prefetch
4029  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
4030 }
4031 
4032 LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
4033  SmallVectorImpl<OpFoldResult> &results) {
4034  /// prefetch(memrefcast) -> prefetch
4035  return memref::foldMemRefCast(*this);
4036 }
4037 
4038 //===----------------------------------------------------------------------===//
4039 // AffineParallelOp
4040 //===----------------------------------------------------------------------===//
4041 
4042 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
4043  TypeRange resultTypes,
4044  ArrayRef<arith::AtomicRMWKind> reductions,
4045  ArrayRef<int64_t> ranges) {
4046  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
4047  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
4048  return builder.getConstantAffineMap(value);
4049  }));
4050  SmallVector<int64_t> steps(ranges.size(), 1);
4051  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
4052  /*ubArgs=*/{}, steps);
4053 }
4054 
4055 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
4056  TypeRange resultTypes,
4057  ArrayRef<arith::AtomicRMWKind> reductions,
4058  ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
4059  ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
4060  ArrayRef<int64_t> steps) {
4061  assert(llvm::all_of(lbMaps,
4062  [lbMaps](AffineMap m) {
4063  return m.getNumDims() == lbMaps[0].getNumDims() &&
4064  m.getNumSymbols() == lbMaps[0].getNumSymbols();
4065  }) &&
4066  "expected all lower bounds maps to have the same number of dimensions "
4067  "and symbols");
4068  assert(llvm::all_of(ubMaps,
4069  [ubMaps](AffineMap m) {
4070  return m.getNumDims() == ubMaps[0].getNumDims() &&
4071  m.getNumSymbols() == ubMaps[0].getNumSymbols();
4072  }) &&
4073  "expected all upper bounds maps to have the same number of dimensions "
4074  "and symbols");
4075  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
4076  "expected lower bound maps to have as many inputs as lower bound "
4077  "operands");
4078  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
4079  "expected upper bound maps to have as many inputs as upper bound "
4080  "operands");
4081 
4082  OpBuilder::InsertionGuard guard(builder);
4083  result.addTypes(resultTypes);
4084 
4085  // Convert the reductions to integer attributes.
4086  SmallVector<Attribute, 4> reductionAttrs;
4087  for (arith::AtomicRMWKind reduction : reductions)
4088  reductionAttrs.push_back(
4089  builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
4090  result.addAttribute(getReductionsAttrStrName(),
4091  builder.getArrayAttr(reductionAttrs));
4092 
4093  // Concatenates maps defined in the same input space (same dimensions and
4094  // symbols), assumes there is at least one map.
4095  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
4096  SmallVectorImpl<int32_t> &groups) {
4097  if (maps.empty())
4098  return AffineMap::get(builder.getContext());
4100  groups.reserve(groups.size() + maps.size());
4101  exprs.reserve(maps.size());
4102  for (AffineMap m : maps) {
4103  llvm::append_range(exprs, m.getResults());
4104  groups.push_back(m.getNumResults());
4105  }
4106  return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
4107  maps[0].getContext());
4108  };
4109 
4110  // Set up the bounds.
4111  SmallVector<int32_t> lbGroups, ubGroups;
4112  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
4113  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
4114  result.addAttribute(getLowerBoundsMapAttrStrName(),
4115  AffineMapAttr::get(lbMap));
4116  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
4117  builder.getI32TensorAttr(lbGroups));
4118  result.addAttribute(getUpperBoundsMapAttrStrName(),
4119  AffineMapAttr::get(ubMap));
4120  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
4121  builder.getI32TensorAttr(ubGroups));
4122  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
4123  result.addOperands(lbArgs);
4124  result.addOperands(ubArgs);
4125 
4126  // Create a region and a block for the body.
4127  auto *bodyRegion = result.addRegion();
4128  Block *body = builder.createBlock(bodyRegion);
4129 
4130  // Add all the block arguments.
4131  for (unsigned i = 0, e = steps.size(); i < e; ++i)
4132  body->addArgument(IndexType::get(builder.getContext()), result.location);
4133  if (resultTypes.empty())
4134  ensureTerminator(*bodyRegion, builder, result.location);
4135 }
4136 
4137 SmallVector<Region *> AffineParallelOp::getLoopRegions() {
4138  return {&getRegion()};
4139 }
4140 
4141 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
4142 
4143 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
4144  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
4145 }
4146 
4147 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
4148  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
4149 }
4150 
4151 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
4152  auto values = getLowerBoundsGroups().getValues<int32_t>();
4153  unsigned start = 0;
4154  for (unsigned i = 0; i < pos; ++i)
4155  start += values[i];
4156  return getLowerBoundsMap().getSliceMap(start, values[pos]);
4157 }
4158 
4159 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
4160  auto values = getUpperBoundsGroups().getValues<int32_t>();
4161  unsigned start = 0;
4162  for (unsigned i = 0; i < pos; ++i)
4163  start += values[i];
4164  return getUpperBoundsMap().getSliceMap(start, values[pos]);
4165 }
4166 
4167 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
4168  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
4169 }
4170 
4171 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
4172  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
4173 }
4174 
4175 std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
4176  if (hasMinMaxBounds())
4177  return std::nullopt;
4178 
4179  // Try to convert all the ranges to constant expressions.
4181  AffineValueMap rangesValueMap;
4182  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
4183  &rangesValueMap);
4184  out.reserve(rangesValueMap.getNumResults());
4185  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
4186  auto expr = rangesValueMap.getResult(i);
4187  auto cst = dyn_cast<AffineConstantExpr>(expr);
4188  if (!cst)
4189  return std::nullopt;
4190  out.push_back(cst.getValue());
4191  }
4192  return out;
4193 }
4194 
4195 Block *AffineParallelOp::getBody() { return &getRegion().front(); }
4196 
4197 OpBuilder AffineParallelOp::getBodyBuilder() {
4198  return OpBuilder(getBody(), std::prev(getBody()->end()));
4199 }
4200 
4201 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
4202  assert(lbOperands.size() == map.getNumInputs() &&
4203  "operands to map must match number of inputs");
4204 
4205  auto ubOperands = getUpperBoundsOperands();
4206 
4207  SmallVector<Value, 4> newOperands(lbOperands);
4208  newOperands.append(ubOperands.begin(), ubOperands.end());
4209  (*this)->setOperands(newOperands);
4210 
4211  setLowerBoundsMapAttr(AffineMapAttr::get(map));
4212 }
4213 
4214 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
4215  assert(ubOperands.size() == map.getNumInputs() &&
4216  "operands to map must match number of inputs");
4217 
4218  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
4219  newOperands.append(ubOperands.begin(), ubOperands.end());
4220  (*this)->setOperands(newOperands);
4221 
4222  setUpperBoundsMapAttr(AffineMapAttr::get(map));
4223 }
4224 
4225 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
4226  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
4227 }
4228 
4229 // check whether resultType match op or not in affine.parallel
4230 static bool isResultTypeMatchAtomicRMWKind(Type resultType,
4231  arith::AtomicRMWKind op) {
4232  switch (op) {
4233  case arith::AtomicRMWKind::addf:
4234  return isa<FloatType>(resultType);
4235  case arith::AtomicRMWKind::addi:
4236  return isa<IntegerType>(resultType);
4237  case arith::AtomicRMWKind::assign:
4238  return true;
4239  case arith::AtomicRMWKind::mulf:
4240  return isa<FloatType>(resultType);
4241  case arith::AtomicRMWKind::muli:
4242  return isa<IntegerType>(resultType);
4243  case arith::AtomicRMWKind::maximumf:
4244  return isa<FloatType>(resultType);
4245  case arith::AtomicRMWKind::minimumf:
4246  return isa<FloatType>(resultType);
4247  case arith::AtomicRMWKind::maxs: {
4248  auto intType = dyn_cast<IntegerType>(resultType);
4249  return intType && intType.isSigned();
4250  }
4251  case arith::AtomicRMWKind::mins: {
4252  auto intType = dyn_cast<IntegerType>(resultType);
4253  return intType && intType.isSigned();
4254  }
4255  case arith::AtomicRMWKind::maxu: {
4256  auto intType = dyn_cast<IntegerType>(resultType);
4257  return intType && intType.isUnsigned();
4258  }
4259  case arith::AtomicRMWKind::minu: {
4260  auto intType = dyn_cast<IntegerType>(resultType);
4261  return intType && intType.isUnsigned();
4262  }
4263  case arith::AtomicRMWKind::ori:
4264  return isa<IntegerType>(resultType);
4265  case arith::AtomicRMWKind::andi:
4266  return isa<IntegerType>(resultType);
4267  default:
4268  return false;
4269  }
4270 }
4271 
4272 LogicalResult AffineParallelOp::verify() {
4273  auto numDims = getNumDims();
4274  if (getLowerBoundsGroups().getNumElements() != numDims ||
4275  getUpperBoundsGroups().getNumElements() != numDims ||
4276  getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
4277  return emitOpError() << "the number of region arguments ("
4278  << getBody()->getNumArguments()
4279  << ") and the number of map groups for lower ("
4280  << getLowerBoundsGroups().getNumElements()
4281  << ") and upper bound ("
4282  << getUpperBoundsGroups().getNumElements()
4283  << "), and the number of steps (" << getSteps().size()
4284  << ") must all match";
4285  }
4286 
4287  unsigned expectedNumLBResults = 0;
4288  for (APInt v : getLowerBoundsGroups()) {
4289  unsigned results = v.getZExtValue();
4290  if (results == 0)
4291  return emitOpError()
4292  << "expected lower bound map to have at least one result";
4293  expectedNumLBResults += results;
4294  }
4295  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
4296  return emitOpError() << "expected lower bounds map to have "
4297  << expectedNumLBResults << " results";
4298  unsigned expectedNumUBResults = 0;
4299  for (APInt v : getUpperBoundsGroups()) {
4300  unsigned results = v.getZExtValue();
4301  if (results == 0)
4302  return emitOpError()
4303  << "expected upper bound map to have at least one result";
4304  expectedNumUBResults += results;
4305  }
4306  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
4307  return emitOpError() << "expected upper bounds map to have "
4308  << expectedNumUBResults << " results";
4309 
4310  if (getReductions().size() != getNumResults())
4311  return emitOpError("a reduction must be specified for each output");
4312 
4313  // Verify reduction ops are all valid and each result type matches reduction
4314  // ops
4315  for (auto it : llvm::enumerate((getReductions()))) {
4316  Attribute attr = it.value();
4317  auto intAttr = dyn_cast<IntegerAttr>(attr);
4318  if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
4319  return emitOpError("invalid reduction attribute");
4320  auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
4321  if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
4322  return emitOpError("result type cannot match reduction attribute");
4323  }
4324 
4325  // Verify that the bound operands are valid dimension/symbols.
4326  /// Lower bounds.
4327  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
4328  getLowerBoundsMap().getNumDims())))
4329  return failure();
4330  /// Upper bounds.
4331  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
4332  getUpperBoundsMap().getNumDims())))
4333  return failure();
4334  return success();
4335 }
4336 
4337 LogicalResult AffineValueMap::canonicalize() {
4338  SmallVector<Value, 4> newOperands{operands};
4339  auto newMap = getAffineMap();
4340  composeAffineMapAndOperands(&newMap, &newOperands);
4341  if (newMap == getAffineMap() && newOperands == operands)
4342  return failure();
4343  reset(newMap, newOperands);
4344  return success();
4345 }
4346 
4347 /// Canonicalize the bounds of the given loop.
4348 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
4349  AffineValueMap lb = op.getLowerBoundsValueMap();
4350  bool lbCanonicalized = succeeded(lb.canonicalize());
4351 
4352  AffineValueMap ub = op.getUpperBoundsValueMap();
4353  bool ubCanonicalized = succeeded(ub.canonicalize());
4354 
4355  // Any canonicalization change always leads to updated map(s).
4356  if (!lbCanonicalized && !ubCanonicalized)
4357  return failure();
4358 
4359  if (lbCanonicalized)
4360  op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
4361  if (ubCanonicalized)
4362  op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
4363 
4364  return success();
4365 }
4366 
4367 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
4368  SmallVectorImpl<OpFoldResult> &results) {
4369  return canonicalizeLoopBounds(*this);
4370 }
4371 
4372 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
4373 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
4374 /// identifies which of the those expressions form max/min groups. `operands`
4375 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
4376 /// or "max".
4377 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
4378  DenseIntElementsAttr group, ValueRange operands,
4379  StringRef keyword) {
4380  AffineMap map = mapAttr.getValue();
4381  unsigned numDims = map.getNumDims();
4382  ValueRange dimOperands = operands.take_front(numDims);
4383  ValueRange symOperands = operands.drop_front(numDims);
4384  unsigned start = 0;
4385  for (llvm::APInt groupSize : group) {
4386  if (start != 0)
4387  p << ", ";
4388 
4389  unsigned size = groupSize.getZExtValue();
4390  if (size == 1) {
4391  p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4392  ++start;
4393  } else {
4394  p << keyword << '(';
4395  AffineMap submap = map.getSliceMap(start, size);
4396  p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4397  p << ')';
4398  start += size;
4399  }
4400  }
4401 }
4402 
4404  p << " (" << getBody()->getArguments() << ") = (";
4405  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4406  getLowerBoundsOperands(), "max");
4407  p << ") to (";
4408  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4409  getUpperBoundsOperands(), "min");
4410  p << ')';
4411  SmallVector<int64_t, 8> steps = getSteps();
4412  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4413  if (!elideSteps) {
4414  p << " step (";
4415  llvm::interleaveComma(steps, p);
4416  p << ')';
4417  }
4418  if (getNumResults()) {
4419  p << " reduce (";
4420  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4421  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4422  llvm::cast<IntegerAttr>(attr).getInt());
4423  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4424  });
4425  p << ") -> (" << getResultTypes() << ")";
4426  }
4427 
4428  p << ' ';
4429  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4430  /*printBlockTerminators=*/getNumResults());
4432  (*this)->getAttrs(),
4433  /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4434  AffineParallelOp::getLowerBoundsMapAttrStrName(),
4435  AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4436  AffineParallelOp::getUpperBoundsMapAttrStrName(),
4437  AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4438  AffineParallelOp::getStepsAttrStrName()});
4439 }
4440 
4441 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
4442 /// unique operands. Also populates `replacements with affine expressions of
4443 /// `kind` that can be used to update affine maps previously accepting a
4444 /// `operands` to accept `uniqueOperands` instead.
4446  OpAsmParser &parser,
4448  SmallVectorImpl<Value> &uniqueOperands,
4451  "expected operands to be dim or symbol expression");
4452 
4453  Type indexType = parser.getBuilder().getIndexType();
4454  for (const auto &list : operands) {
4455  SmallVector<Value> valueOperands;
4456  if (parser.resolveOperands(list, indexType, valueOperands))
4457  return failure();
4458  for (Value operand : valueOperands) {
4459  unsigned pos = std::distance(uniqueOperands.begin(),
4460  llvm::find(uniqueOperands, operand));
4461  if (pos == uniqueOperands.size())
4462  uniqueOperands.push_back(operand);
4463  replacements.push_back(
4465  ? getAffineDimExpr(pos, parser.getContext())
4466  : getAffineSymbolExpr(pos, parser.getContext()));
4467  }
4468  }
4469  return success();
4470 }
4471 
namespace {
/// Selects which keyword a grouped affine.parallel bound uses.
enum class MinMaxKind {
  Min, ///< Bound groups are written with `min`.
  Max  ///< Bound groups are written with `max`.
};
} // namespace
4475 
4476 /// Parses an affine map that can contain a min/max for groups of its results,
4477 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4478 /// `result` attributes with the map (flat list of expressions) and the grouping
4479 /// (list of integers that specify how many expressions to put into each
4480 /// min/max) attributes. Deduplicates repeated operands.
4481 ///
4482 /// parallel-bound ::= `(` parallel-group-list `)`
4483 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4484 /// parallel-group ::= simple-group | min-max-group
4485 /// simple-group ::= expr-of-ssa-ids
4486 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4487 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-ids-list)?
4488 ///
4489 /// Examples:
4490 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4491 /// (%0, max(%1 - 2 * %2))
///
/// `kind` selects both the accepted keyword and the attributes written:
/// `Min` accepts `min` and fills the upper bounds map/groups attributes, `Max`
/// accepts `max` and fills the lower bounds map/groups attributes. The
/// resolved operands are appended to `result.operands`, dimensions first and
/// symbols second.
4492 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4493  OperationState &result,
4494  MinMaxKind kind) {
4495  // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4496  // see: https://reviews.llvm.org/D134227#3821753
4497  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4498 
4499  StringRef mapName = kind == MinMaxKind::Min
4500  ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4501  : AffineParallelOp::getLowerBoundsMapAttrStrName();
4502  StringRef groupsName =
4503  kind == MinMaxKind::Min
4504  ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4505  : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4506 
4507  if (failed(parser.parseLParen()))
4508  return failure();
4509 
// An empty bound list `()` yields an empty map and an empty groups tensor.
4510  if (succeeded(parser.parseOptionalRParen())) {
4511  result.addAttribute(
4512  mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4513  result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4514  return success();
4515  }
4516 
// NOTE(review): `flatDimOperands`, `flatSymOperands`, `mapOperands`, `dims`
// and `syms` are used below, but their declarations are not visible in this
// listing (the embedded line numbers skip) - confirm against the full file.
4517  SmallVector<AffineExpr> flatExprs;
4520  SmallVector<int32_t> numMapsPerGroup;
// Parses one group. A `min`/`max` group is parsed as a whole affine map under
// a temporary attribute name that is erased again right away; each of its
// results gets its own copy of the group's dim and symbol operand lists. A
// plain expression forms a group of size one with its own operand lists.
4522  auto parseOperands = [&]() {
4523  if (succeeded(parser.parseOptionalKeyword(
4524  kind == MinMaxKind::Min ? "min" : "max"))) {
4525  mapOperands.clear();
4526  AffineMapAttr map;
4527  if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4528  result.attributes,
4530  return failure();
4531  result.attributes.erase(tmpAttrStrName);
4532  llvm::append_range(flatExprs, map.getValue().getResults());
4533  auto operandsRef = llvm::ArrayRef(mapOperands);
4534  auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4536  auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4538  flatDimOperands.append(map.getValue().getNumResults(), dims);
4539  flatSymOperands.append(map.getValue().getNumResults(), syms);
4540  numMapsPerGroup.push_back(map.getValue().getNumResults());
4541  } else {
4542  if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4543  flatSymOperands.emplace_back(),
4544  flatExprs.emplace_back())))
4545  return failure();
4546  numMapsPerGroup.push_back(1);
4547  }
4548  return success();
4549  };
4550  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4551  return failure();
4552 
// Each expression was parsed against its own operand lists. Shift its
// dimensions and symbols by the running totals so that all expressions can
// live in one combined map over `totalNumDims` dims and `totalNumSyms` symbols.
4553  unsigned totalNumDims = 0;
4554  unsigned totalNumSyms = 0;
4555  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4556  unsigned numDims = flatDimOperands[i].size();
4557  unsigned numSyms = flatSymOperands[i].size();
4558  flatExprs[i] = flatExprs[i]
4559  .shiftDims(numDims, totalNumDims)
4560  .shiftSymbols(numSyms, totalNumSyms);
4561  totalNumDims += numDims;
4562  totalNumSyms += numSyms;
4563  }
4564 
4565  // Deduplicate map operands.
4566  SmallVector<Value> dimOperands, symOperands;
4567  SmallVector<AffineExpr> dimRplacements, symRepacements;
4568  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4569  dimRplacements, AffineExprKind::DimId) ||
4570  deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4571  symRepacements, AffineExprKind::SymbolId))
4572  return failure();
4573 
4574  result.operands.append(dimOperands.begin(), dimOperands.end());
4575  result.operands.append(symOperands.begin(), symOperands.end());
4576 
// Build the map over the combined space, then rewrite it in terms of the
// deduplicated operands.
4577  Builder &builder = parser.getBuilder();
4578  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4579  parser.getContext());
4580  flatMap = flatMap.replaceDimsAndSymbols(
4581  dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4582 
4583  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4584  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4585  return success();
4586 }
4587 
4588 //
4589 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4590 // `to` parallel-bound steps? region attr-dict?
4591 // steps ::= `step` `(` integer-literals `)`
4592 //
// The code below additionally accepts, between the steps and the region, an
// optional `reduce` `(` quoted-reduction-kind-list `)` clause and an optional
// arrow type list for the results.
4593 ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4594  OperationState &result) {
4595  auto &builder = parser.getBuilder();
4596  auto indexType = builder.getIndexType();
// NOTE(review): the declaration of `ivs` and the start of this condition are
// not visible in this listing (the embedded line numbers skip) - confirm
// against the full file. The lower bound is parsed with "max" groups and the
// upper bound with "min" groups.
4599  parser.parseEqual() ||
4600  parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4601  parser.parseKeyword("to") ||
4602  parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4603  return failure();
4604 
4605  AffineMapAttr stepsMapAttr;
4606  NamedAttrList stepsAttrs;
// NOTE(review): the declarations of `stepsMapOperands` and `steps` (used in
// the else branch) are not visible in this listing - confirm.
// Without a `step` keyword every induction variable gets a step of 1.
4608  if (failed(parser.parseOptionalKeyword("step"))) {
4609  SmallVector<int64_t, 4> steps(ivs.size(), 1);
4610  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4611  builder.getI64ArrayAttr(steps));
4612  } else {
// The steps are parsed as an affine map into the scratch list `stepsAttrs`,
// so the map itself is not added to `result`.
4613  if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4614  AffineParallelOp::getStepsAttrStrName(),
4615  stepsAttrs,
4617  return failure();
4618 
4619  // Convert steps from an AffineMap into an I64ArrayAttr.
4621  auto stepsMap = stepsMapAttr.getValue();
4622  for (const auto &result : stepsMap.getResults()) {
4623  auto constExpr = dyn_cast<AffineConstantExpr>(result);
4624  if (!constExpr)
4625  return parser.emitError(parser.getNameLoc(),
4626  "steps must be constant integers");
4627  steps.push_back(constExpr.getValue());
4628  }
4629  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4630  builder.getI64ArrayAttr(steps));
4631  }
4632 
4633  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4634  // quoted strings are a member of the enum AtomicRMWKind.
4635  SmallVector<Attribute, 4> reductions;
4636  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4637  if (parser.parseLParen())
4638  return failure();
4639  auto parseAttributes = [&]() -> ParseResult {
4640  // Parse a single quoted string via the attribute parsing, and then
4641  // verify it is a member of the enum and convert to its integer
4642  // representation.
4643  StringAttr attrVal;
4644  NamedAttrList attrStorage;
4645  auto loc = parser.getCurrentLocation();
4646  if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4647  attrStorage))
4648  return failure();
4649  std::optional<arith::AtomicRMWKind> reduction =
4650  arith::symbolizeAtomicRMWKind(attrVal.getValue());
4651  if (!reduction)
4652  return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4653  reductions.push_back(
4654  builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4655  // While we keep getting commas, keep parsing.
4656  return success();
4657  };
4658  if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4659  return failure();
4660  }
// The reductions attribute is always added; it is an empty array when no
// `reduce` clause was parsed.
4661  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4662  builder.getArrayAttr(reductions));
4663 
4664  // Parse return types of reductions (if any)
4665  if (parser.parseOptionalArrowTypeList(result.types))
4666  return failure();
4667 
4668  // Now parse the body.
4669  Region *body = result.addRegion();
// Every induction variable is given index type before the region is parsed.
4670  for (auto &iv : ivs)
4671  iv.type = indexType;
4672  if (parser.parseRegion(*body, ivs) ||
4673  parser.parseOptionalAttrDict(result.attributes))
4674  return failure();
4675 
4676  // Add a terminator if none was parsed.
4677  AffineParallelOp::ensureTerminator(*body, builder, result.location);
4678  return success();
4679 }
4680 
4681 //===----------------------------------------------------------------------===//
4682 // AffineYieldOp
4683 //===----------------------------------------------------------------------===//
4684 
4685 LogicalResult AffineYieldOp::verify() {
4686  auto *parentOp = (*this)->getParentOp();
4687  auto results = parentOp->getResults();
4688  auto operands = getOperands();
4689 
4690  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4691  return emitOpError() << "only terminates affine.if/for/parallel regions";
4692  if (parentOp->getNumResults() != getNumOperands())
4693  return emitOpError() << "parent of yield must have same number of "
4694  "results as the yield operands";
4695  for (auto it : llvm::zip(results, operands)) {
4696  if (std::get<0>(it).getType() != std::get<1>(it).getType())
4697  return emitOpError() << "types mismatch between yield op and its parent";
4698  }
4699 
4700  return success();
4701 }
4702 
4703 //===----------------------------------------------------------------------===//
4704 // AffineVectorLoadOp
4705 //===----------------------------------------------------------------------===//
4706 
4707 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4708  VectorType resultType, AffineMap map,
4709  ValueRange operands) {
4710  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4711  result.addOperands(operands);
4712  if (map)
4713  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4714  result.types.push_back(resultType);
4715 }
4716 
4717 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4718  VectorType resultType, Value memref,
4719  AffineMap map, ValueRange mapOperands) {
4720  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4721  result.addOperands(memref);
4722  result.addOperands(mapOperands);
4723  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4724  result.types.push_back(resultType);
4725 }
4726 
4727 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4728  VectorType resultType, Value memref,
4729  ValueRange indices) {
4730  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4731  int64_t rank = memrefType.getRank();
4732  // Create identity map for memrefs with at least one dimension or () -> ()
4733  // for zero-dimensional memrefs.
4734  auto map =
4735  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4736  build(builder, result, resultType, memref, map, indices);
4737 }
4738 
4739 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4740  MLIRContext *context) {
4741  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4742 }
4743 
4744 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4745  OperationState &result) {
4746  auto &builder = parser.getBuilder();
4747  auto indexTy = builder.getIndexType();
4748 
4749  MemRefType memrefType;
4750  VectorType resultType;
4751  OpAsmParser::UnresolvedOperand memrefInfo;
4752  AffineMapAttr mapAttr;
4754  return failure(
4755  parser.parseOperand(memrefInfo) ||
4756  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4757  AffineVectorLoadOp::getMapAttrStrName(),
4758  result.attributes) ||
4759  parser.parseOptionalAttrDict(result.attributes) ||
4760  parser.parseColonType(memrefType) || parser.parseComma() ||
4761  parser.parseType(resultType) ||
4762  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4763  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4764  parser.addTypeToList(resultType, result.types));
4765 }
4766 
4768  p << " " << getMemRef() << '[';
4769  if (AffineMapAttr mapAttr =
4770  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4771  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4772  p << ']';
4773  p.printOptionalAttrDict((*this)->getAttrs(),
4774  /*elidedAttrs=*/{getMapAttrStrName()});
4775  p << " : " << getMemRefType() << ", " << getType();
4776 }
4777 
4778 /// Verify common invariants of affine.vector_load and affine.vector_store.
4779 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4780  VectorType vectorType) {
4781  // Check that memref and vector element types match.
4782  if (memrefType.getElementType() != vectorType.getElementType())
4783  return op->emitOpError(
4784  "requires memref and vector types of the same elemental type");
4785  return success();
4786 }
4787 
/// Verifies affine.vector_load: first a check over the map attribute, the map
/// operands and the memref type, then the element-type check shared with
/// affine.vector_store.
4788 LogicalResult AffineVectorLoadOp::verify() {
4789  MemRefType memrefType = getMemRefType();
// NOTE(review): the line opening this check (the `if (failed(<helper>(` part)
// is not visible in this listing - confirm the helper against the full file.
// All operands except the memref are counted as index operands.
4791  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4792  getMapOperands(), memrefType,
4793  /*numIndexOperands=*/getNumOperands() - 1)))
4794  return failure();
4795 
4796  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4797  return failure();
4798 
4799  return success();
4800 }
4801 
4802 //===----------------------------------------------------------------------===//
4803 // AffineVectorStoreOp
4804 //===----------------------------------------------------------------------===//
4805 
4806 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4807  Value valueToStore, Value memref, AffineMap map,
4808  ValueRange mapOperands) {
4809  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4810  result.addOperands(valueToStore);
4811  result.addOperands(memref);
4812  result.addOperands(mapOperands);
4813  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4814 }
4815 
4816 // Use identity map.
4817 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4818  Value valueToStore, Value memref,
4819  ValueRange indices) {
4820  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4821  int64_t rank = memrefType.getRank();
4822  // Create identity map for memrefs with at least one dimension or () -> ()
4823  // for zero-dimensional memrefs.
4824  auto map =
4825  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4826  build(builder, result, valueToStore, memref, map, indices);
4827 }
4828 void AffineVectorStoreOp::getCanonicalizationPatterns(
4829  RewritePatternSet &results, MLIRContext *context) {
4830  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4831 }
4832 
4833 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4834  OperationState &result) {
4835  auto indexTy = parser.getBuilder().getIndexType();
4836 
4837  MemRefType memrefType;
4838  VectorType resultType;
4839  OpAsmParser::UnresolvedOperand storeValueInfo;
4840  OpAsmParser::UnresolvedOperand memrefInfo;
4841  AffineMapAttr mapAttr;
4843  return failure(
4844  parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4845  parser.parseOperand(memrefInfo) ||
4846  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4847  AffineVectorStoreOp::getMapAttrStrName(),
4848  result.attributes) ||
4849  parser.parseOptionalAttrDict(result.attributes) ||
4850  parser.parseColonType(memrefType) || parser.parseComma() ||
4851  parser.parseType(resultType) ||
4852  parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4853  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4854  parser.resolveOperands(mapOperands, indexTy, result.operands));
4855 }
4856 
4858  p << " " << getValueToStore();
4859  p << ", " << getMemRef() << '[';
4860  if (AffineMapAttr mapAttr =
4861  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4862  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4863  p << ']';
4864  p.printOptionalAttrDict((*this)->getAttrs(),
4865  /*elidedAttrs=*/{getMapAttrStrName()});
4866  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4867 }
4868 
/// Verifies affine.vector_store: first a check over the map attribute, the map
/// operands and the memref type, then the element-type check shared with
/// affine.vector_load.
4869 LogicalResult AffineVectorStoreOp::verify() {
4870  MemRefType memrefType = getMemRefType();
// NOTE(review): the line opening this check (the `if (failed(<helper>(` part)
// is not visible in this listing - confirm the helper against the full file.
// All operands except the stored value and the memref are counted as index
// operands.
4872  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4873  getMapOperands(), memrefType,
4874  /*numIndexOperands=*/getNumOperands() - 2)))
4875  return failure();
4876 
4877  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4878  return failure();
4879 
4880  return success();
4881 }
4882 
4883 //===----------------------------------------------------------------------===//
4884 // DelinearizeIndexOp
4885 //===----------------------------------------------------------------------===//
4886 
4887 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4888  OperationState &odsState,
4889  Value linearIndex, ValueRange dynamicBasis,
4890  ArrayRef<int64_t> staticBasis,
4891  bool hasOuterBound) {
4892  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4893  : staticBasis.size() + 1,
4894  linearIndex.getType());
4895  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4896  staticBasis);
4897 }
4898 
4899 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4900  OperationState &odsState,
4901  Value linearIndex, ValueRange basis,
4902  bool hasOuterBound) {
4903  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4904  hasOuterBound = false;
4905  basis = basis.drop_front();
4906  }
4907  SmallVector<Value> dynamicBasis;
4908  SmallVector<int64_t> staticBasis;
4909  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4910  staticBasis);
4911  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4912  hasOuterBound);
4913 }
4914 
4915 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4916  OperationState &odsState,
4917  Value linearIndex,
4918  ArrayRef<OpFoldResult> basis,
4919  bool hasOuterBound) {
4920  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4921  hasOuterBound = false;
4922  basis = basis.drop_front();
4923  }
4924  SmallVector<Value> dynamicBasis;
4925  SmallVector<int64_t> staticBasis;
4926  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4927  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4928  hasOuterBound);
4929 }
4930 
4931 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4932  OperationState &odsState,
4933  Value linearIndex, ArrayRef<int64_t> basis,
4934  bool hasOuterBound) {
4935  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
4936 }
4937 
4938 LogicalResult AffineDelinearizeIndexOp::verify() {
4939  ArrayRef<int64_t> staticBasis = getStaticBasis();
4940  if (getNumResults() != staticBasis.size() &&
4941  getNumResults() != staticBasis.size() + 1)
4942  return emitOpError("should return an index for each basis element and up "
4943  "to one extra index");
4944 
4945  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4946  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4947  return emitOpError(
4948  "mismatch between dynamic and static basis (kDynamic marker but no "
4949  "corresponding dynamic basis entry) -- this can only happen due to an "
4950  "incorrect fold/rewrite");
4951 
4952  if (!llvm::all_of(staticBasis, [](int64_t v) {
4953  return v > 0 || ShapedType::isDynamic(v);
4954  }))
4955  return emitOpError("no basis element may be statically non-positive");
4956 
4957  return success();
4958 }
4959 
4960 /// Given mixed basis of affine.delinearize_index/linearize_index replace
4961 /// constant SSA values with the constant integer value and return the new
4962 /// static basis. In case no such candidate for replacement exists, this utility
4963 /// returns std::nullopt.
4964 static std::optional<SmallVector<int64_t>>
4966  MutableOperandRange mutableDynamicBasis,
4967  ArrayRef<Attribute> dynamicBasis) {
4968  uint64_t dynamicBasisIndex = 0;
4969  for (OpFoldResult basis : dynamicBasis) {
4970  if (basis) {
4971  mutableDynamicBasis.erase(dynamicBasisIndex);
4972  } else {
4973  ++dynamicBasisIndex;
4974  }
4975  }
4976 
4977  // No constant SSA value exists.
4978  if (dynamicBasisIndex == dynamicBasis.size())
4979  return std::nullopt;
4980 
4981  SmallVector<int64_t> staticBasis;
4982  for (OpFoldResult basis : mixedBasis) {
4983  std::optional<int64_t> basisVal = getConstantIntValue(basis);
4984  if (!basisVal)
4985  staticBasis.push_back(ShapedType::kDynamic);
4986  else
4987  staticBasis.push_back(*basisVal);
4988  }
4989 
4990  return staticBasis;
4991 }
4992 
4993 LogicalResult
4994 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4996  std::optional<SmallVector<int64_t>> maybeStaticBasis =
4997  foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
4998  adaptor.getDynamicBasis());
4999  if (maybeStaticBasis) {
5000  setStaticBasis(*maybeStaticBasis);
5001  return success();
5002  }
5003  // If we won't be doing any division or modulo (no basis or the one basis
5004  // element is purely advisory), simply return the input value.
5005  if (getNumResults() == 1) {
5006  result.push_back(getLinearIndex());
5007  return success();
5008  }
5009 
5010  if (adaptor.getLinearIndex() == nullptr)
5011  return failure();
5012 
5013  if (!adaptor.getDynamicBasis().empty())
5014  return failure();
5015 
5016  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
5017  Type attrType = getLinearIndex().getType();
5018 
5019  ArrayRef<int64_t> staticBasis = getStaticBasis();
5020  if (hasOuterBound())
5021  staticBasis = staticBasis.drop_front();
5022  for (int64_t modulus : llvm::reverse(staticBasis)) {
5023  result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
5024  highPart = llvm::divideFloorSigned(highPart, modulus);
5025  }
5026  result.push_back(IntegerAttr::get(attrType, highPart));
5027  std::reverse(result.begin(), result.end());
5028  return success();
5029 }
5030 
5031 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
5032  OpBuilder builder(getContext());
5033  if (hasOuterBound()) {
5034  if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5035  return getMixedValues(getStaticBasis().drop_front(),
5036  getDynamicBasis().drop_front(), builder);
5037 
5038  return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5039  builder);
5040  }
5041 
5042  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5043 }
5044 
5045 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
5046  SmallVector<OpFoldResult> ret = getMixedBasis();
5047  if (!hasOuterBound())
5048  ret.insert(ret.begin(), OpFoldResult());
5049  return ret;
5050 }
5051 
5052 namespace {
5053 
5054 // Drops delinearization indices that correspond to unit-extent basis
5055 struct DropUnitExtentBasis
5056  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5058 
5059  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5060  PatternRewriter &rewriter) const override {
5061  SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
5062  std::optional<Value> zero = std::nullopt;
5063  Location loc = delinearizeOp->getLoc();
5064  auto getZero = [&]() -> Value {
5065  if (!zero)
5066  zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
5067  return zero.value();
5068  };
5069 
5070  // Replace all indices corresponding to unit-extent basis with 0.
5071  // Remaining basis can be used to get a new `affine.delinearize_index` op.
5072  SmallVector<OpFoldResult> newBasis;
5073  for (auto [index, basis] :
5074  llvm::enumerate(delinearizeOp.getPaddedBasis())) {
5075  std::optional<int64_t> basisVal =
5076  basis ? getConstantIntValue(basis) : std::nullopt;
5077  if (basisVal == 1)
5078  replacements[index] = getZero();
5079  else
5080  newBasis.push_back(basis);
5081  }
5082 
5083  if (newBasis.size() == delinearizeOp.getNumResults())
5084  return rewriter.notifyMatchFailure(delinearizeOp,
5085  "no unit basis elements");
5086 
5087  if (!newBasis.empty()) {
5088  // Will drop the leading nullptr from `basis` if there was no outer bound.
5089  auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
5090  rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
5091  int newIndex = 0;
5092  // Map back the new delinearized indices to the values they replace.
5093  for (auto &replacement : replacements) {
5094  if (replacement)
5095  continue;
5096  replacement = newDelinearizeOp->getResult(newIndex++);
5097  }
5098  }
5099 
5100  rewriter.replaceOp(delinearizeOp, replacements);
5101  return success();
5102  }
5103 };
5104 
5105 /// If a `affine.delinearize_index`'s input is a `affine.linearize_index
5106 /// disjoint` and the two operations end with the same basis elements,
5107 /// cancel those parts of the operations out because they are inverses
5108 /// of each other.
5109 ///
5110 /// If the operations have the same basis, cancel them entirely.
5111 ///
5112 /// The `disjoint` flag is needed on the `affine.linearize_index` because
5113 /// otherwise, there is no guarantee that the inputs to the linearization are
5114 /// in-bounds the way the outputs of the delinearization would be.
5115 struct CancelDelinearizeOfLinearizeDisjointExactTail
5116  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5118 
5119  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5120  PatternRewriter &rewriter) const override {
5121  auto linearizeOp = delinearizeOp.getLinearIndex()
5122  .getDefiningOp<affine::AffineLinearizeIndexOp>();
5123  if (!linearizeOp)
5124  return rewriter.notifyMatchFailure(delinearizeOp,
5125  "index doesn't come from linearize");
5126 
5127  if (!linearizeOp.getDisjoint())
5128  return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
5129 
5130  ValueRange linearizeIns = linearizeOp.getMultiIndex();
5131  // Note: we use the full basis so we don't lose outer bounds later.
5132  SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
5133  SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
5134  size_t numMatches = 0;
5135  for (auto [linSize, delinSize] : llvm::zip(
5136  llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
5137  if (linSize != delinSize)
5138  break;
5139  ++numMatches;
5140  }
5141 
5142  if (numMatches == 0)
5143  return rewriter.notifyMatchFailure(
5144  delinearizeOp, "final basis element doesn't match linearize");
5145 
5146  // The easy case: everything lines up and the basis match sup completely.
5147  if (numMatches == linearizeBasis.size() &&
5148  numMatches == delinearizeBasis.size() &&
5149  linearizeIns.size() == delinearizeOp.getNumResults()) {
5150  rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
5151  return success();
5152  }
5153 
5154  Value newLinearize = affine::AffineLinearizeIndexOp::create(
5155  rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
5156  ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
5157  linearizeOp.getDisjoint());
5158  auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
5159  rewriter, delinearizeOp.getLoc(), newLinearize,
5160  ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
5161  delinearizeOp.hasOuterBound());
5162  SmallVector<Value> mergedResults(newDelinearize.getResults());
5163  mergedResults.append(linearizeIns.take_back(numMatches).begin(),
5164  linearizeIns.take_back(numMatches).end());
5165  rewriter.replaceOp(delinearizeOp, mergedResults);
5166  return success();
5167  }
5168 };
5169 
5170 /// If the input to a delinearization is a disjoint linearization, and the
5171 /// last k > 1 components of the delinearization basis multiply to the
5172 /// last component of the linearization basis, break the linearization and
5173 /// delinearization into two parts, peeling off the last input to linearization.
5174 ///
5175 /// For example:
5176 /// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
5177 /// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
5178 /// becomes
5179 /// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
5180 /// %1:2 = affine.delinearize_index %0 by (2, 3) : index
5181 /// %2:2 = affine.delinearize_index %x by (8, 4) : index
5182 /// where the original %1:4 is replaced by %1:2 ++ %2:2
5183 struct SplitDelinearizeSpanningLastLinearizeArg final
5184  : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5186 
5187  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5188  PatternRewriter &rewriter) const override {
5189  auto linearizeOp = delinearizeOp.getLinearIndex()
5190  .getDefiningOp<affine::AffineLinearizeIndexOp>();
5191  if (!linearizeOp)
5192  return rewriter.notifyMatchFailure(delinearizeOp,
5193  "index doesn't come from linearize");
5194 
5195  if (!linearizeOp.getDisjoint())
5196  return rewriter.notifyMatchFailure(linearizeOp,
5197  "linearize isn't disjoint");
5198 
5199  int64_t target = linearizeOp.getStaticBasis().back();
5200  if (ShapedType::isDynamic(target))
5201  return rewriter.notifyMatchFailure(
5202  linearizeOp, "linearize ends with dynamic basis value");
5203 
5204  int64_t sizeToSplit = 1;
5205  size_t elemsToSplit = 0;
5206  ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
5207  for (int64_t basisElem : llvm::reverse(basis)) {
5208  if (ShapedType::isDynamic(basisElem))
5209  return rewriter.notifyMatchFailure(
5210  delinearizeOp, "dynamic basis element while scanning for split");
5211  sizeToSplit *= basisElem;
5212  elemsToSplit += 1;
5213 
5214  if (sizeToSplit > target)
5215  return rewriter.notifyMatchFailure(delinearizeOp,
5216  "overshot last argument size");
5217  if (sizeToSplit == target)
5218  break;
5219  }
5220 
5221  if (sizeToSplit < target)
5222  return rewriter.notifyMatchFailure(
5223  delinearizeOp, "product of known basis elements doesn't exceed last "
5224  "linearize argument");
5225 
5226  if (elemsToSplit < 2)
5227  return rewriter.notifyMatchFailure(
5228  delinearizeOp,
5229  "need at least two elements to form the basis product");
5230 
5231  Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
5232  rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
5233  linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(),
5234  linearizeOp.getDisjoint());
5235  auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
5236  rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
5237  delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
5238  delinearizeOp.hasOuterBound());
5239  auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
5240  rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
5241  basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
5242  SmallVector<Value> results = llvm::to_vector(
5243  llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
5244  delinearizeBack.getResults()));
5245  rewriter.replaceOp(delinearizeOp, results);
5246 
5247  return success();
5248  }
5249 };
5250 } // namespace
5251 
5252 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
5253  RewritePatternSet &patterns, MLIRContext *context) {
5254  patterns
5255  .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
5256  DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
5257  context);
5258 }
5259 
5260 //===----------------------------------------------------------------------===//
5261 // LinearizeIndexOp
5262 //===----------------------------------------------------------------------===//
5263 
5264 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5265  OperationState &odsState,
5266  ValueRange multiIndex, ValueRange basis,
5267  bool disjoint) {
5268  if (!basis.empty() && basis.front() == Value())
5269  basis = basis.drop_front();
5270  SmallVector<Value> dynamicBasis;
5271  SmallVector<int64_t> staticBasis;
5272  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
5273  staticBasis);
5274  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5275 }
5276 
5277 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5278  OperationState &odsState,
5279  ValueRange multiIndex,
5280  ArrayRef<OpFoldResult> basis,
5281  bool disjoint) {
5282  if (!basis.empty() && basis.front() == OpFoldResult())
5283  basis = basis.drop_front();
5284  SmallVector<Value> dynamicBasis;
5285  SmallVector<int64_t> staticBasis;
5286  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
5287  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5288 }
5289 
5290 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5291  OperationState &odsState,
5292  ValueRange multiIndex,
5293  ArrayRef<int64_t> basis, bool disjoint) {
5294  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
5295 }
5296 
5297 LogicalResult AffineLinearizeIndexOp::verify() {
5298  size_t numIndexes = getMultiIndex().size();
5299  size_t numBasisElems = getStaticBasis().size();
5300  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5301  return emitOpError("should be passed a basis element for each index except "
5302  "possibly the first");
5303 
5304  auto dynamicMarkersCount =
5305  llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
5306  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5307  return emitOpError(
5308  "mismatch between dynamic and static basis (kDynamic marker but no "
5309  "corresponding dynamic basis entry) -- this can only happen due to an "
5310  "incorrect fold/rewrite");
5311 
5312  return success();
5313 }
5314 
5315 OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
5316  std::optional<SmallVector<int64_t>> maybeStaticBasis =
5317  foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
5318  adaptor.getDynamicBasis());
5319  if (maybeStaticBasis) {
5320  setStaticBasis(*maybeStaticBasis);
5321  return getResult();
5322  }
5323  // No indices linearizes to zero.
5324  if (getMultiIndex().empty())
5325  return IntegerAttr::get(getResult().getType(), 0);
5326 
5327  // One single index linearizes to itself.
5328  if (getMultiIndex().size() == 1)
5329  return getMultiIndex().front();
5330 
5331  if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))
5332  return nullptr;
5333 
5334  if (!adaptor.getDynamicBasis().empty())
5335  return nullptr;
5336 
5337  int64_t result = 0;
5338  int64_t stride = 1;
5339  for (auto [length, indexAttr] :
5340  llvm::zip_first(llvm::reverse(getStaticBasis()),
5341  llvm::reverse(adaptor.getMultiIndex()))) {
5342  result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
5343  stride = stride * length;
5344  }
5345  // Handle the index element with no basis element.
5346  if (!hasOuterBound())
5347  result =
5348  result +
5349  cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
5350 
5351  return IntegerAttr::get(getResult().getType(), result);
5352 }
5353 
5354 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
5355  OpBuilder builder(getContext());
5356  if (hasOuterBound()) {
5357  if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5358  return getMixedValues(getStaticBasis().drop_front(),
5359  getDynamicBasis().drop_front(), builder);
5360 
5361  return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5362  builder);
5363  }
5364 
5365  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5366 }
5367 
5368 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
5369  SmallVector<OpFoldResult> ret = getMixedBasis();
5370  if (!hasOuterBound())
5371  ret.insert(ret.begin(), OpFoldResult());
5372  return ret;
5373 }
5374 
5375 namespace {
5376 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
5377 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
5378 /// %...d)`.
5379 
5380 /// Note that `disjoint` is required here, because, without it, we could have
5381 /// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
5382 /// is a valid operation where the `%c64` cannot be trivially dropped.
5383 ///
5384 /// Alternatively, if `%x` in the above is a known constant 0, remove it even if
5385 /// the operation isn't asserted to be `disjoint`.
5386 struct DropLinearizeUnitComponentsIfDisjointOrZero final
5387  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5389 
5390  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5391  PatternRewriter &rewriter) const override {
5392  ValueRange multiIndex = op.getMultiIndex();
5393  size_t numIndices = multiIndex.size();
5394  SmallVector<Value> newIndices;
5395  newIndices.reserve(numIndices);
5396  SmallVector<OpFoldResult> newBasis;
5397  newBasis.reserve(numIndices);
5398 
5399  if (!op.hasOuterBound()) {
5400  newIndices.push_back(multiIndex.front());
5401  multiIndex = multiIndex.drop_front();
5402  }
5403 
5404  SmallVector<OpFoldResult> basis = op.getMixedBasis();
5405  for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5406  std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
5407  if (!basisEntry || *basisEntry != 1) {
5408  newIndices.push_back(index);
5409  newBasis.push_back(basisElem);
5410  continue;
5411  }
5412 
5413  std::optional<int64_t> indexValue = getConstantIntValue(index);
5414  if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5415  newIndices.push_back(index);
5416  newBasis.push_back(basisElem);
5417  continue;
5418  }
5419  }
5420  if (newIndices.size() == numIndices)
5421  return rewriter.notifyMatchFailure(op,
5422  "no unit basis entries to replace");
5423 
5424  if (newIndices.size() == 0) {
5425  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
5426  return success();
5427  }
5428  rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5429  op, newIndices, newBasis, op.getDisjoint());
5430  return success();
5431  }
5432 };
5433 
5435  ArrayRef<OpFoldResult> terms) {
5436  int64_t nDynamic = 0;
5437  SmallVector<Value> dynamicPart;
5438  AffineExpr result = builder.getAffineConstantExpr(1);
5439  for (OpFoldResult term : terms) {
5440  if (!term)
5441  return term;
5442  std::optional<int64_t> maybeConst = getConstantIntValue(term);
5443  if (maybeConst) {
5444  result = result * builder.getAffineConstantExpr(*maybeConst);
5445  } else {
5446  dynamicPart.push_back(cast<Value>(term));
5447  result = result * builder.getAffineSymbolExpr(nDynamic++);
5448  }
5449  }
5450  if (auto constant = dyn_cast<AffineConstantExpr>(result))
5451  return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
5452  return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult();
5453 }
5454 
5455 /// If conseceutive outputs of a delinearize_index are linearized with the same
5456 /// bounds, canonicalize away the redundant arithmetic.
5457 ///
5458 /// That is, if we have
5459 /// ```
5460 /// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
5461 /// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
5462 /// by (...e, B1, B2, ..., BK, ...f)
5463 /// ```
5464 ///
5465 /// We can rewrite this to
5466 /// ```
5467 /// B = B1 * B2 ... BK
5468 /// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b)
5469 /// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
5470 /// ```
5471 /// where we replace all results of %s unaffected by the change with results
5472 /// from %sMerged.
5473 ///
5474 /// As a special case, if all results of the delinearize are merged in this way
5475 /// we can replace those usages with %x, thus cancelling the delinearization
5476 /// entirely, as in
5477 /// ```
5478 /// %s:3 = affine.delinearize_index %x into (2, 4, 8)
5479 /// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
5480 /// ```
5481 /// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
5482 struct CancelLinearizeOfDelinearizePortion final
5483  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5485 
5486 private:
5487  // Struct representing a case where the cancellation pattern
5488  // applies. A `Match` means that `length` inputs to the linearize operation
5489  // starting at `linStart` can be cancelled with `length` outputs of
5490  // `delinearize`, starting from `delinStart`.
5491  struct Match {
5492  AffineDelinearizeIndexOp delinearize;
5493  unsigned linStart = 0;
5494  unsigned delinStart = 0;
5495  unsigned length = 0;
5496  };
5497 
5498 public:
5499  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
5500  PatternRewriter &rewriter) const override {
5501  SmallVector<Match> matches;
5502 
5503  const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
5504  ArrayRef<OpFoldResult> linBasisRef = linBasis;
5505 
5506  ValueRange multiIndex = linearizeOp.getMultiIndex();
5507  unsigned numLinArgs = multiIndex.size();
5508  unsigned linArgIdx = 0;
5509  // We only want to replace one run from the same delinearize op per
5510  // pattern invocation lest we run into invalidation issues.
5511  llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
5512  while (linArgIdx < numLinArgs) {
5513  auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
5514  if (!asResult) {
5515  linArgIdx++;
5516  continue;
5517  }
5518 
5519  auto delinearizeOp =
5520  dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
5521  if (!delinearizeOp) {
5522  linArgIdx++;
5523  continue;
5524  }
5525 
5526  /// Result 0 of the delinearize and argument 0 of the linearize can
5527  /// leave their maximum value unspecified. However, even if this happens
5528  /// we can still sometimes start the match process. Specifically, if
5529  /// - The argument we're matching is result 0 and argument 0 (so the
5530  /// bounds don't matter). For example,
5531  ///
5532  /// %0:2 = affine.delinearize_index %x into (8) : index, index
5533  /// %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...)
5534  /// allows cancellation
5535  /// - The delinearization doesn't specify a bound, but the linearization
5536  /// is `disjoint`, which asserts that the bound on the linearization is
5537  /// correct.
5538  unsigned delinArgIdx = asResult.getResultNumber();
5539  SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
5540  OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
5541  OpFoldResult firstLinBound = linBasis[linArgIdx];
5542  bool boundsMatch = firstDelinBound == firstLinBound;
5543  bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
5544  bool knownByDisjoint =
5545  linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
5546  if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5547  linArgIdx++;
5548  continue;
5549  }
5550 
5551  unsigned j = 1;
5552  unsigned numDelinOuts = delinearizeOp.getNumResults();
5553  for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
5554  ++j) {
5555  if (multiIndex[linArgIdx + j] !=
5556  delinearizeOp.getResult(delinArgIdx + j))
5557  break;
5558  if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
5559  break;
5560  }
5561  // If there're multiple matches against the same delinearize_index,
5562  // only rewrite the first one we find to prevent invalidations. The next
5563  // ones will be taken care of by subsequent pattern invocations.
5564  if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
5565  linArgIdx++;
5566  continue;
5567  }
5568  matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
5569  linArgIdx += j;
5570  }
5571 
5572  if (matches.empty())
5573  return rewriter.notifyMatchFailure(
5574  linearizeOp, "no run of delinearize outputs to deal with");
5575 
5576  // Record all the delinearize replacements so we can do them after creating
5577  // the new linearization operation, since the new operation might use
5578  // outputs of something we're replacing.
5579  SmallVector<SmallVector<Value>> delinearizeReplacements;
5580 
5581  SmallVector<Value> newIndex;
5582  newIndex.reserve(numLinArgs);
5583  SmallVector<OpFoldResult> newBasis;
5584  newBasis.reserve(numLinArgs);
5585  unsigned prevMatchEnd = 0;
5586  for (Match m : matches) {
5587  unsigned gap = m.linStart - prevMatchEnd;
5588  llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
5589  llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
5590  // Update here so we don't forget this during early continues
5591  prevMatchEnd = m.linStart + m.length;
5592 
5593  PatternRewriter::InsertionGuard g(rewriter);
5594  rewriter.setInsertionPoint(m.delinearize);
5595 
5596  ArrayRef<OpFoldResult> basisToMerge =
5597  linBasisRef.slice(m.linStart, m.length);
5598  // We use the slice from the linearize's basis above because of the
5599  // "bounds inferred from `disjoint`" case above.
5600  OpFoldResult newSize =
5601  computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
5602 
5603  // Trivial case where we can just skip past the delinearize all together
5604  if (m.length == m.delinearize.getNumResults()) {
5605  newIndex.push_back(m.delinearize.getLinearIndex());
5606  newBasis.push_back(newSize);
5607  // Pad out set of replacements so we don't do anything with this one.
5608  delinearizeReplacements.push_back(SmallVector<Value>());
5609  continue;
5610  }
5611 
5612  SmallVector<Value> newDelinResults;
5613  SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
5614  newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
5615  newDelinBasis.begin() + m.delinStart + m.length);
5616  newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
5617  auto newDelinearize = AffineDelinearizeIndexOp::create(
5618  rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5619  newDelinBasis);
5620 
5621  // Since there may be other uses of the indices we just merged together,
5622  // create a residual affine.delinearize_index that delinearizes the
5623  // merged output into its component parts.
5624  Value combinedElem = newDelinearize.getResult(m.delinStart);
5625  auto residualDelinearize = AffineDelinearizeIndexOp::create(
5626  rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);
5627 
5628  // Swap all the uses of the unaffected delinearize outputs to the new
5629  // delinearization so that the old code can be removed if this
5630  // linearize_index is the only user of the merged results.
5631  llvm::append_range(newDelinResults,
5632  newDelinearize.getResults().take_front(m.delinStart));
5633  llvm::append_range(newDelinResults, residualDelinearize.getResults());
5634  llvm::append_range(
5635  newDelinResults,
5636  newDelinearize.getResults().drop_front(m.delinStart + 1));
5637 
5638  delinearizeReplacements.push_back(newDelinResults);
5639  newIndex.push_back(combinedElem);
5640  newBasis.push_back(newSize);
5641  }
5642  llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
5643  llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
5644  rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
5645  linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
5646 
5647  for (auto [m, newResults] :
5648  llvm::zip_equal(matches, delinearizeReplacements)) {
5649  if (newResults.empty())
5650  continue;
5651  rewriter.replaceOp(m.delinearize, newResults);
5652  }
5653 
5654  return success();
5655  }
5656 };
5657 
5658 /// Strip leading zero from affine.linearize_index.
5659 ///
5660 /// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
5661 /// to `affine.linearize_index [...a] by (...b)` in all cases.
5662 struct DropLinearizeLeadingZero final
5663  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5665 
5666  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5667  PatternRewriter &rewriter) const override {
5668  Value leadingIdx = op.getMultiIndex().front();
5669  if (!matchPattern(leadingIdx, m_Zero()))
5670  return failure();
5671 
5672  if (op.getMultiIndex().size() == 1) {
5673  rewriter.replaceOp(op, leadingIdx);
5674  return success();
5675  }
5676 
5677  SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
5678  ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
5679  if (op.hasOuterBound())
5680  newMixedBasis = newMixedBasis.drop_front();
5681 
5682  rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5683  op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5684  return success();
5685  }
5686 };
5687 } // namespace
5688 
5689 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5690  RewritePatternSet &patterns, MLIRContext *context) {
5691  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5692  DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5693 }
5694 
5695 //===----------------------------------------------------------------------===//
5696 // TableGen'd op method definitions
5697 //===----------------------------------------------------------------------===//
5698 
5699 #define GET_OP_CLASSES
5700 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
Definition: AffineOps.cpp:3026
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
Definition: AffineOps.cpp:2694
static LogicalResult verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
Definition: AffineOps.cpp:3419
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
Definition: AffineOps.cpp:3597
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
Definition: AffineOps.cpp:4230
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
Definition: AffineOps.cpp:62
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
Definition: AffineOps.cpp:4377
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
Definition: AffineOps.cpp:3633
static bool isTopLevelValueOrAbove(Value value, Region *region)
A utility function to check if a value is defined at the top level of region or is an argument of reg...
Definition: AffineOps.cpp:432
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
Definition: AffineOps.cpp:2663
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
Definition: AffineOps.cpp:822
static bool isValidAffineIndexOperand(Value value, Region *region)
Definition: AffineOps.cpp:492
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1682
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:757
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
Definition: AffineOps.cpp:2293
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:749
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1588
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
Definition: AffineOps.cpp:4779
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
Definition: AffineOps.cpp:925
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
Definition: AffineOps.cpp:3610
static std::optional< uint64_t > getTrivialConstantTripCount(AffineForOp forOp)
Returns constant trip count in trivial cases.
Definition: AffineOps.cpp:2604
static LogicalResult replaceAffineDelinearizeIndexInverseExpression(AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
If this map contains the expression x_1 + x_1 * C_1 + ...
Definition: AffineOps.cpp:1190
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
Definition: AffineOps.cpp:3311
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
Definition: AffineOps.cpp:349
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
Definition: AffineOps.cpp:697
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
Definition: AffineOps.cpp:4492
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
Definition: AffineOps.cpp:3035
static void simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
Definition: AffineOps.cpp:1035
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
Definition: AffineOps.cpp:497
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
Definition: AffineOps.cpp:659
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
Definition: AffineOps.cpp:1336
static void legalizeDemotedDims(MapOrSet &mapOrSet, SmallVectorImpl< Value > &operands)
A valid affine dimension may appear as a symbol in affine.apply operations.
Definition: AffineOps.cpp:1636
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1525
static SmallVector< OpFoldResult > AffineForEmptyLoopFolder(AffineForOp forOp)
Fold the empty loop.
Definition: AffineOps.cpp:2614
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
Definition: AffineOps.cpp:2985
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
Definition: AffineOps.cpp:2558
static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map, ValueRange dims, ValueRange syms)
Assuming dimOrSym is a quantity in the apply op map map and defined by minOp = affine_min(x_1,...
Definition: AffineOps.cpp:1061
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
Definition: AffineOps.cpp:529
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1540
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
Definition: AffineOps.cpp:368
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
Definition: AffineOps.cpp:725
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Definition: AffineOps.cpp:4445
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms, bool replaceAffineMin)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
Definition: AffineOps.cpp:1275
static LogicalResult verifyAffineMinMaxOp(T op)
Definition: AffineOps.cpp:3584
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
Definition: AffineOps.cpp:2468
static void shortenAddChainsContainingAll(AffineExpr e, const llvm::SmallDenseSet< AffineExpr, 4 > &exprsToRemove, AffineExpr newVal, DenseMap< AffineExpr, AffineExpr > &replacementsMap)
Recursively traverse e.
Definition: AffineOps.cpp:1133
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
Definition: AffineOps.cpp:1439
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
Definition: AffineOps.cpp:4965
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
Definition: AffineOps.cpp:3796
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value getMemRef(Operation *memOp)
Returns the memref being read/written by a memref/affine load/store op.
Definition: Utils.cpp:246
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:51
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
union mlir::linalg::@1257::ArityGroupAndKind::Kind kind
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:75
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:80
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
#define div(a, b)
#define rem(a, b)
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:959
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:33
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:241
MLIRContext * getContext() const
Definition: AffineExpr.cpp:31
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
Definition: AffineExpr.cpp:179
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:655
MLIRContext * getContext() const
Definition: AffineMap.cpp:339
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
Definition: AffineMap.h:221
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:267
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
Definition: AffineMap.h:228
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineMap.cpp:496
unsigned getNumInputs() const
Definition: AffineMap.cpp:399
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineMap.h:280
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:511
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
Definition: AffineMap.cpp:124
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:647
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
Definition: AffineMap.cpp:430
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:308
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:163
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:228
AffineMap getDimIdentityMap()
Definition: Builders.cpp:383
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:387
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:368
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:372
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition: Builders.cpp:179
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
NoneType getNoneType()
Definition: Builders.cpp:88
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:100
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
Definition: Builders.cpp:376
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition: Builders.cpp:378
MLIRContext * getContext() const
Definition: Builders.h:56
AffineMap getSymbolIdentityMap()
Definition: Builders.cpp:396
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:266
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:281
IndexType getIndexType()
Definition: Builders.cpp:51
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition: Builders.h:630
Location getLoc() const
Accessors for the implied location.
Definition: Builders.h:663
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44
unsigned getNumDims() const
Definition: IntegerSet.cpp:15
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
Definition: IntegerSet.cpp:57
unsigned getNumInputs() const
Definition: IntegerSet.cpp:17
ArrayRef< AffineExpr > getConstraints() const
Definition: IntegerSet.cpp:41
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
Definition: IntegerSet.cpp:51
unsigned getNumSymbols() const
Definition: IntegerSet.cpp:16
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:320
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:457
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:442
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:219
operand_range::iterator operand_iterator
Definition: Operation.h:372
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:855
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:622
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
A variable that can be added to the constraint set as a "column".
static bool compare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Return "true" if "lhs cmp rhs" was proven to hold.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition: Value.cpp:39
AffineBound represents a lower or upper bound in the for operation.
Definition: AffineOps.h:550
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
Definition: AffineOps.h:106
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Definition: AffineOps.h:330
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
Definition: AffineOps.cpp:4337
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
unsigned getNumResults() const
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
Definition: AffineOps.cpp:3048
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1416
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
Definition: AffineOps.cpp:2962
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
Definition: AffineOps.cpp:290
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
Definition: AffineOps.cpp:2934
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1514
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: AffineOps.cpp:2938
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1759
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1579
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Definition: AffineOps.cpp:2926
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1469
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1572
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Definition: AffineOps.cpp:250
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
Definition: AffineOps.cpp:275
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
Definition: AffineOps.cpp:1400
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
Definition: AffineOps.cpp:1764
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
Definition: AffineOps.cpp:2969
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Definition: AffineOps.cpp:412
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp that has the provided value among its induction variables, if any.
Definition: AffineOps.cpp:2949
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
Definition: AffineOps.cpp:265
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
Definition: AffineOps.cpp:507
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
Definition: AffineOps.cpp:2930
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
Definition: AffineOps.cpp:1534
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:766
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
Definition: AffineMap.cpp:776
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., it is speculatable and does not touch memory.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explanatory: computes the product of the elements of basis.
AffineExprKind
Definition: AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:68
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:643
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
Definition: AffineMap.cpp:738
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:629
Canonicalize the affine map result expression order of an affine min/max operation.
Definition: AffineOps.cpp:3850
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3853
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3867
Remove duplicated expressions in affine min/max ops.
Definition: AffineOps.cpp:3666
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3669
Merge an affine min/max op into its consumer if that consumer is also an affine min/max op.
Definition: AffineOps.cpp:3709
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3712
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:285
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:322
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.