MLIR  22.0.0git
AffineOps.cpp
Go to the documentation of this file.
1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/AffineExpr.h"
16 #include "mlir/IR/IRMapping.h"
17 #include "mlir/IR/IntegerSet.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/OpDefinition.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/Value.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallBitVector.h"
27 #include "llvm/ADT/SmallVectorExtras.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/DebugLog.h"
30 #include "llvm/Support/LogicalResult.h"
31 #include "llvm/Support/MathExtras.h"
32 #include <numeric>
33 #include <optional>
34 
35 using namespace mlir;
36 using namespace mlir::affine;
37 
38 using llvm::divideCeilSigned;
39 using llvm::divideFloorSigned;
40 using llvm::mod;
41 
42 #define DEBUG_TYPE "affine-ops"
43 
44 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
45 
46 /// A utility function to check if a value is defined at the top level of
47 /// `region` or is an argument of `region`. A value of index type defined at the
48 /// top level of a `AffineScope` region is always a valid symbol for all
49 /// uses in that region.
51  if (auto arg = dyn_cast<BlockArgument>(value))
52  return arg.getParentRegion() == region;
53  return value.getDefiningOp()->getParentRegion() == region;
54 }
55 
56 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
57 /// region remains legal if the operation that uses it is inlined into `dest`
58 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
59 /// `isValidSymbol`, depending on the value being required to remain a valid
60 /// dimension or symbol.
61 static bool
63  const IRMapping &mapping,
64  function_ref<bool(Value, Region *)> legalityCheck) {
65  // If the value is a valid dimension for any other reason than being
66  // a top-level value, it will remain valid: constants get inlined
67  // with the function, transitive affine applies also get inlined and
68  // will be checked themselves, etc.
69  if (!isTopLevelValue(value, src))
70  return true;
71 
72  // If it's a top-level value because it's a block operand, i.e. a
73  // function argument, check whether the value replacing it after
74  // inlining is a valid dimension in the new region.
75  if (llvm::isa<BlockArgument>(value))
76  return legalityCheck(mapping.lookup(value), dest);
77 
78  // If it's a top-level value because it's defined in the region,
79  // it can only be inlined if the defining op is a constant or a
80  // `dim`, which can appear anywhere and be valid, since the defining
81  // op won't be top-level anymore after inlining.
82  Attribute operandCst;
83  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
84  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
85  isDimLikeOp;
86 }
87 
88 /// Checks if all values known to be legal affine dimensions or symbols in `src`
89 /// remain so if their respective users are inlined into `dest`.
90 static bool
92  const IRMapping &mapping,
93  function_ref<bool(Value, Region *)> legalityCheck) {
94  return llvm::all_of(values, [&](Value v) {
95  return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
96  });
97 }
98 
99 /// Checks if an affine read or write operation remains legal after inlining
100 /// from `src` to `dest`.
101 template <typename OpTy>
102 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
103  const IRMapping &mapping) {
104  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
105  AffineWriteOpInterface>::value,
106  "only ops with affine read/write interface are supported");
107 
108  AffineMap map = op.getAffineMap();
109  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
110  ValueRange symbolOperands =
111  op.getMapOperands().take_back(map.getNumSymbols());
113  dimOperands, src, dest, mapping,
114  static_cast<bool (*)(Value, Region *)>(isValidDim)))
115  return false;
117  symbolOperands, src, dest, mapping,
118  static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
119  return false;
120  return true;
121 }
122 
123 /// Checks if an affine apply operation remains legal after inlining from `src`
124 /// to `dest`.
125 // Use "unused attribute" marker to silence clang-tidy warning stemming from
126 // the inability to see through "llvm::TypeSwitch".
127 template <>
128 bool LLVM_ATTRIBUTE_UNUSED remainsLegalAfterInline(AffineApplyOp op,
129  Region *src, Region *dest,
130  const IRMapping &mapping) {
131  // If it's a valid dimension, we need to check that it remains so.
132  if (isValidDim(op.getResult(), src))
134  op.getMapOperands(), src, dest, mapping,
135  static_cast<bool (*)(Value, Region *)>(isValidDim));
136 
137  // Otherwise it must be a valid symbol, check that it remains so.
139  op.getMapOperands(), src, dest, mapping,
140  static_cast<bool (*)(Value, Region *)>(isValidSymbol));
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // AffineDialect Interfaces
145 //===----------------------------------------------------------------------===//
146 
147 namespace {
148 /// This class defines the interface for handling inlining with affine
149 /// operations.
150 struct AffineInlinerInterface : public DialectInlinerInterface {
152 
153  //===--------------------------------------------------------------------===//
154  // Analysis Hooks
155  //===--------------------------------------------------------------------===//
156 
157  /// Returns true if the given region 'src' can be inlined into the region
158  /// 'dest' that is attached to an operation registered to the current dialect.
159  /// 'wouldBeCloned' is set if the region is cloned into its new location
160  /// rather than moved, indicating there may be other users.
161  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
162  IRMapping &valueMapping) const final {
163  // We can inline into affine loops and conditionals if this doesn't break
164  // affine value categorization rules.
165  Operation *destOp = dest->getParentOp();
166  if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
167  return false;
168 
169  // Multi-block regions cannot be inlined into affine constructs, all of
170  // which require single-block regions.
171  if (!src->hasOneBlock())
172  return false;
173 
174  // Side-effecting operations that the affine dialect cannot understand
175  // should not be inlined.
176  Block &srcBlock = src->front();
177  for (Operation &op : srcBlock) {
178  // Ops with no side effects are fine,
179  if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
180  if (iface.hasNoEffect())
181  continue;
182  }
183 
184  // Assuming the inlined region is valid, we only need to check if the
185  // inlining would change it.
186  bool remainsValid =
188  .Case<AffineApplyOp, AffineReadOpInterface,
189  AffineWriteOpInterface>([&](auto op) {
190  return remainsLegalAfterInline(op, src, dest, valueMapping);
191  })
192  .Default([](Operation *) {
193  // Conservatively disallow inlining ops we cannot reason about.
194  return false;
195  });
196 
197  if (!remainsValid)
198  return false;
199  }
200 
201  return true;
202  }
203 
204  /// Returns true if the given operation 'op', that is registered to this
205  /// dialect, can be inlined into the given region, false otherwise.
206  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
207  IRMapping &valueMapping) const final {
208  // Always allow inlining affine operations into a region that is marked as
209  // affine scope, or into affine loops and conditionals. There are some edge
210  // cases when inlining *into* affine structures, but that is handled in the
211  // other 'isLegalToInline' hook above.
212  Operation *parentOp = region->getParentOp();
213  return parentOp->hasTrait<OpTrait::AffineScope>() ||
214  isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
215  }
216 
217  /// Affine regions should be analyzed recursively.
218  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
219 };
220 } // namespace
221 
222 //===----------------------------------------------------------------------===//
223 // AffineDialect
224 //===----------------------------------------------------------------------===//
225 
void AffineDialect::initialize() {
  // Register all affine operations: the two manually-defined DMA ops plus
  // every op in the tablegen-generated list pulled in via GET_OP_LIST.
  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
                >();
  // Hook the dialect into the inlining framework.
  addInterfaces<AffineInlinerInterface>();
  // Promise ValueBoundsOpInterface implementations for these ops; the actual
  // external models are registered elsewhere.
  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
                            AffineMinOp>();
}
235 
236 /// Materialize a single constant operation from a given attribute value with
237 /// the desired resultant type.
239  Attribute value, Type type,
240  Location loc) {
241  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
242  return ub::PoisonOp::create(builder, loc, type, poison);
243  return arith::ConstantOp::materialize(builder, value, type, loc);
244 }
245 
246 /// A utility function to check if a value is defined at the top level of an
247 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
248 /// conservatively assume it is not top-level. A value of index type defined at
249 /// the top level is always a valid symbol.
251  if (auto arg = dyn_cast<BlockArgument>(value)) {
252  // The block owning the argument may be unlinked, e.g. when the surrounding
253  // region has not yet been attached to an Op, at which point the parent Op
254  // is null.
255  Operation *parentOp = arg.getOwner()->getParentOp();
256  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
257  }
258  // The defining Op may live in an unlinked block so its parent Op may be null.
259  Operation *parentOp = value.getDefiningOp()->getParentOp();
260  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
261 }
262 
263 /// Returns the closest region enclosing `op` that is held by an operation with
264 /// trait `AffineScope`; `nullptr` if there is no such region.
266  auto *curOp = op;
267  while (auto *parentOp = curOp->getParentOp()) {
268  if (parentOp->hasTrait<OpTrait::AffineScope>())
269  return curOp->getParentRegion();
270  curOp = parentOp;
271  }
272  return nullptr;
273 }
274 
276  Operation *curOp = op;
277  while (auto *parentOp = curOp->getParentOp()) {
278  if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
279  return curOp->getParentRegion();
280  curOp = parentOp;
281  }
282  return nullptr;
283 }
284 
285 // A Value can be used as a dimension id iff it meets one of the following
286 // conditions:
287 // *) It is valid as a symbol.
288 // *) It is an induction variable.
289 // *) It is the result of affine apply operation with dimension id arguments.
291  // The value must be an index type.
292  if (!value.getType().isIndex())
293  return false;
294 
295  if (auto *defOp = value.getDefiningOp())
296  return isValidDim(value, getAffineScope(defOp));
297 
298  // This value has to be a block argument for an op that has the
299  // `AffineScope` trait or an induction var of an affine.for or
300  // affine.parallel.
301  if (isAffineInductionVar(value))
302  return true;
303  auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
304  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
305 }
306 
307 // Value can be used as a dimension id iff it meets one of the following
308 // conditions:
309 // *) It is valid as a symbol.
310 // *) It is an induction variable.
311 // *) It is the result of an affine apply operation with dimension id operands.
312 // *) It is the result of a more specialized index transformation (ex.
313 // delinearize_index or linearize_index) with dimension id operands.
314 bool mlir::affine::isValidDim(Value value, Region *region) {
315  // The value must be an index type.
316  if (!value.getType().isIndex())
317  return false;
318 
319  // All valid symbols are okay.
320  if (isValidSymbol(value, region))
321  return true;
322 
323  auto *op = value.getDefiningOp();
324  if (!op) {
325  // This value has to be an induction var for an affine.for or an
326  // affine.parallel.
327  return isAffineInductionVar(value);
328  }
329 
330  // Affine apply operation is ok if all of its operands are ok.
331  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
332  return applyOp.isValidDim(region);
333  // delinearize_index and linearize_index are special forms of apply
334  // and so are valid dimensions if all their arguments are valid dimensions.
335  if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
336  return llvm::all_of(op->getOperands(),
337  [&](Value arg) { return ::isValidDim(arg, region); });
338  // The dim op is okay if its operand memref/tensor is defined at the top
339  // level.
340  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
341  return isTopLevelValue(dimOp.getShapedValue());
342  return false;
343 }
344 
345 /// Returns true if the 'index' dimension of the `memref` defined by
346 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
347 /// for `region`.
348 template <typename AnyMemRefDefOp>
349 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
350  Region *region) {
351  MemRefType memRefType = memrefDefOp.getType();
352 
353  // Dimension index is out of bounds.
354  if (index >= memRefType.getRank()) {
355  return false;
356  }
357 
358  // Statically shaped.
359  if (!memRefType.isDynamicDim(index))
360  return true;
361  // Get the position of the dimension among dynamic dimensions;
362  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
363  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
364  region);
365 }
366 
367 /// Returns true if the result of the dim op is a valid symbol for `region`.
368 static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
369  // The dim op is okay if its source is defined at the top level.
370  if (isTopLevelValue(dimOp.getShapedValue()))
371  return true;
372 
373  // Conservatively handle remaining BlockArguments as non-valid symbols.
374  // E.g. scf.for iterArgs.
375  if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
376  return false;
377 
378  // The dim op is also okay if its operand memref is a view/subview whose
379  // corresponding size is a valid symbol.
380  std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
381 
382  // Be conservative if we can't understand the dimension.
383  if (!index.has_value())
384  return false;
385 
386  // Skip over all memref.cast ops (if any).
387  Operation *op = dimOp.getShapedValue().getDefiningOp();
388  while (auto castOp = dyn_cast<memref::CastOp>(op)) {
389  // Bail on unranked memrefs.
390  if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
391  return false;
392  op = castOp.getSource().getDefiningOp();
393  if (!op)
394  return false;
395  }
396 
397  int64_t i = index.value();
399  .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
400  [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
401  .Default([](Operation *) { return false; });
402 }
403 
404 // A value can be used as a symbol (at all its use sites) iff it meets one of
405 // the following conditions:
406 // *) It is a constant.
407 // *) Its defining op or block arg appearance is immediately enclosed by an op
408 // with `AffineScope` trait.
409 // *) It is the result of an affine.apply operation with symbol operands.
410 // *) It is a result of the dim op on a memref whose corresponding size is a
411 // valid symbol.
413  if (!value)
414  return false;
415 
416  // The value must be an index type.
417  if (!value.getType().isIndex())
418  return false;
419 
420  // Check that the value is a top level value.
421  if (isTopLevelValue(value))
422  return true;
423 
424  if (auto *defOp = value.getDefiningOp())
425  return isValidSymbol(value, getAffineScope(defOp));
426 
427  return false;
428 }
429 
430 /// A value can be used as a symbol for `region` iff it meets one of the
431 /// following conditions:
432 /// *) It is a constant.
433 /// *) It is a result of a `Pure` operation whose operands are valid symbolic
434 /// *) identifiers.
435 /// *) It is a result of the dim op on a memref whose corresponding size is
436 /// a valid symbol.
437 /// *) It is defined at the top level of 'region' or is its argument.
438 /// *) It dominates `region`'s parent op.
439 /// If `region` is null, conservatively assume the symbol definition scope does
440 /// not exist and only accept the values that would be symbols regardless of
441 /// the surrounding region structure, i.e. the first three cases above.
443  // The value must be an index type.
444  if (!value.getType().isIndex())
445  return false;
446 
447  // A top-level value is a valid symbol.
448  if (region && ::isTopLevelValue(value, region))
449  return true;
450 
451  auto *defOp = value.getDefiningOp();
452  if (!defOp) {
453  // A block argument that is not a top-level value is a valid symbol if it
454  // dominates region's parent op.
455  Operation *regionOp = region ? region->getParentOp() : nullptr;
456  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
457  if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
458  return isValidSymbol(value, parentOpRegion);
459  return false;
460  }
461 
462  // Constant operation is ok.
463  Attribute operandCst;
464  if (matchPattern(defOp, m_Constant(&operandCst)))
465  return true;
466 
467  // `Pure` operation that whose operands are valid symbolic identifiers.
468  if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
469  return affine::isValidSymbol(operand, region);
470  })) {
471  return true;
472  }
473 
474  // Dim op results could be valid symbols at any level.
475  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
476  return isDimOpValidSymbol(dimOp, region);
477 
478  // Check for values dominating `region`'s parent op.
479  Operation *regionOp = region ? region->getParentOp() : nullptr;
480  if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
481  if (auto *parentRegion = region->getParentOp()->getParentRegion())
482  return isValidSymbol(value, parentRegion);
483 
484  return false;
485 }
486 
487 // Returns true if 'value' is a valid index to an affine operation (e.g.
488 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
489 // `region` provides the polyhedral symbol scope. Returns false otherwise.
490 static bool isValidAffineIndexOperand(Value value, Region *region) {
491  return isValidDim(value, region) || isValidSymbol(value, region);
492 }
493 
494 /// Prints dimension and symbol list.
497  unsigned numDims, OpAsmPrinter &printer) {
498  OperandRange operands(begin, end);
499  printer << '(' << operands.take_front(numDims) << ')';
500  if (operands.size() > numDims)
501  printer << '[' << operands.drop_front(numDims) << ']';
502 }
503 
504 /// Parses dimension and symbol list and returns true if parsing failed.
506  OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
508  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
509  return failure();
510  // Store number of dimensions for validation by caller.
511  numDims = opInfos.size();
512 
513  // Parse the optional symbol operands.
514  auto indexTy = parser.getBuilder().getIndexType();
515  return failure(parser.parseOperandList(
517  parser.resolveOperands(opInfos, indexTy, operands));
518 }
519 
520 /// Utility function to verify that a set of operands are valid dimension and
521 /// symbol identifiers. The operands should be laid out such that the dimension
522 /// operands are before the symbol operands. This function returns failure if
523 /// there was an invalid operand. An operation is provided to emit any necessary
524 /// errors.
525 template <typename OpTy>
526 static LogicalResult
528  unsigned numDims) {
529  unsigned opIt = 0;
530  for (auto operand : operands) {
531  if (opIt++ < numDims) {
532  if (!isValidDim(operand, getAffineScope(op)))
533  return op.emitOpError("operand cannot be used as a dimension id");
534  } else if (!isValidSymbol(operand, getAffineScope(op))) {
535  return op.emitOpError("operand cannot be used as a symbol");
536  }
537  }
538  return success();
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // AffineApplyOp
543 //===----------------------------------------------------------------------===//
544 
545 AffineValueMap AffineApplyOp::getAffineValueMap() {
546  return AffineValueMap(getAffineMap(), getOperands(), getResult());
547 }
548 
549 ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
550  auto &builder = parser.getBuilder();
551  auto indexTy = builder.getIndexType();
552 
553  AffineMapAttr mapAttr;
554  unsigned numDims;
555  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
556  parseDimAndSymbolList(parser, result.operands, numDims) ||
557  parser.parseOptionalAttrDict(result.attributes))
558  return failure();
559  auto map = mapAttr.getValue();
560 
561  if (map.getNumDims() != numDims ||
562  numDims + map.getNumSymbols() != result.operands.size()) {
563  return parser.emitError(parser.getNameLoc(),
564  "dimension or symbol index mismatch");
565  }
566 
567  result.types.append(map.getNumResults(), indexTy);
568  return success();
569 }
570 
572  p << " " << getMapAttr();
573  printDimAndSymbolList(operand_begin(), operand_end(),
574  getAffineMap().getNumDims(), p);
575  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
576 }
577 
578 LogicalResult AffineApplyOp::verify() {
579  // Check input and output dimensions match.
580  AffineMap affineMap = getMap();
581 
582  // Verify that operand count matches affine map dimension and symbol count.
583  if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
584  return emitOpError(
585  "operand count and affine map dimension and symbol count must match");
586 
587  // Verify that the map only produces one result.
588  if (affineMap.getNumResults() != 1)
589  return emitOpError("mapping must produce one value");
590 
591  // Do not allow valid dims to be used in symbol positions. We do allow
592  // affine.apply to use operands for values that may neither qualify as affine
593  // dims or affine symbols due to usage outside of affine ops, analyses, etc.
594  Region *region = getAffineScope(*this);
595  for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
596  if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
597  return emitError("dimensional operand cannot be used as a symbol");
598  }
599 
600  return success();
601 }
602 
603 // The result of the affine apply operation can be used as a dimension id if all
604 // its operands are valid dimension ids.
606  return llvm::all_of(getOperands(),
607  [](Value op) { return affine::isValidDim(op); });
608 }
609 
610 // The result of the affine apply operation can be used as a dimension id if all
611 // its operands are valid dimension ids with the parent operation of `region`
612 // defining the polyhedral scope for symbols.
613 bool AffineApplyOp::isValidDim(Region *region) {
614  return llvm::all_of(getOperands(),
615  [&](Value op) { return ::isValidDim(op, region); });
616 }
617 
618 // The result of the affine apply operation can be used as a symbol if all its
619 // operands are symbols.
621  return llvm::all_of(getOperands(),
622  [](Value op) { return affine::isValidSymbol(op); });
623 }
624 
625 // The result of the affine apply operation can be used as a symbol in `region`
626 // if all its operands are symbols in `region`.
627 bool AffineApplyOp::isValidSymbol(Region *region) {
628  return llvm::all_of(getOperands(), [&](Value operand) {
629  return affine::isValidSymbol(operand, region);
630  });
631 }
632 
633 OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
634  auto map = getAffineMap();
635 
636  // Fold dims and symbols to existing values.
637  auto expr = map.getResult(0);
638  if (auto dim = dyn_cast<AffineDimExpr>(expr))
639  return getOperand(dim.getPosition());
640  if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
641  return getOperand(map.getNumDims() + sym.getPosition());
642 
643  // Otherwise, default to folding the map.
645  bool hasPoison = false;
646  auto foldResult =
647  map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
648  if (hasPoison)
650  if (failed(foldResult))
651  return {};
652  return result[0];
653 }
654 
655 /// Returns the largest known divisor of `e`. Exploits information from the
656 /// values in `operands`.
657 static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef<Value> operands) {
658  // This method isn't aware of `operands`.
659  int64_t div = e.getLargestKnownDivisor();
660 
661  // We now make use of operands for the case `e` is a dim expression.
662  // TODO: More powerful simplification would have to modify
663  // getLargestKnownDivisor to take `operands` and exploit that information as
664  // well for dim/sym expressions, but in that case, getLargestKnownDivisor
665  // can't be part of the IR library but of the `Analysis` library. The IR
666  // library can only really depend on simple O(1) checks.
667  auto dimExpr = dyn_cast<AffineDimExpr>(e);
668  // If it's not a dim expr, `div` is the best we have.
669  if (!dimExpr)
670  return div;
671 
672  // We simply exploit information from loop IVs.
673  // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
674  // desired simplifications are expected to be part of other
675  // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
676  // LoopAnalysis library.
677  Value operand = operands[dimExpr.getPosition()];
678  int64_t operandDivisor = 1;
679  // TODO: With the right accessors, this can be extended to
680  // LoopLikeOpInterface.
681  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
682  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
683  operandDivisor = forOp.getStepAsInt();
684  } else {
685  uint64_t lbLargestKnownDivisor =
686  forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
687  operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
688  }
689  }
690  return operandDivisor;
691 }
692 
693 /// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
694 /// being an affine dim expression or a constant.
696  int64_t k) {
697  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
698  int64_t constVal = constExpr.getValue();
699  return constVal >= 0 && constVal < k;
700  }
701  auto dimExpr = dyn_cast<AffineDimExpr>(e);
702  if (!dimExpr)
703  return false;
704  Value operand = operands[dimExpr.getPosition()];
705  // TODO: With the right accessors, this can be extended to
706  // LoopLikeOpInterface.
707  if (AffineForOp forOp = getForInductionVarOwner(operand)) {
708  if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
709  forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
710  return true;
711  }
712  }
713 
714  // We don't consider other cases like `operand` being defined by a constant or
715  // an affine.apply op since such cases will already be handled by other
716  // patterns and propagation of loop IVs or constant would happen.
717  return false;
718 }
719 
720 /// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
721 /// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
722 /// expression is in that form.
723 static bool isQTimesDPlusR(AffineExpr e, ArrayRef<Value> operands, int64_t &div,
724  AffineExpr &quotientTimesDiv, AffineExpr &rem) {
725  auto bin = dyn_cast<AffineBinaryOpExpr>(e);
726  if (!bin || bin.getKind() != AffineExprKind::Add)
727  return false;
728 
729  AffineExpr llhs = bin.getLHS();
730  AffineExpr rlhs = bin.getRHS();
731  div = getLargestKnownDivisor(llhs, operands);
732  if (isNonNegativeBoundedBy(rlhs, operands, div)) {
733  quotientTimesDiv = llhs;
734  rem = rlhs;
735  return true;
736  }
737  div = getLargestKnownDivisor(rlhs, operands);
738  if (isNonNegativeBoundedBy(llhs, operands, div)) {
739  quotientTimesDiv = rlhs;
740  rem = llhs;
741  return true;
742  }
743  return false;
744 }
745 
746 /// Gets the constant lower bound on an `iv`.
747 static std::optional<int64_t> getLowerBound(Value iv) {
748  AffineForOp forOp = getForInductionVarOwner(iv);
749  if (forOp && forOp.hasConstantLowerBound())
750  return forOp.getConstantLowerBound();
751  return std::nullopt;
752 }
753 
754 /// Gets the constant upper bound on an affine.for `iv`.
755 static std::optional<int64_t> getUpperBound(Value iv) {
756  AffineForOp forOp = getForInductionVarOwner(iv);
757  if (!forOp || !forOp.hasConstantUpperBound())
758  return std::nullopt;
759 
760  // If its lower bound is also known, we can get a more precise bound
761  // whenever the step is not one.
762  if (forOp.hasConstantLowerBound()) {
763  return forOp.getConstantUpperBound() - 1 -
764  (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
765  forOp.getStepAsInt();
766  }
767  return forOp.getConstantUpperBound() - 1;
768 }
769 
770 /// Determine a constant upper bound for `expr` if one exists while exploiting
771 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
772 /// is guaranteed to be less than or equal to it.
773 static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
774  unsigned numSymbols,
775  ArrayRef<Value> operands) {
776  // Get the constant lower or upper bounds on the operands.
777  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
778  constLowerBounds.reserve(operands.size());
779  constUpperBounds.reserve(operands.size());
780  for (Value operand : operands) {
781  constLowerBounds.push_back(getLowerBound(operand));
782  constUpperBounds.push_back(getUpperBound(operand));
783  }
784 
785  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
786  return constExpr.getValue();
787 
788  return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
789  constUpperBounds,
790  /*isUpper=*/true);
791 }
792 
793 /// Determine a constant lower bound for `expr` if one exists while exploiting
794 /// values in `operands`. Note that the upper bound is an inclusive one. `expr`
795 /// is guaranteed to be less than or equal to it.
796 static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
797  unsigned numSymbols,
798  ArrayRef<Value> operands) {
799  // Get the constant lower or upper bounds on the operands.
800  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
801  constLowerBounds.reserve(operands.size());
802  constUpperBounds.reserve(operands.size());
803  for (Value operand : operands) {
804  constLowerBounds.push_back(getLowerBound(operand));
805  constUpperBounds.push_back(getUpperBound(operand));
806  }
807 
808  std::optional<int64_t> lowerBound;
809  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
810  lowerBound = constExpr.getValue();
811  } else {
812  lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
813  constLowerBounds, constUpperBounds,
814  /*isUpper=*/false);
815  }
816  return lowerBound;
817 }
818 
/// Simplify `expr` while exploiting information from the values in `operands`.
/// Recursively simplifies children, then attempts to turn floordiv/ceildiv/mod
/// expressions with a positive constant RHS into simpler forms using constant
/// bounds derived from `operands`.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
                                    unsigned numSymbols,
                                    ArrayRef<Value> operands) {
  // We do this only for certain floordiv/mod expressions.
  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binExpr)
    return;

  // Simplify the child expressions first.
  AffineExpr lhs = binExpr.getLHS();
  AffineExpr rhs = binExpr.getRHS();
  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);

  // Rebuilding may have folded the expression into a non-binary form or a
  // kind we don't handle; only floordiv/ceildiv/mod are simplified below.
  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
                   expr.getKind() != AffineExprKind::CeilDiv &&
                   expr.getKind() != AffineExprKind::Mod)) {
    return;
  }

  // The `lhs` and `rhs` may be different post construction of simplified expr.
  lhs = binExpr.getLHS();
  rhs = binExpr.getRHS();
  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
  if (!rhsConst)
    return;

  int64_t rhsConstVal = rhsConst.getValue();
  // Undefined expressions aren't touched; IR can still be valid with them.
  if (rhsConstVal <= 0)
    return;

  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
  MLIRContext *context = expr.getContext();
  std::optional<int64_t> lhsLbConst =
      getLowerBound(lhs, numDims, numSymbols, operands);
  std::optional<int64_t> lhsUbConst =
      getUpperBound(lhs, numDims, numSymbols, operands);
  if (lhsLbConst && lhsUbConst) {
    int64_t lhsLbConstVal = *lhsLbConst;
    int64_t lhsUbConstVal = *lhsUbConst;
    // lhs floordiv c is a single value if lhs is bounded in a range that has
    // the same quotient.
    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
        divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
            divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
      expr = getAffineConstantExpr(
          divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
      return;
    }
    // lhs ceildiv c is a single value if the entire range has the same ceil
    // quotient.
    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
        divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
            divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
      expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
                                   context);
      return;
    }
    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
      expr = lhs;
      return;
    }
  }

  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
  AffineExpr quotientTimesDiv, rem;
  int64_t divisor;
  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
    if (rhsConstVal % divisor == 0 &&
        binExpr.getKind() == AffineExprKind::FloorDiv) {
      expr = quotientTimesDiv.floorDiv(rhsConst);
    } else if (divisor % rhsConstVal == 0 &&
               binExpr.getKind() == AffineExprKind::Mod) {
      expr = rem % rhsConst;
    }
    return;
  }

  // Handle the simple case when the LHS expression can be either upper
  // bounded or is a known multiple of RHS constant.
  // lhs floordiv c -> 0 if 0 <= lhs < c,
  // lhs mod c -> 0 if lhs % c = 0.
  if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
       binExpr.getKind() == AffineExprKind::FloorDiv) ||
      (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
       binExpr.getKind() == AffineExprKind::Mod)) {
    expr = getAffineConstantExpr(0, expr.getContext());
  }
}
917 
918 /// Simplify the expressions in `map` while making use of lower or upper bounds
919 /// of its operands. If `isMax` is true, the map is to be treated as a max of
920 /// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
921 /// d1) can be simplified to (8) if the operands are respectively lower bounded
922 /// by 2 and 0 (the second expression can't be lower than 8).
924  ArrayRef<Value> operands,
925  bool isMax) {
926  // Can't simplify.
927  if (operands.empty())
928  return;
929 
930  // Get the upper or lower bound on an affine.for op IV using its range.
931  // Get the constant lower or upper bounds on the operands.
932  SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
933  constLowerBounds.reserve(operands.size());
934  constUpperBounds.reserve(operands.size());
935  for (Value operand : operands) {
936  constLowerBounds.push_back(getLowerBound(operand));
937  constUpperBounds.push_back(getUpperBound(operand));
938  }
939 
940  // We will compute the lower and upper bounds on each of the expressions
941  // Then, we will check (depending on max or min) as to whether a specific
942  // bound is redundant by checking if its highest (in case of max) and its
943  // lowest (in the case of min) value is already lower than (or higher than)
944  // the lower bound (or upper bound in the case of min) of another bound.
945  SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
946  lowerBounds.reserve(map.getNumResults());
947  upperBounds.reserve(map.getNumResults());
948  for (AffineExpr e : map.getResults()) {
949  if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
950  lowerBounds.push_back(constExpr.getValue());
951  upperBounds.push_back(constExpr.getValue());
952  } else {
953  lowerBounds.push_back(
955  constLowerBounds, constUpperBounds,
956  /*isUpper=*/false));
957  upperBounds.push_back(
959  constLowerBounds, constUpperBounds,
960  /*isUpper=*/true));
961  }
962  }
963 
964  // Collect expressions that are not redundant.
965  SmallVector<AffineExpr, 4> irredundantExprs;
966  for (auto exprEn : llvm::enumerate(map.getResults())) {
967  AffineExpr e = exprEn.value();
968  unsigned i = exprEn.index();
969  // Some expressions can be turned into constants.
970  if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
971  e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
972 
973  // Check if the expression is redundant.
974  if (isMax) {
975  if (!upperBounds[i]) {
976  irredundantExprs.push_back(e);
977  continue;
978  }
979  // If there exists another expression such that its lower bound is greater
980  // than this expression's upper bound, it's redundant.
981  if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
982  auto otherLowerBound = en.value();
983  unsigned pos = en.index();
984  if (pos == i || !otherLowerBound)
985  return false;
986  if (*otherLowerBound > *upperBounds[i])
987  return true;
988  if (*otherLowerBound < *upperBounds[i])
989  return false;
990  // Equality case. When both expressions are considered redundant, we
991  // don't want to get both of them. We keep the one that appears
992  // first.
993  if (upperBounds[pos] && lowerBounds[i] &&
994  lowerBounds[i] == upperBounds[i] &&
995  otherLowerBound == *upperBounds[pos] && i < pos)
996  return false;
997  return true;
998  }))
999  irredundantExprs.push_back(e);
1000  } else {
1001  if (!lowerBounds[i]) {
1002  irredundantExprs.push_back(e);
1003  continue;
1004  }
1005  // Likewise for the `min` case. Use the complement of the condition above.
1006  if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
1007  auto otherUpperBound = en.value();
1008  unsigned pos = en.index();
1009  if (pos == i || !otherUpperBound)
1010  return false;
1011  if (*otherUpperBound < *lowerBounds[i])
1012  return true;
1013  if (*otherUpperBound > *lowerBounds[i])
1014  return false;
1015  if (lowerBounds[pos] && upperBounds[i] &&
1016  lowerBounds[i] == upperBounds[i] &&
1017  otherUpperBound == lowerBounds[pos] && i < pos)
1018  return false;
1019  return true;
1020  }))
1021  irredundantExprs.push_back(e);
1022  }
1023  }
1024 
1025  // Create the map without the redundant expressions.
1026  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
1027  map.getContext());
1028 }
1029 
1030 /// Simplify the map while exploiting information on the values in `operands`.
1031 // Use "unused attribute" marker to silence warning stemming from the inability
1032 // to see through the template expansion.
1033 static void LLVM_ATTRIBUTE_UNUSED
1035  assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1036  SmallVector<AffineExpr> newResults;
1037  newResults.reserve(map.getNumResults());
1038  for (AffineExpr expr : map.getResults()) {
1040  operands);
1041  newResults.push_back(expr);
1042  }
1043  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1044  map.getContext());
1045 }
1046 
1047 /// Assuming `dimOrSym` is a quantity in the apply op map `map` and defined by
1048 /// `minOp = affine_min(x_1, ..., x_n)`. This function checks that:
1049 /// `0 < affine_min(x_1, ..., x_n)` and proceeds with replacing the patterns:
1050 /// ```
1051 /// dimOrSym.ceildiv(x_k)
1052 /// (dimOrSym + x_k - 1).floordiv(x_k)
1053 /// ```
1054 /// by `1` for all `k` in `1, ..., n`. This is possible because `x / x_k <= 1`.
1055 ///
1056 ///
1057 /// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
1058 /// `minOp` is positive.
1059 static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
1060  AffineExpr dimOrSym,
1061  AffineMap *map,
1062  ValueRange dims,
1063  ValueRange syms) {
1064  LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";
1065  AffineMap affineMinMap = minOp.getAffineMap();
1066 
1067  // Check the value is positive.
1068  for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
1069  // Compare each expression in the minimum against 0.
1071  getAsIndexOpFoldResult(minOp.getContext(), 0),
1072  ValueBoundsConstraintSet::ComparisonOperator::LT,
1074  minOp.getOperands())))
1075  return failure();
1076  }
1077 
1078  /// Convert affine symbols and dimensions in minOp to symbols or dimensions in
1079  /// the apply op affine map.
1080  DenseMap<AffineExpr, AffineExpr> dimSymConversionTable;
1081  SmallVector<unsigned> unmappedDims, unmappedSyms;
1082  for (auto [i, dim] : llvm::enumerate(minOp.getDimOperands())) {
1083  auto it = llvm::find(dims, dim);
1084  if (it == dims.end()) {
1085  unmappedDims.push_back(i);
1086  continue;
1087  }
1088  dimSymConversionTable[getAffineDimExpr(i, minOp.getContext())] =
1089  getAffineDimExpr(it.getIndex(), minOp.getContext());
1090  }
1091  for (auto [i, sym] : llvm::enumerate(minOp.getSymbolOperands())) {
1092  auto it = llvm::find(syms, sym);
1093  if (it == syms.end()) {
1094  unmappedSyms.push_back(i);
1095  continue;
1096  }
1097  dimSymConversionTable[getAffineSymbolExpr(i, minOp.getContext())] =
1098  getAffineSymbolExpr(it.getIndex(), minOp.getContext());
1099  }
1100 
1101  // Create the replacement map.
1103  AffineExpr c1 = getAffineConstantExpr(1, minOp.getContext());
1104  for (AffineExpr expr : affineMinMap.getResults()) {
1105  // If we cannot express the result in terms of the apply map symbols and
1106  // sims then continue.
1107  if (llvm::any_of(unmappedDims,
1108  [&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
1109  llvm::any_of(unmappedSyms,
1110  [&](unsigned i) { return expr.isFunctionOfSymbol(i); }))
1111  continue;
1112 
1113  AffineExpr convertedExpr = expr.replace(dimSymConversionTable);
1114 
1115  // dimOrSym.ceilDiv(expr) -> 1
1116  repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
1117  // (dimOrSym + expr - 1).floorDiv(expr) -> 1
1118  repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
1119  }
1120  AffineMap initialMap = *map;
1121  *map = initialMap.replace(repl, initialMap.getNumDims(),
1122  initialMap.getNumSymbols());
1123  return success(*map != initialMap);
1124 }
1125 
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
/// When `dimOrSymbolPosition >= dims.size()`,
/// AffineSymbolExpr@[pos - dims.size()] is replaced.
/// Mutate `map`,`dims` and `syms` in place as follows:
/// 1. `dims` and `syms` are only appended to.
/// 2. `map` dim and symbols are gradually shifted to higher positions.
/// 3. Old `dim` and `sym` entries are replaced by nullptr
/// This avoids the need for any bookkeeping.
/// If `replaceAffineMin` is set to true, additionally triggers more expensive
/// replacements involving affine_min operations.
static LogicalResult replaceDimOrSym(AffineMap *map,
                                     unsigned dimOrSymbolPosition,
                                     SmallVectorImpl<Value> &dims,
                                     SmallVectorImpl<Value> &syms,
                                     bool replaceAffineMin) {
  MLIRContext *ctx = map->getContext();
  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
  unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                  : dimOrSymbolPosition - dims.size();
  // A null entry means this position was already replaced earlier.
  Value &v = isDimReplacement ? dims[pos] : syms[pos];
  if (!v)
    return failure();

  // If requested, try the affine_min bounding-box replacement first.
  if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
    AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx)
                                           : getAffineSymbolExpr(pos, ctx);
    return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map, dims,
                                                 syms);
  }

  auto affineApply = v.getDefiningOp<AffineApplyOp>();
  if (!affineApply)
    return failure();

  // At this point we will perform a replacement of `v`, set the entry in `dim`
  // or `sym` to nullptr immediately.
  v = nullptr;

  // Compute the map, dims and symbols coming from the AffineApplyOp.
  AffineMap composeMap = affineApply.getAffineMap();
  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
  SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
                                     affineApply.getMapOperands().end());
  // Canonicalize the map to promote dims to symbols when possible. This is to
  // avoid generating invalid maps.
  canonicalizeMapAndOperands(&composeMap, &composeOperands);
  // Shift the apply's dims/syms past the existing ones so they can simply be
  // appended to `dims`/`syms` below without renumbering anything else.
  AffineExpr replacementExpr =
      composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
  ValueRange composeDims =
      ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
  ValueRange composeSyms =
      ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
                                          : getAffineSymbolExpr(pos, ctx);

  // Append the dims and symbols where relevant and perform the replacement.
  dims.append(composeDims.begin(), composeDims.end());
  syms.append(composeSyms.begin(), composeSyms.end());
  *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());

  return success();
}
1190 
1191 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1192 /// iteratively. Perform canonicalization of map and operands as well as
1193 /// AffineMap simplification. `map` and `operands` are mutated in place.
1195  SmallVectorImpl<Value> *operands,
1196  bool composeAffineMin = false) {
1197  if (map->getNumResults() == 0) {
1198  canonicalizeMapAndOperands(map, operands);
1199  *map = simplifyAffineMap(*map);
1200  return;
1201  }
1202 
1203  MLIRContext *ctx = map->getContext();
1204  SmallVector<Value, 4> dims(operands->begin(),
1205  operands->begin() + map->getNumDims());
1206  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1207  operands->end());
1208 
1209  // Iterate over dims and symbols coming from AffineApplyOp and replace until
1210  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1211  // and `syms` can only increase by construction.
1212  // The implementation uses a `while` loop to support the case of symbols
1213  // that may be constructed from dims ;this may be overkill.
1214  while (true) {
1215  bool changed = false;
1216  for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1217  if ((changed |=
1218  succeeded(replaceDimOrSym(map, pos, dims, syms, composeAffineMin))))
1219  break;
1220  if (!changed)
1221  break;
1222  }
1223 
1224  // Clear operands so we can fill them anew.
1225  operands->clear();
1226 
1227  // At this point we may have introduced null operands, prune them out before
1228  // canonicalizing map and operands.
1229  unsigned nDims = 0, nSyms = 0;
1230  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1231  dimReplacements.reserve(dims.size());
1232  symReplacements.reserve(syms.size());
1233  for (auto *container : {&dims, &syms}) {
1234  bool isDim = (container == &dims);
1235  auto &repls = isDim ? dimReplacements : symReplacements;
1236  for (const auto &en : llvm::enumerate(*container)) {
1237  Value v = en.value();
1238  if (!v) {
1239  assert(isDim ? !map->isFunctionOfDim(en.index())
1240  : !map->isFunctionOfSymbol(en.index()) &&
1241  "map is function of unexpected expr@pos");
1242  repls.push_back(getAffineConstantExpr(0, ctx));
1243  continue;
1244  }
1245  repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1246  : getAffineSymbolExpr(nSyms++, ctx));
1247  operands->push_back(v);
1248  }
1249  }
1250  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1251  nSyms);
1252 
1253  // Canonicalize and simplify before returning.
1254  canonicalizeMapAndOperands(map, operands);
1255  *map = simplifyAffineMap(*map);
1256 }
1257 
1259  AffineMap *map, SmallVectorImpl<Value> *operands, bool composeAffineMin) {
1260  while (llvm::any_of(*operands, [](Value v) {
1261  return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1262  })) {
1263  composeAffineMapAndOperands(map, operands, composeAffineMin);
1264  }
1265  // Additional trailing step for AffineMinOps in case no chains of AffineApply.
1266  if (composeAffineMin && llvm::any_of(*operands, [](Value v) {
1267  return isa_and_nonnull<AffineMinOp>(v.getDefiningOp());
1268  })) {
1269  composeAffineMapAndOperands(map, operands, composeAffineMin);
1270  }
1271 }
1272 
1273 AffineApplyOp
1275  ArrayRef<OpFoldResult> operands,
1276  bool composeAffineMin) {
1277  SmallVector<Value> valueOperands;
1278  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1279  composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
1280  assert(map);
1281  return AffineApplyOp::create(b, loc, map, valueOperands);
1282 }
1283 
1284 AffineApplyOp
1286  ArrayRef<OpFoldResult> operands,
1287  bool composeAffineMin) {
1288  return makeComposedAffineApply(
1289  b, loc,
1291  .front(),
1292  operands, composeAffineMin);
1293 }
1294 
1295 /// Composes the given affine map with the given list of operands, pulling in
1296 /// the maps from any affine.apply operations that supply the operands.
1298  SmallVectorImpl<Value> &operands,
1299  bool composeAffineMin = false) {
1300  // Compose and canonicalize each expression in the map individually because
1301  // composition only applies to single-result maps, collecting potentially
1302  // duplicate operands in a single list with shifted dimensions and symbols.
1303  SmallVector<Value> dims, symbols;
1305  for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1306  SmallVector<Value> submapOperands(operands.begin(), operands.end());
1307  AffineMap submap = map.getSubMap({i});
1308  fullyComposeAffineMapAndOperands(&submap, &submapOperands,
1309  composeAffineMin);
1310  canonicalizeMapAndOperands(&submap, &submapOperands);
1311  unsigned numNewDims = submap.getNumDims();
1312  submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1313  llvm::append_range(dims,
1314  ArrayRef<Value>(submapOperands).take_front(numNewDims));
1315  llvm::append_range(symbols,
1316  ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1317  exprs.push_back(submap.getResult(0));
1318  }
1319 
1320  // Canonicalize the map created from composed expressions to deduplicate the
1321  // dimension and symbol operands.
1322  operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1323  map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1324  canonicalizeMapAndOperands(&map, &operands);
1325 }
1326 
1328  OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
1329  bool composeAffineMin) {
1330  assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1331 
1332  // Create new builder without a listener, so that no notification is
1333  // triggered if the op is folded.
1334  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1335  // workaround is no longer needed.
1336  OpBuilder newBuilder(b.getContext());
1338 
1339  // Create op.
1340  AffineApplyOp applyOp =
1341  makeComposedAffineApply(newBuilder, loc, map, operands, composeAffineMin);
1342 
1343  // Get constant operands.
1344  SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1345  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1346  matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1347 
1348  // Try to fold the operation.
1349  SmallVector<OpFoldResult> foldResults;
1350  if (failed(applyOp->fold(constOperands, foldResults)) ||
1351  foldResults.empty()) {
1352  if (OpBuilder::Listener *listener = b.getListener())
1353  listener->notifyOperationInserted(applyOp, /*previous=*/{});
1354  return applyOp.getResult();
1355  }
1356 
1357  applyOp->erase();
1358  return llvm::getSingleElement(foldResults);
1359 }
1360 
1362  OpBuilder &b, Location loc, AffineExpr expr,
1363  ArrayRef<OpFoldResult> operands, bool composeAffineMin) {
1365  b, loc,
1367  .front(),
1368  operands, composeAffineMin);
1369 }
1370 
1373  OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
1374  bool composeAffineMin) {
1375  return llvm::map_to_vector(
1376  llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
1377  return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
1378  operands, composeAffineMin);
1379  });
1380 }
1381 
1382 template <typename OpTy>
1384  ArrayRef<OpFoldResult> operands) {
1385  SmallVector<Value> valueOperands;
1386  map = foldAttributesIntoMap(b, map, operands, valueOperands);
1387  composeMultiResultAffineMap(map, valueOperands);
1388  return OpTy::create(b, loc, b.getIndexType(), map, valueOperands);
1389 }
1390 
1391 AffineMinOp
1393  ArrayRef<OpFoldResult> operands) {
1394  return makeComposedMinMax<AffineMinOp>(b, loc, map, operands);
1395 }
1396 
1397 template <typename OpTy>
1399  AffineMap map,
1400  ArrayRef<OpFoldResult> operands) {
1401  // Create new builder without a listener, so that no notification is
1402  // triggered if the op is folded.
1403  // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1404  // workaround is no longer needed.
1405  OpBuilder newBuilder(b.getContext());
1407 
1408  // Create op.
1409  auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1410 
1411  // Get constant operands.
1412  SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1413  for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1414  matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1415 
1416  // Try to fold the operation.
1417  SmallVector<OpFoldResult> foldResults;
1418  if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1419  foldResults.empty()) {
1420  if (OpBuilder::Listener *listener = b.getListener())
1421  listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1422  return minMaxOp.getResult();
1423  }
1424 
1425  minMaxOp->erase();
1426  return llvm::getSingleElement(foldResults);
1427 }
1428 
1431  AffineMap map,
1432  ArrayRef<OpFoldResult> operands) {
1433  return makeComposedFoldedMinMax<AffineMinOp>(b, loc, map, operands);
1434 }
1435 
1438  AffineMap map,
1439  ArrayRef<OpFoldResult> operands) {
1440  return makeComposedFoldedMinMax<AffineMaxOp>(b, loc, map, operands);
1441 }
1442 
1443 // A symbol may appear as a dim in affine.apply operations. This function
1444 // canonicalizes dims that are valid symbols into actual symbols.
1445 template <class MapOrSet>
1446 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1447  SmallVectorImpl<Value> *operands) {
1448  if (!mapOrSet || operands->empty())
1449  return;
1450 
1451  assert(mapOrSet->getNumInputs() == operands->size() &&
1452  "map/set inputs must match number of operands");
1453 
1454  auto *context = mapOrSet->getContext();
1455  SmallVector<Value, 8> resultOperands;
1456  resultOperands.reserve(operands->size());
1457  SmallVector<Value, 8> remappedSymbols;
1458  remappedSymbols.reserve(operands->size());
1459  unsigned nextDim = 0;
1460  unsigned nextSym = 0;
1461  unsigned oldNumSyms = mapOrSet->getNumSymbols();
1462  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1463  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1464  if (i < mapOrSet->getNumDims()) {
1465  if (isValidSymbol((*operands)[i])) {
1466  // This is a valid symbol that appears as a dim, canonicalize it.
1467  dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1468  remappedSymbols.push_back((*operands)[i]);
1469  } else {
1470  dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1471  resultOperands.push_back((*operands)[i]);
1472  }
1473  } else {
1474  resultOperands.push_back((*operands)[i]);
1475  }
1476  }
1477 
1478  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1479  *operands = resultOperands;
1480  *mapOrSet = mapOrSet->replaceDimsAndSymbols(
1481  dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
1482 
1483  assert(mapOrSet->getNumInputs() == operands->size() &&
1484  "map/set inputs must match number of operands");
1485 }
1486 
1487 /// A valid affine dimension may appear as a symbol in affine.apply operations.
1488 /// Given an application of `operands` to an affine map or integer set
1489 /// `mapOrSet`, this function canonicalizes symbols of `mapOrSet` that are valid
1490 /// dims, but not valid symbols into actual dims. Without such a legalization,
1491 /// the affine.apply will be invalid. This method is the exact inverse of
1492 /// canonicalizePromotedSymbols.
1493 template <class MapOrSet>
1494 static void legalizeDemotedDims(MapOrSet &mapOrSet,
1495  SmallVectorImpl<Value> &operands) {
1496  if (!mapOrSet || operands.empty())
1497  return;
1498 
1499  unsigned numOperands = operands.size();
1500 
1501  assert(mapOrSet.getNumInputs() == numOperands &&
1502  "map/set inputs must match number of operands");
1503 
1504  auto *context = mapOrSet.getContext();
1505  SmallVector<Value, 8> resultOperands;
1506  resultOperands.reserve(numOperands);
1507  SmallVector<Value, 8> remappedDims;
1508  remappedDims.reserve(numOperands);
1509  SmallVector<Value, 8> symOperands;
1510  symOperands.reserve(mapOrSet.getNumSymbols());
1511  unsigned nextSym = 0;
1512  unsigned nextDim = 0;
1513  unsigned oldNumDims = mapOrSet.getNumDims();
1514  SmallVector<AffineExpr, 8> symRemapping(mapOrSet.getNumSymbols());
1515  resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
1516  for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
1517  if (operands[i] && isValidDim(operands[i]) && !isValidSymbol(operands[i])) {
1518  // This is a valid dim that appears as a symbol, legalize it.
1519  symRemapping[i - oldNumDims] =
1520  getAffineDimExpr(oldNumDims + nextDim++, context);
1521  remappedDims.push_back(operands[i]);
1522  } else {
1523  symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
1524  symOperands.push_back(operands[i]);
1525  }
1526  }
1527 
1528  append_range(resultOperands, remappedDims);
1529  append_range(resultOperands, symOperands);
1530  operands = resultOperands;
1531  mapOrSet = mapOrSet.replaceDimsAndSymbols(
1532  /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
1533 
1534  assert(mapOrSet.getNumInputs() == operands.size() &&
1535  "map/set inputs must match number of operands");
1536 }
1537 
// Works for either an affine map or an integer set.
// Canonicalizes `mapOrSet` and `operands` together: promotes symbol-valid dims
// to symbols, legalizes dim-valid symbols back to dims, drops unused inputs,
// deduplicates repeated operands, and folds constant symbol operands into the
// map/set.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
  legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());

  // Deduplicate dim operands: identical values share one dim position.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  // Deduplicate symbol operands, and fold constants directly into the map/set.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}
1616 
1618  AffineMap *map, SmallVectorImpl<Value> *operands) {
1619  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
1620 }
1621 
// IntegerSet counterpart of canonicalizeMapAndOperands; same operand
// dedup/constant-folding/unused-removal semantics via the shared template.
// NOTE(review): the signature line is hyperlink-elided in this listing.
 1623  IntegerSet *set, SmallVectorImpl<Value> *operands) {
 1624  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
 1625 }
1626 
 1627 namespace {
 1628 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
 1629 /// maps that supply results into them.
 1630 ///
 1631 template <typename AffineOpTy>
 1632 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
// NOTE(review): an inherited-constructor `using` declaration appears to be
// hyperlink-elided at this point in the listing.
 1634 
 1635  /// Replace the affine op with another instance of it with the supplied
 1636  /// map and mapOperands.
 1637  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
 1638  AffineMap map, ArrayRef<Value> mapOperands) const;
 1639 
  // Composes producing maps into this op's map, canonicalizes/dedupes the
  // operands, and lightly simplifies the map. Returns failure when nothing
  // changed so the rewrite driver can reach a fixpoint.
 1640  LogicalResult matchAndRewrite(AffineOpTy affineOp,
 1641  PatternRewriter &rewriter) const override {
  // Restrict instantiation to the ops that have a replaceAffineOp overload.
 1642  static_assert(
 1643  llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
 1644  AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
 1645  AffineVectorStoreOp, AffineVectorLoadOp>::value,
 1646  "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
 1647  "expected");
 1648  auto map = affineOp.getAffineMap();
 1649  AffineMap oldMap = map;
 1650  auto oldOperands = affineOp.getMapOperands();
 1651  SmallVector<Value, 8> resultOperands(oldOperands);
 1652  composeAffineMapAndOperands(&map, &resultOperands);
 1653  canonicalizeMapAndOperands(&map, &resultOperands);
 1654  simplifyMapWithOperands(map, resultOperands);
  // Bail out when both map and operand list are unchanged; otherwise the
  // pattern would loop forever.
 1655  if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
 1656  resultOperands.begin()))
 1657  return failure();
 1658 
 1659  replaceAffineOp(rewriter, affineOp, map, resultOperands);
 1660  return success();
 1661  }
 1662 };
 1663 
 1664 // Specialize the template to account for the different build signatures for
 1665 // affine load, store, and apply ops.
 1666 template <>
 1667 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
 1668  PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
 1669  ArrayRef<Value> mapOperands) const {
 1670  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
 1671  mapOperands);
 1672 }
 1673 template <>
 1674 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
 1675  PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
 1676  ArrayRef<Value> mapOperands) const {
  // Prefetch carries extra unit/int attributes that must be forwarded.
 1677  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
 1678  prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
 1679  prefetch.getLocalityHint(), prefetch.getIsDataCache());
 1680 }
 1681 template <>
 1682 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
 1683  PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
 1684  ArrayRef<Value> mapOperands) const {
 1685  rewriter.replaceOpWithNewOp<AffineStoreOp>(
 1686  store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
 1687 }
 1688 template <>
 1689 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
 1690  PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
 1691  ArrayRef<Value> mapOperands) const {
 1692  rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
 1693  vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
 1694  mapOperands);
 1695 }
 1696 template <>
 1697 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
 1698  PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
 1699  ArrayRef<Value> mapOperands) const {
 1700  rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
 1701  vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
 1702  mapOperands);
 1703 }
 1704 
 1705 // Generic version for ops that don't have extra operands.
 1706 template <typename AffineOpTy>
 1707 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
 1708  PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
 1709  ArrayRef<Value> mapOperands) const {
 1710  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
 1711 }
 1712 } // namespace
1713 
// Registers the map-composition/canonicalization pattern for affine.apply.
 1714 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
 1715  MLIRContext *context) {
 1716  results.add<SimplifyAffineOp<AffineApplyOp>>(context);
 1717 }
1718 
1719 //===----------------------------------------------------------------------===//
1720 // AffineDmaStartOp
1721 //===----------------------------------------------------------------------===//
1722 
 1723 // TODO: Check that map operands are loop IVs or symbols.
// Appends operands in the fixed order the accessors assume: src memref + its
// map operands, dst memref + its map operands, tag memref + its map operands,
// the element count, then the optional (stride, elementsPerStride) pair.
 1724 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
 1725  Value srcMemRef, AffineMap srcMap,
 1726  ValueRange srcIndices, Value destMemRef,
 1727  AffineMap dstMap, ValueRange destIndices,
 1728  Value tagMemRef, AffineMap tagMap,
 1729  ValueRange tagIndices, Value numElements,
 1730  Value stride, Value elementsPerStride) {
 1731  result.addOperands(srcMemRef);
 1732  result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
 1733  result.addOperands(srcIndices);
 1734  result.addOperands(destMemRef);
 1735  result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
 1736  result.addOperands(destIndices);
 1737  result.addOperands(tagMemRef);
 1738  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
 1739  result.addOperands(tagIndices);
 1740  result.addOperands(numElements);
  // The stride operands are optional, but when present both are added.
 1741  if (stride) {
 1742  result.addOperands({stride, elementsPerStride});
 1743  }
 1744 }
1745 
1746 AffineDmaStartOp AffineDmaStartOp::create(
1747  OpBuilder &builder, Location location, Value srcMemRef, AffineMap srcMap,
1748  ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
1749  ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
1750  ValueRange tagIndices, Value numElements, Value stride,
1751  Value elementsPerStride) {
1752  mlir::OperationState state(location, getOperationName());
1753  build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap,
1754  destIndices, tagMemRef, tagMap, tagIndices, numElements, stride,
1755  elementsPerStride);
1756  auto result = dyn_cast<AffineDmaStartOp>(builder.create(state));
1757  assert(result && "builder didn't return the right type");
1758  return result;
1759 }
1760 
// Convenience overload: takes the location from the ImplicitLocOpBuilder and
// forwards to the Location-taking create() above.
 1761 AffineDmaStartOp AffineDmaStartOp::create(
 1762  ImplicitLocOpBuilder &builder, Value srcMemRef, AffineMap srcMap,
 1763  ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
 1764  ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
 1765  ValueRange tagIndices, Value numElements, Value stride,
 1766  Value elementsPerStride) {
 1767  return create(builder, builder.getLoc(), srcMemRef, srcMap, srcIndices,
 1768  destMemRef, dstMap, destIndices, tagMemRef, tagMap, tagIndices,
 1769  numElements, stride, elementsPerStride);
 1770 }
1771 
// Prints: ` %src[mapOperands], %dst[mapOperands], %tag[mapOperands],
// %numElements[, %stride, %elemsPerStride] : srcTy, dstTy, tagTy`.
// NOTE(review): the `print` signature line is hyperlink-elided in this
// listing.
 1773  p << " " << getSrcMemRef() << '[';
 1774  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
 1775  p << "], " << getDstMemRef() << '[';
 1776  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
 1777  p << "], " << getTagMemRef() << '[';
 1778  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
 1779  p << "], " << getNumElements();
  // The stride pair is only printed for the strided form.
 1780  if (isStrided()) {
 1781  p << ", " << getStride();
 1782  p << ", " << getNumElementsPerStride();
 1783  }
 1784  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
 1785  << getTagMemRefType();
 1786 }
1787 
 1788 // Parse AffineDmaStartOp.
 1789 // Ex:
 1790 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
 1791 // %stride, %num_elt_per_stride
 1792 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
 1793 //
// NOTE(review): the `parse` signature and the SmallVector declarations for
// srcMapOperands/dstMapOperands/tagMapOperands/strideInfo are hyperlink-elided
// in this listing.
 1795  OperationState &result) {
 1796  OpAsmParser::UnresolvedOperand srcMemRefInfo;
 1797  AffineMapAttr srcMapAttr;
 1799  OpAsmParser::UnresolvedOperand dstMemRefInfo;
 1800  AffineMapAttr dstMapAttr;
 1802  OpAsmParser::UnresolvedOperand tagMemRefInfo;
 1803  AffineMapAttr tagMapAttr;
 1805  OpAsmParser::UnresolvedOperand numElementsInfo;
 1807 
 1808  SmallVector<Type, 3> types;
 1809  auto indexType = parser.getBuilder().getIndexType();
 1810 
 1811  // Parse and resolve the following list of operands:
 1812  // *) dst memref followed by its affine maps operands (in square brackets).
 1813  // *) src memref followed by its affine map operands (in square brackets).
 1814  // *) tag memref followed by its affine map operands (in square brackets).
 1815  // *) number of elements transferred by DMA operation.
 1816  if (parser.parseOperand(srcMemRefInfo) ||
 1817  parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
 1818  getSrcMapAttrStrName(),
 1819  result.attributes) ||
 1820  parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
 1821  parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
 1822  getDstMapAttrStrName(),
 1823  result.attributes) ||
 1824  parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
 1825  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
 1826  getTagMapAttrStrName(),
 1827  result.attributes) ||
 1828  parser.parseComma() || parser.parseOperand(numElementsInfo))
 1829  return failure();
 1830 
 1831  // Parse optional stride and elements per stride.
 1832  if (parser.parseTrailingOperandList(strideInfo))
 1833  return failure();
 1834 
  // Stride operands must come as a (stride, elementsPerStride) pair or not at
  // all.
 1835  if (!strideInfo.empty() && strideInfo.size() != 2) {
 1836  return parser.emitError(parser.getNameLoc(),
 1837  "expected two stride related operands");
 1838  }
 1839  bool isStrided = strideInfo.size() == 2;
 1840 
 1841  if (parser.parseColonTypeList(types))
 1842  return failure();
 1843 
  // Exactly three types are expected: src, dst, and tag memref types.
 1844  if (types.size() != 3)
 1845  return parser.emitError(parser.getNameLoc(), "expected three types");
 1846 
  // Map operands are always of index type; memrefs use the parsed types.
 1847  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
 1848  parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
 1849  parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
 1850  parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
 1851  parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
 1852  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
 1853  parser.resolveOperand(numElementsInfo, indexType, result.operands))
 1854  return failure();
 1855 
 1856  if (isStrided) {
 1857  if (parser.resolveOperands(strideInfo, indexType, result.operands))
 1858  return failure();
 1859  }
 1860 
 1861  // Check that src/dst/tag operand counts match their map.numInputs.
 1862  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
 1863  dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
 1864  tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
 1865  return parser.emitError(parser.getNameLoc(),
 1866  "memref operand count not equal to map.numInputs");
 1867  return success();
 1868 }
1869 
1870 LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
1871  if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1872  return emitOpError("expected DMA source to be of memref type");
1873  if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1874  return emitOpError("expected DMA destination to be of memref type");
1875  if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1876  return emitOpError("expected DMA tag to be of memref type");
1877 
1878  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1879  getDstMap().getNumInputs() +
1880  getTagMap().getNumInputs();
1881  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1882  getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1883  return emitOpError("incorrect number of operands");
1884  }
1885 
1886  Region *scope = getAffineScope(*this);
1887  for (auto idx : getSrcIndices()) {
1888  if (!idx.getType().isIndex())
1889  return emitOpError("src index to dma_start must have 'index' type");
1890  if (!isValidAffineIndexOperand(idx, scope))
1891  return emitOpError(
1892  "src index must be a valid dimension or symbol identifier");
1893  }
1894  for (auto idx : getDstIndices()) {
1895  if (!idx.getType().isIndex())
1896  return emitOpError("dst index to dma_start must have 'index' type");
1897  if (!isValidAffineIndexOperand(idx, scope))
1898  return emitOpError(
1899  "dst index must be a valid dimension or symbol identifier");
1900  }
1901  for (auto idx : getTagIndices()) {
1902  if (!idx.getType().isIndex())
1903  return emitOpError("tag index to dma_start must have 'index' type");
1904  if (!isValidAffineIndexOperand(idx, scope))
1905  return emitOpError(
1906  "tag index must be a valid dimension or symbol identifier");
1907  }
1908  return success();
1909 }
1910 
// Folds memref.cast ops feeding the memref operands into the op itself.
 1911 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
 1912  SmallVectorImpl<OpFoldResult> &results) {
 1913  /// dma_start(memrefcast) -> dma_start
 1914  return memref::foldMemRefCast(*this);
 1915 }
1916 
// Declares the memory effects: read of the source memref, write of the
// destination memref, and read of the tag memref.
// NOTE(review): the effects-vector parameter line and the resource arguments
// of each emplace_back are hyperlink-elided in this listing.
 1917 void AffineDmaStartOp::getEffects(
 1919  &effects) {
 1920  effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
 1922  effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
 1924  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
 1926 }
1927 
1928 //===----------------------------------------------------------------------===//
1929 // AffineDmaWaitOp
1930 //===----------------------------------------------------------------------===//
1931 
 1932 // TODO: Check that map operands are loop IVs or symbols.
// Appends operands in the fixed order the accessors assume: tag memref, its
// map operands, then the element count.
 1933 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
 1934  Value tagMemRef, AffineMap tagMap,
 1935  ValueRange tagIndices, Value numElements) {
 1936  result.addOperands(tagMemRef);
 1937  result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
 1938  result.addOperands(tagIndices);
 1939  result.addOperands(numElements);
 1940 }
1941 
1942 AffineDmaWaitOp AffineDmaWaitOp::create(OpBuilder &builder, Location location,
1943  Value tagMemRef, AffineMap tagMap,
1944  ValueRange tagIndices,
1945  Value numElements) {
1946  mlir::OperationState state(location, getOperationName());
1947  build(builder, state, tagMemRef, tagMap, tagIndices, numElements);
1948  auto result = dyn_cast<AffineDmaWaitOp>(builder.create(state));
1949  assert(result && "builder didn't return the right type");
1950  return result;
1951 }
1952 
// Convenience overload: takes the location from the ImplicitLocOpBuilder and
// forwards to the Location-taking create() above.
 1953 AffineDmaWaitOp AffineDmaWaitOp::create(ImplicitLocOpBuilder &builder,
 1954  Value tagMemRef, AffineMap tagMap,
 1955  ValueRange tagIndices,
 1956  Value numElements) {
 1957  return create(builder, builder.getLoc(), tagMemRef, tagMap, tagIndices,
 1958  numElements);
 1959 }
1960 
// Prints: ` %tag[mapOperands], %numElements : tagMemRefType`.
// NOTE(review): the `print` signature line and a line that presumably prints
// the element-count operand (between the "], " and the " : ") are
// hyperlink-elided in this listing.
 1962  p << " " << getTagMemRef() << '[';
 1963  SmallVector<Value, 2> operands(getTagIndices());
 1964  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
 1965  p << "], ";
 1967  p << " : " << getTagMemRef().getType();
 1968 }
1969 
 1970 // Parse AffineDmaWaitOp.
 1971 // Eg:
 1972 // affine.dma_wait %tag[%index], %num_elements
 1973 // : memref<1 x i32, (d0) -> (d0), 4>
 1974 //
// NOTE(review): the `parse` signature and the SmallVector declaration for
// tagMapOperands are hyperlink-elided in this listing.
 1976  OperationState &result) {
 1977  OpAsmParser::UnresolvedOperand tagMemRefInfo;
 1978  AffineMapAttr tagMapAttr;
 1980  Type type;
 1981  auto indexType = parser.getBuilder().getIndexType();
 1982  OpAsmParser::UnresolvedOperand numElementsInfo;
 1983 
 1984  // Parse tag memref, its map operands, and dma size.
 1985  if (parser.parseOperand(tagMemRefInfo) ||
 1986  parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
 1987  getTagMapAttrStrName(),
 1988  result.attributes) ||
 1989  parser.parseComma() || parser.parseOperand(numElementsInfo) ||
 1990  parser.parseColonType(type) ||
 1991  parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
 1992  parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
 1993  parser.resolveOperand(numElementsInfo, indexType, result.operands))
 1994  return failure();
 1995 
  // The single parsed type must be the tag memref's type.
 1996  if (!llvm::isa<MemRefType>(type))
 1997  return parser.emitError(parser.getNameLoc(),
 1998  "expected tag to be of memref type");
 1999 
 2000  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
 2001  return parser.emitError(parser.getNameLoc(),
 2002  "tag memref operand count != to map.numInputs");
 2003  return success();
 2004 }
2005 
2006 LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
2007  if (!llvm::isa<MemRefType>(getOperand(0).getType()))
2008  return emitOpError("expected DMA tag to be of memref type");
2009  Region *scope = getAffineScope(*this);
2010  for (auto idx : getTagIndices()) {
2011  if (!idx.getType().isIndex())
2012  return emitOpError("index to dma_wait must have 'index' type");
2013  if (!isValidAffineIndexOperand(idx, scope))
2014  return emitOpError(
2015  "index must be a valid dimension or symbol identifier");
2016  }
2017  return success();
2018 }
2019 
// Folds memref.cast ops feeding the tag memref operand into the op itself.
 2020 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
 2021  SmallVectorImpl<OpFoldResult> &results) {
 2022  /// dma_wait(memrefcast) -> dma_wait
 2023  return memref::foldMemRefCast(*this);
 2024 }
2025 
// Declares the memory effects: a read of the tag memref.
// NOTE(review): the effects-vector parameter line and the resource argument
// of the emplace_back are hyperlink-elided in this listing.
 2026 void AffineDmaWaitOp::getEffects(
 2028  &effects) {
 2029  effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
 2031 }
2032 
2033 //===----------------------------------------------------------------------===//
2034 // AffineForOp
2035 //===----------------------------------------------------------------------===//
2036 
 2037 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
 2038 /// bodyBuilder are empty/null, we include default terminator op.
 2039 void AffineForOp::build(OpBuilder &builder, OperationState &result,
 2040  ValueRange lbOperands, AffineMap lbMap,
 2041  ValueRange ubOperands, AffineMap ubMap, int64_t step,
 2042  ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
  // Each bound map consumes exactly as many operands as it has inputs.
 2043  assert(((!lbMap && lbOperands.empty()) ||
 2044  lbOperands.size() == lbMap.getNumInputs()) &&
 2045  "lower bound operand count does not match the affine map");
 2046  assert(((!ubMap && ubOperands.empty()) ||
 2047  ubOperands.size() == ubMap.getNumInputs()) &&
 2048  "upper bound operand count does not match the affine map");
 2049  assert(step > 0 && "step has to be a positive integer constant");
 2050 
  // Restore the caller's insertion point on exit; createBlock below moves it.
 2051  OpBuilder::InsertionGuard guard(builder);
 2052 
  // Operands are appended in the order (lbOperands, ubOperands, iterArgs);
  // this attribute records that partition.
 2053  // Set variadic segment sizes.
 2054  result.addAttribute(
 2055  getOperandSegmentSizeAttr(),
 2056  builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
 2057  static_cast<int32_t>(ubOperands.size()),
 2058  static_cast<int32_t>(iterArgs.size())}));
 2059 
  // One result per loop-carried value, with matching type.
 2060  for (Value val : iterArgs)
 2061  result.addTypes(val.getType());
 2062 
 2063  // Add an attribute for the step.
 2064  result.addAttribute(getStepAttrName(result.name),
 2065  builder.getIntegerAttr(builder.getIndexType(), step));
 2066 
 2067  // Add the lower bound.
 2068  result.addAttribute(getLowerBoundMapAttrName(result.name),
 2069  AffineMapAttr::get(lbMap));
 2070  result.addOperands(lbOperands);
 2071 
 2072  // Add the upper bound.
 2073  result.addAttribute(getUpperBoundMapAttrName(result.name),
 2074  AffineMapAttr::get(ubMap));
 2075  result.addOperands(ubOperands);
 2076 
 2077  result.addOperands(iterArgs);
 2078  // Create a region and a block for the body. The argument of the region is
 2079  // the loop induction variable.
 2080  Region *bodyRegion = result.addRegion();
 2081  Block *bodyBlock = builder.createBlock(bodyRegion);
 2082  Value inductionVar =
 2083  bodyBlock->addArgument(builder.getIndexType(), result.location);
  // One block argument per iter arg, following the induction variable.
 2084  for (Value val : iterArgs)
 2085  bodyBlock->addArgument(val.getType(), val.getLoc());
 2086 
 2087  // Create the default terminator if the builder is not provided and if the
 2088  // iteration arguments are not provided. Otherwise, leave this to the caller
 2089  // because we don't know which values to return from the loop.
 2090  if (iterArgs.empty() && !bodyBuilder) {
 2091  ensureTerminator(*bodyRegion, builder, result.location);
 2092  } else if (bodyBuilder) {
 2093  OpBuilder::InsertionGuard guard(builder);
 2094  builder.setInsertionPointToStart(bodyBlock);
 2095  bodyBuilder(builder, result.location, inductionVar,
 2096  bodyBlock->getArguments().drop_front());
 2097  }
 2098 }
2099 
2100 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
2101  int64_t ub, int64_t step, ValueRange iterArgs,
2102  BodyBuilderFn bodyBuilder) {
2103  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
2104  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
2105  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
2106  bodyBuilder);
2107 }
2108 
// Region verifier: checks the induction-variable block argument, the bound
// operands/maps, and the iter-arg/result correspondence.
 2109 LogicalResult AffineForOp::verifyRegions() {
 2110  // Check that the body defines as single block argument for the induction
 2111  // variable.
 2112  auto *body = getBody();
 2113  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
 2114  return emitOpError("expected body to have a single index argument for the "
 2115  "induction variable");
 2116 
 2117  // Verify that the bound operands are valid dimension/symbols.
 2118  /// Lower bound.
// NOTE(review): the lines invoking the dim/symbol-identifier verification for
// each bound are hyperlink-elided in this listing.
 2119  if (getLowerBoundMap().getNumInputs() > 0)
 2121  getLowerBoundMap().getNumDims())))
 2122  return failure();
 2123  /// Upper bound.
 2124  if (getUpperBoundMap().getNumInputs() > 0)
 2126  getUpperBoundMap().getNumDims())))
 2127  return failure();
  // Bound maps must produce at least one result to take a max/min over.
 2128  if (getLowerBoundMap().getNumResults() < 1)
 2129  return emitOpError("expected lower bound map to have at least one result");
 2130  if (getUpperBoundMap().getNumResults() < 1)
 2131  return emitOpError("expected upper bound map to have at least one result");
 2132 
 2133  unsigned opNumResults = getNumResults();
 2134  if (opNumResults == 0)
 2135  return success();
 2136 
 2137  // If ForOp defines values, check that the number and types of the defined
 2138  // values match ForOp initial iter operands and backedge basic block
 2139  // arguments.
 2140  if (getNumIterOperands() != opNumResults)
 2141  return emitOpError(
 2142  "mismatch between the number of loop-carried values and results");
 2143  if (getNumRegionIterArgs() != opNumResults)
 2144  return emitOpError(
 2145  "mismatch between the number of basic block args and results");
 2146 
 2147  return success();
 2148 }
2149 
 2150 /// Parse a for operation loop bounds.
// Accepts three forms: a bare SSA value (symbol identity map), a full
// affine-map attribute followed by its dim/symbol operand list, or a bare
// integer constant (constant map).
// NOTE(review): the SmallVector declaration for boundOpInfos is
// hyperlink-elided in this listing.
 2151 static ParseResult parseBound(bool isLower, OperationState &result,
 2152  OpAsmParser &p) {
 2153  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
 2154  // the map has multiple results.
 2155  bool failedToParsedMinMax =
 2156  failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
 2157 
 2158  auto &builder = p.getBuilder();
 2159  auto boundAttrStrName =
 2160  isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
 2161  : AffineForOp::getUpperBoundMapAttrName(result.name);
 2162 
 2163  // Parse ssa-id as identity map.
 2165  if (p.parseOperandList(boundOpInfos))
 2166  return failure();
 2167 
 2168  if (!boundOpInfos.empty()) {
 2169  // Check that only one operand was parsed.
 2170  if (boundOpInfos.size() > 1)
 2171  return p.emitError(p.getNameLoc(),
 2172  "expected only one loop bound operand");
 2173 
 2174  // TODO: improve error message when SSA value is not of index type.
 2175  // Currently it is 'use of value ... expects different type than prior uses'
 2176  if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
 2177  result.operands))
 2178  return failure();
 2179 
 2180  // Create an identity map using symbol id. This representation is optimized
 2181  // for storage. Analysis passes may expand it into a multi-dimensional map
 2182  // if desired.
 2183  AffineMap map = builder.getSymbolIdentityMap();
 2184  result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
 2185  return success();
 2186  }
 2187 
 2188  // Get the attribute location.
 2189  SMLoc attrLoc = p.getCurrentLocation();
 2190 
 2191  Attribute boundAttr;
 2192  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
 2193  result.attributes))
 2194  return failure();
 2195 
 2196  // Parse full form - affine map followed by dim and symbol list.
 2197  if (auto affineMapAttr = dyn_cast<AffineMapAttr>(boundAttr)) {
 2198  unsigned currentNumOperands = result.operands.size();
 2199  unsigned numDims;
 2200  if (parseDimAndSymbolList(p, result.operands, numDims))
 2201  return failure();
 2202 
  // The parsed operand list must match the map's dim/symbol arity exactly.
 2203  auto map = affineMapAttr.getValue();
 2204  if (map.getNumDims() != numDims)
 2205  return p.emitError(
 2206  p.getNameLoc(),
 2207  "dim operand count and affine map dim count must match");
 2208 
 2209  unsigned numDimAndSymbolOperands =
 2210  result.operands.size() - currentNumOperands;
 2211  if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
 2212  return p.emitError(
 2213  p.getNameLoc(),
 2214  "symbol operand count and affine map symbol count must match");
 2215 
 2216  // If the map has multiple results, make sure that we parsed the min/max
 2217  // prefix.
 2218  if (map.getNumResults() > 1 && failedToParsedMinMax) {
 2219  if (isLower) {
 2220  return p.emitError(attrLoc, "lower loop bound affine map with "
 2221  "multiple results requires 'max' prefix");
 2222  }
 2223  return p.emitError(attrLoc, "upper loop bound affine map with multiple "
 2224  "results requires 'min' prefix");
 2225  }
 2226  return success();
 2227  }
 2228 
 2229  // Parse custom assembly form.
 2230  if (auto integerAttr = dyn_cast<IntegerAttr>(boundAttr)) {
  // Replace the raw integer attribute just parsed with the equivalent
  // constant affine map under the bound attribute name.
 2231  result.attributes.pop_back();
 2232  result.addAttribute(
 2233  boundAttrStrName,
 2234  AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
 2235  return success();
 2236  }
 2237 
 2238  return p.emitError(
 2239  p.getNameLoc(),
 2240  "expected valid affine map representation for loop bounds");
 2241 }
2242 
// Parses: `%iv = <lb> to <ub> (step <c>)? (iter_args(...) -> (...))? region
// attr-dict?`.
 2243 ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
 2244  auto &builder = parser.getBuilder();
 2245  OpAsmParser::Argument inductionVariable;
 2246  inductionVariable.type = builder.getIndexType();
 2247  // Parse the induction variable followed by '='.
 2248  if (parser.parseArgument(inductionVariable) || parser.parseEqual())
 2249  return failure();
 2250 
  // Track operand counts before/after each bound so the segment-size
  // attribute below can record the lb/ub/iter-arg partition.
 2251  // Parse loop bounds.
 2252  int64_t numOperands = result.operands.size();
 2253  if (parseBound(/*isLower=*/true, result, parser))
 2254  return failure();
 2255  int64_t numLbOperands = result.operands.size() - numOperands;
 2256  if (parser.parseKeyword("to", " between bounds"))
 2257  return failure();
 2258  numOperands = result.operands.size();
 2259  if (parseBound(/*isLower=*/false, result, parser))
 2260  return failure();
 2261  int64_t numUbOperands = result.operands.size() - numOperands;
 2262 
 2263  // Parse the optional loop step, we default to 1 if one is not present.
 2264  if (parser.parseOptionalKeyword("step")) {
 2265  result.addAttribute(
 2266  getStepAttrName(result.name),
 2267  builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
 2268  } else {
 2269  SMLoc stepLoc = parser.getCurrentLocation();
 2270  IntegerAttr stepAttr;
 2271  if (parser.parseAttribute(stepAttr, builder.getIndexType(),
 2272  getStepAttrName(result.name).data(),
 2273  result.attributes))
 2274  return failure();
 2275 
 2276  if (stepAttr.getValue().isNegative())
 2277  return parser.emitError(
 2278  stepLoc,
 2279  "expected step to be representable as a positive signed integer");
 2280  }
 2281 
 2282  // Parse the optional initial iteration arguments.
// NOTE(review): the declarations of the region-argument and operand vectors
// are hyperlink-elided in this listing.
 2285 
 2286  // Induction variable.
 2287  regionArgs.push_back(inductionVariable);
 2288 
 2289  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
 2290  // Parse assignment list and results type list.
 2291  if (parser.parseAssignmentList(regionArgs, operands) ||
 2292  parser.parseArrowTypeList(result.types))
 2293  return failure();
  // Each region iter arg takes the type of the corresponding result; the
  // induction variable (front of regionArgs) is skipped.
 2294  // Resolve input operands.
 2295  for (auto argOperandType :
 2296  llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
 2297  Type type = std::get<2>(argOperandType);
 2298  std::get<0>(argOperandType).type = type;
 2299  if (parser.resolveOperand(std::get<1>(argOperandType), type,
 2300  result.operands))
 2301  return failure();
 2302  }
 2303  }
 2304 
 2305  result.addAttribute(
 2306  getOperandSegmentSizeAttr(),
 2307  builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
 2308  static_cast<int32_t>(numUbOperands),
 2309  static_cast<int32_t>(operands.size())}));
 2310 
 2311  // Parse the body region.
 2312  Region *body = result.addRegion();
  // regionArgs = induction variable + one arg per result.
 2313  if (regionArgs.size() != result.types.size() + 1)
 2314  return parser.emitError(
 2315  parser.getNameLoc(),
 2316  "mismatch between the number of loop-carried values and results");
 2317  if (parser.parseRegion(*body, regionArgs))
 2318  return failure();
 2319 
 2320  AffineForOp::ensureTerminator(*body, builder, result.location);
 2321 
 2322  // Parse the optional attribute list.
 2323  return parser.parseOptionalAttrDict(result.attributes);
 2324 }
2325 
2326 static void printBound(AffineMapAttr boundMap,
2327  Operation::operand_range boundOperands,
2328  const char *prefix, OpAsmPrinter &p) {
2329  AffineMap map = boundMap.getValue();
2330 
2331  // Check if this bound should be printed using custom assembly form.
2332  // The decision to restrict printing custom assembly form to trivial cases
2333  // comes from the will to roundtrip MLIR binary -> text -> binary in a
2334  // lossless way.
2335  // Therefore, custom assembly form parsing and printing is only supported for
2336  // zero-operand constant maps and single symbol operand identity maps.
2337  if (map.getNumResults() == 1) {
2338  AffineExpr expr = map.getResult(0);
2339 
2340  // Print constant bound.
2341  if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2342  if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2343  p << constExpr.getValue();
2344  return;
2345  }
2346  }
2347 
2348  // Print bound that consists of a single SSA symbol if the map is over a
2349  // single symbol.
2350  if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2351  if (isa<AffineSymbolExpr>(expr)) {
2352  p.printOperand(*boundOperands.begin());
2353  return;
2354  }
2355  }
2356  } else {
2357  // Map has multiple results. Print 'min' or 'max' prefix.
2358  p << prefix << ' ';
2359  }
2360 
2361  // Print the map and its operands.
2362  p << boundMap;
2363  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2364  map.getNumDims(), p);
2365 }
2366 
2367 unsigned AffineForOp::getNumIterOperands() {
2368  AffineMap lbMap = getLowerBoundMapAttr().getValue();
2369  AffineMap ubMap = getUpperBoundMapAttr().getValue();
2370 
2371  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2372 }
2373 
// The loop-carried yielded values are exactly the operands of the body's
// terminating affine.yield.
 2374 std::optional<MutableArrayRef<OpOperand>>
 2375 AffineForOp::getYieldedValuesMutable() {
 2376  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
 2377 }
2378 
// Printer for affine.for: induction variable, bounds (via printBound), the
// step when not 1, optional iter_args/result types, the body region, and the
// remaining attributes.
// NOTE(review): the `print` signature line and the call printing the optional
// attribute dictionary are hyperlink-elided in this listing.
 2380  p << ' ';
 2381  p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
 2382  /*omitType=*/true);
 2383  p << " = ";
  // "max"/"min" prefixes are only emitted by printBound for multi-result maps.
 2384  printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
 2385  p << " to ";
 2386  printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
 2387 
  // Step 1 is the default and is elided.
 2388  if (getStepAsInt() != 1)
 2389  p << " step " << getStepAsInt();
 2390 
 2391  bool printBlockTerminators = false;
 2392  if (getNumIterOperands() > 0) {
 2393  p << " iter_args(";
 2394  auto regionArgs = getRegionIterArgs();
 2395  auto operands = getInits();
 2396 
 2397  llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
 2398  p << std::get<0>(it) << " = " << std::get<1>(it);
 2399  });
 2400  p << ") -> (" << getResultTypes() << ")";
  // With loop-carried values the yield is meaningful, so print terminators.
 2401  printBlockTerminators = true;
 2402  }
 2403 
 2404  p << ' ';
 2405  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
 2406  printBlockTerminators);
  // Bound/step/segment attributes were printed inline above, so elide them.
 2408  (*this)->getAttrs(),
 2409  /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
 2410  getUpperBoundMapAttrName(getOperation()->getName()),
 2411  getStepAttrName(getOperation()->getName()),
 2412  getOperandSegmentSizeAttr()});
 2413 }
2414 
2415 /// Fold the constant bounds of a loop.
2416 static LogicalResult foldLoopBounds(AffineForOp forOp) {
2417  auto foldLowerOrUpperBound = [&forOp](bool lower) {
2418  // Check to see if each of the operands is the result of a constant. If
2419  // so, get the value. If not, ignore it.
2420  SmallVector<Attribute, 8> operandConstants;
2421  auto boundOperands =
2422  lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2423  for (auto operand : boundOperands) {
2424  Attribute operandCst;
2425  matchPattern(operand, m_Constant(&operandCst));
2426  operandConstants.push_back(operandCst);
2427  }
2428 
2429  AffineMap boundMap =
2430  lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2431  assert(boundMap.getNumResults() >= 1 &&
2432  "bound maps should have at least one result");
2433  SmallVector<Attribute, 4> foldedResults;
2434  if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2435  return failure();
2436 
2437  // Compute the max or min as applicable over the results.
2438  assert(!foldedResults.empty() && "bounds should have at least one result");
2439  auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2440  for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2441  auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2442  maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2443  : llvm::APIntOps::smin(maxOrMin, foldedResult);
2444  }
2445  lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2446  : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2447  return success();
2448  };
2449 
2450  // Try to fold the lower bound.
2451  bool folded = false;
2452  if (!forOp.hasConstantLowerBound())
2453  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2454 
2455  // Try to fold the upper bound.
2456  if (!forOp.hasConstantUpperBound())
2457  folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2458  return success(folded);
2459 }
2460 
2461 /// Canonicalize the bounds of the given loop.
2462 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2463  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2464  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2465 
2466  auto lbMap = forOp.getLowerBoundMap();
2467  auto ubMap = forOp.getUpperBoundMap();
2468  auto prevLbMap = lbMap;
2469  auto prevUbMap = ubMap;
2470 
     // Compose any affine.apply producers into the bound maps, canonicalize,
     // and simplify: a lower bound is the max of its results (isMax=true), an
     // upper bound is the min (isMax=false).
2471  composeAffineMapAndOperands(&lbMap, &lbOperands);
2472  canonicalizeMapAndOperands(&lbMap, &lbOperands);
2473  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
     // NOTE(review): the upper bound is simplified here, before being composed
     // and canonicalized below — confirm this ordering is intentional.
2474  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2475  lbMap = removeDuplicateExprs(lbMap);
2476 
2477  composeAffineMapAndOperands(&ubMap, &ubOperands);
2478  canonicalizeMapAndOperands(&ubMap, &ubOperands);
2479  ubMap = removeDuplicateExprs(ubMap);
2480 
2481  // Any canonicalization change always leads to updated map(s).
2482  if (lbMap == prevLbMap && ubMap == prevUbMap)
2483  return failure();
2484 
2485  if (lbMap != prevLbMap)
2486  forOp.setLowerBound(lbOperands, lbMap);
2487  if (ubMap != prevUbMap)
2488  forOp.setUpperBound(ubOperands, ubMap);
2489  return success();
2490 }
2491 
2492 namespace {
2493 /// Returns constant trip count in trivial cases.
2494 static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2495  int64_t step = forOp.getStepAsInt();
2496  if (!forOp.hasConstantBounds() || step <= 0)
2497  return std::nullopt;
2498  int64_t lb = forOp.getConstantLowerBound();
2499  int64_t ub = forOp.getConstantUpperBound();
2500  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2501 }
2502 
2503 /// This is a pattern to fold trivially empty loop bodies.
2504 /// TODO: This should be moved into the folding hook.
2505 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
2507 
2508  LogicalResult matchAndRewrite(AffineForOp forOp,
2509  PatternRewriter &rewriter) const override {
2510  // Check that the body only contains a yield.
2511  if (!llvm::hasSingleElement(*forOp.getBody()))
2512  return failure();
     // A result-less empty loop is trivially dead; nothing left to rewrite.
2513  if (forOp.getNumResults() == 0)
2514  return success();
2515  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2516  if (tripCount == 0) {
2517  // The initial values of the iteration arguments would be the op's
2518  // results.
2519  rewriter.replaceOp(forOp, forOp.getInits());
2520  return success();
2521  }
     // For each yielded value, pick the replacement: the init of the matching
     // iter_arg, or the value itself when it is defined outside the loop.
2522  SmallVector<Value, 4> replacements;
2523  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2524  auto iterArgs = forOp.getRegionIterArgs();
2525  bool hasValDefinedOutsideLoop = false;
2526  bool iterArgsNotInOrder = false;
2527  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2528  Value val = yieldOp.getOperand(i);
2529  auto *iterArgIt = llvm::find(iterArgs, val);
2530  // TODO: It should be possible to perform a replacement by computing the
2531  // last value of the IV based on the bounds and the step.
2532  if (val == forOp.getInductionVar())
2533  return failure();
2534  if (iterArgIt == iterArgs.end()) {
2535  // `val` is defined outside of the loop.
2536  assert(forOp.isDefinedOutsideOfLoop(val) &&
2537  "must be defined outside of the loop");
2538  hasValDefinedOutsideLoop = true;
2539  replacements.push_back(val);
2540  } else {
2541  unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2542  if (pos != i)
2543  iterArgsNotInOrder = true;
2544  replacements.push_back(forOp.getInits()[pos]);
2545  }
2546  }
2547  // Bail out when the trip count is unknown and the loop returns any value
2548  // defined outside of the loop or any iterArg out of order.
2549  if (!tripCount.has_value() &&
2550  (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2551  return failure();
2552  // Bail out when the loop iterates more than once and it returns any iterArg
2553  // out of order.
2554  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2555  return failure();
2556  rewriter.replaceOp(forOp, replacements);
2557  return success();
2558  }
2559 };
2560 } // namespace
2561 
     // Registers the canonicalization patterns for affine.for.
2562 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
2563  MLIRContext *context) {
2564  results.add<AffineForEmptyLoopFolder>(context);
2565 }
2566 
2567 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
2568  assert((point.isParent() || point == getRegion()) && "invalid region point");
2569 
2570  // The initial operands map to the loop arguments after the induction
2571  // variable or are forwarded to the results when the trip count is zero.
2572  return getInits();
2573 }
2574 
2575 void AffineForOp::getSuccessorRegions(
2577  assert((point.isParent() || point == getRegion()) && "expected loop region");
2578  // The loop may typically branch back to its body or to the parent operation.
2579  // If the predecessor is the parent op and the trip count is known to be at
2580  // least one, branch into the body using the iterator arguments. And in cases
2581  // we know the trip count is zero, it can only branch back to its parent.
2582  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2583  if (point.isParent() && tripCount.has_value()) {
2584  if (tripCount.value() > 0) {
2585  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2586  return;
2587  }
2588  if (tripCount.value() == 0) {
2589  regions.push_back(RegionSuccessor(getResults()));
2590  return;
2591  }
2592  }
2593 
2594  // From the loop body, if the trip count is one, we can only branch back to
2595  // the parent.
2596  if (!point.isParent() && tripCount == 1) {
2597  regions.push_back(RegionSuccessor(getResults()));
2598  return;
2599  }
2600 
2601  // In all other cases, the loop may branch back to itself or the parent
2602  // operation.
2603  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
2604  regions.push_back(RegionSuccessor(getResults()));
2605 }
2606 
2607 /// Returns true if the affine.for has zero iterations in trivial cases.
2608 static bool hasTrivialZeroTripCount(AffineForOp op) {
2609  return getTrivialConstantTripCount(op) == 0;
2610 }
2611 
2612 LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2613  SmallVectorImpl<OpFoldResult> &results) {
2614  bool folded = succeeded(foldLoopBounds(*this));
2615  folded |= succeeded(canonicalizeLoopBounds(*this));
2616  if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2617  // The initial values of the loop-carried variables (iter_args) are the
2618  // results of the op. But this must be avoided for an affine.for op that
2619  // does not return any results. Since ops that do not return results cannot
2620  // be folded away, we would enter an infinite loop of folds on the same
2621  // affine.for op.
2622  results.assign(getInits().begin(), getInits().end());
2623  folded = true;
2624  }
2625  return success(folded);
2626 }
2627 
     // Wraps the lower bound (map + operands) into an AffineBound view.
2629  return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2630 }
2631 
     // Wraps the upper bound (map + operands) into an AffineBound view.
2633  return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2634 }
2635 
     // Replaces the lower bound with the given operands and map.
2636 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2637  assert(lbOperands.size() == map.getNumInputs());
2638  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2639  getLowerBoundOperandsMutable().assign(lbOperands);
2640  setLowerBoundMap(map);
2641 }
2642 
     // Replaces the upper bound with the given operands and map.
2643 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2644  assert(ubOperands.size() == map.getNumInputs());
2645  assert(map.getNumResults() >= 1 && "bound map has at least one result");
2646  getUpperBoundOperandsMutable().assign(ubOperands);
2647  setUpperBoundMap(map);
2648 }
2649 
     // A bound is "constant" when its map has a single constant result.
2650 bool AffineForOp::hasConstantLowerBound() {
2651  return getLowerBoundMap().isSingleConstant();
2652 }
2653 
2654 bool AffineForOp::hasConstantUpperBound() {
2655  return getUpperBoundMap().isSingleConstant();
2656 }
2657 
2658 int64_t AffineForOp::getConstantLowerBound() {
2659  return getLowerBoundMap().getSingleConstantResult();
2660 }
2661 
2662 int64_t AffineForOp::getConstantUpperBound() {
2663  return getUpperBoundMap().getSingleConstantResult();
2664 }
2665 
     // Constant bounds take an operand-less single-constant map.
2666 void AffineForOp::setConstantLowerBound(int64_t value) {
2667  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2668 }
2669 
2670 void AffineForOp::setConstantUpperBound(int64_t value) {
2671  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2672 }
2673 
     // The control operands are the leading lb/ub operands (excludes inits).
2674 AffineForOp::operand_range AffineForOp::getControlOperands() {
2675  return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2676  getUpperBoundOperands().size()};
2677 }
2678 
     // Returns true when the lb and ub maps have matching dim/symbol counts
     // and the lb and ub operand lists are element-wise identical.
2679 bool AffineForOp::matchingBoundOperandList() {
2680  auto lbMap = getLowerBoundMap();
2681  auto ubMap = getUpperBoundMap();
2682  if (lbMap.getNumDims() != ubMap.getNumDims() ||
2683  lbMap.getNumSymbols() != ubMap.getNumSymbols())
2684  return false;
2685 
2686  unsigned numOperands = lbMap.getNumInputs();
2687  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2688  // Compare Value 's.
2689  if (getOperand(i) != getOperand(numOperands + i))
2690  return false;
2691  }
2692  return true;
2693 }
2694 
2695 SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2696 
2697 std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2698  return SmallVector<Value>{getInductionVar()};
2699 }
2700 
     // LoopLikeOpInterface bounds/steps: only representable when constant.
2701 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2702  if (!hasConstantLowerBound())
2703  return std::nullopt;
2704  OpBuilder b(getContext());
2706  OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2707 }
2708 
2709 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2710  OpBuilder b(getContext());
2712  OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2713 }
2714 
2715 std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2716  if (!hasConstantUpperBound())
2717  return {};
2718  OpBuilder b(getContext());
2720  OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2721 }
2722 
2723 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2724  RewriterBase &rewriter, ValueRange newInitOperands,
2725  bool replaceInitOperandUsesInLoop,
2726  const NewYieldValuesFn &newYieldValuesFn) {
2727  // Create a new loop before the existing one, with the extra operands.
2728  OpBuilder::InsertionGuard g(rewriter);
2729  rewriter.setInsertionPoint(getOperation());
2730  auto inits = llvm::to_vector(getInits());
2731  inits.append(newInitOperands.begin(), newInitOperands.end());
2732  AffineForOp newLoop = AffineForOp::create(
2733  rewriter, getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2734  getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
2735 
2736  // Generate the new yield values and append them to the affine.yield operation.
2737  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2738  ArrayRef<BlockArgument> newIterArgs =
2739  newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2740  {
2741  OpBuilder::InsertionGuard g(rewriter);
2742  rewriter.setInsertionPoint(yieldOp);
2743  SmallVector<Value> newYieldedValues =
2744  newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2745  assert(newInitOperands.size() == newYieldedValues.size() &&
2746  "expected as many new yield values as new iter operands");
2747  rewriter.modifyOpInPlace(yieldOp, [&]() {
2748  yieldOp.getOperandsMutable().append(newYieldedValues);
2749  });
2750  }
2751 
2752  // Move the loop body to the new op.
2753  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2754  newLoop.getBody()->getArguments().take_front(
2755  getBody()->getNumArguments()));
2756 
2757  if (replaceInitOperandUsesInLoop) {
2758  // Replace all uses of `newInitOperands` with the corresponding basic block
2759  // arguments.
2760  for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2761  rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2762  [&](OpOperand &use) {
2763  Operation *user = use.getOwner();
2764  return newLoop->isProperAncestor(user);
2765  });
2766  }
2767  }
2768 
2769  // Replace the old loop.
     // The new loop has extra trailing results; only the original result count
     // is used to replace the old op's uses.
2770  rewriter.replaceOp(getOperation(),
2771  newLoop->getResults().take_front(getNumResults()));
2772  return cast<LoopLikeOpInterface>(newLoop.getOperation());
2773 }
2774 
2775 Speculation::Speculatability AffineForOp::getSpeculatability() {
2776  // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2777  // Start and End.
2778  //
2779  // For Step != 1, the loop may not terminate. We can add more smarts here if
2780  // needed.
2781  return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2783 }
2784 
2785 /// Returns true if the provided value is the induction variable of a
2786 /// AffineForOp.
2788  return getForInductionVarOwner(val) != AffineForOp();
2789 }
2790 
     // Returns true if `val` is an induction variable of an affine.parallel.
2792  return getAffineParallelInductionVarOwner(val) != nullptr;
2793 }
2794 
2797 }
2798 
     // Returns the affine.for op owning `val` as its induction variable, or a
     // null op otherwise.
2800  auto ivArg = dyn_cast<BlockArgument>(val);
2801  if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2802  return AffineForOp();
2803  if (auto forOp =
2804  ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2805  // Check to make sure `val` is the induction variable, not an iter_arg.
2806  return forOp.getInductionVar() == val ? forOp : AffineForOp();
2807  return AffineForOp();
2808 }
2809 
     // Returns the affine.parallel op owning `val` as one of its IVs, or null.
2811  auto ivArg = dyn_cast<BlockArgument>(val);
2812  if (!ivArg || !ivArg.getOwner())
2813  return nullptr;
2814  Operation *containingOp = ivArg.getOwner()->getParentOp();
2815  auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
2816  if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2817  return parallelOp;
2818  return nullptr;
2819 }
2820 
2821 /// Extracts the induction variables from a list of AffineForOps and returns
2822 /// them.
2824  SmallVectorImpl<Value> *ivs) {
2825  ivs->reserve(forInsts.size());
2826  for (auto forInst : forInsts)
2827  ivs->push_back(forInst.getInductionVar());
2828  }
2829 
     // Collects the induction variables of affine.for and affine.parallel ops.
2832  ivs.reserve(affineOps.size());
2833  for (Operation *op : affineOps) {
2834  // Add constraints from forOp's bounds.
2835  if (auto forOp = dyn_cast<AffineForOp>(op))
2836  ivs.push_back(forOp.getInductionVar());
2837  else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2838  for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2839  ivs.push_back(parallelOp.getBody()->getArgument(i));
2840  }
2841 }
2842 
2843 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2844 /// operations.
2845 template <typename BoundListTy, typename LoopCreatorTy>
2847  OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2848  ArrayRef<int64_t> steps,
2849  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2850  LoopCreatorTy &&loopCreatorFn) {
2851  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2852  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2853 
2854  // If there are no loops to be constructed, construct the body anyway.
2855  OpBuilder::InsertionGuard guard(builder);
2856  if (lbs.empty()) {
2857  if (bodyBuilderFn)
2858  bodyBuilderFn(builder, loc, ValueRange());
2859  return;
2860  }
2861 
2862  // Create the loops iteratively and store the induction variables.
2864  ivs.reserve(lbs.size());
2865  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2866  // Callback for creating the loop body, always creates the terminator.
2867  auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
2868  ValueRange iterArgs) {
2869  ivs.push_back(iv);
2870  // In the innermost loop, call the body builder.
2871  if (i == e - 1 && bodyBuilderFn) {
2872  OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
2873  bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2874  }
2875  AffineYieldOp::create(nestedBuilder, nestedLoc);
2876  };
2877 
2878  // Delegate actual loop creation to the callback in order to dispatch
2879  // between constant- and variable-bound loops.
2880  auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
     // Subsequent (inner) loops are created inside the loop just built.
2881  builder.setInsertionPointToStart(loop.getBody());
2882  }
2883 }
2884 
2885 /// Creates an affine loop from the bounds known to be constants.
2886 static AffineForOp
2888  int64_t ub, int64_t step,
2889  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2890  return AffineForOp::create(builder, loc, lb, ub, step,
2891  /*iterArgs=*/ValueRange(), bodyBuilderFn);
2892 }
2893 
2894 /// Creates an affine loop from the bounds that may or may not be constants.
2896  int64_t step,
2898  AffineForOp::BodyBuilderFn bodyBuilderFn) {
2899  std::optional<int64_t> lbConst = getConstantIntValue(lb);
2900  std::optional<int64_t> ubConst = getConstantIntValue(ub);
     // Prefer the constant form when both bounds fold to constants.
2901  if (lbConst && ubConst)
2902  return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
2903  ubConst.value(), step, bodyBuilderFn);
2904  return AffineForOp::create(builder, loc, lb, builder.getDimIdentityMap(), ub,
2905  builder.getDimIdentityMap(), step,
2906  /*iterArgs=*/ValueRange(), bodyBuilderFn);
2907  }
2908 
     // Public entry point: nest with constant bounds.
2910  OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
2912  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2913  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2915 }
2916 
     // Public entry point: nest with SSA-value bounds.
2918  OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
2919  ArrayRef<int64_t> steps,
2920  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
2921  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
2923 }
2924 
2925 //===----------------------------------------------------------------------===//
2926 // AffineIfOp
2927 //===----------------------------------------------------------------------===//
2928 
2929 namespace {
2930 /// Remove else blocks that have nothing other than a zero value yield.
2931 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
2933 
2934  LogicalResult matchAndRewrite(AffineIfOp ifOp,
2935  PatternRewriter &rewriter) const override {
2936  if (ifOp.getElseRegion().empty() ||
2937  !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
2938  return failure();
2939 
2940  rewriter.startOpModification(ifOp);
2941  rewriter.eraseBlock(ifOp.getElseBlock());
2942  rewriter.finalizeOpModification(ifOp);
2943  return success();
2944  }
2945 };
2946 
2947 /// Removes affine.if cond if the condition is always true or false in certain
2948 /// trivial cases. Promotes the then/else block in the parent operation block.
2949 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
2951 
2952  LogicalResult matchAndRewrite(AffineIfOp op,
2953  PatternRewriter &rewriter) const override {
2954 
     // An empty integer set is unsatisfiable: the condition is always false.
2955  auto isTriviallyFalse = [](IntegerSet iSet) {
2956  return iSet.isEmptyIntegerSet();
2957  };
2958 
     // The single equality "0 == 0" is always satisfied: always true.
2959  auto isTriviallyTrue = [](IntegerSet iSet) {
2960  return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
2961  iSet.getConstraint(0) == 0);
2962  };
2963 
2964  IntegerSet affineIfConditions = op.getIntegerSet();
2965  Block *blockToMove;
2966  if (isTriviallyFalse(affineIfConditions)) {
2967  // The absence, or equivalently, the emptiness of the else region need not
2968  // be checked when affine.if is returning results because if an affine.if
2969  // operation is returning results, it always has a non-empty else region.
2970  if (op.getNumResults() == 0 && !op.hasElse()) {
2971  // If the else region is absent, or equivalently, empty, remove the
2972  // affine.if operation (which is not returning any results).
2973  rewriter.eraseOp(op);
2974  return success();
2975  }
2976  blockToMove = op.getElseBlock();
2977  } else if (isTriviallyTrue(affineIfConditions)) {
2978  blockToMove = op.getThenBlock();
2979  } else {
2980  return failure();
2981  }
2982  Operation *blockToMoveTerminator = blockToMove->getTerminator();
2983  // Promote the "blockToMove" block to the parent operation block between the
2984  // prologue and epilogue of "op".
2985  rewriter.inlineBlockBefore(blockToMove, op);
2986  // Replace the "op" operation with the operands of the
2987  // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2988  // the affine.yield operation present in the "blockToMove" block. It has no
2989  // operands when affine.if is not returning results and therefore, in that
2990  // case, replaceOp just erases "op". When affine.if is not returning
2991  // results, the affine.yield operation can be omitted. It gets inserted
2992  // implicitly.
2993  rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2994  // Erase the "blockToMoveTerminator" operation since it is now in the parent
2995  // operation block, which already has its own terminator.
2996  rewriter.eraseOp(blockToMoveTerminator);
2997  return success();
2998  }
2999 };
3000 } // namespace
3001 
3002 /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
3003 /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
3004 void AffineIfOp::getSuccessorRegions(
3006  // If the predecessor is an AffineIfOp, then branching into both `then` and
3007  // `else` region is valid.
3008  if (point.isParent()) {
3009  regions.reserve(2);
3010  regions.push_back(
3011  RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
3012  // If the "else" region is empty, branch back into parent.
3013  if (getElseRegion().empty()) {
3014  regions.push_back(getResults());
3015  } else {
3016  regions.push_back(
3017  RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
3018  }
3019  return;
3020  }
3021 
3022  // If the predecessor is the `else`/`then` region, then branching into parent
3023  // op is valid.
3024  regions.push_back(RegionSuccessor(getResults()));
3025 }
3026 
3027 LogicalResult AffineIfOp::verify() {
3028  // Verify that we have a condition attribute.
3029  // FIXME: This should be specified in the arguments list in ODS.
3030  auto conditionAttr =
3031  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3032  if (!conditionAttr)
3033  return emitOpError("requires an integer set attribute named 'condition'");
3034 
3035  // Verify that there are enough operands for the condition.
3036  IntegerSet condition = conditionAttr.getValue();
3037  if (getNumOperands() != condition.getNumInputs())
3038  return emitOpError("operand count and condition integer set dimension and "
3039  "symbol count must match");
3040 
3041  // Verify that the operands are valid dimension/symbols.
3042  if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
3043  condition.getNumDims())))
3044  return failure();
3045 
3046  return success();
3047 }
3048 
3049 ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
3050  // Parse the condition attribute set.
3051  IntegerSetAttr conditionAttr;
3052  unsigned numDims;
3053  if (parser.parseAttribute(conditionAttr,
3054  AffineIfOp::getConditionAttrStrName(),
3055  result.attributes) ||
3056  parseDimAndSymbolList(parser, result.operands, numDims))
3057  return failure();
3058 
3059  // Verify the condition operands.
3060  auto set = conditionAttr.getValue();
3061  if (set.getNumDims() != numDims)
3062  return parser.emitError(
3063  parser.getNameLoc(),
3064  "dim operand count and integer set dim count must match");
3065  if (numDims + set.getNumSymbols() != result.operands.size())
3066  return parser.emitError(
3067  parser.getNameLoc(),
3068  "symbol operand count and integer set symbol count must match");
3069 
     // Optional "-> (types)" result list.
3070  if (parser.parseOptionalArrowTypeList(result.types))
3071  return failure();
3072 
3073  // Create the regions for 'then' and 'else'. The latter must be created even
3074  // if it remains empty for the validity of the operation.
3075  result.regions.reserve(2);
3076  Region *thenRegion = result.addRegion();
3077  Region *elseRegion = result.addRegion();
3078 
3079  // Parse the 'then' region.
3080  if (parser.parseRegion(*thenRegion, {}, {}))
3081  return failure();
     // Insert an implicit affine.yield terminator when it was elided.
3082  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
3083  result.location);
3084 
3085  // If we find an 'else' keyword then parse the 'else' region.
3086  if (!parser.parseOptionalKeyword("else")) {
3087  if (parser.parseRegion(*elseRegion, {}, {}))
3088  return failure();
3089  AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
3090  result.location);
3091  }
3092 
3093  // Parse the optional attribute list.
3094  if (parser.parseOptionalAttrDict(result.attributes))
3095  return failure();
3096 
3097  return success();
3098 }
3099 
3101  auto conditionAttr =
3102  (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3103  p << " " << conditionAttr;
3104  printDimAndSymbolList(operand_begin(), operand_end(),
3105  conditionAttr.getValue().getNumDims(), p);
3106  p.printOptionalArrowTypeList(getResultTypes());
3107  p << ' ';
     // Terminators only print when the op has results (yields carry operands).
3108  p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
3109  /*printBlockTerminators=*/getNumResults());
3110 
3111  // Print the 'else' region if it has any blocks.
3112  auto &elseRegion = this->getElseRegion();
3113  if (!elseRegion.empty()) {
3114  p << " else ";
3115  p.printRegion(elseRegion,
3116  /*printEntryBlockArgs=*/false,
3117  /*printBlockTerminators=*/getNumResults());
3118  }
3119 
3120  // Print the attribute list.
3121  p.printOptionalAttrDict((*this)->getAttrs(),
3122  /*elidedAttrs=*/getConditionAttrStrName());
3123 }
3124 
     // Returns the condition integer set stored in the 'condition' attribute.
3125 IntegerSet AffineIfOp::getIntegerSet() {
3126  return (*this)
3127  ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
3128  .getValue();
3129 }
3130 
3131 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
3132  (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
3133 }
3134 
     // Atomically updates both the condition set and its operands.
3135 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
3136  setIntegerSet(set);
3137  (*this)->setOperands(operands);
3138 }
3139 
3140 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3141  TypeRange resultTypes, IntegerSet set, ValueRange args,
3142  bool withElseRegion) {
     // Result-producing affine.if ops must have an else region.
3143  assert(resultTypes.empty() || withElseRegion);
3144  OpBuilder::InsertionGuard guard(builder);
3145 
3146  result.addTypes(resultTypes);
3147  result.addOperands(args);
3148  result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
3149 
     // With results, the caller must supply the terminators (they carry the
     // yielded values); otherwise insert the implicit affine.yield.
3150  Region *thenRegion = result.addRegion();
3151  builder.createBlock(thenRegion);
3152  if (resultTypes.empty())
3153  AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
3154 
3155  Region *elseRegion = result.addRegion();
3156  if (withElseRegion) {
3157  builder.createBlock(elseRegion);
3158  if (resultTypes.empty())
3159  AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
3160  }
3161 }
3162 
     // Convenience overload for a result-less affine.if.
3163 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3164  IntegerSet set, ValueRange args, bool withElseRegion) {
3165  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
3166  withElseRegion);
3167 }
3168 
3169 /// Compose any affine.apply ops feeding into `operands` of the integer set
3170 /// `set` by composing the maps of such affine.apply ops with the integer
3171 /// set constraints.
3173  SmallVectorImpl<Value> &operands,
3174  bool composeAffineMin = false) {
3175  // We will simply reuse the API of the map composition by viewing the LHSs of
3176  // the equalities and inequalities of `set` as the affine exprs of an affine
3177  // map. Convert to equivalent map, compose, and convert back to set.
3178  auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
3179  set.getConstraints(), set.getContext());
3180  // Check if any composition is possible.
3181  if (llvm::none_of(operands,
3182  [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
3183  return;
3184 
     // The eq/ineq flags are unchanged by composition, so they carry over.
3185  composeAffineMapAndOperands(&map, &operands, composeAffineMin);
3186  set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
3187  set.getEqFlags());
3188 }
3189 
3190 /// Canonicalize an affine if op's conditional (integer set + operands).
3191 LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3192  auto set = getIntegerSet();
3193  SmallVector<Value, 4> operands(getOperands());
3194  composeSetAndOperands(set, operands);
3195  canonicalizeSetAndOperands(&set, &operands);
3196 
3197  // Check if the canonicalization or composition led to any change.
3198  if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
3199  return failure();
3200 
3201  setConditional(set, operands);
3202  return success();
3203 }
3204 
     // Registers the canonicalization patterns for affine.if.
3205 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
3206  MLIRContext *context) {
3207  results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
3208 }
3209 
3210 //===----------------------------------------------------------------------===//
3211 // AffineLoadOp
3212 //===----------------------------------------------------------------------===//
3213 
     // Build from a combined operand list: operands[0] is the memref, the rest
     // are the map operands. The result type is the memref's element type.
3214 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3215  AffineMap map, ValueRange operands) {
3216  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3217  result.addOperands(operands);
3218  if (map)
3219  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3220  auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
3221  result.types.push_back(memrefType.getElementType());
3222 }
3223 
     // Build from an explicit memref plus map and map operands.
3224 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3225  Value memref, AffineMap map, ValueRange mapOperands) {
3226  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3227  result.addOperands(memref);
3228  result.addOperands(mapOperands);
3229  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3230  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3231  result.types.push_back(memrefType.getElementType());
3232 }
3233 
3234 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3235  Value memref, ValueRange indices) {
3236  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3237  int64_t rank = memrefType.getRank();
3238  // Create identity map for memrefs with at least one dimension or () -> ()
3239  // for zero-dimensional memrefs.
3240  auto map =
3241  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3242  build(builder, result, memref, map, indices);
3243 }
3244 
// Parses the custom affine.load syntax:
//   %memref[affine-map-of-ssa-ids] {attrs} : memref-type
// The result type is derived from the memref element type.
ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  // NOTE(review): the declaration of `mapOperands` (a SmallVector of
  // OpAsmParser::UnresolvedOperand) is elided in this rendering — confirm
  // against the original source.
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineLoadOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}
3264 
  // NOTE(review): this is the body of AffineLoadOp::print; its signature line
  // is elided in this rendering.
  p << " " << getMemRef() << '[';
  // Print the subscripts as an affine map applied to the SSA map operands.
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  // The map attribute is printed inline above, so elide it from the dict.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType();
}
3275 
3276 /// Verify common indexing invariants of affine.load, affine.store,
3277 /// affine.vector_load and affine.vector_store.
3278 template <typename AffineMemOpTy>
3279 static LogicalResult
3280 verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr,
3281  Operation::operand_range mapOperands,
3282  MemRefType memrefType, unsigned numIndexOperands) {
3283  AffineMap map = mapAttr.getValue();
3284  if (map.getNumResults() != memrefType.getRank())
3285  return op->emitOpError("affine map num results must equal memref rank");
3286  if (map.getNumInputs() != numIndexOperands)
3287  return op->emitOpError("expects as many subscripts as affine map inputs");
3288 
3289  for (auto idx : mapOperands) {
3290  if (!idx.getType().isIndex())
3291  return op->emitOpError("index to load must have 'index' type");
3292  }
3293  if (failed(verifyDimAndSymbolIdentifiers(op, mapOperands, map.getNumDims())))
3294  return failure();
3295 
3296  return success();
3297 }
3298 
3299 LogicalResult AffineLoadOp::verify() {
3300  auto memrefType = getMemRefType();
3301  if (getType() != memrefType.getElementType())
3302  return emitOpError("result type must match element type of memref");
3303 
3304  if (failed(verifyMemoryOpIndexing(
3305  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3306  getMapOperands(), memrefType,
3307  /*numIndexOperands=*/getNumOperands() - 1)))
3308  return failure();
3309 
3310  return success();
3311 }
3312 
// Registers the map/operand simplification pattern for affine.load.
void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
}
3317 
// Folds an affine.load: drops memref.cast producers, and folds loads from
// constant memref.global definitions when the loaded value can be determined
// (splat initializer, or fully constant indices).
OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
  /// load(memrefcast) -> load
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();

  // Fold load from a global constant memref.
  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobalOp)
    return {};
  // Get to the memref.global defining the symbol.
  auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
  if (!symbolTableOp)
    return {};
  auto global = dyn_cast_or_null<memref::GlobalOp>(
      SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
  if (!global)
    return {};

  // Check if the global memref is a constant.
  auto cstAttr =
      dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
  if (!cstAttr)
    return {};
  // If it's a splat constant, we can fold irrespective of indices.
  if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr))
    return splatAttr.getSplatValue<Attribute>();
  // Otherwise, we can fold only if we know the indices.
  if (!getAffineMap().isConstant())
    return {};
  // The map's constant results are the element coordinates; widen to uint64_t
  // for DenseElementsAttr indexing.
  auto indices = llvm::to_vector<4>(
      llvm::map_range(getAffineMap().getConstantResults(),
                      [](int64_t v) -> uint64_t { return v; }));
  return cstAttr.getValues<Attribute>()[indices];
}
3352 
3353 //===----------------------------------------------------------------------===//
3354 // AffineStoreOp
3355 //===----------------------------------------------------------------------===//
3356 
3357 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3358  Value valueToStore, Value memref, AffineMap map,
3359  ValueRange mapOperands) {
3360  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3361  result.addOperands(valueToStore);
3362  result.addOperands(memref);
3363  result.addOperands(mapOperands);
3364  result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3365 }
3366 
3367 // Use identity map.
3368 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3369  Value valueToStore, Value memref,
3370  ValueRange indices) {
3371  auto memrefType = llvm::cast<MemRefType>(memref.getType());
3372  int64_t rank = memrefType.getRank();
3373  // Create identity map for memrefs with at least one dimension or () -> ()
3374  // for zero-dimensional memrefs.
3375  auto map =
3376  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3377  build(builder, result, valueToStore, memref, map, indices);
3378 }
3379 
// Parses the custom affine.store syntax:
//   %value, %memref[affine-map-of-ssa-ids] {attrs} : memref-type
// The stored value is resolved against the memref element type.
ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  // NOTE(review): the declaration of `mapOperands` (a SmallVector of
  // OpAsmParser::UnresolvedOperand) is elided in this rendering — confirm
  // against the original source.
  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
                 parser.parseOperand(memrefInfo) ||
                 parser.parseAffineMapOfSSAIds(
                     mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
                     result.attributes) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.resolveOperand(storeValueInfo, type.getElementType(),
                                       result.operands) ||
                 parser.resolveOperand(memrefInfo, type, result.operands) ||
                 parser.resolveOperands(mapOperands, indexTy, result.operands));
}
3400 
  // NOTE(review): this is the body of AffineStoreOp::print; its signature line
  // is elided in this rendering.
  p << " " << getValueToStore();
  p << ", " << getMemRef() << '[';
  // Print the subscripts as an affine map applied to the SSA map operands.
  if (AffineMapAttr mapAttr =
          (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  // The map attribute is printed inline above, so elide it from the dict.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          /*elidedAttrs=*/{getMapAttrStrName()});
  p << " : " << getMemRefType();
}
3412 
3413 LogicalResult AffineStoreOp::verify() {
3414  // The value to store must have the same type as memref element type.
3415  auto memrefType = getMemRefType();
3416  if (getValueToStore().getType() != memrefType.getElementType())
3417  return emitOpError(
3418  "value to store must have the same type as memref element type");
3419 
3420  if (failed(verifyMemoryOpIndexing(
3421  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3422  getMapOperands(), memrefType,
3423  /*numIndexOperands=*/getNumOperands() - 2)))
3424  return failure();
3425 
3426  return success();
3427 }
3428 
// Registers the map/operand simplification pattern for affine.store.
void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
}
3433 
// Folds away memref.cast producers feeding the store's memref operand.
LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return memref::foldMemRefCast(*this, getValueToStore());
}
3439 
3440 //===----------------------------------------------------------------------===//
3441 // AffineMinMaxOpBase
3442 //===----------------------------------------------------------------------===//
3443 
3444 template <typename T>
3445 static LogicalResult verifyAffineMinMaxOp(T op) {
3446  // Verify that operand count matches affine map dimension and symbol count.
3447  if (op.getNumOperands() !=
3448  op.getMap().getNumDims() + op.getMap().getNumSymbols())
3449  return op.emitOpError(
3450  "operand count and affine map dimension and symbol count must match");
3451 
3452  if (op.getMap().getNumResults() == 0)
3453  return op.emitOpError("affine map expect at least one result");
3454  return success();
3455 }
3456 
3457 template <typename T>
3458 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3459  p << ' ' << op->getAttr(T::getMapAttrStrName());
3460  auto operands = op.getOperands();
3461  unsigned numDims = op.getMap().getNumDims();
3462  p << '(' << operands.take_front(numDims) << ')';
3463 
3464  if (operands.size() != numDims)
3465  p << '[' << operands.drop_front(numDims) << ']';
3466  p.printOptionalAttrDict(op->getAttrs(),
3467  /*elidedAttrs=*/{T::getMapAttrStrName()});
3468 }
3469 
3470 template <typename T>
3471 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3472  OperationState &result) {
3473  auto &builder = parser.getBuilder();
3474  auto indexType = builder.getIndexType();
3477  AffineMapAttr mapAttr;
3478  return failure(
3479  parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3480  result.attributes) ||
3481  parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
3482  parser.parseOperandList(symInfos,
3484  parser.parseOptionalAttrDict(result.attributes) ||
3485  parser.resolveOperands(dimInfos, indexType, result.operands) ||
3486  parser.resolveOperands(symInfos, indexType, result.operands) ||
3487  parser.addTypeToList(indexType, result.types));
3488 }
3489 
/// Fold an affine min or max operation with the given operands. The operand
/// list may contain nulls, which are interpreted as the operand not being a
/// constant.
template <typename T>
// NOTE(review): the function signature line (static OpFoldResult
// foldMinMaxOp(T op, ...)) is elided in this rendering — confirm against the
// original source.
  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
                "expected affine min or max op");

  // Fold the affine map.
  // TODO: Fold more cases:
  // min(some_affine, some_affine + constant, ...), etc.
  SmallVector<int64_t, 2> results;
  auto foldedMap = op.getMap().partialConstantFold(operands, &results);

  // A single-symbol identity map means the op just forwards its sole operand.
  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
    return op.getOperand(0);

  // If some of the map results are not constant, try changing the map in-place.
  if (results.empty()) {
    // If the map is the same, report that folding did not happen.
    if (foldedMap == op.getMap())
      return {};
    op->setAttr("map", AffineMapAttr::get(foldedMap));
    return op.getResult();
  }

  // Otherwise, completely fold the op into a constant.
  auto resultIt = std::is_same<T, AffineMinOp>::value
                      ? llvm::min_element(results)
                      : llvm::max_element(results);
  if (resultIt == results.end())
    return {};
  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
}
3524 
/// Remove duplicated expressions in affine min/max ops.
template <typename T>
// NOTE(review): the pattern struct's declaration lines (struct
// DeduplicateAffineMinMaxExpressions : OpRewritePattern<T> and its
// using-declaration) are elided in this rendering — confirm against the
// original source.

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap oldMap = affineOp.getAffineMap();

    // Keep only the first occurrence of each result expression.
    SmallVector<AffineExpr, 4> newExprs;
    for (AffineExpr expr : oldMap.getResults()) {
      // This is a linear scan over newExprs, but it should be fine given that
      // we typically just have a few expressions per op.
      if (!llvm::is_contained(newExprs, expr))
        newExprs.push_back(expr);
    }

    // Nothing was deduplicated; leave the op untouched.
    if (newExprs.size() == oldMap.getNumResults())
      return failure();

    auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
                                 newExprs, rewriter.getContext());
    rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());

    return success();
  }
};
3552 
/// Merge an affine min/max op to its consumers if its consumer is also an
/// affine min/max op.
///
/// This pattern requires the producer affine min/max op is bound to a
/// dimension/symbol that is used as a standalone expression in the consumer
/// affine op's map.
///
/// For example, a pattern like the following:
///
///   %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
///   %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
///
/// Can be turned into:
///
///   %1 = affine.min affine_map<
///          ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
template <typename T>
// NOTE(review): the pattern struct's declaration lines (struct
// MergeAffineMinMaxOp : OpRewritePattern<T> and its using-declaration) are
// elided in this rendering — confirm against the original source.

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap oldMap = affineOp.getAffineMap();
    // Map operand layout: [dims..., syms...].
    ValueRange dimOperands =
        affineOp.getMapOperands().take_front(oldMap.getNumDims());
    ValueRange symOperands =
        affineOp.getMapOperands().take_back(oldMap.getNumSymbols());

    auto newDimOperands = llvm::to_vector<8>(dimOperands);
    auto newSymOperands = llvm::to_vector<8>(symOperands);
    SmallVector<AffineExpr, 4> newExprs;
    SmallVector<T, 4> producerOps;

    // Go over each expression to see whether it's a single dimension/symbol
    // with the corresponding operand which is the result of another affine
    // min/max op. If So it can be merged into this affine op.
    for (AffineExpr expr : oldMap.getResults()) {
      if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
        Value symValue = symOperands[symExpr.getPosition()];
        if (auto producerOp = symValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
        Value dimValue = dimOperands[dimExpr.getPosition()];
        if (auto producerOp = dimValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      }
      // For the above cases we will remove the expression by merging the
      // producer affine min/max's affine expressions. Otherwise we need to
      // keep the existing expression.
      newExprs.push_back(expr);
    }

    if (producerOps.empty())
      return failure();

    unsigned numUsedDims = oldMap.getNumDims();
    unsigned numUsedSyms = oldMap.getNumSymbols();

    // Now go over all producer affine ops and merge their expressions.
    for (T producerOp : producerOps) {
      AffineMap producerMap = producerOp.getAffineMap();
      unsigned numProducerDims = producerMap.getNumDims();
      unsigned numProducerSyms = producerMap.getNumSymbols();

      // Collect all dimension/symbol values.
      ValueRange dimValues =
          producerOp.getMapOperands().take_front(numProducerDims);
      ValueRange symValues =
          producerOp.getMapOperands().take_back(numProducerSyms);
      newDimOperands.append(dimValues.begin(), dimValues.end());
      newSymOperands.append(symValues.begin(), symValues.end());

      // For expressions we need to shift to avoid overlap.
      for (AffineExpr expr : producerMap.getResults()) {
        newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
                               .shiftSymbols(numProducerSyms, numUsedSyms));
      }

      numUsedDims += numProducerDims;
      numUsedSyms += numProducerSyms;
    }

    // Rebuild the consumer with the merged map over [dims..., syms...].
    auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
                                 rewriter.getContext());
    auto newOperands =
        llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
    rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);

    return success();
  }
};
3648 
/// Canonicalize the result expression order of an affine map and return success
/// if the order changed.
///
/// The function flattens the map's affine expressions to coefficient arrays and
/// sorts them in lexicographic order. A coefficient array contains a multiplier
/// for every dimension/symbol and a constant term. The canonicalization fails
/// if a result expression is not pure or if the flattening requires local
/// variables that, unlike dimensions and symbols, have no global order.
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
  SmallVector<SmallVector<int64_t>> flattenedExprs;
  for (const AffineExpr &resultExpr : map.getResults()) {
    // Fail if the expression is not pure.
    if (!resultExpr.isPureAffine())
      return failure();

    SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
    auto flattenResult = flattener.walkPostOrder(resultExpr);
    if (failed(flattenResult))
      return failure();

    // Fail if the flattened expression has local variables.
    // A pure dim/symbol flattening has exactly numDims + numSymbols
    // coefficients plus one constant term.
    if (flattener.operandExprStack.back().size() !=
        map.getNumDims() + map.getNumSymbols() + 1)
      return failure();

    flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
                                flattener.operandExprStack.back().end());
  }

  // Fail if sorting is not necessary.
  if (llvm::is_sorted(flattenedExprs))
    return failure();

  // Reorder the result expressions according to their flattened form.
  SmallVector<unsigned> resultPermutation =
      llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
  llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
    return flattenedExprs[lhs] < flattenedExprs[rhs];
  });
  SmallVector<AffineExpr> newExprs;
  for (unsigned idx : resultPermutation)
    newExprs.push_back(map.getResult(idx));

  // Rebuild the map in-place with the sorted result order.
  map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
                       map.getContext());
  return success();
}
3696 
/// Canonicalize the affine map result expression order of an affine min/max
/// operation.
///
/// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
/// expressions and replaces the operation if the order changed.
///
/// For example, the following operation:
///
///   %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
///
/// Turns into:
///
///   %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
template <typename T>
// NOTE(review): the pattern struct's declaration lines (struct
// CanonicalizeAffineMinMaxOpExprAndTermOrder : OpRewritePattern<T> and its
// using-declaration) are elided in this rendering — confirm against the
// original source.

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap map = affineOp.getAffineMap();
    // `map` is updated in place when canonicalization succeeds.
    if (failed(canonicalizeMapExprAndTermOrder(map)))
      return failure();
    rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
    return success();
  }
};
3723 
// Rewrites a single-result affine min/max op into an equivalent affine.apply
// (a one-result map needs no min/max selection).
template <typename T>
// NOTE(review): the pattern struct's declaration lines (struct
// CanonicalizeSingleResultAffineMinMaxOp : OpRewritePattern<T> and its
// using-declaration) are elided in this rendering — confirm against the
// original source.

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    if (affineOp.getMap().getNumResults() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
                                               affineOp.getOperands());
    return success();
  }
};
3737 
3738 //===----------------------------------------------------------------------===//
3739 // AffineMinOp
3740 //===----------------------------------------------------------------------===//
3741 //
3742 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3743 //
3744 
// Delegates to the shared min/max folder defined above.
OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(*this, adaptor.getOperands());
}
3748 
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // NOTE(review): the opening `patterns.add<...>` line(s) of this pattern list
  // are elided in this rendering; the visible entries register the merge and
  // simplify patterns for AffineMinOp. Confirm the full list against the
  // original source.
      MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
      context);
}
3757 
// Delegates to the shared affine min/max verifier.
LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3759 
// Delegates to the shared affine min/max parser.
ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMinOp>(parser, result);
}
3763 
// Delegates to the shared affine min/max printer.
void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3765 
3766 //===----------------------------------------------------------------------===//
3767 // AffineMaxOp
3768 //===----------------------------------------------------------------------===//
3769 //
3770 // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3771 //
3772 
// Delegates to the shared min/max folder defined above.
OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(*this, adaptor.getOperands());
}
3776 
void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  // NOTE(review): the opening `patterns.add<...>` line(s) of this pattern list
  // are elided in this rendering; the visible entries register the merge and
  // simplify patterns for AffineMaxOp. Confirm the full list against the
  // original source.
      MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
      context);
}
3785 
// Delegates to the shared affine min/max verifier.
LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3787 
// Delegates to the shared affine min/max parser.
ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseAffineMinMaxOp<AffineMaxOp>(parser, result);
}
3791 
// Delegates to the shared affine min/max printer.
void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3793 
3794 //===----------------------------------------------------------------------===//
3795 // AffinePrefetchOp
3796 //===----------------------------------------------------------------------===//
3797 
3798 //
3799 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3800 //
// Parses the custom affine.prefetch syntax:
//   %memref[map-of-ssa-ids], (read|write), locality<hint>, (data|instr)
//     {attrs} : memref-type
// The rw specifier and cache-type keyword are converted into the boolean
// isWrite/isDataCache attributes.
ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand memrefInfo;
  IntegerAttr hintInfo;
  auto i32Type = parser.getBuilder().getIntegerType(32);
  StringRef readOrWrite, cacheType;

  AffineMapAttr mapAttr;
  // NOTE(review): the declaration of `mapOperands` (a SmallVector of
  // OpAsmParser::UnresolvedOperand) is elided in this rendering — confirm
  // against the original source.
  if (parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffinePrefetchOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(hintInfo, i32Type,
                            AffinePrefetchOp::getLocalityHintAttrStrName(),
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands))
    return failure();

  // Validate and encode the rw specifier as the boolean isWrite attribute.
  if (readOrWrite != "read" && readOrWrite != "write")
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
                      parser.getBuilder().getBoolAttr(readOrWrite == "write"));

  // Validate and encode the cache type as the boolean isDataCache attribute.
  if (cacheType != "data" && cacheType != "instr")
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
                      parser.getBuilder().getBoolAttr(cacheType == "data"));

  return success();
}
3847 
  // NOTE(review): this is the body of AffinePrefetchOp::print; its signature
  // line is elided in this rendering.
  p << " " << getMemref() << '[';
  AffineMapAttr mapAttr =
      (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
  if (mapAttr)
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  // Print the rw specifier, locality hint, and cache-type keywords from the
  // corresponding attributes.
  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
    << "locality<" << getLocalityHint() << ">, "
    << (getIsDataCache() ? "data" : "instr");
  // NOTE(review): the `p.printOptionalAttrDict(` call opener is elided in this
  // rendering — confirm against the original source. All attributes printed
  // inline above are elided from the dictionary.
      (*this)->getAttrs(),
      /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
                       getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
  p << " : " << getMemRefType();
}
3863 
3864 LogicalResult AffinePrefetchOp::verify() {
3865  auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3866  if (mapAttr) {
3867  AffineMap map = mapAttr.getValue();
3868  if (map.getNumResults() != getMemRefType().getRank())
3869  return emitOpError("affine.prefetch affine map num results must equal"
3870  " memref rank");
3871  if (map.getNumInputs() + 1 != getNumOperands())
3872  return emitOpError("too few operands");
3873  } else {
3874  if (getNumOperands() != 1)
3875  return emitOpError("too few operands");
3876  }
3877 
3878  Region *scope = getAffineScope(*this);
3879  for (auto idx : getMapOperands()) {
3880  if (!isValidAffineIndexOperand(idx, scope))
3881  return emitOpError(
3882  "index must be a valid dimension or symbol identifier");
3883  }
3884  return success();
3885 }
3886 
// Registers the map/operand simplification pattern for affine.prefetch.
void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // prefetch(memrefcast) -> prefetch
  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
}
3892 
// Folds away memref.cast producers feeding the prefetch's memref operand.
LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// prefetch(memrefcast) -> prefetch
  return memref::foldMemRefCast(*this);
}
3898 
3899 //===----------------------------------------------------------------------===//
3900 // AffineParallelOp
3901 //===----------------------------------------------------------------------===//
3902 
3903 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
3904  TypeRange resultTypes,
3905  ArrayRef<arith::AtomicRMWKind> reductions,
3906  ArrayRef<int64_t> ranges) {
3907  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
3908  auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
3909  return builder.getConstantAffineMap(value);
3910  }));
3911  SmallVector<int64_t> steps(ranges.size(), 1);
3912  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
3913  /*ubArgs=*/{}, steps);
3914 }
3915 
// Builds a fully general affine.parallel: per-IV lower/upper bound maps are
// concatenated into single lb/ub maps with group-size attributes recording how
// many results belong to each IV, reductions are encoded as an i64 array
// attribute, and a body block with one index argument per IV is created.
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<arith::AtomicRMWKind> reductions,
                             ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
                             ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
                             ArrayRef<int64_t> steps) {
  assert(llvm::all_of(lbMaps,
                      [lbMaps](AffineMap m) {
                        return m.getNumDims() == lbMaps[0].getNumDims() &&
                               m.getNumSymbols() == lbMaps[0].getNumSymbols();
                      }) &&
         "expected all lower bounds maps to have the same number of dimensions "
         "and symbols");
  assert(llvm::all_of(ubMaps,
                      [ubMaps](AffineMap m) {
                        return m.getNumDims() == ubMaps[0].getNumDims() &&
                               m.getNumSymbols() == ubMaps[0].getNumSymbols();
                      }) &&
         "expected all upper bounds maps to have the same number of dimensions "
         "and symbols");
  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
         "expected lower bound maps to have as many inputs as lower bound "
         "operands");
  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
         "expected upper bound maps to have as many inputs as upper bound "
         "operands");

  OpBuilder::InsertionGuard guard(builder);
  result.addTypes(resultTypes);

  // Convert the reductions to integer attributes.
  SmallVector<Attribute, 4> reductionAttrs;
  for (arith::AtomicRMWKind reduction : reductions)
    reductionAttrs.push_back(
        builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
  result.addAttribute(getReductionsAttrStrName(),
                      builder.getArrayAttr(reductionAttrs));

  // Concatenates maps defined in the same input space (same dimensions and
  // symbols), assumes there is at least one map.
  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
                                        SmallVectorImpl<int32_t> &groups) {
    if (maps.empty())
      return AffineMap::get(builder.getContext());
    // NOTE(review): the declaration of `exprs` (a SmallVector<AffineExpr>) is
    // elided in this rendering — confirm against the original source.
    groups.reserve(groups.size() + maps.size());
    exprs.reserve(maps.size());
    // Each map contributes one group whose size is its result count.
    for (AffineMap m : maps) {
      llvm::append_range(exprs, m.getResults());
      groups.push_back(m.getNumResults());
    }
    return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
                          maps[0].getContext());
  };

  // Set up the bounds.
  SmallVector<int32_t> lbGroups, ubGroups;
  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
  result.addAttribute(getLowerBoundsMapAttrStrName(),
                      AffineMapAttr::get(lbMap));
  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(lbGroups));
  result.addAttribute(getUpperBoundsMapAttrStrName(),
                      AffineMapAttr::get(ubMap));
  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(ubGroups));
  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
  result.addOperands(lbArgs);
  result.addOperands(ubArgs);

  // Create a region and a block for the body.
  auto *bodyRegion = result.addRegion();
  Block *body = builder.createBlock(bodyRegion);

  // Add all the block arguments.
  for (unsigned i = 0, e = steps.size(); i < e; ++i)
    body->addArgument(IndexType::get(builder.getContext()), result.location);
  // Result-less loops get an implicit terminator.
  if (resultTypes.empty())
    ensureTerminator(*bodyRegion, builder, result.location);
}
3997 
// The op has a single loop region: its body region.
SmallVector<Region *> AffineParallelOp::getLoopRegions() {
  return {&getRegion()};
}
4001 
// One induction variable per step entry.
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
4003 
// Operand layout is [lb operands..., ub operands...]; the lower-bound slice
// has as many entries as the lower-bounds map has inputs.
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
}
4007 
// Everything after the lower-bound operands feeds the upper-bounds map.
AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
}
4011 
4012 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
4013  auto values = getLowerBoundsGroups().getValues<int32_t>();
4014  unsigned start = 0;
4015  for (unsigned i = 0; i < pos; ++i)
4016  start += values[i];
4017  return getLowerBoundsMap().getSliceMap(start, values[pos]);
4018 }
4019 
4020 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
4021  auto values = getUpperBoundsGroups().getValues<int32_t>();
4022  unsigned start = 0;
4023  for (unsigned i = 0; i < pos; ++i)
4024  start += values[i];
4025  return getUpperBoundsMap().getSliceMap(start, values[pos]);
4026 }
4027 
// Pairs the lower-bounds map with its operands.
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
}
4031 
// Pairs the upper-bounds map with its operands.
AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
}
4035 
// Returns the constant trip count of each dimension (ub - lb per IV), or
// std::nullopt if the op has min/max bounds or any range is non-constant.
std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
  // Multi-result (min/max) bound groups have no single per-IV range.
  if (hasMinMaxBounds())
    return std::nullopt;

  // Try to convert all the ranges to constant expressions.
  // NOTE(review): the declaration of `out` (SmallVector<int64_t, 8>) is elided
  // in this rendering — confirm against the original source.
  AffineValueMap rangesValueMap;
  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
                             &rangesValueMap);
  out.reserve(rangesValueMap.getNumResults());
  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
    auto expr = rangesValueMap.getResult(i);
    auto cst = dyn_cast<AffineConstantExpr>(expr);
    if (!cst)
      return std::nullopt;
    out.push_back(cst.getValue());
  }
  return out;
}
4055 
// The body is the single block of the op's region.
Block *AffineParallelOp::getBody() { return &getRegion().front(); }
4057 
// Builder positioned just before the body's terminator.
OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));
}
4061 
4062 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
4063  assert(lbOperands.size() == map.getNumInputs() &&
4064  "operands to map must match number of inputs");
4065 
4066  auto ubOperands = getUpperBoundsOperands();
4067 
4068  SmallVector<Value, 4> newOperands(lbOperands);
4069  newOperands.append(ubOperands.begin(), ubOperands.end());
4070  (*this)->setOperands(newOperands);
4071 
4072  setLowerBoundsMapAttr(AffineMapAttr::get(map));
4073 }
4074 
4075 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
4076  assert(ubOperands.size() == map.getNumInputs() &&
4077  "operands to map must match number of inputs");
4078 
4079  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
4080  newOperands.append(ubOperands.begin(), ubOperands.end());
4081  (*this)->setOperands(newOperands);
4082 
4083  setUpperBoundsMapAttr(AffineMapAttr::get(map));
4084 }
4085 
// Overwrites the steps attribute with `newSteps` (one constant step per
// loop dimension).
4086 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
4087  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
4088 }
4089 
4090 // check whether resultType match op or not in affine.parallel
4091 static bool isResultTypeMatchAtomicRMWKind(Type resultType,
4092  arith::AtomicRMWKind op) {
4093  switch (op) {
4094  case arith::AtomicRMWKind::addf:
4095  return isa<FloatType>(resultType);
4096  case arith::AtomicRMWKind::addi:
4097  return isa<IntegerType>(resultType);
4098  case arith::AtomicRMWKind::assign:
4099  return true;
4100  case arith::AtomicRMWKind::mulf:
4101  return isa<FloatType>(resultType);
4102  case arith::AtomicRMWKind::muli:
4103  return isa<IntegerType>(resultType);
4104  case arith::AtomicRMWKind::maximumf:
4105  return isa<FloatType>(resultType);
4106  case arith::AtomicRMWKind::minimumf:
4107  return isa<FloatType>(resultType);
4108  case arith::AtomicRMWKind::maxs: {
4109  auto intType = dyn_cast<IntegerType>(resultType);
4110  return intType && intType.isSigned();
4111  }
4112  case arith::AtomicRMWKind::mins: {
4113  auto intType = dyn_cast<IntegerType>(resultType);
4114  return intType && intType.isSigned();
4115  }
4116  case arith::AtomicRMWKind::maxu: {
4117  auto intType = dyn_cast<IntegerType>(resultType);
4118  return intType && intType.isUnsigned();
4119  }
4120  case arith::AtomicRMWKind::minu: {
4121  auto intType = dyn_cast<IntegerType>(resultType);
4122  return intType && intType.isUnsigned();
4123  }
4124  case arith::AtomicRMWKind::ori:
4125  return isa<IntegerType>(resultType);
4126  case arith::AtomicRMWKind::andi:
4127  return isa<IntegerType>(resultType);
4128  default:
4129  return false;
4130  }
4131 }
4132 
// Verifies the structural invariants of affine.parallel: per-dimension group
// counts, that the group sizes exactly tile the flat bound maps, that each
// reduction attribute is a valid AtomicRMWKind matching its result type, and
// that all bound operands are valid affine dims/symbols.
4133 LogicalResult AffineParallelOp::verify() {
4134  auto numDims = getNumDims();
// Each induction variable needs exactly one lower-bound group, one
// upper-bound group, one step, and one region argument.
4135  if (getLowerBoundsGroups().getNumElements() != numDims ||
4136  getUpperBoundsGroups().getNumElements() != numDims ||
4137  getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
4138  return emitOpError() << "the number of region arguments ("
4139  << getBody()->getNumArguments()
4140  << ") and the number of map groups for lower ("
4141  << getLowerBoundsGroups().getNumElements()
4142  << ") and upper bound ("
4143  << getUpperBoundsGroups().getNumElements()
4144  << "), and the number of steps (" << getSteps().size()
4145  << ") must all match";
4146  }
4147 
// The group sizes must sum to the number of results of the flat lower (resp.
// upper) bounds map, and every group must contribute at least one result.
4148  unsigned expectedNumLBResults = 0;
4149  for (APInt v : getLowerBoundsGroups()) {
4150  unsigned results = v.getZExtValue();
4151  if (results == 0)
4152  return emitOpError()
4153  << "expected lower bound map to have at least one result";
4154  expectedNumLBResults += results;
4155  }
4156  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
4157  return emitOpError() << "expected lower bounds map to have "
4158  << expectedNumLBResults << " results";
4159  unsigned expectedNumUBResults = 0;
4160  for (APInt v : getUpperBoundsGroups()) {
4161  unsigned results = v.getZExtValue();
4162  if (results == 0)
4163  return emitOpError()
4164  << "expected upper bound map to have at least one result";
4165  expectedNumUBResults += results;
4166  }
4167  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
4168  return emitOpError() << "expected upper bounds map to have "
4169  << expectedNumUBResults << " results";
4170 
4171  if (getReductions().size() != getNumResults())
4172  return emitOpError("a reduction must be specified for each output");
4173 
4174  // Verify reduction ops are all valid and each result type matches reduction
4175  // ops
4176  for (auto it : llvm::enumerate((getReductions()))) {
4177  Attribute attr = it.value();
4178  auto intAttr = dyn_cast<IntegerAttr>(attr);
4179  if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
4180  return emitOpError("invalid reduction attribute");
4181  auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
4182  if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
4183  return emitOpError("result type cannot match reduction attribute");
4184  }
4185 
4186  // Verify that the bound operands are valid dimension/symbols.
4187  /// Lower bounds.
4188  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
4189  getLowerBoundsMap().getNumDims())))
4190  return failure();
4191  /// Upper bounds.
4192  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
4193  getUpperBoundsMap().getNumDims())))
4194  return failure();
4195  return success();
4196 }
4197 
4198 LogicalResult AffineValueMap::canonicalize() {
4199  SmallVector<Value, 4> newOperands{operands};
4200  auto newMap = getAffineMap();
4201  composeAffineMapAndOperands(&newMap, &newOperands);
4202  if (newMap == getAffineMap() && newOperands == operands)
4203  return failure();
4204  reset(newMap, newOperands);
4205  return success();
4206 }
4207 
4208 /// Canonicalize the bounds of the given loop.
4209 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
4210  AffineValueMap lb = op.getLowerBoundsValueMap();
4211  bool lbCanonicalized = succeeded(lb.canonicalize());
4212 
4213  AffineValueMap ub = op.getUpperBoundsValueMap();
4214  bool ubCanonicalized = succeeded(ub.canonicalize());
4215 
4216  // Any canonicalization change always leads to updated map(s).
4217  if (!lbCanonicalized && !ubCanonicalized)
4218  return failure();
4219 
4220  if (lbCanonicalized)
4221  op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
4222  if (ubCanonicalized)
4223  op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
4224 
4225  return success();
4226 }
4227 
// Folding hook: folds affine.parallel in place by canonicalizing its bound
// maps/operands; never produces replacement values in `results`.
4228 LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
4229  SmallVectorImpl<OpFoldResult> &results) {
4230  return canonicalizeLoopBounds(*this);
4231 }
4232 
4233 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
4234 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
4235 /// identifies which of the those expressions form max/min groups. `operands`
4236 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
4237 /// or "max".
4238 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
4239  DenseIntElementsAttr group, ValueRange operands,
4240  StringRef keyword) {
4241  AffineMap map = mapAttr.getValue();
4242  unsigned numDims = map.getNumDims();
4243  ValueRange dimOperands = operands.take_front(numDims);
4244  ValueRange symOperands = operands.drop_front(numDims);
4245  unsigned start = 0;
4246  for (llvm::APInt groupSize : group) {
4247  if (start != 0)
4248  p << ", ";
4249 
4250  unsigned size = groupSize.getZExtValue();
4251  if (size == 1) {
4252  p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4253  ++start;
4254  } else {
4255  p << keyword << '(';
4256  AffineMap submap = map.getSliceMap(start, size);
4257  p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4258  p << ')';
4259  start += size;
4260  }
4261  }
4262 }
4263 
// Printer for affine.parallel: IVs, `= (lb) to (ub)`, optional `step`,
// optional `reduce (...) -> (types)`, region, then remaining attributes.
// NOTE(review): the function signature line (orig. line 4264, presumably
// `void AffineParallelOp::print(OpAsmPrinter &p) {`) and the
// `p.printOptionalAttrDict(` call opening (orig. line 4292) were dropped by
// doc extraction — confirm against the upstream source.
4265  p << " (" << getBody()->getArguments() << ") = (";
4266  printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4267  getLowerBoundsOperands(), "max");
4268  p << ") to (";
4269  printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4270  getUpperBoundsOperands(), "min");
4271  p << ')';
// Steps are elided when they are all the default value 1.
4272  SmallVector<int64_t, 8> steps = getSteps();
4273  bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4274  if (!elideSteps) {
4275  p << " step (";
4276  llvm::interleaveComma(steps, p);
4277  p << ')';
4278  }
// The reduce clause is printed only when the op produces results.
4279  if (getNumResults()) {
4280  p << " reduce (";
4281  llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4282  arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4283  llvm::cast<IntegerAttr>(attr).getInt());
4284  p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4285  });
4286  p << ") -> (" << getResultTypes() << ")";
4287  }
4288 
4289  p << ' ';
4290  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4291  /*printBlockTerminators=*/getNumResults());
// Attributes already encoded in the custom syntax are elided below.
4293  (*this)->getAttrs(),
4294  /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4295  AffineParallelOp::getLowerBoundsMapAttrStrName(),
4296  AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4297  AffineParallelOp::getUpperBoundsMapAttrStrName(),
4298  AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4299  AffineParallelOp::getStepsAttrStrName()});
4300 }
4301 
4302 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
4303 /// unique operands. Also populates `replacements` with affine expressions of
4304 /// `kind` that can be used to update affine maps previously accepting a
4305 /// `operands` to accept `uniqueOperands` instead.
// NOTE(review): several signature lines were dropped by doc extraction here
// (the function name/return type, the `operands`/`replacements` parameters,
// the `kind` parameter, and an assert opening) — confirm against upstream.
4307  OpAsmParser &parser,
4309  SmallVectorImpl<Value> &uniqueOperands,
4312  "expected operands to be dim or symbol expression");
4313 
4314  Type indexType = parser.getBuilder().getIndexType();
// Resolve each operand list; every operand maps to the position of its first
// occurrence in `uniqueOperands`, deduplicating repeats.
4315  for (const auto &list : operands) {
4316  SmallVector<Value> valueOperands;
4317  if (parser.resolveOperands(list, indexType, valueOperands))
4318  return failure();
4319  for (Value operand : valueOperands) {
4320  unsigned pos = std::distance(uniqueOperands.begin(),
4321  llvm::find(uniqueOperands, operand));
4322  if (pos == uniqueOperands.size())
4323  uniqueOperands.push_back(operand);
4324  replacements.push_back(
4326  ? getAffineDimExpr(pos, parser.getContext())
4327  : getAffineSymbolExpr(pos, parser.getContext()));
4328  }
4329  }
4330  return success();
4331 }
4332 
// Discriminates whether a grouped parallel bound is a `min` or a `max`
// expression when parsing/selecting attribute names.
4333 namespace {
4334 enum class MinMaxKind { Min, Max };
4335 } // namespace
4336 
4337 /// Parses an affine map that can contain a min/max for groups of its results,
4338 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4339 /// `result` attributes with the map (flat list of expressions) and the grouping
4340 /// (list of integers that specify how many expressions to put into each
4341 /// min/max) attributes. Deduplicates repeated operands.
4342 ///
4343 /// parallel-bound ::= `(` parallel-group-list `)`
4344 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4345 /// parallel-group ::= simple-group | min-max-group
4346 /// simple-group ::= expr-of-ssa-ids
4347 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4348 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4349 ///
4350 /// Examples:
4351 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4352 /// (%0, max(%1 - 2 * %2))
// NOTE(review): several declaration lines were dropped by doc extraction in
// this function (e.g. the `flatDimOperands`/`flatSymOperands`/`mapOperands`
// declarations and delimiter arguments) — confirm against upstream.
4353 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4354  OperationState &result,
4355  MinMaxKind kind) {
4356  // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4357  // see: https://reviews.llvm.org/D134227#3821753
4358  const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4359 
// `min` groups belong to upper bounds, `max` groups to lower bounds; pick the
// matching attribute names.
4360  StringRef mapName = kind == MinMaxKind::Min
4361  ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4362  : AffineParallelOp::getLowerBoundsMapAttrStrName();
4363  StringRef groupsName =
4364  kind == MinMaxKind::Min
4365  ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4366  : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4367 
4368  if (failed(parser.parseLParen()))
4369  return failure();
4370 
// Empty bound list: record an empty map and empty grouping.
4371  if (succeeded(parser.parseOptionalRParen())) {
4372  result.addAttribute(
4373  mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4374  result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4375  return success();
4376  }
4377 
4378  SmallVector<AffineExpr> flatExprs;
4381  SmallVector<int32_t> numMapsPerGroup;
// Parses one group: either `min`/`max` over a sub-map, or a single bare
// expression; appends its expressions and per-result operand lists.
4383  auto parseOperands = [&]() {
4384  if (succeeded(parser.parseOptionalKeyword(
4385  kind == MinMaxKind::Min ? "min" : "max"))) {
4386  mapOperands.clear();
4387  AffineMapAttr map;
4388  if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4389  result.attributes,
4391  return failure();
4392  result.attributes.erase(tmpAttrStrName);
4393  llvm::append_range(flatExprs, map.getValue().getResults());
4394  auto operandsRef = llvm::ArrayRef(mapOperands);
4395  auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4397  auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4399  flatDimOperands.append(map.getValue().getNumResults(), dims);
4400  flatSymOperands.append(map.getValue().getNumResults(), syms);
4401  numMapsPerGroup.push_back(map.getValue().getNumResults());
4402  } else {
4403  if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4404  flatSymOperands.emplace_back(),
4405  flatExprs.emplace_back())))
4406  return failure();
4407  numMapsPerGroup.push_back(1);
4408  }
4409  return success();
4410  };
4411  if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4412  return failure();
4413 
// Shift each expression's dims/symbols so all groups share one numbering.
4414  unsigned totalNumDims = 0;
4415  unsigned totalNumSyms = 0;
4416  for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4417  unsigned numDims = flatDimOperands[i].size();
4418  unsigned numSyms = flatSymOperands[i].size();
4419  flatExprs[i] = flatExprs[i]
4420  .shiftDims(numDims, totalNumDims)
4421  .shiftSymbols(numSyms, totalNumSyms)
4422  totalNumDims += numDims;
4423  totalNumSyms += numSyms;
4424  }
4425 
4426  // Deduplicate map operands.
4427  SmallVector<Value> dimOperands, symOperands;
4428  SmallVector<AffineExpr> dimRplacements, symRepacements;
4429  if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4430  dimRplacements, AffineExprKind::DimId) ||
4431  deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4432  symRepacements, AffineExprKind::SymbolId))
4433  return failure();
4434 
4435  result.operands.append(dimOperands.begin(), dimOperands.end());
4436  result.operands.append(symOperands.begin(), symOperands.end());
4437 
// Rebuild one flat map over the deduplicated operand lists.
4438  Builder &builder = parser.getBuilder();
4439  auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4440  parser.getContext());
4441  flatMap = flatMap.replaceDimsAndSymbols(
4442  dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4443 
4444  result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4445  result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4446  return success();
4447 }
4448 
4449 //
4450 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4451 // `to` parallel-bound steps? region attr-dict?
4452 // steps ::= `steps` `(` integer-literals `)`
4453 //
// NOTE(review): several declaration lines were dropped by doc extraction in
// this function (e.g. the `ivs` region-argument parsing, `stepsMapOperands`
// and the `steps` vector declarations) — confirm against upstream.
4454 ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4455  OperationState &result) {
4456  auto &builder = parser.getBuilder();
4457  auto indexType = builder.getIndexType();
// Lower bounds use `max` groups; upper bounds use `min` groups.
4460  parser.parseEqual() ||
4461  parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4462  parser.parseKeyword("to") ||
4463  parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4464  return failure();
4465 
4466  AffineMapAttr stepsMapAttr;
4467  NamedAttrList stepsAttrs;
// Without a `step` clause, all steps default to 1.
4469  if (failed(parser.parseOptionalKeyword("step"))) {
4470  SmallVector<int64_t, 4> steps(ivs.size(), 1);
4471  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4472  builder.getI64ArrayAttr(steps));
4473  } else {
4474  if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4475  AffineParallelOp::getStepsAttrStrName(),
4476  stepsAttrs,
4478  return failure();
4479 
4480  // Convert steps from an AffineMap into an I64ArrayAttr.
4482  auto stepsMap = stepsMapAttr.getValue();
4483  for (const auto &result : stepsMap.getResults()) {
4484  auto constExpr = dyn_cast<AffineConstantExpr>(result);
4485  if (!constExpr)
4486  return parser.emitError(parser.getNameLoc(),
4487  "steps must be constant integers");
4488  steps.push_back(constExpr.getValue());
4489  }
4490  result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4491  builder.getI64ArrayAttr(steps));
4492  }
4493 
4494  // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4495  // quoted strings are a member of the enum AtomicRMWKind.
4496  SmallVector<Attribute, 4> reductions;
4497  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4498  if (parser.parseLParen())
4499  return failure();
4500  auto parseAttributes = [&]() -> ParseResult {
4501  // Parse a single quoted string via the attribute parsing, and then
4502  // verify it is a member of the enum and convert to it's integer
4503  // representation.
4504  StringAttr attrVal;
4505  NamedAttrList attrStorage;
4506  auto loc = parser.getCurrentLocation();
4507  if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4508  attrStorage))
4509  return failure();
4510  std::optional<arith::AtomicRMWKind> reduction =
4511  arith::symbolizeAtomicRMWKind(attrVal.getValue());
4512  if (!reduction)
4513  return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4514  reductions.push_back(
4515  builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4516  // While we keep getting commas, keep parsing.
4517  return success();
4518  };
4519  if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4520  return failure();
4521  }
4522  result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4523  builder.getArrayAttr(reductions));
4524 
4525  // Parse return types of reductions (if any)
4526  if (parser.parseOptionalArrowTypeList(result.types))
4527  return failure();
4528 
4529  // Now parse the body.
4530  Region *body = result.addRegion();
// The induction variables are typed `index` before parsing the region.
4531  for (auto &iv : ivs)
4532  iv.type = indexType;
4533  if (parser.parseRegion(*body, ivs) ||
4534  parser.parseOptionalAttrDict(result.attributes))
4535  return failure();
4536 
4537  // Add a terminator if none was parsed.
4538  AffineParallelOp::ensureTerminator(*body, builder, result.location);
4539  return success();
4540 }
4541 
4542 //===----------------------------------------------------------------------===//
4543 // AffineYieldOp
4544 //===----------------------------------------------------------------------===//
4545 
4546 LogicalResult AffineYieldOp::verify() {
4547  auto *parentOp = (*this)->getParentOp();
4548  auto results = parentOp->getResults();
4549  auto operands = getOperands();
4550 
4551  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4552  return emitOpError() << "only terminates affine.if/for/parallel regions";
4553  if (parentOp->getNumResults() != getNumOperands())
4554  return emitOpError() << "parent of yield must have same number of "
4555  "results as the yield operands";
4556  for (auto it : llvm::zip(results, operands)) {
4557  if (std::get<0>(it).getType() != std::get<1>(it).getType())
4558  return emitOpError() << "types mismatch between yield op and its parent";
4559  }
4560 
4561  return success();
4562 }
4563 
4564 //===----------------------------------------------------------------------===//
4565 // AffineVectorLoadOp
4566 //===----------------------------------------------------------------------===//
4567 
4568 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4569  VectorType resultType, AffineMap map,
4570  ValueRange operands) {
4571  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4572  result.addOperands(operands);
4573  if (map)
4574  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4575  result.types.push_back(resultType);
4576 }
4577 
4578 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4579  VectorType resultType, Value memref,
4580  AffineMap map, ValueRange mapOperands) {
4581  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4582  result.addOperands(memref);
4583  result.addOperands(mapOperands);
4584  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4585  result.types.push_back(resultType);
4586 }
4587 
4588 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4589  VectorType resultType, Value memref,
4590  ValueRange indices) {
4591  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4592  int64_t rank = memrefType.getRank();
4593  // Create identity map for memrefs with at least one dimension or () -> ()
4594  // for zero-dimensional memrefs.
4595  auto map =
4596  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4597  build(builder, result, resultType, memref, map, indices);
4598 }
4599 
// Registers the map-simplification canonicalization pattern for
// affine.vector_load.
4600 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4601  MLIRContext *context) {
4602  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4603 }
4604 
// Parser for affine.vector_load:
//   %0 = affine.vector_load %mem[map(operands)] : memref-type, vector-type
// NOTE(review): the declaration of `mapOperands` (orig. line 4614) was
// dropped by doc extraction — confirm against upstream.
4605 ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4606  OperationState &result) {
4607  auto &builder = parser.getBuilder();
4608  auto indexTy = builder.getIndexType();
4609 
4610  MemRefType memrefType;
4611  VectorType resultType;
4612  OpAsmParser::UnresolvedOperand memrefInfo;
4613  AffineMapAttr mapAttr;
4615  return failure(
4616  parser.parseOperand(memrefInfo) ||
4617  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4618  AffineVectorLoadOp::getMapAttrStrName(),
4619  result.attributes) ||
4620  parser.parseOptionalAttrDict(result.attributes) ||
4621  parser.parseColonType(memrefType) || parser.parseComma() ||
4622  parser.parseType(resultType) ||
4623  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4624  parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4625  parser.addTypeToList(resultType, result.types));
4626 }
4627 
// Printer for affine.vector_load: memref, bracketed map of SSA ids, the
// non-elided attributes, then the memref and result vector types.
// NOTE(review): the signature line (orig. line 4628, presumably
// `void AffineVectorLoadOp::print(OpAsmPrinter &p) {`) was dropped by doc
// extraction — confirm against upstream.
4629  p << " " << getMemRef() << '[';
4630  if (AffineMapAttr mapAttr =
4631  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4632  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4633  p << ']';
4634  p.printOptionalAttrDict((*this)->getAttrs(),
4635  /*elidedAttrs=*/{getMapAttrStrName()});
4636  p << " : " << getMemRefType() << ", " << getType();
4637 }
4638 
4639 /// Verify common invariants of affine.vector_load and affine.vector_store.
4640 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4641  VectorType vectorType) {
4642  // Check that memref and vector element types match.
4643  if (memrefType.getElementType() != vectorType.getElementType())
4644  return op->emitOpError(
4645  "requires memref and vector types of the same elemental type");
4646  return success();
4647 }
4648 
4649 LogicalResult AffineVectorLoadOp::verify() {
4650  MemRefType memrefType = getMemRefType();
4651  if (failed(verifyMemoryOpIndexing(
4652  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4653  getMapOperands(), memrefType,
4654  /*numIndexOperands=*/getNumOperands() - 1)))
4655  return failure();
4656 
4657  if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4658  return failure();
4659 
4660  return success();
4661 }
4662 
4663 //===----------------------------------------------------------------------===//
4664 // AffineVectorStoreOp
4665 //===----------------------------------------------------------------------===//
4666 
4667 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4668  Value valueToStore, Value memref, AffineMap map,
4669  ValueRange mapOperands) {
4670  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4671  result.addOperands(valueToStore);
4672  result.addOperands(memref);
4673  result.addOperands(mapOperands);
4674  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4675 }
4676 
4677 // Use identity map.
4678 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4679  Value valueToStore, Value memref,
4680  ValueRange indices) {
4681  auto memrefType = llvm::cast<MemRefType>(memref.getType());
4682  int64_t rank = memrefType.getRank();
4683  // Create identity map for memrefs with at least one dimension or () -> ()
4684  // for zero-dimensional memrefs.
4685  auto map =
4686  rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4687  build(builder, result, valueToStore, memref, map, indices);
4688 }
// Registers the map-simplification canonicalization pattern for
// affine.vector_store.
4689 void AffineVectorStoreOp::getCanonicalizationPatterns(
4690  RewritePatternSet &results, MLIRContext *context) {
4691  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4692 }
4693 
// Parser for affine.vector_store:
//   affine.vector_store %val, %mem[map(operands)] : memref-type, vector-type
// NOTE(review): the declaration of `mapOperands` (orig. line 4703) was
// dropped by doc extraction — confirm against upstream.
4694 ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4695  OperationState &result) {
4696  auto indexTy = parser.getBuilder().getIndexType();
4697 
4698  MemRefType memrefType;
4699  VectorType resultType;
4700  OpAsmParser::UnresolvedOperand storeValueInfo;
4701  OpAsmParser::UnresolvedOperand memrefInfo;
4702  AffineMapAttr mapAttr;
4704  return failure(
4705  parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4706  parser.parseOperand(memrefInfo) ||
4707  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4708  AffineVectorStoreOp::getMapAttrStrName(),
4709  result.attributes) ||
4710  parser.parseOptionalAttrDict(result.attributes) ||
4711  parser.parseColonType(memrefType) || parser.parseComma() ||
4712  parser.parseType(resultType) ||
4713  parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4714  parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4715  parser.resolveOperands(mapOperands, indexTy, result.operands));
4716 }
4717 
// Printer for affine.vector_store: stored value, memref with bracketed map of
// SSA ids, non-elided attributes, then the memref and stored-value types.
// NOTE(review): the signature line (orig. line 4718, presumably
// `void AffineVectorStoreOp::print(OpAsmPrinter &p) {`) was dropped by doc
// extraction — confirm against upstream.
4719  p << " " << getValueToStore();
4720  p << ", " << getMemRef() << '[';
4721  if (AffineMapAttr mapAttr =
4722  (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4723  p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4724  p << ']';
4725  p.printOptionalAttrDict((*this)->getAttrs(),
4726  /*elidedAttrs=*/{getMapAttrStrName()});
4727  p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4728 }
4729 
4730 LogicalResult AffineVectorStoreOp::verify() {
4731  MemRefType memrefType = getMemRefType();
4732  if (failed(verifyMemoryOpIndexing(
4733  *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4734  getMapOperands(), memrefType,
4735  /*numIndexOperands=*/getNumOperands() - 2)))
4736  return failure();
4737 
4738  if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4739  return failure();
4740 
4741  return success();
4742 }
4743 
4744 //===----------------------------------------------------------------------===//
4745 // DelinearizeIndexOp
4746 //===----------------------------------------------------------------------===//
4747 
4748 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4749  OperationState &odsState,
4750  Value linearIndex, ValueRange dynamicBasis,
4751  ArrayRef<int64_t> staticBasis,
4752  bool hasOuterBound) {
4753  SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4754  : staticBasis.size() + 1,
4755  linearIndex.getType());
4756  build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4757  staticBasis);
4758 }
4759 
4760 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4761  OperationState &odsState,
4762  Value linearIndex, ValueRange basis,
4763  bool hasOuterBound) {
4764  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4765  hasOuterBound = false;
4766  basis = basis.drop_front();
4767  }
4768  SmallVector<Value> dynamicBasis;
4769  SmallVector<int64_t> staticBasis;
4770  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4771  staticBasis);
4772  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4773  hasOuterBound);
4774 }
4775 
4776 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4777  OperationState &odsState,
4778  Value linearIndex,
4779  ArrayRef<OpFoldResult> basis,
4780  bool hasOuterBound) {
4781  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4782  hasOuterBound = false;
4783  basis = basis.drop_front();
4784  }
4785  SmallVector<Value> dynamicBasis;
4786  SmallVector<int64_t> staticBasis;
4787  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4788  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4789  hasOuterBound);
4790 }
4791 
// Builds a delinearize op from a fully-static basis (no dynamic operands).
4792 void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4793  OperationState &odsState,
4794  Value linearIndex, ArrayRef<int64_t> basis,
4795  bool hasOuterBound) {
4796  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
4797 }
4798 
4799 LogicalResult AffineDelinearizeIndexOp::verify() {
4800  ArrayRef<int64_t> staticBasis = getStaticBasis();
4801  if (getNumResults() != staticBasis.size() &&
4802  getNumResults() != staticBasis.size() + 1)
4803  return emitOpError("should return an index for each basis element and up "
4804  "to one extra index");
4805 
4806  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4807  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4808  return emitOpError(
4809  "mismatch between dynamic and static basis (kDynamic marker but no "
4810  "corresponding dynamic basis entry) -- this can only happen due to an "
4811  "incorrect fold/rewrite");
4812 
4813  if (!llvm::all_of(staticBasis, [](int64_t v) {
4814  return v > 0 || ShapedType::isDynamic(v);
4815  }))
4816  return emitOpError("no basis element may be statically non-positive");
4817 
4818  return success();
4819 }
4820 
4821 /// Given mixed basis of affine.delinearize_index/linearize_index replace
4822 /// constant SSA values with the constant integer value and return the new
4823 /// static basis. In case no such candidate for replacement exists, this utility
4824 /// returns std::nullopt.
// NOTE(review): the signature line carrying the function name and the
// `mixedBasis` parameter (orig. line 4826) was dropped by doc extraction —
// confirm against upstream.
4825 static std::optional<SmallVector<int64_t>>
4827  MutableOperandRange mutableDynamicBasis,
4828  ArrayRef<Attribute> dynamicBasis) {
// Erase the dynamic operands that folded to constant attributes; the
// surviving dynamic operands keep their relative order.
4829  uint64_t dynamicBasisIndex = 0;
4830  for (OpFoldResult basis : dynamicBasis) {
4831  if (basis) {
4832  mutableDynamicBasis.erase(dynamicBasisIndex);
4833  } else {
4834  ++dynamicBasisIndex;
4835  }
4836  }
4837 
4838  // No constant SSA value exists.
4839  if (dynamicBasisIndex == dynamicBasis.size())
4840  return std::nullopt;
4841 
// Recompute the static basis: constants become concrete values, remaining
// dynamic entries stay kDynamic markers.
4842  SmallVector<int64_t> staticBasis;
4843  for (OpFoldResult basis : mixedBasis) {
4844  std::optional<int64_t> basisVal = getConstantIntValue(basis);
4845  if (!basisVal)
4846  staticBasis.push_back(ShapedType::kDynamic);
4847  else
4848  staticBasis.push_back(*basisVal);
4849  }
4850 
4851  return staticBasis;
4852 }
4853 
// Folding for affine.delinearize_index: (1) absorb constant dynamic basis
// operands into the static basis, (2) forward the input when there is a
// single result, (3) fully evaluate when the linear index and basis are all
// constant.
// NOTE(review): the second signature line (orig. line 4856, presumably
// `SmallVectorImpl<OpFoldResult> &result) {`) was dropped by doc extraction —
// confirm against upstream.
4854 LogicalResult
4855 AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
4857  std::optional<SmallVector<int64_t>> maybeStaticBasis =
4858  foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
4859  adaptor.getDynamicBasis());
4860  if (maybeStaticBasis) {
4861  setStaticBasis(*maybeStaticBasis);
4862  return success();
4863  }
4864  // If we won't be doing any division or modulo (no basis or the one basis
4865  // element is purely advisory), simply return the input value.
4866  if (getNumResults() == 1) {
4867  result.push_back(getLinearIndex());
4868  return success();
4869  }
4870 
// Full constant evaluation needs a constant linear index and a fully static
// basis.
4871  if (adaptor.getLinearIndex() == nullptr)
4872  return failure();
4873 
4874  if (!adaptor.getDynamicBasis().empty())
4875  return failure();
4876 
4877  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
4878  Type attrType = getLinearIndex().getType();
4879 
// Peel results innermost-first with mod/floordiv, then reverse into the
// outermost-first result order.
4880  ArrayRef<int64_t> staticBasis = getStaticBasis();
4881  if (hasOuterBound())
4882  staticBasis = staticBasis.drop_front();
4883  for (int64_t modulus : llvm::reverse(staticBasis)) {
4884  result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
4885  highPart = llvm::divideFloorSigned(highPart, modulus);
4886  }
4887  result.push_back(IntegerAttr::get(attrType, highPart));
4888  std::reverse(result.begin(), result.end());
4889  return success();
4890 }
4891 
4892 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
4893  OpBuilder builder(getContext());
4894  if (hasOuterBound()) {
4895  if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
4896  return getMixedValues(getStaticBasis().drop_front(),
4897  getDynamicBasis().drop_front(), builder);
4898 
4899  return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
4900  builder);
4901  }
4902 
4903  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
4904 }
4905 
4906 SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
4907  SmallVector<OpFoldResult> ret = getMixedBasis();
4908  if (!hasOuterBound())
4909  ret.insert(ret.begin(), OpFoldResult());
4910  return ret;
4911 }
4912 
4913 namespace {
4914 
4915 // Drops delinearization indices that correspond to unit-extent basis
4916 struct DropUnitExtentBasis
4917  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4919 
4920  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4921  PatternRewriter &rewriter) const override {
4922  SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
4923  std::optional<Value> zero = std::nullopt;
4924  Location loc = delinearizeOp->getLoc();
4925  auto getZero = [&]() -> Value {
4926  if (!zero)
4927  zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
4928  return zero.value();
4929  };
4930 
4931  // Replace all indices corresponding to unit-extent basis with 0.
4932  // Remaining basis can be used to get a new `affine.delinearize_index` op.
4933  SmallVector<OpFoldResult> newBasis;
4934  for (auto [index, basis] :
4935  llvm::enumerate(delinearizeOp.getPaddedBasis())) {
4936  std::optional<int64_t> basisVal =
4937  basis ? getConstantIntValue(basis) : std::nullopt;
4938  if (basisVal == 1)
4939  replacements[index] = getZero();
4940  else
4941  newBasis.push_back(basis);
4942  }
4943 
4944  if (newBasis.size() == delinearizeOp.getNumResults())
4945  return rewriter.notifyMatchFailure(delinearizeOp,
4946  "no unit basis elements");
4947 
4948  if (!newBasis.empty()) {
4949  // Will drop the leading nullptr from `basis` if there was no outer bound.
4950  auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
4951  rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
4952  int newIndex = 0;
4953  // Map back the new delinearized indices to the values they replace.
4954  for (auto &replacement : replacements) {
4955  if (replacement)
4956  continue;
4957  replacement = newDelinearizeOp->getResult(newIndex++);
4958  }
4959  }
4960 
4961  rewriter.replaceOp(delinearizeOp, replacements);
4962  return success();
4963  }
4964 };
4965 
4966 /// If a `affine.delinearize_index`'s input is a `affine.linearize_index
4967 /// disjoint` and the two operations end with the same basis elements,
4968 /// cancel those parts of the operations out because they are inverses
4969 /// of each other.
4970 ///
4971 /// If the operations have the same basis, cancel them entirely.
4972 ///
4973 /// The `disjoint` flag is needed on the `affine.linearize_index` because
4974 /// otherwise, there is no guarantee that the inputs to the linearization are
4975 /// in-bounds the way the outputs of the delinearization would be.
4976 struct CancelDelinearizeOfLinearizeDisjointExactTail
4977  : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
4979 
4980  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
4981  PatternRewriter &rewriter) const override {
4982  auto linearizeOp = delinearizeOp.getLinearIndex()
4983  .getDefiningOp<affine::AffineLinearizeIndexOp>();
4984  if (!linearizeOp)
4985  return rewriter.notifyMatchFailure(delinearizeOp,
4986  "index doesn't come from linearize");
4987 
4988  if (!linearizeOp.getDisjoint())
4989  return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
4990 
4991  ValueRange linearizeIns = linearizeOp.getMultiIndex();
4992  // Note: we use the full basis so we don't lose outer bounds later.
4993  SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
4994  SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
4995  size_t numMatches = 0;
4996  for (auto [linSize, delinSize] : llvm::zip(
4997  llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
4998  if (linSize != delinSize)
4999  break;
5000  ++numMatches;
5001  }
5002 
5003  if (numMatches == 0)
5004  return rewriter.notifyMatchFailure(
5005  delinearizeOp, "final basis element doesn't match linearize");
5006 
5007  // The easy case: everything lines up and the basis match sup completely.
5008  if (numMatches == linearizeBasis.size() &&
5009  numMatches == delinearizeBasis.size() &&
5010  linearizeIns.size() == delinearizeOp.getNumResults()) {
5011  rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
5012  return success();
5013  }
5014 
5015  Value newLinearize = affine::AffineLinearizeIndexOp::create(
5016  rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
5017  ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
5018  linearizeOp.getDisjoint());
5019  auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
5020  rewriter, delinearizeOp.getLoc(), newLinearize,
5021  ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
5022  delinearizeOp.hasOuterBound());
5023  SmallVector<Value> mergedResults(newDelinearize.getResults());
5024  mergedResults.append(linearizeIns.take_back(numMatches).begin(),
5025  linearizeIns.take_back(numMatches).end());
5026  rewriter.replaceOp(delinearizeOp, mergedResults);
5027  return success();
5028  }
5029 };
5030 
5031 /// If the input to a delinearization is a disjoint linearization, and the
5032 /// last k > 1 components of the delinearization basis multiply to the
5033 /// last component of the linearization basis, break the linearization and
5034 /// delinearization into two parts, peeling off the last input to linearization.
5035 ///
5036 /// For example:
5037 /// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
5038 /// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
5039 /// becomes
5040 /// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
5041 /// %1:2 = affine.delinearize_index %0 by (2, 3) : index
5042 /// %2:2 = affine.delinearize_index %x by (8, 4) : index
5043 /// where the original %1:4 is replaced by %1:2 ++ %2:2
5044 struct SplitDelinearizeSpanningLastLinearizeArg final
5045  : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5047 
5048  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5049  PatternRewriter &rewriter) const override {
5050  auto linearizeOp = delinearizeOp.getLinearIndex()
5051  .getDefiningOp<affine::AffineLinearizeIndexOp>();
5052  if (!linearizeOp)
5053  return rewriter.notifyMatchFailure(delinearizeOp,
5054  "index doesn't come from linearize");
5055 
5056  if (!linearizeOp.getDisjoint())
5057  return rewriter.notifyMatchFailure(linearizeOp,
5058  "linearize isn't disjoint");
5059 
5060  int64_t target = linearizeOp.getStaticBasis().back();
5061  if (ShapedType::isDynamic(target))
5062  return rewriter.notifyMatchFailure(
5063  linearizeOp, "linearize ends with dynamic basis value");
5064 
5065  int64_t sizeToSplit = 1;
5066  size_t elemsToSplit = 0;
5067  ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
5068  for (int64_t basisElem : llvm::reverse(basis)) {
5069  if (ShapedType::isDynamic(basisElem))
5070  return rewriter.notifyMatchFailure(
5071  delinearizeOp, "dynamic basis element while scanning for split");
5072  sizeToSplit *= basisElem;
5073  elemsToSplit += 1;
5074 
5075  if (sizeToSplit > target)
5076  return rewriter.notifyMatchFailure(delinearizeOp,
5077  "overshot last argument size");
5078  if (sizeToSplit == target)
5079  break;
5080  }
5081 
5082  if (sizeToSplit < target)
5083  return rewriter.notifyMatchFailure(
5084  delinearizeOp, "product of known basis elements doesn't exceed last "
5085  "linearize argument");
5086 
5087  if (elemsToSplit < 2)
5088  return rewriter.notifyMatchFailure(
5089  delinearizeOp,
5090  "need at least two elements to form the basis product");
5091 
5092  Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
5093  rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
5094  linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(),
5095  linearizeOp.getDisjoint());
5096  auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
5097  rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
5098  delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
5099  delinearizeOp.hasOuterBound());
5100  auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
5101  rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
5102  basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
5103  SmallVector<Value> results = llvm::to_vector(
5104  llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
5105  delinearizeBack.getResults()));
5106  rewriter.replaceOp(delinearizeOp, results);
5107 
5108  return success();
5109  }
5110 };
5111 } // namespace
5112 
5113 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
5114  RewritePatternSet &patterns, MLIRContext *context) {
5115  patterns
5116  .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
5117  DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
5118  context);
5119 }
5120 
5121 //===----------------------------------------------------------------------===//
5122 // LinearizeIndexOp
5123 //===----------------------------------------------------------------------===//
5124 
5125 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5126  OperationState &odsState,
5127  ValueRange multiIndex, ValueRange basis,
5128  bool disjoint) {
5129  if (!basis.empty() && basis.front() == Value())
5130  basis = basis.drop_front();
5131  SmallVector<Value> dynamicBasis;
5132  SmallVector<int64_t> staticBasis;
5133  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
5134  staticBasis);
5135  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5136 }
5137 
5138 void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5139  OperationState &odsState,
5140  ValueRange multiIndex,
5141  ArrayRef<OpFoldResult> basis,
5142  bool disjoint) {
5143  if (!basis.empty() && basis.front() == OpFoldResult())
5144  basis = basis.drop_front();
5145  SmallVector<Value> dynamicBasis;
5146  SmallVector<int64_t> staticBasis;
5147  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
5148  build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5149 }
5150 
/// Build with a fully static basis: no dynamic basis operands are needed, so
/// forward an empty ValueRange to the primary builder.
void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
                                   OperationState &odsState,
                                   ValueRange multiIndex,
                                   ArrayRef<int64_t> basis, bool disjoint) {
  build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
}
5157 
5158 LogicalResult AffineLinearizeIndexOp::verify() {
5159  size_t numIndexes = getMultiIndex().size();
5160  size_t numBasisElems = getStaticBasis().size();
5161  if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5162  return emitOpError("should be passed a basis element for each index except "
5163  "possibly the first");
5164 
5165  auto dynamicMarkersCount =
5166  llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
5167  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5168  return emitOpError(
5169  "mismatch between dynamic and static basis (kDynamic marker but no "
5170  "corresponding dynamic basis entry) -- this can only happen due to an "
5171  "incorrect fold/rewrite");
5172 
5173  return success();
5174 }
5175 
/// Fold affine.linearize_index. First fold constant basis operands into the
/// static basis attribute (an in-place update); then handle the trivial
/// shapes (no indices, one index); finally, when all indices and the basis
/// are constants, compute the linearized value directly.
OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
                                 adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);
    return getResult();
  }
  // No indices linearizes to zero.
  if (getMultiIndex().empty())
    return IntegerAttr::get(getResult().getType(), 0);

  // One single index linearizes to itself.
  if (getMultiIndex().size() == 1)
    return getMultiIndex().front();

  // Constant computation below requires every index to be a constant...
  if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))
    return nullptr;

  // ...and the basis to be fully static.
  if (!adaptor.getDynamicBasis().empty())
    return nullptr;

  // Accumulate sum(index[i] * stride[i]) from the innermost element outward.
  // zip_first pairs up the (possibly shorter) basis with the indices, which
  // leaves an unbounded outermost index unprocessed; it is added below.
  int64_t result = 0;
  int64_t stride = 1;
  for (auto [length, indexAttr] :
       llvm::zip_first(llvm::reverse(getStaticBasis()),
                       llvm::reverse(adaptor.getMultiIndex()))) {
    result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
    stride = stride * length;
  }
  // Handle the index element with no basis element.
  if (!hasOuterBound())
    result =
        result +
        cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;

  return IntegerAttr::get(getResult().getType(), result);
}
5214 
5215 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
5216  OpBuilder builder(getContext());
5217  if (hasOuterBound()) {
5218  if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5219  return getMixedValues(getStaticBasis().drop_front(),
5220  getDynamicBasis().drop_front(), builder);
5221 
5222  return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5223  builder);
5224  }
5225 
5226  return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5227 }
5228 
5229 SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
5230  SmallVector<OpFoldResult> ret = getMixedBasis();
5231  if (!hasOuterBound())
5232  ret.insert(ret.begin(), OpFoldResult());
5233  return ret;
5234 }
5235 
5236 namespace {
5237 /// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
5238 /// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
5239 /// %...d)`.
5240 
5241 /// Note that `disjoint` is required here, because, without it, we could have
5242 /// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
5243 /// is a valid operation where the `%c64` cannot be trivially dropped.
5244 ///
5245 /// Alternatively, if `%x` in the above is a known constant 0, remove it even if
5246 /// the operation isn't asserted to be `disjoint`.
5247 struct DropLinearizeUnitComponentsIfDisjointOrZero final
5248  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5250 
5251  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5252  PatternRewriter &rewriter) const override {
5253  ValueRange multiIndex = op.getMultiIndex();
5254  size_t numIndices = multiIndex.size();
5255  SmallVector<Value> newIndices;
5256  newIndices.reserve(numIndices);
5257  SmallVector<OpFoldResult> newBasis;
5258  newBasis.reserve(numIndices);
5259 
5260  if (!op.hasOuterBound()) {
5261  newIndices.push_back(multiIndex.front());
5262  multiIndex = multiIndex.drop_front();
5263  }
5264 
5265  SmallVector<OpFoldResult> basis = op.getMixedBasis();
5266  for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5267  std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
5268  if (!basisEntry || *basisEntry != 1) {
5269  newIndices.push_back(index);
5270  newBasis.push_back(basisElem);
5271  continue;
5272  }
5273 
5274  std::optional<int64_t> indexValue = getConstantIntValue(index);
5275  if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5276  newIndices.push_back(index);
5277  newBasis.push_back(basisElem);
5278  continue;
5279  }
5280  }
5281  if (newIndices.size() == numIndices)
5282  return rewriter.notifyMatchFailure(op,
5283  "no unit basis entries to replace");
5284 
5285  if (newIndices.size() == 0) {
5286  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
5287  return success();
5288  }
5289  rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5290  op, newIndices, newBasis, op.getDisjoint());
5291  return success();
5292  }
5293 };
5294 
5296  ArrayRef<OpFoldResult> terms) {
5297  int64_t nDynamic = 0;
5298  SmallVector<Value> dynamicPart;
5299  AffineExpr result = builder.getAffineConstantExpr(1);
5300  for (OpFoldResult term : terms) {
5301  if (!term)
5302  return term;
5303  std::optional<int64_t> maybeConst = getConstantIntValue(term);
5304  if (maybeConst) {
5305  result = result * builder.getAffineConstantExpr(*maybeConst);
5306  } else {
5307  dynamicPart.push_back(cast<Value>(term));
5308  result = result * builder.getAffineSymbolExpr(nDynamic++);
5309  }
5310  }
5311  if (auto constant = dyn_cast<AffineConstantExpr>(result))
5312  return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
5313  return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult();
5314 }
5315 
5316 /// If conseceutive outputs of a delinearize_index are linearized with the same
5317 /// bounds, canonicalize away the redundant arithmetic.
5318 ///
5319 /// That is, if we have
5320 /// ```
5321 /// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
5322 /// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
5323 /// by (...e, B1, B2, ..., BK, ...f)
5324 /// ```
5325 ///
5326 /// We can rewrite this to
5327 /// ```
5328 /// B = B1 * B2 ... BK
5329 /// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b)
5330 /// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
5331 /// ```
5332 /// where we replace all results of %s unaffected by the change with results
5333 /// from %sMerged.
5334 ///
5335 /// As a special case, if all results of the delinearize are merged in this way
5336 /// we can replace those usages with %x, thus cancelling the delinearization
5337 /// entirely, as in
5338 /// ```
5339 /// %s:3 = affine.delinearize_index %x into (2, 4, 8)
5340 /// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
5341 /// ```
5342 /// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
5343 struct CancelLinearizeOfDelinearizePortion final
5344  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5346 
5347 private:
5348  // Struct representing a case where the cancellation pattern
5349  // applies. A `Match` means that `length` inputs to the linearize operation
5350  // starting at `linStart` can be cancelled with `length` outputs of
5351  // `delinearize`, starting from `delinStart`.
5352  struct Match {
5353  AffineDelinearizeIndexOp delinearize;
5354  unsigned linStart = 0;
5355  unsigned delinStart = 0;
5356  unsigned length = 0;
5357  };
5358 
5359 public:
5360  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
5361  PatternRewriter &rewriter) const override {
5362  SmallVector<Match> matches;
5363 
5364  const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
5365  ArrayRef<OpFoldResult> linBasisRef = linBasis;
5366 
5367  ValueRange multiIndex = linearizeOp.getMultiIndex();
5368  unsigned numLinArgs = multiIndex.size();
5369  unsigned linArgIdx = 0;
5370  // We only want to replace one run from the same delinearize op per
5371  // pattern invocation lest we run into invalidation issues.
5372  llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
5373  while (linArgIdx < numLinArgs) {
5374  auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
5375  if (!asResult) {
5376  linArgIdx++;
5377  continue;
5378  }
5379 
5380  auto delinearizeOp =
5381  dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
5382  if (!delinearizeOp) {
5383  linArgIdx++;
5384  continue;
5385  }
5386 
5387  /// Result 0 of the delinearize and argument 0 of the linearize can
5388  /// leave their maximum value unspecified. However, even if this happens
5389  /// we can still sometimes start the match process. Specifically, if
5390  /// - The argument we're matching is result 0 and argument 0 (so the
5391  /// bounds don't matter). For example,
5392  ///
5393  /// %0:2 = affine.delinearize_index %x into (8) : index, index
5394  /// %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...)
5395  /// allows cancellation
5396  /// - The delinearization doesn't specify a bound, but the linearization
5397  /// is `disjoint`, which asserts that the bound on the linearization is
5398  /// correct.
5399  unsigned delinArgIdx = asResult.getResultNumber();
5400  SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
5401  OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
5402  OpFoldResult firstLinBound = linBasis[linArgIdx];
5403  bool boundsMatch = firstDelinBound == firstLinBound;
5404  bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
5405  bool knownByDisjoint =
5406  linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
5407  if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5408  linArgIdx++;
5409  continue;
5410  }
5411 
5412  unsigned j = 1;
5413  unsigned numDelinOuts = delinearizeOp.getNumResults();
5414  for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
5415  ++j) {
5416  if (multiIndex[linArgIdx + j] !=
5417  delinearizeOp.getResult(delinArgIdx + j))
5418  break;
5419  if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
5420  break;
5421  }
5422  // If there're multiple matches against the same delinearize_index,
5423  // only rewrite the first one we find to prevent invalidations. The next
5424  // ones will be taken care of by subsequent pattern invocations.
5425  if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
5426  linArgIdx++;
5427  continue;
5428  }
5429  matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
5430  linArgIdx += j;
5431  }
5432 
5433  if (matches.empty())
5434  return rewriter.notifyMatchFailure(
5435  linearizeOp, "no run of delinearize outputs to deal with");
5436 
5437  // Record all the delinearize replacements so we can do them after creating
5438  // the new linearization operation, since the new operation might use
5439  // outputs of something we're replacing.
5440  SmallVector<SmallVector<Value>> delinearizeReplacements;
5441 
5442  SmallVector<Value> newIndex;
5443  newIndex.reserve(numLinArgs);
5444  SmallVector<OpFoldResult> newBasis;
5445  newBasis.reserve(numLinArgs);
5446  unsigned prevMatchEnd = 0;
5447  for (Match m : matches) {
5448  unsigned gap = m.linStart - prevMatchEnd;
5449  llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
5450  llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
5451  // Update here so we don't forget this during early continues
5452  prevMatchEnd = m.linStart + m.length;
5453 
5454  PatternRewriter::InsertionGuard g(rewriter);
5455  rewriter.setInsertionPoint(m.delinearize);
5456 
5457  ArrayRef<OpFoldResult> basisToMerge =
5458  linBasisRef.slice(m.linStart, m.length);
5459  // We use the slice from the linearize's basis above because of the
5460  // "bounds inferred from `disjoint`" case above.
5461  OpFoldResult newSize =
5462  computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
5463 
5464  // Trivial case where we can just skip past the delinearize all together
5465  if (m.length == m.delinearize.getNumResults()) {
5466  newIndex.push_back(m.delinearize.getLinearIndex());
5467  newBasis.push_back(newSize);
5468  // Pad out set of replacements so we don't do anything with this one.
5469  delinearizeReplacements.push_back(SmallVector<Value>());
5470  continue;
5471  }
5472 
5473  SmallVector<Value> newDelinResults;
5474  SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
5475  newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
5476  newDelinBasis.begin() + m.delinStart + m.length);
5477  newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
5478  auto newDelinearize = AffineDelinearizeIndexOp::create(
5479  rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5480  newDelinBasis);
5481 
5482  // Since there may be other uses of the indices we just merged together,
5483  // create a residual affine.delinearize_index that delinearizes the
5484  // merged output into its component parts.
5485  Value combinedElem = newDelinearize.getResult(m.delinStart);
5486  auto residualDelinearize = AffineDelinearizeIndexOp::create(
5487  rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);
5488 
5489  // Swap all the uses of the unaffected delinearize outputs to the new
5490  // delinearization so that the old code can be removed if this
5491  // linearize_index is the only user of the merged results.
5492  llvm::append_range(newDelinResults,
5493  newDelinearize.getResults().take_front(m.delinStart));
5494  llvm::append_range(newDelinResults, residualDelinearize.getResults());
5495  llvm::append_range(
5496  newDelinResults,
5497  newDelinearize.getResults().drop_front(m.delinStart + 1));
5498 
5499  delinearizeReplacements.push_back(newDelinResults);
5500  newIndex.push_back(combinedElem);
5501  newBasis.push_back(newSize);
5502  }
5503  llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
5504  llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
5505  rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
5506  linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
5507 
5508  for (auto [m, newResults] :
5509  llvm::zip_equal(matches, delinearizeReplacements)) {
5510  if (newResults.empty())
5511  continue;
5512  rewriter.replaceOp(m.delinearize, newResults);
5513  }
5514 
5515  return success();
5516  }
5517 };
5518 
5519 /// Strip leading zero from affine.linearize_index.
5520 ///
5521 /// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
5522 /// to `affine.linearize_index [...a] by (...b)` in all cases.
5523 struct DropLinearizeLeadingZero final
5524  : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5526 
5527  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5528  PatternRewriter &rewriter) const override {
5529  Value leadingIdx = op.getMultiIndex().front();
5530  if (!matchPattern(leadingIdx, m_Zero()))
5531  return failure();
5532 
5533  if (op.getMultiIndex().size() == 1) {
5534  rewriter.replaceOp(op, leadingIdx);
5535  return success();
5536  }
5537 
5538  SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
5539  ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
5540  if (op.hasOuterBound())
5541  newMixedBasis = newMixedBasis.drop_front();
5542 
5543  rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5544  op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5545  return success();
5546  }
5547 };
5548 } // namespace
5549 
5550 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5551  RewritePatternSet &patterns, MLIRContext *context) {
5552  patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5553  DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5554 }
5555 
5556 //===----------------------------------------------------------------------===//
5557 // TableGen'd op method definitions
5558 //===----------------------------------------------------------------------===//
5559 
5560 #define GET_OP_CLASSES
5561 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
static Value getStride(Location loc, MemRefType mType, Value base, RewriterBase &rewriter)
Maps the 2-dim memref shape to the 64-bit stride.
Definition: AMXDialect.cpp:85
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
Definition: AffineOps.cpp:2887
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
Definition: AffineOps.cpp:2608
static LogicalResult verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
Definition: AffineOps.cpp:3280
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
Definition: AffineOps.cpp:3458
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
Definition: AffineOps.cpp:4091
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
Definition: AffineOps.cpp:62
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
Definition: AffineOps.cpp:4238
static void LLVM_ATTRIBUTE_UNUSED simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
Definition: AffineOps.cpp:1034
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
Definition: AffineOps.cpp:3494
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
Definition: AffineOps.cpp:2462
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
Definition: AffineOps.cpp:820
static bool isValidAffineIndexOperand(Value value, Region *region)
Definition: AffineOps.cpp:490
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1540
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:755
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
Definition: AffineOps.cpp:2151
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:747
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
Definition: AffineOps.cpp:1446
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType)
Verify common invariants of affine.vector_load and affine.vector_store.
Definition: AffineOps.cpp:4640
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
Definition: AffineOps.cpp:923
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
Definition: AffineOps.cpp:3471
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
Definition: AffineOps.cpp:3172
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
Definition: AffineOps.cpp:349
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
Definition: AffineOps.cpp:695
static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser, OperationState &result, MinMaxKind kind)
Parses an affine map that can contain a min/max for groups of its results, e.g., max(expr-1,...
Definition: AffineOps.cpp:4353
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
Definition: AffineOps.cpp:2896
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
Definition: AffineOps.cpp:495
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
Definition: AffineOps.cpp:657
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
Definition: AffineOps.cpp:1194
static void legalizeDemotedDims(MapOrSet &mapOrSet, SmallVectorImpl< Value > &operands)
A valid affine dimension may appear as a symbol in affine.apply operations.
Definition: AffineOps.cpp:1494
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1383
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
Definition: AffineOps.cpp:2846
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
Definition: AffineOps.cpp:2416
static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map, ValueRange dims, ValueRange syms)
Assuming dimOrSym is a quantity in the apply op map map and defined by minOp = affine_min(x_1,...
Definition: AffineOps.cpp:1059
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
Definition: AffineOps.cpp:527
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Definition: AffineOps.cpp:1398
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
Definition: AffineOps.cpp:368
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
Definition: AffineOps.cpp:723
static ParseResult deduplicateAndResolveOperands(OpAsmParser &parser, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> operands, SmallVectorImpl< Value > &uniqueOperands, SmallVectorImpl< AffineExpr > &replacements, AffineExprKind kind)
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Definition: AffineOps.cpp:4306
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms, bool replaceAffineMin)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression and operands.
Definition: AffineOps.cpp:1138
static LogicalResult verifyAffineMinMaxOp(T op)
Definition: AffineOps.cpp:3445
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
Definition: AffineOps.cpp:2326
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
Definition: AffineOps.cpp:1297
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
Definition: AffineOps.cpp:4826
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
Definition: AffineOps.cpp:3657
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
union mlir::linalg::@1227::ArityGroupAndKind::Kind kind
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:73
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
Definition: SCFToGPU.cpp:78
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar type.
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
Definition: AffineExpr.cpp:959
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:33
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
Definition: AffineExpr.cpp:241
MLIRContext * getContext() const
Definition: AffineExpr.cpp:31
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
Definition: AffineExpr.cpp:179
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
Definition: AffineMap.cpp:655
MLIRContext * getContext() const
Definition: AffineMap.cpp:339
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
Definition: AffineMap.h:221
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:267
unsigned getNumSymbols() const
Definition: AffineMap.cpp:394
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
Definition: AffineMap.h:228
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
Definition: AffineMap.cpp:496
unsigned getNumInputs() const
Definition: AffineMap.cpp:399
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ...
Definition: AffineMap.h:280
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:511
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
Definition: AffineMap.cpp:124
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
Definition: AffineMap.cpp:647
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
Definition: AffineMap.cpp:430
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:308
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
AffineMap getDimIdentityMap()
Definition: Builders.cpp:378
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:382
AffineExpr getAffineSymbolExpr(unsigned position)
Definition: Builders.cpp:363
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:367
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition: Builders.cpp:174
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:107
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
NoneType getNoneType()
Definition: Builders.cpp:83
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:95
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
Definition: Builders.cpp:371
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition: Builders.cpp:373
MLIRContext * getContext() const
Definition: Builders.h:55
AffineMap getSymbolIdentityMap()
Definition: Builders.cpp:391
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:276
IndexType getIndexType()
Definition: Builders.cpp:50
An attribute that represents a reference to a dense integer vector or tensor object.
This is the interface that must be implemented by the dialects of operations to be inlined.
Definition: InliningUtils.h:44
DialectInlinerInterface(Dialect *dialect)
Definition: InliningUtils.h:46
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:72
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition: Builders.h:621
Location getLoc() const
Accessors for the implied location.
Definition: Builders.h:654
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44
unsigned getNumDims() const
Definition: IntegerSet.cpp:15
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
Definition: IntegerSet.cpp:57
unsigned getNumInputs() const
Definition: IntegerSet.cpp:17
ArrayRef< AffineExpr > getConstraints() const
Definition: IntegerSet.cpp:41
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality.
Definition: IntegerSet.cpp:51
unsigned getNumSymbols() const
Definition: IntegerSet.cpp:16
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void pop_back()
Pop last element from list.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition: Operation.h:230
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:218
operand_range::iterator operand_iterator
Definition: Operation.h:372
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
bool hasOneBlock()
Return true if this region has exactly one block.
Definition: Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:845
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:628
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:612
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
A variable that can be added to the constraint set as a "column".
static bool compare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Return "true" if "lhs cmp rhs" was proven to hold.
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
AffineBound represents a lower or upper bound in the for operation.
Definition: AffineOps.h:550
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
Definition: AffineOps.h:106
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Definition: AffineOps.h:330
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
Definition: AffineOps.cpp:4198
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
unsigned getNumResults() const
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
Operation * getOwner() const
Return the owner of this operand.
Definition: UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only another loop and a terminator.
Definition: AffineOps.cpp:2909
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying those operands.
Definition: AffineOps.cpp:1274
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
Definition: AffineOps.cpp:2823
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surrounding op that has the trait AffineScope.
Definition: AffineOps.cpp:290
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
Definition: AffineOps.cpp:2795
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
Definition: AffineOps.cpp:1372
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
Definition: AffineOps.cpp:2799
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
Definition: AffineOps.cpp:1617
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a maximum across the results of applying map to operands,...
Definition: AffineOps.cpp:1437
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
Definition: AffineOps.cpp:2787
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1327
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
Definition: AffineOps.cpp:1430
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Definition: AffineOps.cpp:250
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is none.
Definition: AffineOps.cpp:275
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
Definition: AffineOps.cpp:1258
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
Definition: AffineOps.cpp:1622
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
Definition: AffineOps.cpp:2830
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
Definition: AffineOps.cpp:412
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns true if the provided value is among the induction variables of an AffineParallelOp.
Definition: AffineOps.cpp:2810
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
Definition: AffineOps.cpp:265
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
Definition: AffineOps.cpp:505
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
Definition: AffineOps.cpp:2791
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
Definition: AffineOps.cpp:1392
BaseMemRefType getMemRefType(TensorType tensorType, const BufferizationOptions &options, MemRefLayoutAttrInterface layout={}, Attribute memorySpace=nullptr)
Return a MemRefType to which the TensorType can be bufferized.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
Definition: AffineMap.cpp:766
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
Definition: AffineMap.cpp:776
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t >> constLowerBounds, ArrayRef< std::optional< int64_t >> constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., it is speculatable and does not touch memory.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
AffineExprKind
Definition: AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ SymbolId
Symbolic identifier.
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
Definition: AffineExpr.cpp:68
std::function< SmallVector< Value >(OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition: Matchers.h:442
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:643
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
Definition: AffineMap.cpp:738
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Definition: AffineExpr.cpp:629
Canonicalize the affine map result expression order of an affine min/max operation.
Definition: AffineOps.cpp:3711
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3714
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3728
Remove duplicated expressions in affine min/max ops.
Definition: AffineOps.cpp:3527
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3530
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
Definition: AffineOps.cpp:3570
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Definition: AffineOps.cpp:3573
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition: Builders.h:283
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.