MLIR 23.0.0git
AffineOps.cpp
Go to the documentation of this file.
1//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/AffineExpr.h"
16#include "mlir/IR/IRMapping.h"
17#include "mlir/IR/IntegerSet.h"
18#include "mlir/IR/Matchers.h"
21#include "mlir/IR/Value.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
27#include "llvm/ADT/SmallVectorExtras.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Support/DebugLog.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/MathExtras.h"
32#include <numeric>
33#include <optional>
34
35using namespace mlir;
36using namespace mlir::affine;
37
38using llvm::divideCeilSigned;
39using llvm::divideFloorSigned;
40using llvm::mod;
41
42#define DEBUG_TYPE "affine-ops"
43
44#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
45
46/// A utility function to check if a value is defined at the top level of
47/// `region` or is an argument of `region`. A value of index type defined at the
48/// top level of a `AffineScope` region is always a valid symbol for all
49/// uses in that region.
51 if (auto arg = dyn_cast<BlockArgument>(value))
52 return arg.getParentRegion() == region;
53 return value.getDefiningOp()->getParentRegion() == region;
54}
55
56/// Checks if `value` known to be a legal affine dimension or symbol in `src`
57/// region remains legal if the operation that uses it is inlined into `dest`
58/// with the given value mapping. `legalityCheck` is either `isValidDim` or
59/// `isValidSymbol`, depending on the value being required to remain a valid
60/// dimension or symbol.
61static bool
63 const IRMapping &mapping,
64 function_ref<bool(Value, Region *)> legalityCheck) {
65 // If the value is a valid dimension for any other reason than being
66 // a top-level value, it will remain valid: constants get inlined
67 // with the function, transitive affine applies also get inlined and
68 // will be checked themselves, etc.
69 if (!isTopLevelValue(value, src))
70 return true;
71
72 // If it's a top-level value because it's a block operand, i.e. a
73 // function argument, check whether the value replacing it after
74 // inlining is a valid dimension in the new region.
75 if (llvm::isa<BlockArgument>(value))
76 return legalityCheck(mapping.lookup(value), dest);
77
78 // If it's a top-level value because it's defined in the region,
79 // it can only be inlined if the defining op is a constant or a
80 // `dim`, which can appear anywhere and be valid, since the defining
81 // op won't be top-level anymore after inlining.
82 Attribute operandCst;
83 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
84 return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
85 isDimLikeOp;
86}
87
88/// Checks if all values known to be legal affine dimensions or symbols in `src`
89/// remain so if their respective users are inlined into `dest`.
90static bool
92 const IRMapping &mapping,
93 function_ref<bool(Value, Region *)> legalityCheck) {
94 return llvm::all_of(values, [&](Value v) {
95 return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
96 });
97}
98
99/// Checks if an affine read or write operation remains legal after inlining
100/// from `src` to `dest`.
101template <typename OpTy>
102static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
103 const IRMapping &mapping) {
104 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
105 AffineWriteOpInterface>::value,
106 "only ops with affine read/write interface are supported");
107
108 AffineMap map = op.getAffineMap();
109 ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
110 ValueRange symbolOperands =
111 op.getMapOperands().take_back(map.getNumSymbols());
113 dimOperands, src, dest, mapping,
114 static_cast<bool (*)(Value, Region *)>(isValidDim)))
115 return false;
117 symbolOperands, src, dest, mapping,
118 static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
119 return false;
120 return true;
121}
122
123/// Checks if an affine apply operation remains legal after inlining from `src`
124/// to `dest`.
125// Use "unused attribute" marker to silence clang-tidy warning stemming from
126// the inability to see through "llvm::TypeSwitch".
127template <>
128[[maybe_unused]] bool remainsLegalAfterInline(AffineApplyOp op, Region *src,
129 Region *dest,
130 const IRMapping &mapping) {
131 // If it's a valid dimension, we need to check that it remains so.
132 if (isValidDim(op.getResult(), src))
134 op.getMapOperands(), src, dest, mapping,
135 static_cast<bool (*)(Value, Region *)>(isValidDim));
136
137 // Otherwise it must be a valid symbol, check that it remains so.
139 op.getMapOperands(), src, dest, mapping,
140 static_cast<bool (*)(Value, Region *)>(isValidSymbol));
141}
142
143//===----------------------------------------------------------------------===//
144// AffineDialect Interfaces
145//===----------------------------------------------------------------------===//
146
147namespace {
148/// This class defines the interface for handling inlining with affine
149/// operations.
150struct AffineInlinerInterface : public DialectInlinerInterface {
151 using DialectInlinerInterface::DialectInlinerInterface;
152
153 //===--------------------------------------------------------------------===//
154 // Analysis Hooks
155 //===--------------------------------------------------------------------===//
156
157 /// Returns true if the given region 'src' can be inlined into the region
158 /// 'dest' that is attached to an operation registered to the current dialect.
159 /// 'wouldBeCloned' is set if the region is cloned into its new location
160 /// rather than moved, indicating there may be other users.
161 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
162 IRMapping &valueMapping) const final {
163 // We can inline into affine loops and conditionals if this doesn't break
164 // affine value categorization rules.
165 Operation *destOp = dest->getParentOp();
166 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
167 return false;
168
169 // Multi-block regions cannot be inlined into affine constructs, all of
170 // which require single-block regions.
171 if (!src->hasOneBlock())
172 return false;
173
174 // Side-effecting operations that the affine dialect cannot understand
175 // should not be inlined.
176 Block &srcBlock = src->front();
177 for (Operation &op : srcBlock) {
178 // Ops with no side effects are fine,
179 if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
180 if (iface.hasNoEffect())
181 continue;
182 }
183
184 // Assuming the inlined region is valid, we only need to check if the
185 // inlining would change it.
186 bool remainsValid =
187 llvm::TypeSwitch<Operation *, bool>(&op)
188 .Case<AffineApplyOp, AffineReadOpInterface,
189 AffineWriteOpInterface>([&](auto op) {
190 return remainsLegalAfterInline(op, src, dest, valueMapping);
191 })
192 .Default([](Operation *) {
193 // Conservatively disallow inlining ops we cannot reason about.
194 return false;
195 });
196
197 if (!remainsValid)
198 return false;
199 }
200
201 return true;
202 }
203
204 /// Returns true if the given operation 'op', that is registered to this
205 /// dialect, can be inlined into the given region, false otherwise.
206 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
207 IRMapping &valueMapping) const final {
208 // Always allow inlining affine operations into a region that is marked as
209 // affine scope, or into affine loops and conditionals. There are some edge
210 // cases when inlining *into* affine structures, but that is handled in the
211 // other 'isLegalToInline' hook above.
212 Operation *parentOp = region->getParentOp();
213 return parentOp->hasTrait<OpTrait::AffineScope>() ||
214 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
215 }
216
217 /// Affine regions should be analyzed recursively.
218 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
219};
220} // namespace
221
222//===----------------------------------------------------------------------===//
223// AffineDialect
224//===----------------------------------------------------------------------===//
225
/// Registers the pieces that make up the affine dialect: its operations, its
/// inliner interface, and the interfaces it promises to provide later.
226void AffineDialect::initialize() {
// The two DMA ops are listed by hand; the remaining ops come from the op list
// that the included AffineOps.cpp.inc expands to under GET_OP_LIST.
227 addOperations<AffineDmaStartOp, AffineDmaWaitOp,
228#define GET_OP_LIST
229#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
230 >();
// Attach the inlining hooks defined by AffineInlinerInterface.
231 addInterfaces<AffineInlinerInterface>();
// ValueBoundsOpInterface is only declared as promised here for
// affine.apply/max/min; its implementation is not registered in this function.
232 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
233 AffineMinOp>();
234}
235
236/// Materialize a single constant operation from a given attribute value with
237/// the desired resultant type.
238Operation *AffineDialect::materializeConstant(OpBuilder &builder,
239 Attribute value, Type type,
240 Location loc) {
241 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
242 return ub::PoisonOp::create(builder, loc, type, poison);
243 return arith::ConstantOp::materialize(builder, value, type, loc);
244}
245
246/// A utility function to check if a value is defined at the top level of an
247/// op with trait `AffineScope`. If the value is defined in an unlinked region,
248/// conservatively assume it is not top-level. A value of index type defined at
249/// the top level is always a valid symbol.
251 if (auto arg = dyn_cast<BlockArgument>(value)) {
252 // The block owning the argument may be unlinked, e.g. when the surrounding
253 // region has not yet been attached to an Op, at which point the parent Op
254 // is null.
255 Operation *parentOp = arg.getOwner()->getParentOp();
256 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
257 }
258 // The defining Op may live in an unlinked block so its parent Op may be null.
259 Operation *parentOp = value.getDefiningOp()->getParentOp();
260 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
261}
262
263/// Returns the closest region enclosing `op` that is held by an operation with
264/// trait `AffineScope`; `nullptr` if there is no such region.
266 auto *curOp = op;
267 while (auto *parentOp = curOp->getParentOp()) {
268 if (parentOp->hasTrait<OpTrait::AffineScope>())
269 return curOp->getParentRegion();
270 curOp = parentOp;
271 }
272 return nullptr;
273}
274
276 Operation *curOp = op;
277 while (auto *parentOp = curOp->getParentOp()) {
278 if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
279 return curOp->getParentRegion();
280 curOp = parentOp;
281 }
282 return nullptr;
283}
284
285// A Value can be used as a dimension id iff it meets one of the following
286// conditions:
287// *) It is valid as a symbol.
288// *) It is an induction variable.
289// *) It is the result of affine apply operation with dimension id arguments.
291 // The value must be an index type.
292 if (!value.getType().isIndex())
293 return false;
294
295 if (auto *defOp = value.getDefiningOp())
296 return isValidDim(value, getAffineScope(defOp));
297
298 // This value has to be a block argument for an op that has the
299 // `AffineScope` trait or an induction var of an affine.for or
300 // affine.parallel.
301 if (isAffineInductionVar(value))
302 return true;
303 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
304 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
305}
306
307// Value can be used as a dimension id iff it meets one of the following
308// conditions:
309// *) It is valid as a symbol.
310// *) It is an induction variable.
311// *) It is the result of an affine apply operation with dimension id operands.
312// *) It is the result of a more specialized index transformation (ex.
313// delinearize_index or linearize_index) with dimension id operands.
315 // The value must be an index type.
316 if (!value.getType().isIndex())
317 return false;
318
319 // All valid symbols are okay.
320 if (isValidSymbol(value, region))
321 return true;
322
323 auto *op = value.getDefiningOp();
324 if (!op) {
325 // This value has to be an induction var for an affine.for or an
326 // affine.parallel.
327 return isAffineInductionVar(value);
328 }
329
330 // Affine apply operation is ok if all of its operands are ok.
331 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
332 return applyOp.isValidDim(region);
333 // delinearize_index and linearize_index are special forms of apply
334 // and so are valid dimensions if all their arguments are valid dimensions.
335 if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
336 return llvm::all_of(op->getOperands(),
337 [&](Value arg) { return ::isValidDim(arg, region); });
338 // The dim op is okay if its operand memref/tensor is defined at the top
339 // level.
340 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
341 return isTopLevelValue(dimOp.getShapedValue());
342 return false;
343}
344
345/// Returns true if the 'index' dimension of the `memref` defined by
346/// `memrefDefOp` is a statically shaped one or defined using a valid symbol
347/// for `region`.
348template <typename AnyMemRefDefOp>
349static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
350 Region *region) {
351 MemRefType memRefType = memrefDefOp.getType();
352
353 // Dimension index is out of bounds.
354 if (index >= memRefType.getRank()) {
355 return false;
356 }
357
358 // Statically shaped.
359 if (!memRefType.isDynamicDim(index))
360 return true;
361 // Get the position of the dimension among dynamic dimensions;
362 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
363 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
364 region);
365}
366
367/// Returns true if the result of the dim op is a valid symbol for `region`.
368static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
369 // The dim op is okay if its source is defined at the top level.
370 if (isTopLevelValue(dimOp.getShapedValue()))
371 return true;
372
373 // Conservatively handle remaining BlockArguments as non-valid symbols.
374 // E.g. scf.for iterArgs.
375 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
376 return false;
377
378 // The dim op is also okay if its operand memref is a view/subview whose
379 // corresponding size is a valid symbol.
380 std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
381
382 // Be conservative if we can't understand the dimension.
383 if (!index.has_value())
384 return false;
385
386 // Skip over all memref.cast ops (if any).
387 Operation *op = dimOp.getShapedValue().getDefiningOp();
388 while (auto castOp = dyn_cast<memref::CastOp>(op)) {
389 // Bail on unranked memrefs.
390 if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
391 return false;
392 op = castOp.getSource().getDefiningOp();
393 if (!op)
394 return false;
395 }
396
397 int64_t i = index.value();
399 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
400 [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
401 .Default([](Operation *) { return false; });
402}
403
404// A value can be used as a symbol (at all its use sites) iff it meets one of
405// the following conditions:
406// *) It is a constant.
407// *) Its defining op or block arg appearance is immediately enclosed by an op
408// with `AffineScope` trait.
409// *) It is the result of an affine.apply operation with symbol operands.
410// *) It is a result of the dim op on a memref whose corresponding size is a
411// valid symbol.
413 if (!value)
414 return false;
415
416 // The value must be an index type.
417 if (!value.getType().isIndex())
418 return false;
419
420 // Check that the value is a top level value.
421 if (isTopLevelValue(value))
422 return true;
423
424 if (auto *defOp = value.getDefiningOp())
425 return isValidSymbol(value, getAffineScope(defOp));
426
427 return false;
428}
429
430/// A utility function to check if a value is defined at the top level of
431/// `region` or is an argument of `region` or is defined above the region.
432static bool isTopLevelValueOrAbove(Value value, Region *region) {
433 Region *parentRegion = value.getParentRegion();
434 do {
435 if (parentRegion == region)
436 return true;
437 Operation *regionOp = region->getParentOp();
438 if (regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
439 break;
440 region = region->getParentOp()->getParentRegion();
441 } while (region);
442 return false;
443}
444
445/// A value can be used as a symbol for `region` iff it meets one of the
446/// following conditions:
447/// *) It is a constant.
448/// *) It is a result of a `Pure` operation whose operands are valid symbolic
449/// identifiers.
450/// *) It is a result of the dim op on a memref whose corresponding size is
451/// a valid symbol.
452/// *) It is defined at the top level of 'region' or is its argument.
453/// *) It dominates `region`'s parent op.
454/// If `region` is null, conservatively assume the symbol definition scope does
455/// not exist and only accept the values that would be symbols regardless of
456/// the surrounding region structure, i.e. the first three cases above.
458 // The value must be an index type.
459 if (!value.getType().isIndex())
460 return false;
461
462 // A top-level value is a valid symbol.
463 if (region && isTopLevelValueOrAbove(value, region))
464 return true;
465
466 auto *defOp = value.getDefiningOp();
467 if (!defOp)
468 return false;
469
470 // Constant operation is ok.
471 Attribute operandCst;
472 if (matchPattern(defOp, m_Constant(&operandCst)))
473 return true;
474
475 // `Pure` operation whose operands are valid symbolic identifiers.
476 if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
477 return affine::isValidSymbol(operand, region);
478 })) {
479 return true;
480 }
481
482 // Dim op results could be valid symbols at any level.
483 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
484 return isDimOpValidSymbol(dimOp, region);
485
486 return false;
487}
488
489// Returns true if 'value' is a valid index to an affine operation (e.g.
490// affine.load, affine.store, affine.dma_start, affine.dma_wait) where
491// `region` provides the polyhedral symbol scope. Returns false otherwise.
492static bool isValidAffineIndexOperand(Value value, Region *region) {
493 return isValidDim(value, region) || isValidSymbol(value, region);
494}
495
496/// Prints dimension and symbol list.
499 unsigned numDims, OpAsmPrinter &printer) {
500 OperandRange operands(begin, end);
501 printer << '(' << operands.take_front(numDims) << ')';
502 if (operands.size() > numDims)
503 printer << '[' << operands.drop_front(numDims) << ']';
504}
505
506/// Parses dimension and symbol list and returns true if parsing failed.
508 OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
511 return failure();
512 // Store number of dimensions for validation by caller.
513 numDims = opInfos.size();
514
515 // Parse the optional symbol operands.
516 auto indexTy = parser.getBuilder().getIndexType();
517 return failure(parser.parseOperandList(
519 parser.resolveOperands(opInfos, indexTy, operands));
520}
521
522/// Utility function to verify that a set of operands are valid dimension and
523/// symbol identifiers. The operands should be laid out such that the dimension
524/// operands are before the symbol operands. This function returns failure if
525/// there was an invalid operand. An operation is provided to emit any necessary
526/// errors.
527template <typename OpTy>
528static LogicalResult
530 unsigned numDims) {
531 unsigned opIt = 0;
532 for (auto operand : operands) {
533 if (opIt++ < numDims) {
534 if (!isValidDim(operand, getAffineScope(op)))
535 return op.emitOpError("operand cannot be used as a dimension id");
536 } else if (!isValidSymbol(operand, getAffineScope(op))) {
537 return op.emitOpError("operand cannot be used as a symbol");
538 }
539 }
540 return success();
541}
542
543//===----------------------------------------------------------------------===//
544// AffineApplyOp
545//===----------------------------------------------------------------------===//
546
547AffineValueMap AffineApplyOp::getAffineValueMap() {
548 return AffineValueMap(getAffineMap(), getOperands(), getResult());
549}
550
551ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
552 auto &builder = parser.getBuilder();
553 auto indexTy = builder.getIndexType();
554
555 AffineMapAttr mapAttr;
556 unsigned numDims;
557 if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
558 parseDimAndSymbolList(parser, result.operands, numDims) ||
559 parser.parseOptionalAttrDict(result.attributes))
560 return failure();
561 auto map = mapAttr.getValue();
562
563 if (map.getNumDims() != numDims ||
564 numDims + map.getNumSymbols() != result.operands.size()) {
565 return parser.emitError(parser.getNameLoc(),
566 "dimension or symbol index mismatch");
567 }
568
569 result.types.append(map.getNumResults(), indexTy);
570 return success();
571}
572
573void AffineApplyOp::print(OpAsmPrinter &p) {
574 p << " " << getMapAttr();
575 printDimAndSymbolList(operand_begin(), operand_end(),
576 getAffineMap().getNumDims(), p);
577 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
578}
579
580LogicalResult AffineApplyOp::verify() {
581 // Check input and output dimensions match.
582 AffineMap affineMap = getMap();
583
584 // Verify that operand count matches affine map dimension and symbol count.
585 if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
586 return emitOpError(
587 "operand count and affine map dimension and symbol count must match");
588
589 // Verify that the map only produces one result.
590 if (affineMap.getNumResults() != 1)
591 return emitOpError("mapping must produce one value");
592
593 // Do not allow valid dims to be used in symbol positions. We do allow
594 // affine.apply to use operands for values that may neither qualify as affine
595 // dims or affine symbols due to usage outside of affine ops, analyses, etc.
596 Region *region = getAffineScope(*this);
597 for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
598 if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
599 return emitError("dimensional operand cannot be used as a symbol");
600 }
601
602 return success();
603}
604
605// The result of the affine apply operation can be used as a dimension id if all
606// its operands are valid dimension ids.
607bool AffineApplyOp::isValidDim() {
608 return llvm::all_of(getOperands(),
609 [](Value op) { return affine::isValidDim(op); });
610}
611
612// The result of the affine apply operation can be used as a dimension id if all
613// its operands are valid dimension ids with the parent operation of `region`
614// defining the polyhedral scope for symbols.
615bool AffineApplyOp::isValidDim(Region *region) {
616 return llvm::all_of(getOperands(),
617 [&](Value op) { return ::isValidDim(op, region); });
618}
619
620// The result of the affine apply operation can be used as a symbol if all its
621// operands are symbols.
622bool AffineApplyOp::isValidSymbol() {
623 return llvm::all_of(getOperands(),
624 [](Value op) { return affine::isValidSymbol(op); });
625}
626
627// The result of the affine apply operation can be used as a symbol in `region`
628// if all its operands are symbols in `region`.
629bool AffineApplyOp::isValidSymbol(Region *region) {
630 return llvm::all_of(getOperands(), [&](Value operand) {
631 return affine::isValidSymbol(operand, region);
632 });
633}
634
635OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
636 auto map = getAffineMap();
637
638 // Fold dims and symbols to existing values.
639 auto expr = map.getResult(0);
640 if (auto dim = dyn_cast<AffineDimExpr>(expr))
641 return getOperand(dim.getPosition());
642 if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
643 return getOperand(map.getNumDims() + sym.getPosition());
644
645 // Otherwise, default to folding the map.
647 bool hasPoison = false;
648 auto foldResult =
649 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
650 if (hasPoison)
651 return ub::PoisonAttr::get(getContext());
652 if (failed(foldResult))
653 return {};
654 return result[0];
655}
656
657/// Returns the largest known divisor of `e`. Exploits information from the
658/// values in `operands`.
660 // This method isn't aware of `operands`.
662
663 // We now make use of operands for the case `e` is a dim expression.
664 // TODO: More powerful simplification would have to modify
665 // getLargestKnownDivisor to take `operands` and exploit that information as
666 // well for dim/sym expressions, but in that case, getLargestKnownDivisor
667 // can't be part of the IR library but of the `Analysis` library. The IR
668 // library can only really depend on simple O(1) checks.
669 auto dimExpr = dyn_cast<AffineDimExpr>(e);
670 // If it's not a dim expr, `div` is the best we have.
671 if (!dimExpr)
672 return div;
673
674 // We simply exploit information from loop IVs.
675 // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
676 // desired simplifications are expected to be part of other
677 // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
678 // LoopAnalysis library.
679 Value operand = operands[dimExpr.getPosition()];
680 int64_t operandDivisor = 1;
681 // TODO: With the right accessors, this can be extended to
682 // LoopLikeOpInterface.
683 if (AffineForOp forOp = getForInductionVarOwner(operand)) {
684 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
685 operandDivisor = forOp.getStepAsInt();
686 } else {
687 uint64_t lbLargestKnownDivisor =
688 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
689 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
690 }
691 }
692 return operandDivisor;
693}
694
695/// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
696/// being an affine dim expression or a constant.
698 int64_t k) {
699 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
700 int64_t constVal = constExpr.getValue();
701 return constVal >= 0 && constVal < k;
702 }
703 auto dimExpr = dyn_cast<AffineDimExpr>(e);
704 if (!dimExpr)
705 return false;
706 Value operand = operands[dimExpr.getPosition()];
707 // TODO: With the right accessors, this can be extended to
708 // LoopLikeOpInterface.
709 if (AffineForOp forOp = getForInductionVarOwner(operand)) {
710 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
711 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
712 return true;
713 }
714 }
715
716 // We don't consider other cases like `operand` being defined by a constant or
717 // an affine.apply op since such cases will already be handled by other
718 // patterns and propagation of loop IVs or constant would happen.
719 return false;
720}
721
722/// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
723/// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
724/// expression is in that form.
726 AffineExpr &quotientTimesDiv, AffineExpr &rem) {
727 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
728 if (!bin || bin.getKind() != AffineExprKind::Add)
729 return false;
730
731 AffineExpr llhs = bin.getLHS();
732 AffineExpr rlhs = bin.getRHS();
733 div = getLargestKnownDivisor(llhs, operands);
734 if (isNonNegativeBoundedBy(rlhs, operands, div)) {
735 quotientTimesDiv = llhs;
736 rem = rlhs;
737 return true;
738 }
739 div = getLargestKnownDivisor(rlhs, operands);
740 if (isNonNegativeBoundedBy(llhs, operands, div)) {
741 quotientTimesDiv = rlhs;
742 rem = llhs;
743 return true;
744 }
745 return false;
746}
747
748/// Gets the constant lower bound on an `iv`.
749static std::optional<int64_t> getLowerBound(Value iv) {
750 AffineForOp forOp = getForInductionVarOwner(iv);
751 if (forOp && forOp.hasConstantLowerBound())
752 return forOp.getConstantLowerBound();
753 return std::nullopt;
754}
755
756/// Gets the constant upper bound on an affine.for `iv`.
757static std::optional<int64_t> getUpperBound(Value iv) {
758 AffineForOp forOp = getForInductionVarOwner(iv);
759 if (!forOp || !forOp.hasConstantUpperBound())
760 return std::nullopt;
761
762 // If its lower bound is also known, we can get a more precise bound
763 // whenever the step is not one.
764 if (forOp.hasConstantLowerBound()) {
765 return forOp.getConstantUpperBound() - 1 -
766 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
767 forOp.getStepAsInt();
768 }
769 return forOp.getConstantUpperBound() - 1;
770}
771
772/// Determine a constant upper bound for `expr` if one exists while exploiting
773/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
774/// is guaranteed to be less than or equal to it.
775static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
776 unsigned numSymbols,
777 ArrayRef<Value> operands) {
778 // Get the constant lower or upper bounds on the operands.
779 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
780 constLowerBounds.reserve(operands.size());
781 constUpperBounds.reserve(operands.size());
782 for (Value operand : operands) {
783 constLowerBounds.push_back(getLowerBound(operand));
784 constUpperBounds.push_back(getUpperBound(operand));
785 }
786
787 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
788 return constExpr.getValue();
789
790 return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
791 constUpperBounds,
792 /*isUpper=*/true);
793}
794
795/// Determine a constant lower bound for `expr` if one exists while exploiting
796/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
797/// is guaranteed to be less than or equal to it.
798static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
799 unsigned numSymbols,
800 ArrayRef<Value> operands) {
801 // Get the constant lower or upper bounds on the operands.
802 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
803 constLowerBounds.reserve(operands.size());
804 constUpperBounds.reserve(operands.size());
805 for (Value operand : operands) {
806 constLowerBounds.push_back(getLowerBound(operand));
807 constUpperBounds.push_back(getUpperBound(operand));
808 }
809
810 std::optional<int64_t> lowerBound;
811 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
812 lowerBound = constExpr.getValue();
813 } else {
814 lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
815 constLowerBounds, constUpperBounds,
816 /*isUpper=*/false);
817 }
818 return lowerBound;
819}
820
821/// Simplify `expr` while exploiting information from the values in `operands`.
822static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
823 unsigned numSymbols,
824 ArrayRef<Value> operands) {
825 // We do this only for certain floordiv/mod expressions.
826 auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
827 if (!binExpr)
828 return;
829
830 // Simplify the child expressions first.
831 AffineExpr lhs = binExpr.getLHS();
832 AffineExpr rhs = binExpr.getRHS();
833 simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
834 simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
835 expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
836
837 binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
838 if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
840 expr.getKind() != AffineExprKind::Mod)) {
841 return;
842 }
843
844 // The `lhs` and `rhs` may be different post construction of simplified expr.
845 lhs = binExpr.getLHS();
846 rhs = binExpr.getRHS();
847 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
848 if (!rhsConst)
849 return;
850
851 int64_t rhsConstVal = rhsConst.getValue();
852 // Undefined exprsessions aren't touched; IR can still be valid with them.
853 if (rhsConstVal <= 0)
854 return;
855
856 // Exploit constant lower/upper bounds to simplify a floordiv or mod.
857 MLIRContext *context = expr.getContext();
858 std::optional<int64_t> lhsLbConst =
859 getLowerBound(lhs, numDims, numSymbols, operands);
860 std::optional<int64_t> lhsUbConst =
861 getUpperBound(lhs, numDims, numSymbols, operands);
862 if (lhsLbConst && lhsUbConst) {
863 int64_t lhsLbConstVal = *lhsLbConst;
864 int64_t lhsUbConstVal = *lhsUbConst;
865 // lhs floordiv c is a single value lhs is bounded in a range `c` that has
866 // the same quotient.
867 if (binExpr.getKind() == AffineExprKind::FloorDiv &&
868 divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
869 divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
871 divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
872 return;
873 }
874 // lhs ceildiv c is a single value if the entire range has the same ceil
875 // quotient.
876 if (binExpr.getKind() == AffineExprKind::CeilDiv &&
877 divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
878 divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
879 expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
880 context);
881 return;
882 }
883 // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
884 if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
885 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
886 expr = lhs;
887 return;
888 }
889 }
890
891 // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
892 // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
893 // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
894 // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
895 AffineExpr quotientTimesDiv, rem;
896 int64_t divisor;
897 if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
898 if (rhsConstVal % divisor == 0 &&
899 binExpr.getKind() == AffineExprKind::FloorDiv) {
900 expr = quotientTimesDiv.floorDiv(rhsConst);
901 } else if (divisor % rhsConstVal == 0 &&
902 binExpr.getKind() == AffineExprKind::Mod) {
903 expr = rem % rhsConst;
904 }
905 return;
906 }
907
908 // Handle the simple case when the LHS expression can be either upper
909 // bounded or is a known multiple of RHS constant.
910 // lhs floordiv c -> 0 if 0 <= lhs < c,
911 // lhs mod c -> 0 if lhs % c = 0.
912 if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
913 binExpr.getKind() == AffineExprKind::FloorDiv) ||
914 (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
915 binExpr.getKind() == AffineExprKind::Mod)) {
916 expr = getAffineConstantExpr(0, expr.getContext());
917 }
918}
919
920/// Simplify the expressions in `map` while making use of lower or upper bounds
921/// of its operands. If `isMax` is true, the map is to be treated as a max of
922/// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
923/// d1) can be simplified to (8) if the operands are respectively lower bounded
924/// by 2 and 0 (the second expression can't be lower than 8).
926 ArrayRef<Value> operands,
927 bool isMax) {
928 // Can't simplify.
929 if (operands.empty())
930 return;
931
932 // Get the upper or lower bound on an affine.for op IV using its range.
933 // Get the constant lower or upper bounds on the operands.
934 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
935 constLowerBounds.reserve(operands.size());
936 constUpperBounds.reserve(operands.size());
937 for (Value operand : operands) {
938 constLowerBounds.push_back(getLowerBound(operand));
939 constUpperBounds.push_back(getUpperBound(operand));
940 }
941
942 // We will compute the lower and upper bounds on each of the expressions
943 // Then, we will check (depending on max or min) as to whether a specific
944 // bound is redundant by checking if its highest (in case of max) and its
945 // lowest (in the case of min) value is already lower than (or higher than)
946 // the lower bound (or upper bound in the case of min) of another bound.
947 SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
948 lowerBounds.reserve(map.getNumResults());
949 upperBounds.reserve(map.getNumResults());
950 for (AffineExpr e : map.getResults()) {
951 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
952 lowerBounds.push_back(constExpr.getValue());
953 upperBounds.push_back(constExpr.getValue());
954 } else {
955 lowerBounds.push_back(
957 constLowerBounds, constUpperBounds,
958 /*isUpper=*/false));
959 upperBounds.push_back(
961 constLowerBounds, constUpperBounds,
962 /*isUpper=*/true));
963 }
964 }
965
966 // Collect expressions that are not redundant.
967 SmallVector<AffineExpr, 4> irredundantExprs;
968 for (auto exprEn : llvm::enumerate(map.getResults())) {
969 AffineExpr e = exprEn.value();
970 unsigned i = exprEn.index();
971 // Some expressions can be turned into constants.
972 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
973 e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
974
975 // Check if the expression is redundant.
976 if (isMax) {
977 if (!upperBounds[i]) {
978 irredundantExprs.push_back(e);
979 continue;
980 }
981 // If there exists another expression such that its lower bound is greater
982 // than this expression's upper bound, it's redundant.
983 if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
984 auto otherLowerBound = en.value();
985 unsigned pos = en.index();
986 if (pos == i || !otherLowerBound)
987 return false;
988 if (*otherLowerBound > *upperBounds[i])
989 return true;
990 if (*otherLowerBound < *upperBounds[i])
991 return false;
992 // Equality case. When both expressions are considered redundant, we
993 // don't want to get both of them. We keep the one that appears
994 // first.
995 if (upperBounds[pos] && lowerBounds[i] &&
996 lowerBounds[i] == upperBounds[i] &&
997 otherLowerBound == *upperBounds[pos] && i < pos)
998 return false;
999 return true;
1000 }))
1001 irredundantExprs.push_back(e);
1002 } else {
1003 if (!lowerBounds[i]) {
1004 irredundantExprs.push_back(e);
1005 continue;
1006 }
1007 // Likewise for the `min` case. Use the complement of the condition above.
1008 if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
1009 auto otherUpperBound = en.value();
1010 unsigned pos = en.index();
1011 if (pos == i || !otherUpperBound)
1012 return false;
1013 if (*otherUpperBound < *lowerBounds[i])
1014 return true;
1015 if (*otherUpperBound > *lowerBounds[i])
1016 return false;
1017 if (lowerBounds[pos] && upperBounds[i] &&
1018 lowerBounds[i] == upperBounds[i] &&
1019 otherUpperBound == lowerBounds[pos] && i < pos)
1020 return false;
1021 return true;
1022 }))
1023 irredundantExprs.push_back(e);
1024 }
1025 }
1026
1027 // Create the map without the redundant expressions.
1028 map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
1029 map.getContext());
1030}
1031
1032/// Simplify the map while exploiting information on the values in `operands`.
1033// Use "unused attribute" marker to silence warning stemming from the inability
1034// to see through the template expansion.
1035[[maybe_unused]] static void simplifyMapWithOperands(AffineMap &map,
1036 ArrayRef<Value> operands) {
1037 assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1038 SmallVector<AffineExpr> newResults;
1039 newResults.reserve(map.getNumResults());
1040 for (AffineExpr expr : map.getResults()) {
1042 operands);
1043 newResults.push_back(expr);
1044 }
1045 map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1046 map.getContext());
1047}
1048
1049/// Assuming `dimOrSym` is a quantity in the apply op map `map` and defined by
1050/// `minOp = affine_min(x_1, ..., x_n)`. This function checks that:
1051/// `0 < affine_min(x_1, ..., x_n)` and proceeds with replacing the patterns:
1052/// ```
1053/// dimOrSym.ceildiv(x_k)
1054/// (dimOrSym + x_k - 1).floordiv(x_k)
1055/// ```
1056/// by `1` for all `k` in `1, ..., n`. This is possible because `x / x_k <= 1`.
1057///
1058///
1059/// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
1060/// `minOp` is positive.
1061static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
1062 AffineExpr dimOrSym,
1063 AffineMap *map,
1064 ValueRange dims,
1065 ValueRange syms) {
1066 LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";
1067 AffineMap affineMinMap = minOp.getAffineMap();
1068
1069 // Check the value is positive.
1070 for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
1071 // Compare each expression in the minimum against 0.
1073 getAsIndexOpFoldResult(minOp.getContext(), 0),
1076 minOp.getOperands())))
1077 return failure();
1078 }
1079
1080 /// Convert affine symbols and dimensions in minOp to symbols or dimensions in
1081 /// the apply op affine map.
1082 DenseMap<AffineExpr, AffineExpr> dimSymConversionTable;
1083 SmallVector<unsigned> unmappedDims, unmappedSyms;
1084 for (auto [i, dim] : llvm::enumerate(minOp.getDimOperands())) {
1085 auto it = llvm::find(dims, dim);
1086 if (it == dims.end()) {
1087 unmappedDims.push_back(i);
1088 continue;
1089 }
1090 dimSymConversionTable[getAffineDimExpr(i, minOp.getContext())] =
1091 getAffineDimExpr(it.getIndex(), minOp.getContext());
1092 }
1093 for (auto [i, sym] : llvm::enumerate(minOp.getSymbolOperands())) {
1094 auto it = llvm::find(syms, sym);
1095 if (it == syms.end()) {
1096 unmappedSyms.push_back(i);
1097 continue;
1098 }
1099 dimSymConversionTable[getAffineSymbolExpr(i, minOp.getContext())] =
1100 getAffineSymbolExpr(it.getIndex(), minOp.getContext());
1101 }
1102
1103 // Create the replacement map.
1105 AffineExpr c1 = getAffineConstantExpr(1, minOp.getContext());
1106 for (AffineExpr expr : affineMinMap.getResults()) {
1107 // If we cannot express the result in terms of the apply map symbols and
1108 // sims then continue.
1109 if (llvm::any_of(unmappedDims,
1110 [&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
1111 llvm::any_of(unmappedSyms,
1112 [&](unsigned i) { return expr.isFunctionOfSymbol(i); }))
1113 continue;
1114
1115 AffineExpr convertedExpr = expr.replace(dimSymConversionTable);
1116
1117 // dimOrSym.ceilDiv(expr) -> 1
1118 repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
1119 // (dimOrSym + expr - 1).floorDiv(expr) -> 1
1120 repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
1121 }
1122 AffineMap initialMap = *map;
1123 *map = initialMap.replace(repl, initialMap.getNumDims(),
1124 initialMap.getNumSymbols());
1125 return success(*map != initialMap);
1126}
1127
1128/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form
1129/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`,
1130/// place a map between e and `newVal` + sum({e1, e2, .. eK} - exprsToRemove)
1131/// into `replacementsMap`. If no entries were added to `replacementsMap`,
1132/// nothing was found.
1134 AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
1135 AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
1136 auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
1137 if (!binOp)
1138 return;
1139 AffineExpr lhs = binOp.getLHS();
1140 AffineExpr rhs = binOp.getRHS();
1141 if (binOp.getKind() != AffineExprKind::Add) {
1142 shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap);
1143 shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap);
1144 return;
1145 }
1146 SmallVector<AffineExpr> toPreserve;
1147 llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove);
1148 AffineExpr thisTerm = rhs;
1149 AffineExpr nextTerm = lhs;
1150
1151 while (thisTerm) {
1152 if (!ourTracker.erase(thisTerm)) {
1153 toPreserve.push_back(thisTerm);
1154 shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal,
1155 replacementsMap);
1156 }
1157 auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
1158 if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) {
1159 thisTerm = nextTerm;
1160 nextTerm = AffineExpr();
1161 } else {
1162 thisTerm = nextBinOp.getRHS();
1163 nextTerm = nextBinOp.getLHS();
1164 }
1165 }
1166 if (!ourTracker.empty())
1167 return;
1168 // We reverse the terms to be preserved here in order to preserve
1169 // associativity between them.
1170 AffineExpr newExpr = newVal;
1171 for (AffineExpr preserved : llvm::reverse(toPreserve))
1172 newExpr = newExpr + preserved;
1173 replacementsMap.insert({e, newExpr});
1174}
1175
1176/// If this map contains of the expression `x_1 + x_1 * C_1 + ... x_n * C_N +
1177/// ...` (not necessarily in order) where the set of the `x_i` is the set of
1178/// outputs of an `affine.delinearize_index` whos inverse is that expression,
1179/// replace that expression with the input of that delinearize_index op.
1180///
1181/// `unitDimInput` is the input that was detected as the potential start to this
1182/// replacement chain - if it isn't the rightmost result of the delinearization,
1183/// this method fails. (This is intended to ensure we don't have redundant scans
1184/// over the same expression).
1185///
1186/// While this currently only handles delinearizations with a constant basis,
1187/// that isn't a fundamental limitation.
1188///
1189/// This is a utility function for `replaceDimOrSym` below.
1191 AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
1193 if (!delinOp.getDynamicBasis().empty())
1194 return failure();
1195 if (resultToReplace != delinOp.getMultiIndex().back())
1196 return failure();
1197
1198 MLIRContext *ctx = delinOp.getContext();
1199 SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr());
1200 for (auto [pos, dim] : llvm::enumerate(dims)) {
1201 auto asResult = dyn_cast_if_present<OpResult>(dim);
1202 if (!asResult)
1203 continue;
1204 if (asResult.getOwner() == delinOp.getOperation())
1205 resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx);
1206 }
1207 for (auto [pos, sym] : llvm::enumerate(syms)) {
1208 auto asResult = dyn_cast_if_present<OpResult>(sym);
1209 if (!asResult)
1210 continue;
1211 if (asResult.getOwner() == delinOp.getOperation())
1212 resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx);
1213 }
1214 if (llvm::is_contained(resToExpr, AffineExpr()))
1215 return failure();
1216
1217 bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>);
1218 int64_t stride = 1;
1219 llvm::SmallDenseSet<AffineExpr, 4> expectedExprs;
1220 // This isn't zip_equal since sometimes the delinearize basis is missing a
1221 // size for the first result.
1222 for (auto [binding, size] : llvm::zip(
1223 llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) {
1224 expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx));
1225 stride *= size;
1226 }
1227 if (resToExpr.size() != delinOp.getStaticBasis().size())
1228 expectedExprs.insert(resToExpr[0] * stride);
1229
1231 AffineExpr delinInExpr = isDimReplacement
1232 ? getAffineDimExpr(dims.size(), ctx)
1233 : getAffineSymbolExpr(syms.size(), ctx);
1234
1235 for (AffineExpr e : map->getResults())
1236 shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements);
1237 if (replacements.empty())
1238 return failure();
1239
1240 AffineMap origMap = *map;
1241 if (isDimReplacement)
1242 dims.push_back(delinOp.getLinearIndex());
1243 else
1244 syms.push_back(delinOp.getLinearIndex());
1245 *map = origMap.replace(replacements, dims.size(), syms.size());
1246
1247 // Blank out dead dimensions and symbols
1248 for (AffineExpr e : resToExpr) {
1249 if (auto d = dyn_cast<AffineDimExpr>(e)) {
1250 unsigned pos = d.getPosition();
1251 if (!map->isFunctionOfDim(pos))
1252 dims[pos] = nullptr;
1253 }
1254 if (auto s = dyn_cast<AffineSymbolExpr>(e)) {
1255 unsigned pos = s.getPosition();
1256 if (!map->isFunctionOfSymbol(pos))
1257 syms[pos] = nullptr;
1258 }
1259 }
1260 return success();
1261}
1262
1263/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
1264/// defining AffineApplyOp expression and operands.
1265/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
1266/// When `dimOrSymbolPosition >= dims.size()`,
1267/// AffineSymbolExpr@[pos - dims.size()] is replaced.
1268/// Mutate `map`,`dims` and `syms` in place as follows:
1269/// 1. `dims` and `syms` are only appended to.
1270/// 2. `map` dim and symbols are gradually shifted to higher positions.
1271/// 3. Old `dim` and `sym` entries are replaced by nullptr
1272/// This avoids the need for any bookkeeping.
1273/// If `replaceAffineMin` is set to true, additionally triggers more expensive
1274/// replacements involving affine_min operations.
1275static LogicalResult replaceDimOrSym(AffineMap *map,
1276 unsigned dimOrSymbolPosition,
1279 bool replaceAffineMin) {
1280 MLIRContext *ctx = map->getContext();
1281 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1282 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1283 : dimOrSymbolPosition - dims.size();
1284 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1285 if (!v)
1286 return failure();
1287
1288 if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
1289 AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx)
1290 : getAffineSymbolExpr(pos, ctx);
1291 return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map, dims,
1292 syms);
1293 }
1294
1295 if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
1296 return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims,
1297 syms);
1298 }
1299
1300 auto affineApply = v.getDefiningOp<AffineApplyOp>();
1301 if (!affineApply)
1302 return failure();
1303
1304 // At this point we will perform a replacement of `v`, set the entry in `dim`
1305 // or `sym` to nullptr immediately.
1306 v = nullptr;
1307
1308 // Compute the map, dims and symbols coming from the AffineApplyOp.
1309 AffineMap composeMap = affineApply.getAffineMap();
1310 assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1311 SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1312 affineApply.getMapOperands().end());
1313 // Canonicalize the map to promote dims to symbols when possible. This is to
1314 // avoid generating invalid maps.
1315 canonicalizeMapAndOperands(&composeMap, &composeOperands);
1316 AffineExpr replacementExpr =
1317 composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
1318 ValueRange composeDims =
1319 ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
1320 ValueRange composeSyms =
1321 ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
1322 AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
1323 : getAffineSymbolExpr(pos, ctx);
1324
1325 // Append the dims and symbols where relevant and perform the replacement.
1326 dims.append(composeDims.begin(), composeDims.end());
1327 syms.append(composeSyms.begin(), composeSyms.end());
1328 *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1329
1330 return success();
1331}
1332
1333/// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1334/// iteratively. Perform canonicalization of map and operands as well as
1335/// AffineMap simplification. `map` and `operands` are mutated in place.
1337 SmallVectorImpl<Value> *operands,
1338 bool composeAffineMin = false) {
1339 if (map->getNumResults() == 0) {
1340 canonicalizeMapAndOperands(map, operands);
1341 *map = simplifyAffineMap(*map);
1342 return;
1343 }
1344
1345 MLIRContext *ctx = map->getContext();
1346 SmallVector<Value, 4> dims(operands->begin(),
1347 operands->begin() + map->getNumDims());
1348 SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1349 operands->end());
1350
1351 // Iterate over dims and symbols coming from AffineApplyOp and replace until
1352 // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1353 // and `syms` can only increase by construction.
1354 // The implementation uses a `while` loop to support the case of symbols
1355 // that may be constructed from dims ;this may be overkill.
1356 while (true) {
1357 bool changed = false;
1358 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1359 if ((changed |=
1360 succeeded(replaceDimOrSym(map, pos, dims, syms, composeAffineMin))))
1361 break;
1362 if (!changed)
1363 break;
1364 }
1365
1366 // Clear operands so we can fill them anew.
1367 operands->clear();
1368
1369 // At this point we may have introduced null operands, prune them out before
1370 // canonicalizing map and operands.
1371 unsigned nDims = 0, nSyms = 0;
1372 SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1373 dimReplacements.reserve(dims.size());
1374 symReplacements.reserve(syms.size());
1375 for (auto *container : {&dims, &syms}) {
1376 bool isDim = (container == &dims);
1377 auto &repls = isDim ? dimReplacements : symReplacements;
1378 for (const auto &en : llvm::enumerate(*container)) {
1379 Value v = en.value();
1380 if (!v) {
1381 assert(isDim ? !map->isFunctionOfDim(en.index())
1382 : !map->isFunctionOfSymbol(en.index()) &&
1383 "map is function of unexpected expr@pos");
1384 repls.push_back(getAffineConstantExpr(0, ctx));
1385 continue;
1386 }
1387 repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1388 : getAffineSymbolExpr(nSyms++, ctx));
1389 operands->push_back(v);
1390 }
1391 }
1392 *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1393 nSyms);
1394
1395 // Canonicalize and simplify before returning.
1396 canonicalizeMapAndOperands(map, operands);
1397 *map = simplifyAffineMap(*map);
1398}
1399
1401 AffineMap *map, SmallVectorImpl<Value> *operands, bool composeAffineMin) {
1402 while (llvm::any_of(*operands, [](Value v) {
1403 return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1404 })) {
1405 composeAffineMapAndOperands(map, operands, composeAffineMin);
1406 }
1407 // Additional trailing step for AffineMinOps in case no chains of AffineApply.
1408 if (composeAffineMin && llvm::any_of(*operands, [](Value v) {
1409 return isa_and_nonnull<AffineMinOp>(v.getDefiningOp());
1410 })) {
1411 composeAffineMapAndOperands(map, operands, composeAffineMin);
1412 }
1413}
1414
1415AffineApplyOp
1417 ArrayRef<OpFoldResult> operands,
1418 bool composeAffineMin) {
1419 SmallVector<Value> valueOperands;
1420 map = foldAttributesIntoMap(b, map, operands, valueOperands);
1421 composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
1422 assert(map);
1423 return AffineApplyOp::create(b, loc, map, valueOperands);
1424}
1425
1426AffineApplyOp
1428 ArrayRef<OpFoldResult> operands,
1429 bool composeAffineMin) {
1431 b, loc,
1433 .front(),
1434 operands, composeAffineMin);
1435}
1436
1437/// Composes the given affine map with the given list of operands, pulling in
1438/// the maps from any affine.apply operations that supply the operands.
1440 SmallVectorImpl<Value> &operands,
1441 bool composeAffineMin = false) {
1442 // Compose and canonicalize each expression in the map individually because
1443 // composition only applies to single-result maps, collecting potentially
1444 // duplicate operands in a single list with shifted dimensions and symbols.
1445 SmallVector<Value> dims, symbols;
1447 for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1448 SmallVector<Value> submapOperands(operands.begin(), operands.end());
1449 AffineMap submap = map.getSubMap({i});
1450 fullyComposeAffineMapAndOperands(&submap, &submapOperands,
1451 composeAffineMin);
1452 canonicalizeMapAndOperands(&submap, &submapOperands);
1453 unsigned numNewDims = submap.getNumDims();
1454 submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1455 llvm::append_range(dims,
1456 ArrayRef<Value>(submapOperands).take_front(numNewDims));
1457 llvm::append_range(symbols,
1458 ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1459 exprs.push_back(submap.getResult(0));
1460 }
1461
1462 // Canonicalize the map created from composed expressions to deduplicate the
1463 // dimension and symbol operands.
1464 operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1465 map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1466 canonicalizeMapAndOperands(&map, &operands);
1467}
1468
1471 bool composeAffineMin) {
1472 assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1473
1474 // Create new builder without a listener, so that no notification is
1475 // triggered if the op is folded.
1476 // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1477 // workaround is no longer needed.
1478 OpBuilder newBuilder(b.getContext());
1479 newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1480
1481 // Create op.
1482 AffineApplyOp applyOp =
1483 makeComposedAffineApply(newBuilder, loc, map, operands, composeAffineMin);
1484
1485 // Get constant operands.
1486 SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1487 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1488 matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1489
1490 // Try to fold the operation.
1491 SmallVector<OpFoldResult> foldResults;
1492 if (failed(applyOp->fold(constOperands, foldResults)) ||
1493 foldResults.empty()) {
1494 if (OpBuilder::Listener *listener = b.getListener())
1495 listener->notifyOperationInserted(applyOp, /*previous=*/{});
1496 return applyOp.getResult();
1497 }
1498
1499 applyOp->erase();
1500 return llvm::getSingleElement(foldResults);
1501}
1502
1504 OpBuilder &b, Location loc, AffineExpr expr,
1505 ArrayRef<OpFoldResult> operands, bool composeAffineMin) {
1507 b, loc,
1509 .front(),
1510 operands, composeAffineMin);
1511}
1512
1516 bool composeAffineMin) {
1517 return llvm::map_to_vector(
1518 llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
1519 return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
1520 operands, composeAffineMin);
1521 });
1522}
1523
1524template <typename OpTy>
1526 ArrayRef<OpFoldResult> operands) {
1527 SmallVector<Value> valueOperands;
1528 map = foldAttributesIntoMap(b, map, operands, valueOperands);
1529 composeMultiResultAffineMap(map, valueOperands);
1530 return OpTy::create(b, loc, b.getIndexType(), map, valueOperands);
1531}
1532
1533AffineMinOp
1538
1539template <typename OpTy>
1541 AffineMap map,
1542 ArrayRef<OpFoldResult> operands) {
1543 // Create new builder without a listener, so that no notification is
1544 // triggered if the op is folded.
1545 // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1546 // workaround is no longer needed.
1547 OpBuilder newBuilder(b.getContext());
1548 newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1549
1550 // Create op.
1551 auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1552
1553 // Get constant operands.
1554 SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1555 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1556 matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1557
1558 // Try to fold the operation.
1559 SmallVector<OpFoldResult> foldResults;
1560 if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1561 foldResults.empty()) {
1562 if (OpBuilder::Listener *listener = b.getListener())
1563 listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1564 return minMaxOp.getResult();
1565 }
1566
1567 minMaxOp->erase();
1568 return llvm::getSingleElement(foldResults);
1569}
1570
1577
1584
1585// A symbol may appear as a dim in affine.apply operations. This function
1586// canonicalizes dims that are valid symbols into actual symbols.
1587template <class MapOrSet>
1588static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1589 SmallVectorImpl<Value> *operands) {
1590 if (!mapOrSet || operands->empty())
1591 return;
1592
1593 assert(mapOrSet->getNumInputs() == operands->size() &&
1594 "map/set inputs must match number of operands");
1595
1596 auto *context = mapOrSet->getContext();
1597 SmallVector<Value, 8> resultOperands;
1598 resultOperands.reserve(operands->size());
1599 SmallVector<Value, 8> remappedSymbols;
1600 remappedSymbols.reserve(operands->size());
1601 unsigned nextDim = 0;
1602 unsigned nextSym = 0;
1603 unsigned oldNumSyms = mapOrSet->getNumSymbols();
1604 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1605 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1606 if (i < mapOrSet->getNumDims()) {
1607 if (isValidSymbol((*operands)[i])) {
1608 // This is a valid symbol that appears as a dim, canonicalize it.
1609 dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1610 remappedSymbols.push_back((*operands)[i]);
1611 } else {
1612 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1613 resultOperands.push_back((*operands)[i]);
1614 }
1615 } else {
1616 resultOperands.push_back((*operands)[i]);
1617 }
1618 }
1619
1620 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1621 *operands = resultOperands;
1622 *mapOrSet = mapOrSet->replaceDimsAndSymbols(
1623 dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
1624
1625 assert(mapOrSet->getNumInputs() == operands->size() &&
1626 "map/set inputs must match number of operands");
1627}
1628
1629/// A valid affine dimension may appear as a symbol in affine.apply operations.
1630/// Given an application of `operands` to an affine map or integer set
1631/// `mapOrSet`, this function canonicalizes symbols of `mapOrSet` that are valid
1632/// dims, but not valid symbols into actual dims. Without such a legalization,
1633/// the affine.apply will be invalid. This method is the exact inverse of
1634/// canonicalizePromotedSymbols.
1635template <class MapOrSet>
1636static void legalizeDemotedDims(MapOrSet &mapOrSet,
1637 SmallVectorImpl<Value> &operands) {
1638 if (!mapOrSet || operands.empty())
1639 return;
1640
1641 unsigned numOperands = operands.size();
1642
1643 assert(mapOrSet.getNumInputs() == numOperands &&
1644 "map/set inputs must match number of operands");
1645
1646 auto *context = mapOrSet.getContext();
1647 SmallVector<Value, 8> resultOperands;
1648 resultOperands.reserve(numOperands);
1649 SmallVector<Value, 8> remappedDims;
1650 remappedDims.reserve(numOperands);
1651 SmallVector<Value, 8> symOperands;
1652 symOperands.reserve(mapOrSet.getNumSymbols());
1653 unsigned nextSym = 0;
1654 unsigned nextDim = 0;
1655 unsigned oldNumDims = mapOrSet.getNumDims();
1656 SmallVector<AffineExpr, 8> symRemapping(mapOrSet.getNumSymbols());
1657 resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
1658 for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
1659 if (operands[i] && isValidDim(operands[i]) && !isValidSymbol(operands[i])) {
1660 // This is a valid dim that appears as a symbol, legalize it.
1661 symRemapping[i - oldNumDims] =
1662 getAffineDimExpr(oldNumDims + nextDim++, context);
1663 remappedDims.push_back(operands[i]);
1664 } else {
1665 symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
1666 symOperands.push_back(operands[i]);
1667 }
1668 }
1669
1670 append_range(resultOperands, remappedDims);
1671 append_range(resultOperands, symOperands);
1672 operands = resultOperands;
1673 mapOrSet = mapOrSet.replaceDimsAndSymbols(
1674 /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
1675
1676 assert(mapOrSet.getNumInputs() == operands.size() &&
1677 "map/set inputs must match number of operands");
1678}
1679
1680// Works for either an affine map or an integer set.
1681template <class MapOrSet>
1682static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1683 SmallVectorImpl<Value> *operands) {
1684 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1685 "Argument must be either of AffineMap or IntegerSet type");
1686
1687 if (!mapOrSet || operands->empty())
1688 return;
1689
1690 assert(mapOrSet->getNumInputs() == operands->size() &&
1691 "map/set inputs must match number of operands");
1692
1693 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1694 legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);
1695
1696 // Check to see what dims are used.
1697 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1698 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1699 mapOrSet->walkExprs([&](AffineExpr expr) {
1700 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1701 usedDims[dimExpr.getPosition()] = true;
1702 else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1703 usedSyms[symExpr.getPosition()] = true;
1704 });
1705
1706 auto *context = mapOrSet->getContext();
1707
1708 SmallVector<Value, 8> resultOperands;
1709 resultOperands.reserve(operands->size());
1710
1711 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1712 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1713 unsigned nextDim = 0;
1714 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1715 if (usedDims[i]) {
1716 // Remap dim positions for duplicate operands.
1717 auto it = seenDims.find((*operands)[i]);
1718 if (it == seenDims.end()) {
1719 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1720 resultOperands.push_back((*operands)[i]);
1721 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1722 } else {
1723 dimRemapping[i] = it->second;
1724 }
1725 }
1726 }
1727 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1728 SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1729 unsigned nextSym = 0;
1730 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1731 if (!usedSyms[i])
1732 continue;
1733 // Handle constant operands (only needed for symbolic operands since
1734 // constant operands in dimensional positions would have already been
1735 // promoted to symbolic positions above).
1736 IntegerAttr operandCst;
1737 if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1738 m_Constant(&operandCst))) {
1739 symRemapping[i] =
1740 getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1741 continue;
1742 }
1743 // Remap symbol positions for duplicate operands.
1744 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1745 if (it == seenSymbols.end()) {
1746 symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1747 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1748 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1749 symRemapping[i]));
1750 } else {
1751 symRemapping[i] = it->second;
1752 }
1753 }
1754 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1755 nextDim, nextSym);
1756 *operands = resultOperands;
1757}
1758
1763
1768
1769namespace {
1770/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1771/// maps that supply results into them.
1772///
1773template <typename AffineOpTy>
1774struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1775 using OpRewritePattern<AffineOpTy>::OpRewritePattern;
1776
1777 /// Replace the affine op with another instance of it with the supplied
1778 /// map and mapOperands.
1779 void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1780 AffineMap map, ArrayRef<Value> mapOperands) const;
1781
1782 LogicalResult matchAndRewrite(AffineOpTy affineOp,
1783 PatternRewriter &rewriter) const override {
1784 static_assert(
1785 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1786 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1787 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1788 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1789 "expected");
1790 auto map = affineOp.getAffineMap();
1791 AffineMap oldMap = map;
1792 auto oldOperands = affineOp.getMapOperands();
1793 SmallVector<Value, 8> resultOperands(oldOperands);
1794 composeAffineMapAndOperands(&map, &resultOperands);
1795 canonicalizeMapAndOperands(&map, &resultOperands);
1796 simplifyMapWithOperands(map, resultOperands);
1797 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1798 resultOperands.begin()))
1799 return failure();
1800
1801 replaceAffineOp(rewriter, affineOp, map, resultOperands);
1802 return success();
1803 }
1804};
1805
1806// Specialize the template to account for the different build signatures for
1807// affine load, store, and apply ops.
1808template <>
1809void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1810 PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1811 ArrayRef<Value> mapOperands) const {
1812 rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1813 mapOperands);
1814}
1815template <>
1816void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1817 PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1818 ArrayRef<Value> mapOperands) const {
1819 rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1820 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1821 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1822}
1823template <>
1824void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1825 PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1826 ArrayRef<Value> mapOperands) const {
1827 rewriter.replaceOpWithNewOp<AffineStoreOp>(
1828 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1829}
1830template <>
1831void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1832 PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1833 ArrayRef<Value> mapOperands) const {
1834 rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1835 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1836 mapOperands);
1837}
1838template <>
1839void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1840 PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1841 ArrayRef<Value> mapOperands) const {
1842 rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1843 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1844 mapOperands);
1845}
1846
1847// Generic version for ops that don't have extra operands.
1848template <typename AffineOpTy>
1849void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1850 PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1851 ArrayRef<Value> mapOperands) const {
1852 rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1853}
1854} // namespace
1855
1856void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1857 MLIRContext *context) {
1858 results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1859}
1860
1861//===----------------------------------------------------------------------===//
1862// AffineDmaStartOp
1863//===----------------------------------------------------------------------===//
1864
1865// TODO: Check that map operands are loop IVs or symbols.
1867 Value srcMemRef, AffineMap srcMap,
1868 ValueRange srcIndices, Value destMemRef,
1869 AffineMap dstMap, ValueRange destIndices,
1870 Value tagMemRef, AffineMap tagMap,
1871 ValueRange tagIndices, Value numElements,
1872 Value stride, Value elementsPerStride) {
1873 result.addOperands(srcMemRef);
1874 result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1875 result.addOperands(srcIndices);
1876 result.addOperands(destMemRef);
1877 result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1878 result.addOperands(destIndices);
1879 result.addOperands(tagMemRef);
1880 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1881 result.addOperands(tagIndices);
1882 result.addOperands(numElements);
1883 if (stride) {
1884 result.addOperands({stride, elementsPerStride});
1885 }
1886}
1887
1889 OpBuilder &builder, Location location, Value srcMemRef, AffineMap srcMap,
1890 ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
1891 ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
1892 ValueRange tagIndices, Value numElements, Value stride,
1893 Value elementsPerStride) {
1894 mlir::OperationState state(location, getOperationName());
1895 build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap,
1896 destIndices, tagMemRef, tagMap, tagIndices, numElements, stride,
1897 elementsPerStride);
1898 auto result = dyn_cast<AffineDmaStartOp>(builder.create(state));
1899 assert(result && "builder didn't return the right type");
1900 return result;
1901}
1902
1904 ImplicitLocOpBuilder &builder, Value srcMemRef, AffineMap srcMap,
1905 ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
1906 ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
1907 ValueRange tagIndices, Value numElements, Value stride,
1908 Value elementsPerStride) {
1909 return create(builder, builder.getLoc(), srcMemRef, srcMap, srcIndices,
1910 destMemRef, dstMap, destIndices, tagMemRef, tagMap, tagIndices,
1911 numElements, stride, elementsPerStride);
1912}
1913
1915 p << " " << getSrcMemRef() << '[';
1917 p << "], " << getDstMemRef() << '[';
1919 p << "], " << getTagMemRef() << '[';
1921 p << "], " << getNumElements();
1922 if (isStrided()) {
1923 p << ", " << getStride();
1924 p << ", " << getNumElementsPerStride();
1925 }
1926 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1927 << getTagMemRefType();
1928}
1929
1930// Parse AffineDmaStartOp.
1931// Ex:
1932// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1933// %stride, %num_elt_per_stride
1934// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1935//
1938 OpAsmParser::UnresolvedOperand srcMemRefInfo;
1939 AffineMapAttr srcMapAttr;
1941 OpAsmParser::UnresolvedOperand dstMemRefInfo;
1942 AffineMapAttr dstMapAttr;
1944 OpAsmParser::UnresolvedOperand tagMemRefInfo;
1945 AffineMapAttr tagMapAttr;
1947 OpAsmParser::UnresolvedOperand numElementsInfo;
1949
1951 auto indexType = parser.getBuilder().getIndexType();
1952
1953 // Parse and resolve the following list of operands:
1954 // *) dst memref followed by its affine maps operands (in square brackets).
1955 // *) src memref followed by its affine map operands (in square brackets).
1956 // *) tag memref followed by its affine map operands (in square brackets).
1957 // *) number of elements transferred by DMA operation.
1958 if (parser.parseOperand(srcMemRefInfo) ||
1959 parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1961 result.attributes) ||
1962 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1963 parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1965 result.attributes) ||
1966 parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1967 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1969 result.attributes) ||
1970 parser.parseComma() || parser.parseOperand(numElementsInfo))
1971 return failure();
1972
1973 // Parse optional stride and elements per stride.
1974 if (parser.parseTrailingOperandList(strideInfo))
1975 return failure();
1976
1977 if (!strideInfo.empty() && strideInfo.size() != 2) {
1978 return parser.emitError(parser.getNameLoc(),
1979 "expected two stride related operands");
1980 }
1981 bool isStrided = strideInfo.size() == 2;
1982
1983 if (parser.parseColonTypeList(types))
1984 return failure();
1985
1986 if (types.size() != 3)
1987 return parser.emitError(parser.getNameLoc(), "expected three types");
1988
1989 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1990 parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1991 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1992 parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1993 parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1994 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1995 parser.resolveOperand(numElementsInfo, indexType, result.operands))
1996 return failure();
1997
1998 if (isStrided) {
1999 if (parser.resolveOperands(strideInfo, indexType, result.operands))
2000 return failure();
2001 }
2002
2003 // Check that src/dst/tag operand counts match their map.numInputs.
2004 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
2005 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
2006 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
2007 return parser.emitError(parser.getNameLoc(),
2008 "memref operand count not equal to map.numInputs");
2009 return success();
2010}
2011
2013 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
2014 return emitOpError("expected DMA source to be of memref type");
2015 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
2016 return emitOpError("expected DMA destination to be of memref type");
2017 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
2018 return emitOpError("expected DMA tag to be of memref type");
2019
2020 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
2023 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
2024 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
2025 return emitOpError("incorrect number of operands");
2026 }
2027
2028 Region *scope = getAffineScope(*this);
2029 for (auto idx : getSrcIndices()) {
2030 if (!idx.getType().isIndex())
2031 return emitOpError("src index to dma_start must have 'index' type");
2032 if (!isValidAffineIndexOperand(idx, scope))
2033 return emitOpError(
2034 "src index must be a valid dimension or symbol identifier");
2035 }
2036 for (auto idx : getDstIndices()) {
2037 if (!idx.getType().isIndex())
2038 return emitOpError("dst index to dma_start must have 'index' type");
2039 if (!isValidAffineIndexOperand(idx, scope))
2040 return emitOpError(
2041 "dst index must be a valid dimension or symbol identifier");
2042 }
2043 for (auto idx : getTagIndices()) {
2044 if (!idx.getType().isIndex())
2045 return emitOpError("tag index to dma_start must have 'index' type");
2046 if (!isValidAffineIndexOperand(idx, scope))
2047 return emitOpError(
2048 "tag index must be a valid dimension or symbol identifier");
2049 }
2050 return success();
2051}
2052
2055 /// dma_start(memrefcast) -> dma_start
2056 return memref::foldMemRefCast(*this);
2057}
2058
2069
2070//===----------------------------------------------------------------------===//
2071// AffineDmaWaitOp
2072//===----------------------------------------------------------------------===//
2073
2074// TODO: Check that map operands are loop IVs or symbols.
2076 Value tagMemRef, AffineMap tagMap,
2077 ValueRange tagIndices, Value numElements) {
2078 result.addOperands(tagMemRef);
2079 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
2080 result.addOperands(tagIndices);
2081 result.addOperands(numElements);
2082}
2083
2085 Value tagMemRef, AffineMap tagMap,
2086 ValueRange tagIndices,
2087 Value numElements) {
2088 mlir::OperationState state(location, getOperationName());
2089 build(builder, state, tagMemRef, tagMap, tagIndices, numElements);
2090 auto result = dyn_cast<AffineDmaWaitOp>(builder.create(state));
2091 assert(result && "builder didn't return the right type");
2092 return result;
2093}
2094
2096 Value tagMemRef, AffineMap tagMap,
2097 ValueRange tagIndices,
2098 Value numElements) {
2099 return create(builder, builder.getLoc(), tagMemRef, tagMap, tagIndices,
2100 numElements);
2101}
2102
2104 p << " " << getTagMemRef() << '[';
2107 p << "], ";
2109 p << " : " << getTagMemRef().getType();
2110}
2111
2112// Parse AffineDmaWaitOp.
2113// Eg:
2114// affine.dma_wait %tag[%index], %num_elements
2115// : memref<1 x i32, (d0) -> (d0), 4>
2116//
2119 OpAsmParser::UnresolvedOperand tagMemRefInfo;
2120 AffineMapAttr tagMapAttr;
2122 Type type;
2123 auto indexType = parser.getBuilder().getIndexType();
2124 OpAsmParser::UnresolvedOperand numElementsInfo;
2125
2126 // Parse tag memref, its map operands, and dma size.
2127 if (parser.parseOperand(tagMemRefInfo) ||
2128 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
2130 result.attributes) ||
2131 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
2132 parser.parseColonType(type) ||
2133 parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
2134 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
2135 parser.resolveOperand(numElementsInfo, indexType, result.operands))
2136 return failure();
2137
2138 if (!llvm::isa<MemRefType>(type))
2139 return parser.emitError(parser.getNameLoc(),
2140 "expected tag to be of memref type");
2141
2142 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
2143 return parser.emitError(parser.getNameLoc(),
2144 "tag memref operand count != to map.numInputs");
2145 return success();
2146}
2147
2149 if (!llvm::isa<MemRefType>(getOperand(0).getType()))
2150 return emitOpError("expected DMA tag to be of memref type");
2151 Region *scope = getAffineScope(*this);
2152 for (auto idx : getTagIndices()) {
2153 if (!idx.getType().isIndex())
2154 return emitOpError("index to dma_wait must have 'index' type");
2155 if (!isValidAffineIndexOperand(idx, scope))
2156 return emitOpError(
2157 "index must be a valid dimension or symbol identifier");
2158 }
2159 return success();
2160}
2161
2164 /// dma_wait(memrefcast) -> dma_wait
2165 return memref::foldMemRefCast(*this);
2166}
2167
2174
2175//===----------------------------------------------------------------------===//
2176// AffineForOp
2177//===----------------------------------------------------------------------===//
2178
2179/// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
2180/// bodyBuilder are empty/null, we include default terminator op.
2181void AffineForOp::build(OpBuilder &builder, OperationState &result,
2182 ValueRange lbOperands, AffineMap lbMap,
2183 ValueRange ubOperands, AffineMap ubMap, int64_t step,
2184 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
2185 assert(((!lbMap && lbOperands.empty()) ||
2186 lbOperands.size() == lbMap.getNumInputs()) &&
2187 "lower bound operand count does not match the affine map");
2188 assert(((!ubMap && ubOperands.empty()) ||
2189 ubOperands.size() == ubMap.getNumInputs()) &&
2190 "upper bound operand count does not match the affine map");
2191 assert(step > 0 && "step has to be a positive integer constant");
2192
2193 OpBuilder::InsertionGuard guard(builder);
2194
2195 // Set variadic segment sizes.
2196 result.addAttribute(
2197 getOperandSegmentSizeAttr(),
2198 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
2199 static_cast<int32_t>(ubOperands.size()),
2200 static_cast<int32_t>(iterArgs.size())}));
2201
2202 for (Value val : iterArgs)
2203 result.addTypes(val.getType());
2204
2205 // Add an attribute for the step.
2206 result.addAttribute(getStepAttrName(result.name),
2207 builder.getIntegerAttr(builder.getIndexType(), step));
2208
2209 // Add the lower bound.
2210 result.addAttribute(getLowerBoundMapAttrName(result.name),
2211 AffineMapAttr::get(lbMap));
2212 result.addOperands(lbOperands);
2213
2214 // Add the upper bound.
2215 result.addAttribute(getUpperBoundMapAttrName(result.name),
2216 AffineMapAttr::get(ubMap));
2217 result.addOperands(ubOperands);
2218
2219 result.addOperands(iterArgs);
2220 // Create a region and a block for the body. The argument of the region is
2221 // the loop induction variable.
2222 Region *bodyRegion = result.addRegion();
2223 Block *bodyBlock = builder.createBlock(bodyRegion);
2224 Value inductionVar =
2225 bodyBlock->addArgument(builder.getIndexType(), result.location);
2226 for (Value val : iterArgs)
2227 bodyBlock->addArgument(val.getType(), val.getLoc());
2228
2229 // Create the default terminator if the builder is not provided and if the
2230 // iteration arguments are not provided. Otherwise, leave this to the caller
2231 // because we don't know which values to return from the loop.
2232 if (iterArgs.empty() && !bodyBuilder) {
2233 ensureTerminator(*bodyRegion, builder, result.location);
2234 } else if (bodyBuilder) {
2235 OpBuilder::InsertionGuard guard(builder);
2236 builder.setInsertionPointToStart(bodyBlock);
2237 bodyBuilder(builder, result.location, inductionVar,
2238 bodyBlock->getArguments().drop_front());
2239 }
2240}
2241
2242void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
2243 int64_t ub, int64_t step, ValueRange iterArgs,
2244 BodyBuilderFn bodyBuilder) {
2245 auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
2246 auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
2247 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
2248 bodyBuilder);
2249}
2250
2251LogicalResult AffineForOp::verifyRegions() {
2252 // Check that the body defines as single block argument for the induction
2253 // variable.
2254 auto *body = getBody();
2255 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
2256 return emitOpError("expected body to have a single index argument for the "
2257 "induction variable");
2258
2259 // Verify that the bound operands are valid dimension/symbols.
2260 /// Lower bound.
2261 if (getLowerBoundMap().getNumInputs() > 0)
2263 getLowerBoundMap().getNumDims())))
2264 return failure();
2265 /// Upper bound.
2266 if (getUpperBoundMap().getNumInputs() > 0)
2268 getUpperBoundMap().getNumDims())))
2269 return failure();
2270 if (getLowerBoundMap().getNumResults() < 1)
2271 return emitOpError("expected lower bound map to have at least one result");
2272 if (getUpperBoundMap().getNumResults() < 1)
2273 return emitOpError("expected upper bound map to have at least one result");
2274
2275 unsigned opNumResults = getNumResults();
2276 if (opNumResults == 0)
2277 return success();
2278
2279 // If ForOp defines values, check that the number and types of the defined
2280 // values match ForOp initial iter operands and backedge basic block
2281 // arguments.
2282 if (getNumIterOperands() != opNumResults)
2283 return emitOpError(
2284 "mismatch between the number of loop-carried values and results");
2285 if (getNumRegionIterArgs() != opNumResults)
2286 return emitOpError(
2287 "mismatch between the number of basic block args and results");
2288
2289 return success();
2290}
2291
2292/// Parse a for operation loop bounds.
2293static ParseResult parseBound(bool isLower, OperationState &result,
2294 OpAsmParser &p) {
2295 // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
2296 // the map has multiple results.
2297 bool failedToParsedMinMax =
2298 failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
2299
2300 auto &builder = p.getBuilder();
2301 auto boundAttrStrName =
2302 isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
2303 : AffineForOp::getUpperBoundMapAttrName(result.name);
2304
2305 // Parse ssa-id as identity map.
2307 if (p.parseOperandList(boundOpInfos))
2308 return failure();
2309
2310 if (!boundOpInfos.empty()) {
2311 // Check that only one operand was parsed.
2312 if (boundOpInfos.size() > 1)
2313 return p.emitError(p.getNameLoc(),
2314 "expected only one loop bound operand");
2315
2316 // TODO: improve error message when SSA value is not of index type.
2317 // Currently it is 'use of value ... expects different type than prior uses'
2318 if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
2319 result.operands))
2320 return failure();
2321
2322 // Create an identity map using symbol id. This representation is optimized
2323 // for storage. Analysis passes may expand it into a multi-dimensional map
2324 // if desired.
2325 AffineMap map = builder.getSymbolIdentityMap();
2326 result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
2327 return success();
2328 }
2329
2330 // Get the attribute location.
2331 SMLoc attrLoc = p.getCurrentLocation();
2332
2333 Attribute boundAttr;
2334 if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
2335 result.attributes))
2336 return failure();
2337
2338 // Parse full form - affine map followed by dim and symbol list.
2339 if (auto affineMapAttr = dyn_cast<AffineMapAttr>(boundAttr)) {
2340 unsigned currentNumOperands = result.operands.size();
2341 unsigned numDims;
2342 if (parseDimAndSymbolList(p, result.operands, numDims))
2343 return failure();
2344
2345 auto map = affineMapAttr.getValue();
2346 if (map.getNumDims() != numDims)
2347 return p.emitError(
2348 p.getNameLoc(),
2349 "dim operand count and affine map dim count must match");
2350
2351 unsigned numDimAndSymbolOperands =
2352 result.operands.size() - currentNumOperands;
2353 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2354 return p.emitError(
2355 p.getNameLoc(),
2356 "symbol operand count and affine map symbol count must match");
2357
2358 // If the map has multiple results, make sure that we parsed the min/max
2359 // prefix.
2360 if (map.getNumResults() > 1 && failedToParsedMinMax) {
2361 if (isLower) {
2362 return p.emitError(attrLoc, "lower loop bound affine map with "
2363 "multiple results requires 'max' prefix");
2364 }
2365 return p.emitError(attrLoc, "upper loop bound affine map with multiple "
2366 "results requires 'min' prefix");
2367 }
2368 return success();
2369 }
2370
2371 // Parse custom assembly form.
2372 if (auto integerAttr = dyn_cast<IntegerAttr>(boundAttr)) {
2373 result.attributes.pop_back();
2374 result.addAttribute(
2375 boundAttrStrName,
2376 AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2377 return success();
2378 }
2379
2380 return p.emitError(
2381 p.getNameLoc(),
2382 "expected valid affine map representation for loop bounds");
2383}
2384
2385ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2386 auto &builder = parser.getBuilder();
2387 OpAsmParser::Argument inductionVariable;
2388 inductionVariable.type = builder.getIndexType();
2389 // Parse the induction variable followed by '='.
2390 if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2391 return failure();
2392
2393 // Parse loop bounds.
2394 int64_t numOperands = result.operands.size();
2395 if (parseBound(/*isLower=*/true, result, parser))
2396 return failure();
2397 int64_t numLbOperands = result.operands.size() - numOperands;
2398 if (parser.parseKeyword("to", " between bounds"))
2399 return failure();
2400 numOperands = result.operands.size();
2401 if (parseBound(/*isLower=*/false, result, parser))
2402 return failure();
2403 int64_t numUbOperands = result.operands.size() - numOperands;
2404
2405 // Parse the optional loop step, we default to 1 if one is not present.
2406 if (parser.parseOptionalKeyword("step")) {
2407 result.addAttribute(
2408 getStepAttrName(result.name),
2409 builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2410 } else {
2411 SMLoc stepLoc = parser.getCurrentLocation();
2412 IntegerAttr stepAttr;
2413 if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2414 getStepAttrName(result.name).data(),
2415 result.attributes))
2416 return failure();
2417
2418 if (stepAttr.getValue().isNegative())
2419 return parser.emitError(
2420 stepLoc,
2421 "expected step to be representable as a positive signed integer");
2422 }
2423
2424 // Parse the optional initial iteration arguments.
2425 SmallVector<OpAsmParser::Argument, 4> regionArgs;
2426 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2427
2428 // Induction variable.
2429 regionArgs.push_back(inductionVariable);
2430
2431 if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2432 // Parse assignment list and results type list.
2433 if (parser.parseAssignmentList(regionArgs, operands) ||
2434 parser.parseArrowTypeList(result.types))
2435 return failure();
2436 // Resolve input operands.
2437 for (auto argOperandType :
2438 llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2439 Type type = std::get<2>(argOperandType);
2440 std::get<0>(argOperandType).type = type;
2441 if (parser.resolveOperand(std::get<1>(argOperandType), type,
2442 result.operands))
2443 return failure();
2444 }
2445 }
2446
2447 result.addAttribute(
2448 getOperandSegmentSizeAttr(),
2449 builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2450 static_cast<int32_t>(numUbOperands),
2451 static_cast<int32_t>(operands.size())}));
2452
2453 // Parse the body region.
2454 Region *body = result.addRegion();
2455 if (regionArgs.size() != result.types.size() + 1)
2456 return parser.emitError(
2457 parser.getNameLoc(),
2458 "mismatch between the number of loop-carried values and results");
2459 if (parser.parseRegion(*body, regionArgs))
2460 return failure();
2461
2462 AffineForOp::ensureTerminator(*body, builder, result.location);
2463
2464 // Parse the optional attribute list.
2465 return parser.parseOptionalAttrDict(result.attributes);
2466}
2467
2468static void printBound(AffineMapAttr boundMap,
2469 Operation::operand_range boundOperands,
2470 const char *prefix, OpAsmPrinter &p) {
2471 AffineMap map = boundMap.getValue();
2472
2473 // Check if this bound should be printed using custom assembly form.
2474 // The decision to restrict printing custom assembly form to trivial cases
2475 // comes from the will to roundtrip MLIR binary -> text -> binary in a
2476 // lossless way.
2477 // Therefore, custom assembly form parsing and printing is only supported for
2478 // zero-operand constant maps and single symbol operand identity maps.
2479 if (map.getNumResults() == 1) {
2480 AffineExpr expr = map.getResult(0);
2481
2482 // Print constant bound.
2483 if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2484 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2485 p << constExpr.getValue();
2486 return;
2487 }
2488 }
2489
2490 // Print bound that consists of a single SSA symbol if the map is over a
2491 // single symbol.
2492 if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2493 if (isa<AffineSymbolExpr>(expr)) {
2494 p.printOperand(*boundOperands.begin());
2495 return;
2496 }
2497 }
2498 } else {
2499 // Map has multiple results. Print 'min' or 'max' prefix.
2500 p << prefix << ' ';
2501 }
2502
2503 // Print the map and its operands.
2504 p << boundMap;
2505 printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2506 map.getNumDims(), p);
2507}
2508
2509unsigned AffineForOp::getNumIterOperands() {
2510 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2511 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2512
2513 return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2514}
2515
2516std::optional<MutableArrayRef<OpOperand>>
2517AffineForOp::getYieldedValuesMutable() {
2518 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2519}
2520
2521void AffineForOp::print(OpAsmPrinter &p) {
2522 p << ' ';
2523 p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2524 /*omitType=*/true);
2525 p << " = ";
2526 printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2527 p << " to ";
2528 printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2529
2530 if (getStepAsInt() != 1)
2531 p << " step " << getStepAsInt();
2532
2533 bool printBlockTerminators = false;
2534 if (getNumIterOperands() > 0) {
2535 p << " iter_args(";
2536 auto regionArgs = getRegionIterArgs();
2537 auto operands = getInits();
2538
2539 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2540 p << std::get<0>(it) << " = " << std::get<1>(it);
2541 });
2542 p << ") -> (" << getResultTypes() << ")";
2543 printBlockTerminators = true;
2544 }
2545
2546 p << ' ';
2547 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2548 printBlockTerminators);
2550 (*this)->getAttrs(),
2551 /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2552 getUpperBoundMapAttrName(getOperation()->getName()),
2553 getStepAttrName(getOperation()->getName()),
2554 getOperandSegmentSizeAttr()});
2555}
2556
2557/// Fold the constant bounds of a loop.
2558static LogicalResult foldLoopBounds(AffineForOp forOp) {
2559 auto foldLowerOrUpperBound = [&forOp](bool lower) {
2560 // Check to see if each of the operands is the result of a constant. If
2561 // so, get the value. If not, ignore it.
2562 SmallVector<Attribute, 8> operandConstants;
2563 auto boundOperands =
2564 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2565 for (auto operand : boundOperands) {
2566 Attribute operandCst;
2567 matchPattern(operand, m_Constant(&operandCst));
2568 operandConstants.push_back(operandCst);
2569 }
2570
2571 AffineMap boundMap =
2572 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2573 assert(boundMap.getNumResults() >= 1 &&
2574 "bound maps should have at least one result");
2575 SmallVector<Attribute, 4> foldedResults;
2576 if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2577 return failure();
2578
2579 // Compute the max or min as applicable over the results.
2580 assert(!foldedResults.empty() && "bounds should have at least one result");
2581 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2582 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2583 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2584 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2585 : llvm::APIntOps::smin(maxOrMin, foldedResult);
2586 }
2587 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2588 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2589 return success();
2590 };
2591
2592 // Try to fold the lower bound.
2593 bool folded = false;
2594 if (!forOp.hasConstantLowerBound())
2595 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2596
2597 // Try to fold the upper bound.
2598 if (!forOp.hasConstantUpperBound())
2599 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2600 return success(folded);
2601}
2602
2603/// Returns constant trip count in trivial cases.
2604static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2605 int64_t step = forOp.getStepAsInt();
2606 if (!forOp.hasConstantBounds() || step <= 0)
2607 return std::nullopt;
2608 int64_t lb = forOp.getConstantLowerBound();
2609 int64_t ub = forOp.getConstantUpperBound();
2610 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2611}
2612
2613/// Fold the empty loop.
2615 if (!llvm::hasSingleElement(*forOp.getBody()))
2616 return {};
2617 if (forOp.getNumResults() == 0)
2618 return {};
2619 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2620 if (tripCount == 0) {
2621 // The initial values of the iteration arguments would be the op's
2622 // results.
2623 return forOp.getInits();
2624 }
2625 SmallVector<Value, 4> replacements;
2626 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2627 auto iterArgs = forOp.getRegionIterArgs();
2628 bool hasValDefinedOutsideLoop = false;
2629 bool iterArgsNotInOrder = false;
2630 for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2631 Value val = yieldOp.getOperand(i);
2632 BlockArgument *iterArgIt = llvm::find(iterArgs, val);
2633 // TODO: It should be possible to perform a replacement by computing the
2634 // last value of the IV based on the bounds and the step.
2635 if (val == forOp.getInductionVar())
2636 return {};
2637 if (iterArgIt == iterArgs.end()) {
2638 // `val` is defined outside of the loop.
2639 assert(forOp.isDefinedOutsideOfLoop(val) &&
2640 "must be defined outside of the loop");
2641 hasValDefinedOutsideLoop = true;
2642 replacements.push_back(val);
2643 } else {
2644 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2645 if (pos != i)
2646 iterArgsNotInOrder = true;
2647 replacements.push_back(forOp.getInits()[pos]);
2648 }
2649 }
2650 // Bail out when the trip count is unknown and the loop returns any value
2651 // defined outside of the loop or any iterArg out of order.
2652 if (!tripCount.has_value() &&
2653 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2654 return {};
2655 // Bail out when the loop iterates more than once and it returns any iterArg
2656 // out of order.
2657 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2658 return {};
2659 return llvm::to_vector_of<OpFoldResult>(replacements);
2660}
2661
2662/// Canonicalize the bounds of the given loop.
2663static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2664 SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2665 SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2666
2667 auto lbMap = forOp.getLowerBoundMap();
2668 auto ubMap = forOp.getUpperBoundMap();
2669 auto prevLbMap = lbMap;
2670 auto prevUbMap = ubMap;
2671
2672 composeAffineMapAndOperands(&lbMap, &lbOperands);
2673 canonicalizeMapAndOperands(&lbMap, &lbOperands);
2674 simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
2675 simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2676 lbMap = removeDuplicateExprs(lbMap);
2677
2678 composeAffineMapAndOperands(&ubMap, &ubOperands);
2679 canonicalizeMapAndOperands(&ubMap, &ubOperands);
2680 ubMap = removeDuplicateExprs(ubMap);
2681
2682 // Any canonicalization change always leads to updated map(s).
2683 if (lbMap == prevLbMap && ubMap == prevUbMap)
2684 return failure();
2685
2686 if (lbMap != prevLbMap)
2687 forOp.setLowerBound(lbOperands, lbMap);
2688 if (ubMap != prevUbMap)
2689 forOp.setUpperBound(ubOperands, ubMap);
2690 return success();
2691}
2692
2693/// Returns true if the affine.for has zero iterations in trivial cases.
2694static bool hasTrivialZeroTripCount(AffineForOp op) {
2695 return getTrivialConstantTripCount(op) == 0;
2696}
2697
2698LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2699 SmallVectorImpl<OpFoldResult> &results) {
2700 bool folded = succeeded(foldLoopBounds(*this));
2701 folded |= succeeded(canonicalizeLoopBounds(*this));
2702 if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2703 // The initial values of the loop-carried variables (iter_args) are the
2704 // results of the op. But this must be avoided for an affine.for op that
2705 // does not return any results. Since ops that do not return results cannot
2706 // be folded away, we would enter an infinite loop of folds on the same
2707 // affine.for op.
2708 results.assign(getInits().begin(), getInits().end());
2709 folded = true;
2710 }
2711 SmallVector<OpFoldResult> foldResults = AffineForEmptyLoopFolder(*this);
2712 if (!foldResults.empty()) {
2713 results.assign(foldResults);
2714 folded = true;
2715 }
2716 return success(folded);
2717}
2718
2719OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2720 assert((successor.isParent() || successor.getSuccessor() == &getRegion()) &&
2721 "invalid region point");
2722
2723 // The initial operands map to the loop arguments after the induction
2724 // variable or are forwarded to the results when the trip count is zero.
2725 return getInits();
2726}
2727
2728void AffineForOp::getSuccessorRegions(
2729 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2730 assert((point.isParent() ||
2732 &getRegion()) &&
2733 "expected loop region");
2734 // The loop may typically branch back to its body or to the parent operation.
2735 // If the predecessor is the parent op and the trip count is known to be at
2736 // least one, branch into the body using the iterator arguments. And in cases
2737 // we know the trip count is zero, it can only branch back to its parent.
2738 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2739 if (tripCount.has_value()) {
2740 if (!point.isParent()) {
2741 // From the loop body, if the trip count is one, we can only branch back
2742 // to the parent.
2743 if (tripCount == 1) {
2744 regions.push_back(RegionSuccessor::parent());
2745 return;
2746 }
2747 if (tripCount == 0)
2748 return;
2749 } else {
2750 if (tripCount.value() > 0) {
2751 regions.push_back(RegionSuccessor(&getRegion()));
2752 return;
2753 }
2754 if (tripCount.value() == 0) {
2755 regions.push_back(RegionSuccessor::parent());
2756 return;
2757 }
2758 }
2759 }
2760
2761 // In all other cases, the loop may branch back to itself or the parent
2762 // operation.
2763 regions.push_back(RegionSuccessor(&getRegion()));
2764 regions.push_back(RegionSuccessor::parent());
2765}
2766
2767ValueRange AffineForOp::getSuccessorInputs(RegionSuccessor successor) {
2768 if (successor.isParent())
2769 return getResults();
2770 return getRegionIterArgs();
2771}
2772
2773AffineBound AffineForOp::getLowerBound() {
2774 return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2775}
2776
2777AffineBound AffineForOp::getUpperBound() {
2778 return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2779}
2780
2781void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2782 assert(lbOperands.size() == map.getNumInputs());
2783 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2784 getLowerBoundOperandsMutable().assign(lbOperands);
2785 setLowerBoundMap(map);
2786}
2787
2788void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2789 assert(ubOperands.size() == map.getNumInputs());
2790 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2791 getUpperBoundOperandsMutable().assign(ubOperands);
2792 setUpperBoundMap(map);
2793}
2794
2795bool AffineForOp::hasConstantLowerBound() {
2796 return getLowerBoundMap().isSingleConstant();
2797}
2798
2799bool AffineForOp::hasConstantUpperBound() {
2800 return getUpperBoundMap().isSingleConstant();
2801}
2802
2803int64_t AffineForOp::getConstantLowerBound() {
2804 return getLowerBoundMap().getSingleConstantResult();
2805}
2806
2807int64_t AffineForOp::getConstantUpperBound() {
2808 return getUpperBoundMap().getSingleConstantResult();
2809}
2810
2811void AffineForOp::setConstantLowerBound(int64_t value) {
2812 setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2813}
2814
2815void AffineForOp::setConstantUpperBound(int64_t value) {
2816 setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2817}
2818
2819AffineForOp::operand_range AffineForOp::getControlOperands() {
2820 return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2821 getUpperBoundOperands().size()};
2822}
2823
2824bool AffineForOp::matchingBoundOperandList() {
2825 auto lbMap = getLowerBoundMap();
2826 auto ubMap = getUpperBoundMap();
2827 if (lbMap.getNumDims() != ubMap.getNumDims() ||
2828 lbMap.getNumSymbols() != ubMap.getNumSymbols())
2829 return false;
2830
2831 unsigned numOperands = lbMap.getNumInputs();
2832 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2833 // Compare Value 's.
2834 if (getOperand(i) != getOperand(numOperands + i))
2835 return false;
2836 }
2837 return true;
2838}
2839
2840SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2841
2842std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2843 return SmallVector<Value>{getInductionVar()};
2844}
2845
2846std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2847 if (!hasConstantLowerBound())
2848 return std::nullopt;
2849 OpBuilder b(getContext());
2850 return SmallVector<OpFoldResult>{
2851 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2852}
2853
2854std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2855 OpBuilder b(getContext());
2856 return SmallVector<OpFoldResult>{
2857 OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2858}
2859
2860std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2861 if (!hasConstantUpperBound())
2862 return {};
2863 OpBuilder b(getContext());
2864 return SmallVector<OpFoldResult>{
2865 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2866}
2867
2868FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2869 RewriterBase &rewriter, ValueRange newInitOperands,
2870 bool replaceInitOperandUsesInLoop,
2871 const NewYieldValuesFn &newYieldValuesFn) {
2872 // Create a new loop before the existing one, with the extra operands.
2873 OpBuilder::InsertionGuard g(rewriter);
2874 rewriter.setInsertionPoint(getOperation());
2875 auto inits = llvm::to_vector(getInits());
2876 inits.append(newInitOperands.begin(), newInitOperands.end());
2877 AffineForOp newLoop = AffineForOp::create(
2878 rewriter, getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2879 getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
2880
2881 // Generate the new yield values and append them to the scf.yield operation.
2882 auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2883 ArrayRef<BlockArgument> newIterArgs =
2884 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2885 {
2886 OpBuilder::InsertionGuard g(rewriter);
2887 rewriter.setInsertionPoint(yieldOp);
2888 SmallVector<Value> newYieldedValues =
2889 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2890 assert(newInitOperands.size() == newYieldedValues.size() &&
2891 "expected as many new yield values as new iter operands");
2892 rewriter.modifyOpInPlace(yieldOp, [&]() {
2893 yieldOp.getOperandsMutable().append(newYieldedValues);
2894 });
2895 }
2896
2897 // Move the loop body to the new op.
2898 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2899 newLoop.getBody()->getArguments().take_front(
2900 getBody()->getNumArguments()));
2901
2902 if (replaceInitOperandUsesInLoop) {
2903 // Replace all uses of `newInitOperands` with the corresponding basic block
2904 // arguments.
2905 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2906 rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2907 [&](OpOperand &use) {
2908 Operation *user = use.getOwner();
2909 return newLoop->isProperAncestor(user);
2910 });
2911 }
2912 }
2913
2914 // Replace the old loop.
2915 rewriter.replaceOp(getOperation(),
2916 newLoop->getResults().take_front(getNumResults()));
2917 return cast<LoopLikeOpInterface>(newLoop.getOperation());
2918}
2919
2920Speculation::Speculatability AffineForOp::getSpeculatability() {
2921 // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2922 // Start and End.
2923 //
2924 // For Step != 1, the loop may not terminate. We can add more smarts here if
2925 // needed.
2926 return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2928}
2929
2930/// Returns true if the provided value is the induction variable of a
2931/// AffineForOp.
2933 return getForInductionVarOwner(val) != AffineForOp();
2934}
2935
2939
2943
2945 auto ivArg = dyn_cast<BlockArgument>(val);
2946 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2947 return AffineForOp();
2948 if (auto forOp =
2949 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2950 // Check to make sure `val` is the induction variable, not an iter_arg.
2951 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2952 return AffineForOp();
2953}
2954
2956 auto ivArg = dyn_cast<BlockArgument>(val);
2957 if (!ivArg || !ivArg.getOwner())
2958 return nullptr;
2959 Operation *containingOp = ivArg.getOwner()->getParentOp();
2960 auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
2961 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2962 return parallelOp;
2963 return nullptr;
2964}
2965
2966/// Extracts the induction variables from a list of AffineForOps and returns
2967/// them.
2970 ivs->reserve(forInsts.size());
2971 for (auto forInst : forInsts)
2972 ivs->push_back(forInst.getInductionVar());
2973}
2974
2977 ivs.reserve(affineOps.size());
2978 for (Operation *op : affineOps) {
2979 // Add constraints from forOp's bounds.
2980 if (auto forOp = dyn_cast<AffineForOp>(op))
2981 ivs.push_back(forOp.getInductionVar());
2982 else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2983 for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2984 ivs.push_back(parallelOp.getBody()->getArgument(i));
2985 }
2986}
2987
2988/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2989/// operations.
2990template <typename BoundListTy, typename LoopCreatorTy>
2992 OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2993 ArrayRef<int64_t> steps,
2994 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2995 LoopCreatorTy &&loopCreatorFn) {
2996 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2997 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2998
2999 // If there are no loops to be constructed, construct the body anyway.
3000 OpBuilder::InsertionGuard guard(builder);
3001 if (lbs.empty()) {
3002 if (bodyBuilderFn)
3003 bodyBuilderFn(builder, loc, ValueRange());
3004 return;
3005 }
3006
3007 // Create the loops iteratively and store the induction variables.
3009 ivs.reserve(lbs.size());
3010 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
3011 // Callback for creating the loop body, always creates the terminator.
3012 auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
3013 ValueRange iterArgs) {
3014 ivs.push_back(iv);
3015 // In the innermost loop, call the body builder.
3016 if (i == e - 1 && bodyBuilderFn) {
3017 OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
3018 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
3019 }
3020 AffineYieldOp::create(nestedBuilder, nestedLoc);
3021 };
3022
3023 // Delegate actual loop creation to the callback in order to dispatch
3024 // between constant- and variable-bound loops.
3025 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
3026 builder.setInsertionPointToStart(loop.getBody());
3027 }
3028}
3029
3030/// Creates an affine loop from the bounds known to be constants.
3031static AffineForOp
3033 int64_t ub, int64_t step,
3034 AffineForOp::BodyBuilderFn bodyBuilderFn) {
3035 return AffineForOp::create(builder, loc, lb, ub, step,
3036 /*iterArgs=*/ValueRange(), bodyBuilderFn);
3037}
3038
3039/// Creates an affine loop from the bounds that may or may not be constants.
3040static AffineForOp
3042 int64_t step,
3043 AffineForOp::BodyBuilderFn bodyBuilderFn) {
3044 std::optional<int64_t> lbConst = getConstantIntValue(lb);
3045 std::optional<int64_t> ubConst = getConstantIntValue(ub);
3046 if (lbConst && ubConst)
3047 return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
3048 ubConst.value(), step, bodyBuilderFn);
3049 return AffineForOp::create(builder, loc, lb, builder.getDimIdentityMap(), ub,
3050 builder.getDimIdentityMap(), step,
3051 /*iterArgs=*/ValueRange(), bodyBuilderFn);
3052}
3053
3055 OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
3057 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
3058 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
3060}
3061
3063 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
3064 ArrayRef<int64_t> steps,
3065 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
3066 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
3068}
3069
3070//===----------------------------------------------------------------------===//
3071// AffineIfOp
3072//===----------------------------------------------------------------------===//
3073
3074namespace {
3075/// Remove else blocks that have nothing other than a zero value yield.
3076struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
3077 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
3078
3079 LogicalResult matchAndRewrite(AffineIfOp ifOp,
3080 PatternRewriter &rewriter) const override {
3081 if (ifOp.getElseRegion().empty() ||
3082 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
3083 return failure();
3084
3085 rewriter.startOpModification(ifOp);
3086 rewriter.eraseBlock(ifOp.getElseBlock());
3087 rewriter.finalizeOpModification(ifOp);
3088 return success();
3089 }
3090};
3091
3092/// Removes affine.if cond if the condition is always true or false in certain
3093/// trivial cases. Promotes the then/else block in the parent operation block.
3094struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
3095 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
3096
3097 LogicalResult matchAndRewrite(AffineIfOp op,
3098 PatternRewriter &rewriter) const override {
3099
3100 auto isTriviallyFalse = [](IntegerSet iSet) {
3101 return iSet.isEmptyIntegerSet();
3102 };
3103
3104 auto isTriviallyTrue = [](IntegerSet iSet) {
3105 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
3106 iSet.getConstraint(0) == 0);
3107 };
3108
3109 IntegerSet affineIfConditions = op.getIntegerSet();
3110 Block *blockToMove;
3111 if (isTriviallyFalse(affineIfConditions)) {
3112 // The absence, or equivalently, the emptiness of the else region need not
3113 // be checked when affine.if is returning results because if an affine.if
3114 // operation is returning results, it always has a non-empty else region.
3115 if (op.getNumResults() == 0 && !op.hasElse()) {
3116 // If the else region is absent, or equivalently, empty, remove the
3117 // affine.if operation (which is not returning any results).
3118 rewriter.eraseOp(op);
3119 return success();
3120 }
3121 blockToMove = op.getElseBlock();
3122 } else if (isTriviallyTrue(affineIfConditions)) {
3123 blockToMove = op.getThenBlock();
3124 } else {
3125 return failure();
3126 }
3127 Operation *blockToMoveTerminator = blockToMove->getTerminator();
3128 // Promote the "blockToMove" block to the parent operation block between the
3129 // prologue and epilogue of "op".
3130 rewriter.inlineBlockBefore(blockToMove, op);
3131 // Replace the "op" operation with the operands of the
3132 // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
3133 // the affine.yield operation present in the "blockToMove" block. It has no
3134 // operands when affine.if is not returning results and therefore, in that
3135 // case, replaceOp just erases "op". When affine.if is not returning
3136 // results, the affine.yield operation can be omitted. It gets inserted
3137 // implicitly.
3138 rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
3139 // Erase the "blockToMoveTerminator" operation since it is now in the parent
3140 // operation block, which already has its own terminator.
3141 rewriter.eraseOp(blockToMoveTerminator);
3142 return success();
3143 }
3144};
3145} // namespace
3146
3147/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
3148/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
3149void AffineIfOp::getSuccessorRegions(
3150 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
3151 // If the predecessor is an AffineIfOp, then branching into both `then` and
3152 // `else` region is valid.
3153 if (point.isParent()) {
3154 regions.reserve(2);
3155 regions.push_back(RegionSuccessor(&getThenRegion()));
3156 // If the "else" region is empty, branch bach into parent.
3157 if (getElseRegion().empty()) {
3158 regions.push_back(RegionSuccessor::parent());
3159 } else {
3160 regions.push_back(RegionSuccessor(&getElseRegion()));
3161 }
3162 return;
3163 }
3164
3165 // If the predecessor is the `else`/`then` region, then branching into parent
3166 // op is valid.
3167 regions.push_back(RegionSuccessor::parent());
3168}
3169
3170ValueRange AffineIfOp::getSuccessorInputs(RegionSuccessor successor) {
3171 if (successor.isParent())
3172 return getResults();
3173 if (successor == &getThenRegion())
3174 return getThenRegion().getArguments();
3175 if (successor == &getElseRegion())
3176 return getElseRegion().getArguments();
3177 llvm_unreachable("invalid region successor");
3178}
3179
3180LogicalResult AffineIfOp::verify() {
3181 // Verify that we have a condition attribute.
3182 // FIXME: This should be specified in the arguments list in ODS.
3183 auto conditionAttr =
3184 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3185 if (!conditionAttr)
3186 return emitOpError("requires an integer set attribute named 'condition'");
3187
3188 // Verify that there are enough operands for the condition.
3189 IntegerSet condition = conditionAttr.getValue();
3190 if (getNumOperands() != condition.getNumInputs())
3191 return emitOpError("operand count and condition integer set dimension and "
3192 "symbol count must match");
3193
3194 // Verify that the operands are valid dimension/symbols.
3195 if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
3196 condition.getNumDims())))
3197 return failure();
3198
3199 return success();
3200}
3201
3202ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
3203 // Parse the condition attribute set.
3204 IntegerSetAttr conditionAttr;
3205 unsigned numDims;
3206 if (parser.parseAttribute(conditionAttr,
3207 AffineIfOp::getConditionAttrStrName(),
3208 result.attributes) ||
3209 parseDimAndSymbolList(parser, result.operands, numDims))
3210 return failure();
3211
3212 // Verify the condition operands.
3213 auto set = conditionAttr.getValue();
3214 if (set.getNumDims() != numDims)
3215 return parser.emitError(
3216 parser.getNameLoc(),
3217 "dim operand count and integer set dim count must match");
3218 if (numDims + set.getNumSymbols() != result.operands.size())
3219 return parser.emitError(
3220 parser.getNameLoc(),
3221 "symbol operand count and integer set symbol count must match");
3222
3223 if (parser.parseOptionalArrowTypeList(result.types))
3224 return failure();
3225
3226 // Create the regions for 'then' and 'else'. The latter must be created even
3227 // if it remains empty for the validity of the operation.
3228 result.regions.reserve(2);
3229 Region *thenRegion = result.addRegion();
3230 Region *elseRegion = result.addRegion();
3231
3232 // Parse the 'then' region.
3233 if (parser.parseRegion(*thenRegion, {}, {}))
3234 return failure();
3235 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
3236 result.location);
3237
3238 // If we find an 'else' keyword then parse the 'else' region.
3239 if (!parser.parseOptionalKeyword("else")) {
3240 if (parser.parseRegion(*elseRegion, {}, {}))
3241 return failure();
3242 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
3243 result.location);
3244 }
3245
3246 // Parse the optional attribute list.
3247 if (parser.parseOptionalAttrDict(result.attributes))
3248 return failure();
3249
3250 return success();
3251}
3252
3253void AffineIfOp::print(OpAsmPrinter &p) {
3254 auto conditionAttr =
3255 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3256 p << " " << conditionAttr;
3257 printDimAndSymbolList(operand_begin(), operand_end(),
3258 conditionAttr.getValue().getNumDims(), p);
3259 p.printOptionalArrowTypeList(getResultTypes());
3260 p << ' ';
3261 p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
3262 /*printBlockTerminators=*/getNumResults());
3263
3264 // Print the 'else' regions if it has any blocks.
3265 auto &elseRegion = this->getElseRegion();
3266 if (!elseRegion.empty()) {
3267 p << " else ";
3268 p.printRegion(elseRegion,
3269 /*printEntryBlockArgs=*/false,
3270 /*printBlockTerminators=*/getNumResults());
3271 }
3272
3273 // Print the attribute list.
3274 p.printOptionalAttrDict((*this)->getAttrs(),
3275 /*elidedAttrs=*/getConditionAttrStrName());
3276}
3277
3278IntegerSet AffineIfOp::getIntegerSet() {
3279 return (*this)
3280 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
3281 .getValue();
3282}
3283
3284void AffineIfOp::setIntegerSet(IntegerSet newSet) {
3285 (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
3286}
3287
3288void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
3289 setIntegerSet(set);
3290 (*this)->setOperands(operands);
3291}
3292
3293void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3294 TypeRange resultTypes, IntegerSet set, ValueRange args,
3295 bool withElseRegion) {
3296 assert(resultTypes.empty() || withElseRegion);
3297 OpBuilder::InsertionGuard guard(builder);
3298
3299 result.addTypes(resultTypes);
3300 result.addOperands(args);
3301 result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
3302
3303 Region *thenRegion = result.addRegion();
3304 builder.createBlock(thenRegion);
3305 if (resultTypes.empty())
3306 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
3307
3308 Region *elseRegion = result.addRegion();
3309 if (withElseRegion) {
3310 builder.createBlock(elseRegion);
3311 if (resultTypes.empty())
3312 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
3313 }
3314}
3315
3316void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3317 IntegerSet set, ValueRange args, bool withElseRegion) {
3318 AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
3319 withElseRegion);
3320}
3321
3322/// Compose any affine.apply ops feeding into `operands` of the integer set
3323/// `set` by composing the maps of such affine.apply ops with the integer
3324/// set constraints.
3326 SmallVectorImpl<Value> &operands,
3327 bool composeAffineMin = false) {
3328 // We will simply reuse the API of the map composition by viewing the LHSs of
3329 // the equalities and inequalities of `set` as the affine exprs of an affine
3330 // map. Convert to equivalent map, compose, and convert back to set.
3331 auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
3332 set.getConstraints(), set.getContext());
3333 // Check if any composition is possible.
3334 if (llvm::none_of(operands,
3335 [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
3336 return;
3337
3338 composeAffineMapAndOperands(&map, &operands, composeAffineMin);
3339 set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
3340 set.getEqFlags());
3341}
3342
3343/// Canonicalize an affine if op's conditional (integer set + operands).
3344LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3345 auto set = getIntegerSet();
3346 SmallVector<Value, 4> operands(getOperands());
3347 composeSetAndOperands(set, operands);
3348 canonicalizeSetAndOperands(&set, &operands);
3349
3350 // Check if the canonicalization or composition led to any change.
3351 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
3352 return failure();
3353
3354 setConditional(set, operands);
3355 return success();
3356}
3357
3358void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
3359 MLIRContext *context) {
3360 results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
3361}
3362
3363//===----------------------------------------------------------------------===//
3364// AffineLoadOp
3365//===----------------------------------------------------------------------===//
3366
3367void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3368 AffineMap map, ValueRange operands) {
3369 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3370 result.addOperands(operands);
3371 if (map)
3372 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3373 auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
3374 result.types.push_back(memrefType.getElementType());
3375}
3376
3377void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3378 Value memref, AffineMap map, ValueRange mapOperands) {
3379 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3380 result.addOperands(memref);
3381 result.addOperands(mapOperands);
3382 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3383 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3384 result.types.push_back(memrefType.getElementType());
3385}
3386
3387void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3388 Value memref, ValueRange indices) {
3389 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3390 int64_t rank = memrefType.getRank();
3391 // Create identity map for memrefs with at least one dimension or () -> ()
3392 // for zero-dimensional memrefs.
3393 auto map =
3394 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3395 build(builder, result, memref, map, indices);
3396}
3397
3398ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
3399 auto &builder = parser.getBuilder();
3400 auto indexTy = builder.getIndexType();
3401
3402 MemRefType type;
3403 OpAsmParser::UnresolvedOperand memrefInfo;
3404 AffineMapAttr mapAttr;
3405 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3406 return failure(
3407 parser.parseOperand(memrefInfo) ||
3408 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3409 AffineLoadOp::getMapAttrStrName(),
3410 result.attributes) ||
3411 parser.parseOptionalAttrDict(result.attributes) ||
3412 parser.parseColonType(type) ||
3413 parser.resolveOperand(memrefInfo, type, result.operands) ||
3414 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3415 parser.addTypeToList(type.getElementType(), result.types));
3416}
3417
3418void AffineLoadOp::print(OpAsmPrinter &p) {
3419 p << " " << getMemRef() << '[';
3420 if (AffineMapAttr mapAttr =
3421 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3422 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3423 p << ']';
3424 p.printOptionalAttrDict((*this)->getAttrs(),
3425 /*elidedAttrs=*/{getMapAttrStrName()});
3426 p << " : " << getMemRefType();
3427}
3428
3429/// Verify common indexing invariants of affine.load, affine.store,
3430/// affine.vector_load and affine.vector_store.
3431template <typename AffineMemOpTy>
3432static LogicalResult
3433verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr,
3434 Operation::operand_range mapOperands,
3435 MemRefType memrefType, unsigned numIndexOperands) {
3436 AffineMap map = mapAttr.getValue();
3437 if (map.getNumResults() != memrefType.getRank())
3438 return op->emitOpError("affine map num results must equal memref rank");
3439 if (map.getNumInputs() != numIndexOperands)
3440 return op->emitOpError("expects as many subscripts as affine map inputs");
3441
3442 for (auto idx : mapOperands) {
3443 if (!idx.getType().isIndex())
3444 return op->emitOpError("index to load must have 'index' type");
3445 }
3446 if (failed(verifyDimAndSymbolIdentifiers(op, mapOperands, map.getNumDims())))
3447 return failure();
3448
3449 return success();
3450}
3451
3452LogicalResult AffineLoadOp::verify() {
3453 auto memrefType = getMemRefType();
3454 if (getType() != memrefType.getElementType())
3455 return emitOpError("result type must match element type of memref");
3456
3458 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3459 getMapOperands(), memrefType,
3460 /*numIndexOperands=*/getNumOperands() - 1)))
3461 return failure();
3462
3463 return success();
3464}
3465
3466void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3467 MLIRContext *context) {
3468 results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3469}
3470
3471OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3472 /// load(memrefcast) -> load
3473 if (succeeded(memref::foldMemRefCast(*this)))
3474 return getResult();
3475
3476 // Fold load from a global constant memref.
3477 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3478 if (!getGlobalOp)
3479 return {};
3480 // Get to the memref.global defining the symbol.
3481 auto *symbolTableOp = getGlobalOp->getParentWithTrait<OpTrait::SymbolTable>();
3482 if (!symbolTableOp)
3483 return {};
3484 auto global = dyn_cast_or_null<memref::GlobalOp>(
3485 SymbolTable::lookupSymbolIn(symbolTableOp, getGlobalOp.getNameAttr()));
3486 if (!global)
3487 return {};
3488
3489 // Check if the global memref is a constant.
3490 auto cstAttr =
3491 dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3492 if (!cstAttr)
3493 return {};
3494 // If it's a splat constant, we can fold irrespective of indices.
3495 if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr))
3496 return splatAttr.getSplatValue<Attribute>();
3497 // Otherwise, we can fold only if we know the indices.
3498 if (!getAffineMap().isConstant())
3499 return {};
3500 auto indices = llvm::to_vector<4>(
3501 llvm::map_range(getAffineMap().getConstantResults(),
3502 [](int64_t v) -> uint64_t { return v; }));
3503 return cstAttr.getValues<Attribute>()[indices];
3504}
3505
3506//===----------------------------------------------------------------------===//
3507// AffineStoreOp
3508//===----------------------------------------------------------------------===//
3509
3510void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3511 Value valueToStore, Value memref, AffineMap map,
3512 ValueRange mapOperands) {
3513 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3514 result.addOperands(valueToStore);
3515 result.addOperands(memref);
3516 result.addOperands(mapOperands);
3517 result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3518}
3519
3520// Use identity map.
3521void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3522 Value valueToStore, Value memref,
3524 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3525 int64_t rank = memrefType.getRank();
3526 // Create identity map for memrefs with at least one dimension or () -> ()
3527 // for zero-dimensional memrefs.
3528 auto map =
3529 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3530 build(builder, result, valueToStore, memref, map, indices);
3531}
3532
3533ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3534 auto indexTy = parser.getBuilder().getIndexType();
3535
3536 MemRefType type;
3537 OpAsmParser::UnresolvedOperand storeValueInfo;
3538 OpAsmParser::UnresolvedOperand memrefInfo;
3539 AffineMapAttr mapAttr;
3540 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3541 return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3542 parser.parseOperand(memrefInfo) ||
3544 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3545 result.attributes) ||
3546 parser.parseOptionalAttrDict(result.attributes) ||
3547 parser.parseColonType(type) ||
3548 parser.resolveOperand(storeValueInfo, type.getElementType(),
3549 result.operands) ||
3550 parser.resolveOperand(memrefInfo, type, result.operands) ||
3551 parser.resolveOperands(mapOperands, indexTy, result.operands));
3552}
3553
3554void AffineStoreOp::print(OpAsmPrinter &p) {
3555 p << " " << getValueToStore();
3556 p << ", " << getMemRef() << '[';
3557 if (AffineMapAttr mapAttr =
3558 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3559 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3560 p << ']';
3561 p.printOptionalAttrDict((*this)->getAttrs(),
3562 /*elidedAttrs=*/{getMapAttrStrName()});
3563 p << " : " << getMemRefType();
3564}
3565
3566LogicalResult AffineStoreOp::verify() {
3567 // The value to store must have the same type as memref element type.
3568 auto memrefType = getMemRefType();
3569 if (getValueToStore().getType() != memrefType.getElementType())
3570 return emitOpError(
3571 "value to store must have the same type as memref element type");
3572
3574 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3575 getMapOperands(), memrefType,
3576 /*numIndexOperands=*/getNumOperands() - 2)))
3577 return failure();
3578
3579 return success();
3580}
3581
3582void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3583 MLIRContext *context) {
3584 results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3585}
3586
3587LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3588 SmallVectorImpl<OpFoldResult> &results) {
3589 /// store(memrefcast) -> store
3590 return memref::foldMemRefCast(*this, getValueToStore());
3591}
3592
3593//===----------------------------------------------------------------------===//
3594// AffineMinMaxOpBase
3595//===----------------------------------------------------------------------===//
3596
3597template <typename T>
3598static LogicalResult verifyAffineMinMaxOp(T op) {
3599 // Verify that operand count matches affine map dimension and symbol count.
3600 if (op.getNumOperands() !=
3601 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3602 return op.emitOpError(
3603 "operand count and affine map dimension and symbol count must match");
3604
3605 if (op.getMap().getNumResults() == 0)
3606 return op.emitOpError("affine map expect at least one result");
3607 return success();
3608}
3609
3610template <typename T>
3611static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3612 p << ' ' << op->getAttr(T::getMapAttrStrName());
3613 auto operands = op.getOperands();
3614 unsigned numDims = op.getMap().getNumDims();
3615 p << '(' << operands.take_front(numDims) << ')';
3616
3617 if (operands.size() != numDims)
3618 p << '[' << operands.drop_front(numDims) << ']';
3619 p.printOptionalAttrDict(op->getAttrs(),
3620 /*elidedAttrs=*/{T::getMapAttrStrName()});
3621}
3622
3623template <typename T>
3624static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3626 auto &builder = parser.getBuilder();
3627 auto indexType = builder.getIndexType();
3630 AffineMapAttr mapAttr;
3631 return failure(
3632 parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3633 result.attributes) ||
3635 parser.parseOperandList(symInfos,
3637 parser.parseOptionalAttrDict(result.attributes) ||
3638 parser.resolveOperands(dimInfos, indexType, result.operands) ||
3639 parser.resolveOperands(symInfos, indexType, result.operands) ||
3640 parser.addTypeToList(indexType, result.types));
3641}
3642
3643/// Fold an affine min or max operation with the given operands. The operand
3644/// list may contain nulls, which are interpreted as the operand not being a
3645/// constant.
3646template <typename T>
3648 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3649 "expected affine min or max op");
3650
3651 // Fold the affine map.
3652 // TODO: Fold more cases:
3653 // min(some_affine, some_affine + constant, ...), etc.
3655 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3656
3657 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3658 return op.getOperand(0);
3659
3660 // If some of the map results are not constant, try changing the map in-place.
3661 if (results.empty()) {
3662 // If the map is the same, report that folding did not happen.
3663 if (foldedMap == op.getMap())
3664 return {};
3665 op->setAttr("map", AffineMapAttr::get(foldedMap));
3666 return op.getResult();
3667 }
3668
3669 // Otherwise, completely fold the op into a constant.
3670 auto resultIt = std::is_same<T, AffineMinOp>::value
3671 ? llvm::min_element(results)
3672 : llvm::max_element(results);
3673 if (resultIt == results.end())
3674 return {};
3675 return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3676}
3677
3678/// Remove duplicated expressions in affine min/max ops.
3679template <typename T>
3682
3683 LogicalResult matchAndRewrite(T affineOp,
3684 PatternRewriter &rewriter) const override {
3685 AffineMap oldMap = affineOp.getAffineMap();
3686
3688 for (AffineExpr expr : oldMap.getResults()) {
3689 // This is a linear scan over newExprs, but it should be fine given that
3690 // we typically just have a few expressions per op.
3691 if (!llvm::is_contained(newExprs, expr))
3692 newExprs.push_back(expr);
3693 }
3694
3695 if (newExprs.size() == oldMap.getNumResults())
3696 return failure();
3697
3698 auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3699 newExprs, rewriter.getContext());
3700 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3701
3702 return success();
3703 }
3704};
3705
3706/// Merge an affine min/max op to its consumers if its consumer is also an
3707/// affine min/max op.
3708///
3709/// This pattern requires the producer affine min/max op is bound to a
3710/// dimension/symbol that is used as a standalone expression in the consumer
3711/// affine op's map.
3712///
3713/// For example, a pattern like the following:
3714///
3715/// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3716/// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3717///
3718/// Can be turned into:
3719///
3720/// %1 = affine.min affine_map<
3721/// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3722template <typename T>
3725
3726 LogicalResult matchAndRewrite(T affineOp,
3727 PatternRewriter &rewriter) const override {
3728 AffineMap oldMap = affineOp.getAffineMap();
3729 ValueRange dimOperands =
3730 affineOp.getMapOperands().take_front(oldMap.getNumDims());
3731 ValueRange symOperands =
3732 affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3733
3734 auto newDimOperands = llvm::to_vector<8>(dimOperands);
3735 auto newSymOperands = llvm::to_vector<8>(symOperands);
3737 SmallVector<T, 4> producerOps;
3738
3739 // Go over each expression to see whether it's a single dimension/symbol
3740 // with the corresponding operand which is the result of another affine
3741 // min/max op. If So it can be merged into this affine op.
3742 for (AffineExpr expr : oldMap.getResults()) {
3743 if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3744 Value symValue = symOperands[symExpr.getPosition()];
3745 if (auto producerOp = symValue.getDefiningOp<T>()) {
3746 producerOps.push_back(producerOp);
3747 continue;
3748 }
3749 } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3750 Value dimValue = dimOperands[dimExpr.getPosition()];
3751 if (auto producerOp = dimValue.getDefiningOp<T>()) {
3752 producerOps.push_back(producerOp);
3753 continue;
3754 }
3755 }
3756 // For the above cases we will remove the expression by merging the
3757 // producer affine min/max's affine expressions. Otherwise we need to
3758 // keep the existing expression.
3759 newExprs.push_back(expr);
3760 }
3761
3762 if (producerOps.empty())
3763 return failure();
3764
3765 unsigned numUsedDims = oldMap.getNumDims();
3766 unsigned numUsedSyms = oldMap.getNumSymbols();
3767
3768 // Now go over all producer affine ops and merge their expressions.
3769 for (T producerOp : producerOps) {
3770 AffineMap producerMap = producerOp.getAffineMap();
3771 unsigned numProducerDims = producerMap.getNumDims();
3772 unsigned numProducerSyms = producerMap.getNumSymbols();
3773
3774 // Collect all dimension/symbol values.
3775 ValueRange dimValues =
3776 producerOp.getMapOperands().take_front(numProducerDims);
3777 ValueRange symValues =
3778 producerOp.getMapOperands().take_back(numProducerSyms);
3779 newDimOperands.append(dimValues.begin(), dimValues.end());
3780 newSymOperands.append(symValues.begin(), symValues.end());
3781
3782 // For expressions we need to shift to avoid overlap.
3783 for (AffineExpr expr : producerMap.getResults()) {
3784 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3785 .shiftSymbols(numProducerSyms, numUsedSyms));
3786 }
3787
3788 numUsedDims += numProducerDims;
3789 numUsedSyms += numProducerSyms;
3790 }
3791
3792 auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3793 rewriter.getContext());
3794 auto newOperands =
3795 llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3796 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3797
3798 return success();
3799 }
3800};
3801
3802/// Canonicalize the result expression order of an affine map and return success
3803/// if the order changed.
3804///
3805/// The function flattens the map's affine expressions to coefficient arrays and
3806/// sorts them in lexicographic order. A coefficient array contains a multiplier
3807/// for every dimension/symbol and a constant term. The canonicalization fails
3808/// if a result expression is not pure or if the flattening requires local
3809/// variables that, unlike dimensions and symbols, have no global order.
3810static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3811 SmallVector<SmallVector<int64_t>> flattenedExprs;
3812 for (const AffineExpr &resultExpr : map.getResults()) {
3813 // Fail if the expression is not pure.
3814 if (!resultExpr.isPureAffine())
3815 return failure();
3816
3817 SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3818 auto flattenResult = flattener.walkPostOrder(resultExpr);
3819 if (failed(flattenResult))
3820 return failure();
3821
3822 // Fail if the flattened expression has local variables.
3823 if (flattener.operandExprStack.back().size() !=
3824 map.getNumDims() + map.getNumSymbols() + 1)
3825 return failure();
3826
3827 flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3828 flattener.operandExprStack.back().end());
3829 }
3830
3831 // Fail if sorting is not necessary.
3832 if (llvm::is_sorted(flattenedExprs))
3833 return failure();
3834
3835 // Reorder the result expressions according to their flattened form.
3836 SmallVector<unsigned> resultPermutation =
3837 llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3838 llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3839 return flattenedExprs[lhs] < flattenedExprs[rhs];
3840 });
3841 SmallVector<AffineExpr> newExprs;
3842 for (unsigned idx : resultPermutation)
3843 newExprs.push_back(map.getResult(idx));
3844
3845 map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3846 map.getContext());
3847 return success();
3848}
3849
3850/// Canonicalize the affine map result expression order of an affine min/max
3851/// operation.
3852///
3853/// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3854/// expressions and replaces the operation if the order changed.
3855///
3856/// For example, the following operation:
3857///
3858/// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3859///
3860/// Turns into:
3861///
3862/// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3863template <typename T>
3866
3867 LogicalResult matchAndRewrite(T affineOp,
3868 PatternRewriter &rewriter) const override {
3869 AffineMap map = affineOp.getAffineMap();
3870 if (failed(canonicalizeMapExprAndTermOrder(map)))
3871 return failure();
3872 rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3873 return success();
3874 }
3875};
3876
3877template <typename T>
3880
3881 LogicalResult matchAndRewrite(T affineOp,
3882 PatternRewriter &rewriter) const override {
3883 if (affineOp.getMap().getNumResults() != 1)
3884 return failure();
3885 rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3886 affineOp.getOperands());
3887 return success();
3888 }
3889};
3890
3891//===----------------------------------------------------------------------===//
3892// AffineMinOp
3893//===----------------------------------------------------------------------===//
3894//
3895// %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3896//
3897
3898OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3899 return foldMinMaxOp(*this, adaptor.getOperands());
3900}
3901
3902void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3903 MLIRContext *context) {
3904 patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
3905 DeduplicateAffineMinMaxExpressions<AffineMinOp>,
3906 MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3907 CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMinOp>>(
3908 context);
3909}
3910
3911LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3912
3913ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
3915}
3916
3917void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3918
3919//===----------------------------------------------------------------------===//
3920// AffineMaxOp
3921//===----------------------------------------------------------------------===//
3922//
3923// %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
3924//
3925
3926OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3927 return foldMinMaxOp(*this, adaptor.getOperands());
3928}
3929
3930void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3931 MLIRContext *context) {
3932 patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
3933 DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
3934 MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3935 CanonicalizeAffineMinMaxOpExprAndTermOrder<AffineMaxOp>>(
3936 context);
3937}
3938
3939LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3940
3941ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
3943}
3944
3945void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3946
3947//===----------------------------------------------------------------------===//
3948// AffinePrefetchOp
3949//===----------------------------------------------------------------------===//
3950
3951//
3952// affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3953//
3954ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3955 OperationState &result) {
3956 auto &builder = parser.getBuilder();
3957 auto indexTy = builder.getIndexType();
3958
3959 MemRefType type;
3960 OpAsmParser::UnresolvedOperand memrefInfo;
3961 IntegerAttr hintInfo;
3962 auto i32Type = parser.getBuilder().getIntegerType(32);
3963 StringRef readOrWrite, cacheType;
3964
3965 AffineMapAttr mapAttr;
3966 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
3967 if (parser.parseOperand(memrefInfo) ||
3968 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3969 AffinePrefetchOp::getMapAttrStrName(),
3970 result.attributes) ||
3971 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3972 parser.parseComma() || parser.parseKeyword("locality") ||
3973 parser.parseLess() ||
3974 parser.parseAttribute(hintInfo, i32Type,
3975 AffinePrefetchOp::getLocalityHintAttrStrName(),
3976 result.attributes) ||
3977 parser.parseGreater() || parser.parseComma() ||
3978 parser.parseKeyword(&cacheType) ||
3979 parser.parseOptionalAttrDict(result.attributes) ||
3980 parser.parseColonType(type) ||
3981 parser.resolveOperand(memrefInfo, type, result.operands) ||
3982 parser.resolveOperands(mapOperands, indexTy, result.operands))
3983 return failure();
3984
3985 if (readOrWrite != "read" && readOrWrite != "write")
3986 return parser.emitError(parser.getNameLoc(),
3987 "rw specifier has to be 'read' or 'write'");
3988 result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
3989 parser.getBuilder().getBoolAttr(readOrWrite == "write"));
3990
3991 if (cacheType != "data" && cacheType != "instr")
3992 return parser.emitError(parser.getNameLoc(),
3993 "cache type has to be 'data' or 'instr'");
3994
3995 result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
3996 parser.getBuilder().getBoolAttr(cacheType == "data"));
3997
3998 return success();
3999}
4000
4001void AffinePrefetchOp::print(OpAsmPrinter &p) {
4002 p << " " << getMemref() << '[';
4003 AffineMapAttr mapAttr =
4004 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
4005 if (mapAttr)
4006 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4007 p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", "
4008 << "locality<" << getLocalityHint() << ">, "
4009 << (getIsDataCache() ? "data" : "instr");
4011 (*this)->getAttrs(),
4012 /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
4013 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
4014 p << " : " << getMemRefType();
4015}
4016
4017LogicalResult AffinePrefetchOp::verify() {
4018 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
4019 if (mapAttr) {
4020 AffineMap map = mapAttr.getValue();
4021 if (map.getNumResults() != getMemRefType().getRank())
4022 return emitOpError("affine.prefetch affine map num results must equal"
4023 " memref rank");
4024 if (map.getNumInputs() + 1 != getNumOperands())
4025 return emitOpError("too few operands");
4026 } else {
4027 if (getNumOperands() != 1)
4028 return emitOpError("too few operands");
4029 }
4030
4031 Region *scope = getAffineScope(*this);
4032 for (auto idx : getMapOperands()) {
4033 if (!isValidAffineIndexOperand(idx, scope))
4034 return emitOpError(
4035 "index must be a valid dimension or symbol identifier");
4036 }
4037 return success();
4038}
4039
4040void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4041 MLIRContext *context) {
4042 // prefetch(memrefcast) -> prefetch
4043 results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
4044}
4045
4046LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
4047 SmallVectorImpl<OpFoldResult> &results) {
4048 /// prefetch(memrefcast) -> prefetch
4049 return memref::foldMemRefCast(*this);
4050}
4051
4052//===----------------------------------------------------------------------===//
4053// AffineParallelOp
4054//===----------------------------------------------------------------------===//
4055
4056void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
4057 TypeRange resultTypes,
4058 ArrayRef<arith::AtomicRMWKind> reductions,
4059 ArrayRef<int64_t> ranges) {
4060 SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
4061 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
4062 return builder.getConstantAffineMap(value);
4063 }));
4064 SmallVector<int64_t> steps(ranges.size(), 1);
4065 build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
4066 /*ubArgs=*/{}, steps);
4067}
4068
4069void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
4070 TypeRange resultTypes,
4071 ArrayRef<arith::AtomicRMWKind> reductions,
4072 ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
4073 ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
4074 ArrayRef<int64_t> steps) {
4075 assert(llvm::all_of(lbMaps,
4076 [lbMaps](AffineMap m) {
4077 return m.getNumDims() == lbMaps[0].getNumDims() &&
4078 m.getNumSymbols() == lbMaps[0].getNumSymbols();
4079 }) &&
4080 "expected all lower bounds maps to have the same number of dimensions "
4081 "and symbols");
4082 assert(llvm::all_of(ubMaps,
4083 [ubMaps](AffineMap m) {
4084 return m.getNumDims() == ubMaps[0].getNumDims() &&
4085 m.getNumSymbols() == ubMaps[0].getNumSymbols();
4086 }) &&
4087 "expected all upper bounds maps to have the same number of dimensions "
4088 "and symbols");
4089 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
4090 "expected lower bound maps to have as many inputs as lower bound "
4091 "operands");
4092 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
4093 "expected upper bound maps to have as many inputs as upper bound "
4094 "operands");
4095
4096 OpBuilder::InsertionGuard guard(builder);
4097 result.addTypes(resultTypes);
4098
4099 // Convert the reductions to integer attributes.
4100 SmallVector<Attribute, 4> reductionAttrs;
4101 for (arith::AtomicRMWKind reduction : reductions)
4102 reductionAttrs.push_back(
4103 builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
4104 result.addAttribute(getReductionsAttrStrName(),
4105 builder.getArrayAttr(reductionAttrs));
4106
4107 // Concatenates maps defined in the same input space (same dimensions and
4108 // symbols), assumes there is at least one map.
4109 auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
4110 SmallVectorImpl<int32_t> &groups) {
4111 if (maps.empty())
4112 return AffineMap::get(builder.getContext());
4113 SmallVector<AffineExpr> exprs;
4114 groups.reserve(groups.size() + maps.size());
4115 exprs.reserve(maps.size());
4116 for (AffineMap m : maps) {
4117 llvm::append_range(exprs, m.getResults());
4118 groups.push_back(m.getNumResults());
4119 }
4120 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
4121 maps[0].getContext());
4122 };
4123
4124 // Set up the bounds.
4125 SmallVector<int32_t> lbGroups, ubGroups;
4126 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
4127 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
4128 result.addAttribute(getLowerBoundsMapAttrStrName(),
4129 AffineMapAttr::get(lbMap));
4130 result.addAttribute(getLowerBoundsGroupsAttrStrName(),
4131 builder.getI32TensorAttr(lbGroups));
4132 result.addAttribute(getUpperBoundsMapAttrStrName(),
4133 AffineMapAttr::get(ubMap));
4134 result.addAttribute(getUpperBoundsGroupsAttrStrName(),
4135 builder.getI32TensorAttr(ubGroups));
4136 result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
4137 result.addOperands(lbArgs);
4138 result.addOperands(ubArgs);
4139
4140 // Create a region and a block for the body.
4141 auto *bodyRegion = result.addRegion();
4142 Block *body = builder.createBlock(bodyRegion);
4143
4144 // Add all the block arguments.
4145 for (unsigned i = 0, e = steps.size(); i < e; ++i)
4146 body->addArgument(IndexType::get(builder.getContext()), result.location);
4147 if (resultTypes.empty())
4148 ensureTerminator(*bodyRegion, builder, result.location);
4149}
4150
4151SmallVector<Region *> AffineParallelOp::getLoopRegions() {
4152 return {&getRegion()};
4153}
4154
4155unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
4156
4157AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
4158 return getOperands().take_front(getLowerBoundsMap().getNumInputs());
4159}
4160
4161AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
4162 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
4163}
4164
4165AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
4166 auto values = getLowerBoundsGroups().getValues<int32_t>();
4167 unsigned start = 0;
4168 for (unsigned i = 0; i < pos; ++i)
4169 start += values[i];
4170 return getLowerBoundsMap().getSliceMap(start, values[pos]);
4171}
4172
4173AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
4174 auto values = getUpperBoundsGroups().getValues<int32_t>();
4175 unsigned start = 0;
4176 for (unsigned i = 0; i < pos; ++i)
4177 start += values[i];
4178 return getUpperBoundsMap().getSliceMap(start, values[pos]);
4179}
4180
4181AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
4182 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
4183}
4184
4185AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
4186 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
4187}
4188
4189std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
4190 if (hasMinMaxBounds())
4191 return std::nullopt;
4192
4193 // Try to convert all the ranges to constant expressions.
4194 SmallVector<int64_t, 8> out;
4195 AffineValueMap rangesValueMap;
4196 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
4197 &rangesValueMap);
4198 out.reserve(rangesValueMap.getNumResults());
4199 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
4200 auto expr = rangesValueMap.getResult(i);
4201 auto cst = dyn_cast<AffineConstantExpr>(expr);
4202 if (!cst)
4203 return std::nullopt;
4204 out.push_back(cst.getValue());
4205 }
4206 return out;
4207}
4208
4209Block *AffineParallelOp::getBody() { return &getRegion().front(); }
4210
4211OpBuilder AffineParallelOp::getBodyBuilder() {
4212 return OpBuilder(getBody(), std::prev(getBody()->end()));
4213}
4214
4215void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
4216 assert(lbOperands.size() == map.getNumInputs() &&
4217 "operands to map must match number of inputs");
4218
4219 auto ubOperands = getUpperBoundsOperands();
4220
4221 SmallVector<Value, 4> newOperands(lbOperands);
4222 newOperands.append(ubOperands.begin(), ubOperands.end());
4223 (*this)->setOperands(newOperands);
4224
4225 setLowerBoundsMapAttr(AffineMapAttr::get(map));
4226}
4227
4228void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
4229 assert(ubOperands.size() == map.getNumInputs() &&
4230 "operands to map must match number of inputs");
4231
4232 SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
4233 newOperands.append(ubOperands.begin(), ubOperands.end());
4234 (*this)->setOperands(newOperands);
4235
4236 setUpperBoundsMapAttr(AffineMapAttr::get(map));
4237}
4238
4239void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
4240 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
4241}
4242
4243// check whether resultType match op or not in affine.parallel
4245 arith::AtomicRMWKind op) {
4246 switch (op) {
4247 case arith::AtomicRMWKind::addf:
4248 return isa<FloatType>(resultType);
4249 case arith::AtomicRMWKind::addi:
4250 return isa<IntegerType>(resultType);
4251 case arith::AtomicRMWKind::assign:
4252 return true;
4253 case arith::AtomicRMWKind::mulf:
4254 return isa<FloatType>(resultType);
4255 case arith::AtomicRMWKind::muli:
4256 return isa<IntegerType>(resultType);
4257 case arith::AtomicRMWKind::maximumf:
4258 return isa<FloatType>(resultType);
4259 case arith::AtomicRMWKind::minimumf:
4260 return isa<FloatType>(resultType);
4261 case arith::AtomicRMWKind::maxs: {
4262 auto intType = dyn_cast<IntegerType>(resultType);
4263 return intType && intType.isSigned();
4264 }
4265 case arith::AtomicRMWKind::mins: {
4266 auto intType = dyn_cast<IntegerType>(resultType);
4267 return intType && intType.isSigned();
4268 }
4269 case arith::AtomicRMWKind::maxu: {
4270 auto intType = dyn_cast<IntegerType>(resultType);
4271 return intType && intType.isUnsigned();
4272 }
4273 case arith::AtomicRMWKind::minu: {
4274 auto intType = dyn_cast<IntegerType>(resultType);
4275 return intType && intType.isUnsigned();
4276 }
4277 case arith::AtomicRMWKind::ori:
4278 return isa<IntegerType>(resultType);
4279 case arith::AtomicRMWKind::andi:
4280 return isa<IntegerType>(resultType);
4281 default:
4282 return false;
4283 }
4284}
4285
4286LogicalResult AffineParallelOp::verify() {
4287 auto numDims = getNumDims();
4288 if (getLowerBoundsGroups().getNumElements() != numDims ||
4289 getUpperBoundsGroups().getNumElements() != numDims ||
4290 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
4291 return emitOpError() << "the number of region arguments ("
4292 << getBody()->getNumArguments()
4293 << ") and the number of map groups for lower ("
4294 << getLowerBoundsGroups().getNumElements()
4295 << ") and upper bound ("
4296 << getUpperBoundsGroups().getNumElements()
4297 << "), and the number of steps (" << getSteps().size()
4298 << ") must all match";
4299 }
4300
4301 unsigned expectedNumLBResults = 0;
4302 for (APInt v : getLowerBoundsGroups()) {
4303 unsigned results = v.getZExtValue();
4304 if (results == 0)
4305 return emitOpError()
4306 << "expected lower bound map to have at least one result";
4307 expectedNumLBResults += results;
4308 }
4309 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
4310 return emitOpError() << "expected lower bounds map to have "
4311 << expectedNumLBResults << " results";
4312 unsigned expectedNumUBResults = 0;
4313 for (APInt v : getUpperBoundsGroups()) {
4314 unsigned results = v.getZExtValue();
4315 if (results == 0)
4316 return emitOpError()
4317 << "expected upper bound map to have at least one result";
4318 expectedNumUBResults += results;
4319 }
4320 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
4321 return emitOpError() << "expected upper bounds map to have "
4322 << expectedNumUBResults << " results";
4323
4324 if (getReductions().size() != getNumResults())
4325 return emitOpError("a reduction must be specified for each output");
4326
4327 // Verify reduction ops are all valid and each result type matches reduction
4328 // ops
4329 for (auto it : llvm::enumerate((getReductions()))) {
4330 Attribute attr = it.value();
4331 auto intAttr = dyn_cast<IntegerAttr>(attr);
4332 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
4333 return emitOpError("invalid reduction attribute");
4334 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
4335 if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
4336 return emitOpError("result type cannot match reduction attribute");
4337 }
4338
4339 // Verify that the bound operands are valid dimension/symbols.
4340 /// Lower bounds.
4341 if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
4342 getLowerBoundsMap().getNumDims())))
4343 return failure();
4344 /// Upper bounds.
4345 if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
4346 getUpperBoundsMap().getNumDims())))
4347 return failure();
4348 return success();
4349}
4350
4352 SmallVector<Value, 4> newOperands{operands};
4353 auto newMap = getAffineMap();
4354 composeAffineMapAndOperands(&newMap, &newOperands);
4355 if (newMap == getAffineMap() && newOperands == operands)
4356 return failure();
4357 reset(newMap, newOperands);
4358 return success();
4359}
4360
4361/// Canonicalize the bounds of the given loop.
4362static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
4363 AffineValueMap lb = op.getLowerBoundsValueMap();
4364 bool lbCanonicalized = succeeded(lb.canonicalize());
4365
4366 AffineValueMap ub = op.getUpperBoundsValueMap();
4367 bool ubCanonicalized = succeeded(ub.canonicalize());
4368
4369 // Any canonicalization change always leads to updated map(s).
4370 if (!lbCanonicalized && !ubCanonicalized)
4371 return failure();
4372
4373 if (lbCanonicalized)
4374 op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
4375 if (ubCanonicalized)
4376 op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
4377
4378 return success();
4379}
4380
4381LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
4382 SmallVectorImpl<OpFoldResult> &results) {
4383 return canonicalizeLoopBounds(*this);
4384}
4385
4386/// Prints a lower(upper) bound of an affine parallel loop with max(min)
4387/// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
4388/// identifies which of the those expressions form max/min groups. `operands`
4389/// are the SSA values of dimensions and symbols and `keyword` is either "min"
4390/// or "max".
4391static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
4392 DenseIntElementsAttr group, ValueRange operands,
4393 StringRef keyword) {
4394 AffineMap map = mapAttr.getValue();
4395 unsigned numDims = map.getNumDims();
4396 ValueRange dimOperands = operands.take_front(numDims);
4397 ValueRange symOperands = operands.drop_front(numDims);
4398 unsigned start = 0;
4399 for (llvm::APInt groupSize : group) {
4400 if (start != 0)
4401 p << ", ";
4402
4403 unsigned size = groupSize.getZExtValue();
4404 if (size == 1) {
4405 p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4406 ++start;
4407 } else {
4408 p << keyword << '(';
4409 AffineMap submap = map.getSliceMap(start, size);
4410 p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4411 p << ')';
4412 start += size;
4413 }
4414 }
4415}
4416
4417void AffineParallelOp::print(OpAsmPrinter &p) {
4418 p << " (" << getBody()->getArguments() << ") = (";
4419 printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4420 getLowerBoundsOperands(), "max");
4421 p << ") to (";
4422 printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4423 getUpperBoundsOperands(), "min");
4424 p << ')';
4425 SmallVector<int64_t, 8> steps = getSteps();
4426 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4427 if (!elideSteps) {
4428 p << " step (";
4429 llvm::interleaveComma(steps, p);
4430 p << ')';
4431 }
4432 if (getNumResults()) {
4433 p << " reduce (";
4434 llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4435 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4436 llvm::cast<IntegerAttr>(attr).getInt());
4437 p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4438 });
4439 p << ") -> (" << getResultTypes() << ")";
4440 }
4441
4442 p << ' ';
4443 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4444 /*printBlockTerminators=*/getNumResults());
4446 (*this)->getAttrs(),
4447 /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4448 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4449 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4450 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4451 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4452 AffineParallelOp::getStepsAttrStrName()});
4453}
4454
4455/// Given a list of lists of parsed operands, populates `uniqueOperands` with
4456/// unique operands. Also populates `replacements with affine expressions of
4457/// `kind` that can be used to update affine maps previously accepting a
4458/// `operands` to accept `uniqueOperands` instead.
4459static ParseResult deduplicateAndResolveOperands(
4460 OpAsmParser &parser,
4461 ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
4462 SmallVectorImpl<Value> &uniqueOperands,
4463 SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4464 assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4465 "expected operands to be dim or symbol expression");
4466
4467 Type indexType = parser.getBuilder().getIndexType();
4468 for (const auto &list : operands) {
4469 SmallVector<Value> valueOperands;
4470 if (parser.resolveOperands(list, indexType, valueOperands))
4471 return failure();
4472 for (Value operand : valueOperands) {
4473 unsigned pos = std::distance(uniqueOperands.begin(),
4474 llvm::find(uniqueOperands, operand));
4475 if (pos == uniqueOperands.size())
4476 uniqueOperands.push_back(operand);
4477 replacements.push_back(
4478 kind == AffineExprKind::DimId
4479 ? getAffineDimExpr(pos, parser.getContext())
4480 : getAffineSymbolExpr(pos, parser.getContext()));
4481 }
4482 }
4483 return success();
4484}
4485
namespace {
/// Selects which kind of bound list `parseAffineMapWithMinMax` parses.
enum class MinMaxKind {
  Min, ///< Upper bounds; groups are spelled `min(...)`.
  Max, ///< Lower bounds; groups are spelled `max(...)`.
};
} // namespace
4489
4490/// Parses an affine map that can contain a min/max for groups of its results,
4491/// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4492/// `result` attributes with the map (flat list of expressions) and the grouping
4493/// (list of integers that specify how many expressions to put into each
4494/// min/max) attributes. Deduplicates repeated operands.
4495///
4496/// parallel-bound ::= `(` parallel-group-list `)`
4497/// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4498/// parallel-group ::= simple-group | min-max-group
4499/// simple-group ::= expr-of-ssa-ids
4500/// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4501/// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4502///
4503/// Examples:
4504/// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4505/// (%0, max(%1 - 2 * %2))
4506static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4507 OperationState &result,
4508 MinMaxKind kind) {
4509 // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4510 // see: https://reviews.llvm.org/D134227#3821753
4511 const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4512
4513 StringRef mapName = kind == MinMaxKind::Min
4514 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4515 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4516 StringRef groupsName =
4517 kind == MinMaxKind::Min
4518 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4519 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4520
4521 if (failed(parser.parseLParen()))
4522 return failure();
4523
4524 if (succeeded(parser.parseOptionalRParen())) {
4525 result.addAttribute(
4526 mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4527 result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4528 return success();
4529 }
4530
4531 SmallVector<AffineExpr> flatExprs;
4532 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
4533 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
4534 SmallVector<int32_t> numMapsPerGroup;
4535 SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
4536 auto parseOperands = [&]() {
4537 if (succeeded(parser.parseOptionalKeyword(
4538 kind == MinMaxKind::Min ? "min" : "max"))) {
4539 mapOperands.clear();
4540 AffineMapAttr map;
4541 if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4542 result.attributes,
4544 return failure();
4545 result.attributes.erase(tmpAttrStrName);
4546 llvm::append_range(flatExprs, map.getValue().getResults());
4547 auto operandsRef = llvm::ArrayRef(mapOperands);
4548 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4549 SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef);
4550 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4551 SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef);
4552 flatDimOperands.append(map.getValue().getNumResults(), dims);
4553 flatSymOperands.append(map.getValue().getNumResults(), syms);
4554 numMapsPerGroup.push_back(map.getValue().getNumResults());
4555 } else {
4556 if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4557 flatSymOperands.emplace_back(),
4558 flatExprs.emplace_back())))
4559 return failure();
4560 numMapsPerGroup.push_back(1);
4561 }
4562 return success();
4563 };
4564 if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4565 return failure();
4566
4567 unsigned totalNumDims = 0;
4568 unsigned totalNumSyms = 0;
4569 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4570 unsigned numDims = flatDimOperands[i].size();
4571 unsigned numSyms = flatSymOperands[i].size();
4572 flatExprs[i] = flatExprs[i]
4573 .shiftDims(numDims, totalNumDims)
4574 .shiftSymbols(numSyms, totalNumSyms);
4575 totalNumDims += numDims;
4576 totalNumSyms += numSyms;
4577 }
4578
4579 // Deduplicate map operands.
4580 SmallVector<Value> dimOperands, symOperands;
4581 SmallVector<AffineExpr> dimRplacements, symRepacements;
4582 if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4583 dimRplacements, AffineExprKind::DimId) ||
4584 deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4585 symRepacements, AffineExprKind::SymbolId))
4586 return failure();
4587
4588 result.operands.append(dimOperands.begin(), dimOperands.end());
4589 result.operands.append(symOperands.begin(), symOperands.end());
4590
4591 Builder &builder = parser.getBuilder();
4592 auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4593 parser.getContext());
4594 flatMap = flatMap.replaceDimsAndSymbols(
4595 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4596
4597 result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4598 result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4599 return success();
4600}
4601
4602//
4603// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4604// `to` parallel-bound steps? region attr-dict?
4605// steps ::= `steps` `(` integer-literals `)`
4606//
4607ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4608 OperationState &result) {
4609 auto &builder = parser.getBuilder();
4610 auto indexType = builder.getIndexType();
4611 SmallVector<OpAsmParser::Argument, 4> ivs;
4613 parser.parseEqual() ||
4614 parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4615 parser.parseKeyword("to") ||
4616 parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4617 return failure();
4618
4619 AffineMapAttr stepsMapAttr;
4620 NamedAttrList stepsAttrs;
4621 SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
4622 if (failed(parser.parseOptionalKeyword("step"))) {
4623 SmallVector<int64_t, 4> steps(ivs.size(), 1);
4624 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4625 builder.getI64ArrayAttr(steps));
4626 } else {
4627 if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4628 AffineParallelOp::getStepsAttrStrName(),
4629 stepsAttrs,
4631 return failure();
4632
4633 // Convert steps from an AffineMap into an I64ArrayAttr.
4634 SmallVector<int64_t, 4> steps;
4635 auto stepsMap = stepsMapAttr.getValue();
4636 for (const auto &result : stepsMap.getResults()) {
4637 auto constExpr = dyn_cast<AffineConstantExpr>(result);
4638 if (!constExpr)
4639 return parser.emitError(parser.getNameLoc(),
4640 "steps must be constant integers");
4641 steps.push_back(constExpr.getValue());
4642 }
4643 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4644 builder.getI64ArrayAttr(steps));
4645 }
4646
4647 // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4648 // quoted strings are a member of the enum AtomicRMWKind.
4649 SmallVector<Attribute, 4> reductions;
4650 if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4651 if (parser.parseLParen())
4652 return failure();
4653 auto parseAttributes = [&]() -> ParseResult {
4654 // Parse a single quoted string via the attribute parsing, and then
4655 // verify it is a member of the enum and convert to it's integer
4656 // representation.
4657 StringAttr attrVal;
4658 NamedAttrList attrStorage;
4659 auto loc = parser.getCurrentLocation();
4660 if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4661 attrStorage))
4662 return failure();
4663 std::optional<arith::AtomicRMWKind> reduction =
4664 arith::symbolizeAtomicRMWKind(attrVal.getValue());
4665 if (!reduction)
4666 return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4667 reductions.push_back(
4668 builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4669 // While we keep getting commas, keep parsing.
4670 return success();
4671 };
4672 if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4673 return failure();
4674 }
4675 result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4676 builder.getArrayAttr(reductions));
4677
4678 // Parse return types of reductions (if any)
4679 if (parser.parseOptionalArrowTypeList(result.types))
4680 return failure();
4681
4682 // Now parse the body.
4683 Region *body = result.addRegion();
4684 for (auto &iv : ivs)
4685 iv.type = indexType;
4686 if (parser.parseRegion(*body, ivs) ||
4687 parser.parseOptionalAttrDict(result.attributes))
4688 return failure();
4689
4690 // Add a terminator if none was parsed.
4691 AffineParallelOp::ensureTerminator(*body, builder, result.location);
4692 return success();
4693}
4694
4695//===----------------------------------------------------------------------===//
4696// AffineYieldOp
4697//===----------------------------------------------------------------------===//
4698
4699LogicalResult AffineYieldOp::verify() {
4700 auto *parentOp = (*this)->getParentOp();
4701 auto results = parentOp->getResults();
4702 auto operands = getOperands();
4703
4704 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4705 return emitOpError() << "only terminates affine.if/for/parallel regions";
4706 if (parentOp->getNumResults() != getNumOperands())
4707 return emitOpError() << "parent of yield must have same number of "
4708 "results as the yield operands";
4709 for (auto it : llvm::zip(results, operands)) {
4710 if (std::get<0>(it).getType() != std::get<1>(it).getType())
4711 return emitOpError() << "types mismatch between yield op and its parent";
4712 }
4713
4714 return success();
4715}
4716
4717//===----------------------------------------------------------------------===//
4718// AffineVectorLoadOp
4719//===----------------------------------------------------------------------===//
4720
4721void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4722 VectorType resultType, AffineMap map,
4723 ValueRange operands) {
4724 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4725 result.addOperands(operands);
4726 if (map)
4727 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4728 result.types.push_back(resultType);
4729}
4730
4731void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4732 VectorType resultType, Value memref,
4733 AffineMap map, ValueRange mapOperands) {
4734 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4735 result.addOperands(memref);
4736 result.addOperands(mapOperands);
4737 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4738 result.types.push_back(resultType);
4739}
4740
4741void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4742 VectorType resultType, Value memref,
4744 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4745 int64_t rank = memrefType.getRank();
4746 // Create identity map for memrefs with at least one dimension or () -> ()
4747 // for zero-dimensional memrefs.
4748 auto map =
4749 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4750 build(builder, result, resultType, memref, map, indices);
4751}
4752
4753void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4754 MLIRContext *context) {
4755 results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4756}
4757
4758ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4759 OperationState &result) {
4760 auto &builder = parser.getBuilder();
4761 auto indexTy = builder.getIndexType();
4762
4763 MemRefType memrefType;
4764 VectorType resultType;
4765 OpAsmParser::UnresolvedOperand memrefInfo;
4766 AffineMapAttr mapAttr;
4767 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4768 return failure(
4769 parser.parseOperand(memrefInfo) ||
4770 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4771 AffineVectorLoadOp::getMapAttrStrName(),
4772 result.attributes) ||
4773 parser.parseOptionalAttrDict(result.attributes) ||
4774 parser.parseColonType(memrefType) || parser.parseComma() ||
4775 parser.parseType(resultType) ||
4776 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4777 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4778 parser.addTypeToList(resultType, result.types));
4779}
4780
4781void AffineVectorLoadOp::print(OpAsmPrinter &p) {
4782 p << " " << getMemRef() << '[';
4783 if (AffineMapAttr mapAttr =
4784 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4785 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4786 p << ']';
4787 p.printOptionalAttrDict((*this)->getAttrs(),
4788 /*elidedAttrs=*/{getMapAttrStrName()});
4789 p << " : " << getMemRefType() << ", " << getType();
4790}
4791
4792/// Verify common invariants of affine.vector_load and affine.vector_store.
4793static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4794 VectorType vectorType) {
4795 // Check that memref and vector element types match.
4796 if (memrefType.getElementType() != vectorType.getElementType())
4797 return op->emitOpError(
4798 "requires memref and vector types of the same elemental type");
4799 return success();
4800}
4801
4802LogicalResult AffineVectorLoadOp::verify() {
4803 MemRefType memrefType = getMemRefType();
4805 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4806 getMapOperands(), memrefType,
4807 /*numIndexOperands=*/getNumOperands() - 1)))
4808 return failure();
4809
4810 if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4811 return failure();
4812
4813 return success();
4814}
4815
4816//===----------------------------------------------------------------------===//
4817// AffineVectorStoreOp
4818//===----------------------------------------------------------------------===//
4819
4820void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4821 Value valueToStore, Value memref, AffineMap map,
4822 ValueRange mapOperands) {
4823 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4824 result.addOperands(valueToStore);
4825 result.addOperands(memref);
4826 result.addOperands(mapOperands);
4827 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4828}
4829
4830// Use identity map.
4831void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4832 Value valueToStore, Value memref,
4834 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4835 int64_t rank = memrefType.getRank();
4836 // Create identity map for memrefs with at least one dimension or () -> ()
4837 // for zero-dimensional memrefs.
4838 auto map =
4839 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4840 build(builder, result, valueToStore, memref, map, indices);
4841}
4842void AffineVectorStoreOp::getCanonicalizationPatterns(
4843 RewritePatternSet &results, MLIRContext *context) {
4844 results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4845}
4846
4847ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4848 OperationState &result) {
4849 auto indexTy = parser.getBuilder().getIndexType();
4850
4851 MemRefType memrefType;
4852 VectorType resultType;
4853 OpAsmParser::UnresolvedOperand storeValueInfo;
4854 OpAsmParser::UnresolvedOperand memrefInfo;
4855 AffineMapAttr mapAttr;
4856 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4857 return failure(
4858 parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4859 parser.parseOperand(memrefInfo) ||
4860 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4861 AffineVectorStoreOp::getMapAttrStrName(),
4862 result.attributes) ||
4863 parser.parseOptionalAttrDict(result.attributes) ||
4864 parser.parseColonType(memrefType) || parser.parseComma() ||
4865 parser.parseType(resultType) ||
4866 parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4867 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4868 parser.resolveOperands(mapOperands, indexTy, result.operands));
4869}
4870
4871void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4872 p << " " << getValueToStore();
4873 p << ", " << getMemRef() << '[';
4874 if (AffineMapAttr mapAttr =
4875 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4876 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4877 p << ']';
4878 p.printOptionalAttrDict((*this)->getAttrs(),
4879 /*elidedAttrs=*/{getMapAttrStrName()});
4880 p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4881}
4882
4883LogicalResult AffineVectorStoreOp::verify() {
4884 MemRefType memrefType = getMemRefType();
4886 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4887 getMapOperands(), memrefType,
4888 /*numIndexOperands=*/getNumOperands() - 2)))
4889 return failure();
4890
4891 if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4892 return failure();
4893
4894 return success();
4895}
4896
4897//===----------------------------------------------------------------------===//
4898// DelinearizeIndexOp
4899//===----------------------------------------------------------------------===//
4900
4901void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4902 OperationState &odsState,
4903 Value linearIndex, ValueRange dynamicBasis,
4904 ArrayRef<int64_t> staticBasis,
4905 bool hasOuterBound) {
4906 SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4907 : staticBasis.size() + 1,
4908 linearIndex.getType());
4909 build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4910 staticBasis);
4911}
4912
4913void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4914 OperationState &odsState,
4915 Value linearIndex, ValueRange basis,
4916 bool hasOuterBound) {
4917 if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4918 hasOuterBound = false;
4919 basis = basis.drop_front();
4920 }
4921 SmallVector<Value> dynamicBasis;
4922 SmallVector<int64_t> staticBasis;
4923 dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4924 staticBasis);
4925 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4926 hasOuterBound);
4927}
4928
4929void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4930 OperationState &odsState,
4931 Value linearIndex,
4932 ArrayRef<OpFoldResult> basis,
4933 bool hasOuterBound) {
4934 if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4935 hasOuterBound = false;
4936 basis = basis.drop_front();
4937 }
4938 SmallVector<Value> dynamicBasis;
4939 SmallVector<int64_t> staticBasis;
4940 dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4941 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4942 hasOuterBound);
4943}
4944
4945void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4946 OperationState &odsState,
4947 Value linearIndex, ArrayRef<int64_t> basis,
4948 bool hasOuterBound) {
4949 build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
4950}
4951
4952LogicalResult AffineDelinearizeIndexOp::verify() {
4953 ArrayRef<int64_t> staticBasis = getStaticBasis();
4954 if (getNumResults() != staticBasis.size() &&
4955 getNumResults() != staticBasis.size() + 1)
4956 return emitOpError("should return an index for each basis element and up "
4957 "to one extra index");
4958
4959 auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4960 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4961 return emitOpError(
4962 "mismatch between dynamic and static basis (kDynamic marker but no "
4963 "corresponding dynamic basis entry) -- this can only happen due to an "
4964 "incorrect fold/rewrite");
4965
4966 if (!llvm::all_of(staticBasis, [](int64_t v) {
4967 return v > 0 || ShapedType::isDynamic(v);
4968 }))
4969 return emitOpError("no basis element may be statically non-positive");
4970
4971 return success();
4972}
4973
4974/// Given mixed basis of affine.delinearize_index/linearize_index replace
4975/// constant SSA values with the constant integer value and return the new
4976/// static basis. In case no such candidate for replacement exists, this utility
4977/// returns std::nullopt.
4978static std::optional<SmallVector<int64_t>>
4980 MutableOperandRange mutableDynamicBasis,
4981 ArrayRef<Attribute> dynamicBasis) {
4982 uint64_t dynamicBasisIndex = 0;
4983 for (OpFoldResult basis : dynamicBasis) {
4984 if (basis) {
4985 mutableDynamicBasis.erase(dynamicBasisIndex);
4986 } else {
4987 ++dynamicBasisIndex;
4988 }
4989 }
4990
4991 // No constant SSA value exists.
4992 if (dynamicBasisIndex == dynamicBasis.size())
4993 return std::nullopt;
4994
4995 SmallVector<int64_t> staticBasis;
4996 for (OpFoldResult basis : mixedBasis) {
4997 std::optional<int64_t> basisVal = getConstantIntValue(basis);
4998 if (!basisVal)
4999 staticBasis.push_back(ShapedType::kDynamic);
5000 else
5001 staticBasis.push_back(*basisVal);
5002 }
5003
5004 return staticBasis;
5005}
5006
5007LogicalResult
5008AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
5009 SmallVectorImpl<OpFoldResult> &result) {
5010 std::optional<SmallVector<int64_t>> maybeStaticBasis =
5011 foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
5012 adaptor.getDynamicBasis());
5013 if (maybeStaticBasis) {
5014 setStaticBasis(*maybeStaticBasis);
5015 return success();
5016 }
5017 // If we won't be doing any division or modulo (no basis or the one basis
5018 // element is purely advisory), simply return the input value.
5019 if (getNumResults() == 1) {
5020 result.push_back(getLinearIndex());
5021 return success();
5022 }
5023
5024 if (adaptor.getLinearIndex() == nullptr)
5025 return failure();
5026
5027 if (!adaptor.getDynamicBasis().empty())
5028 return failure();
5029
5030 int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
5031 Type attrType = getLinearIndex().getType();
5032
5033 ArrayRef<int64_t> staticBasis = getStaticBasis();
5034 if (hasOuterBound())
5035 staticBasis = staticBasis.drop_front();
5036 for (int64_t modulus : llvm::reverse(staticBasis)) {
5037 result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
5038 highPart = llvm::divideFloorSigned(highPart, modulus);
5039 }
5040 result.push_back(IntegerAttr::get(attrType, highPart));
5041 std::reverse(result.begin(), result.end());
5042 return success();
5043}
5044
5045SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
5046 OpBuilder builder(getContext());
5047 if (hasOuterBound()) {
5048 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5049 return getMixedValues(getStaticBasis().drop_front(),
5050 getDynamicBasis().drop_front(), builder);
5051
5052 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5053 builder);
5054 }
5055
5056 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5057}
5058
5059SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
5060 SmallVector<OpFoldResult> ret = getMixedBasis();
5061 if (!hasOuterBound())
5062 ret.insert(ret.begin(), OpFoldResult());
5063 return ret;
5064}
5065
5066namespace {
5067
5068// Drops delinearization indices that correspond to unit-extent basis
5069struct DropUnitExtentBasis
5070 : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5072
5073 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5074 PatternRewriter &rewriter) const override {
5075 SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
5076 std::optional<Value> zero = std::nullopt;
5077 Location loc = delinearizeOp->getLoc();
5078 auto getZero = [&]() -> Value {
5079 if (!zero)
5080 zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
5081 return zero.value();
5082 };
5083
5084 // Replace all indices corresponding to unit-extent basis with 0.
5085 // Remaining basis can be used to get a new `affine.delinearize_index` op.
5086 SmallVector<OpFoldResult> newBasis;
5087 for (auto [index, basis] :
5088 llvm::enumerate(delinearizeOp.getPaddedBasis())) {
5089 std::optional<int64_t> basisVal =
5090 basis ? getConstantIntValue(basis) : std::nullopt;
5091 if (basisVal == 1)
5092 replacements[index] = getZero();
5093 else
5094 newBasis.push_back(basis);
5095 }
5096
5097 if (newBasis.size() == delinearizeOp.getNumResults())
5098 return rewriter.notifyMatchFailure(delinearizeOp,
5099 "no unit basis elements");
5100
5101 if (!newBasis.empty()) {
5102 // Will drop the leading nullptr from `basis` if there was no outer bound.
5103 auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
5104 rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
5105 int newIndex = 0;
5106 // Map back the new delinearized indices to the values they replace.
5107 for (auto &replacement : replacements) {
5108 if (replacement)
5109 continue;
5110 replacement = newDelinearizeOp->getResult(newIndex++);
5111 }
5112 }
5113
5114 rewriter.replaceOp(delinearizeOp, replacements);
5115 return success();
5116 }
5117};
5118
5119/// If a `affine.delinearize_index`'s input is a `affine.linearize_index
5120/// disjoint` and the two operations end with the same basis elements,
5121/// cancel those parts of the operations out because they are inverses
5122/// of each other.
5123///
5124/// If the operations have the same basis, cancel them entirely.
5125///
5126/// The `disjoint` flag is needed on the `affine.linearize_index` because
5127/// otherwise, there is no guarantee that the inputs to the linearization are
5128/// in-bounds the way the outputs of the delinearization would be.
5129struct CancelDelinearizeOfLinearizeDisjointExactTail
5130 : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5132
5133 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5134 PatternRewriter &rewriter) const override {
5135 auto linearizeOp = delinearizeOp.getLinearIndex()
5136 .getDefiningOp<affine::AffineLinearizeIndexOp>();
5137 if (!linearizeOp)
5138 return rewriter.notifyMatchFailure(delinearizeOp,
5139 "index doesn't come from linearize");
5140
5141 if (!linearizeOp.getDisjoint())
5142 return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
5143
5144 ValueRange linearizeIns = linearizeOp.getMultiIndex();
5145 // Note: we use the full basis so we don't lose outer bounds later.
5146 SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
5147 SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
5148 size_t numMatches = 0;
5149 for (auto [linSize, delinSize] : llvm::zip(
5150 llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
5151 if (linSize != delinSize)
5152 break;
5153 ++numMatches;
5154 }
5155
5156 if (numMatches == 0)
5157 return rewriter.notifyMatchFailure(
5158 delinearizeOp, "final basis element doesn't match linearize");
5159
5160 // The easy case: everything lines up and the basis match sup completely.
5161 if (numMatches == linearizeBasis.size() &&
5162 numMatches == delinearizeBasis.size() &&
5163 linearizeIns.size() == delinearizeOp.getNumResults()) {
5164 rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
5165 return success();
5166 }
5167
5168 Value newLinearize = affine::AffineLinearizeIndexOp::create(
5169 rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
5170 ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
5171 linearizeOp.getDisjoint());
5172 auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
5173 rewriter, delinearizeOp.getLoc(), newLinearize,
5174 ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
5175 delinearizeOp.hasOuterBound());
5176 SmallVector<Value> mergedResults(newDelinearize.getResults());
5177 mergedResults.append(linearizeIns.take_back(numMatches).begin(),
5178 linearizeIns.take_back(numMatches).end());
5179 rewriter.replaceOp(delinearizeOp, mergedResults);
5180 return success();
5181 }
5182};
5183
5184/// If the input to a delinearization is a disjoint linearization, and the
5185/// last k > 1 components of the delinearization basis multiply to the
5186/// last component of the linearization basis, break the linearization and
5187/// delinearization into two parts, peeling off the last input to linearization.
5188///
5189/// For example:
5190/// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
5191/// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
5192/// becomes
5193/// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
5194/// %1:2 = affine.delinearize_index %0 by (2, 3) : index
5195/// %2:2 = affine.delinearize_index %x by (8, 4) : index
5196/// where the original %1:4 is replaced by %1:2 ++ %2:2
5197struct SplitDelinearizeSpanningLastLinearizeArg final
5198 : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5200
5201 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5202 PatternRewriter &rewriter) const override {
5203 auto linearizeOp = delinearizeOp.getLinearIndex()
5204 .getDefiningOp<affine::AffineLinearizeIndexOp>();
5205 if (!linearizeOp)
5206 return rewriter.notifyMatchFailure(delinearizeOp,
5207 "index doesn't come from linearize");
5208
5209 if (!linearizeOp.getDisjoint())
5210 return rewriter.notifyMatchFailure(linearizeOp,
5211 "linearize isn't disjoint");
5212
5213 int64_t target = linearizeOp.getStaticBasis().back();
5214 if (ShapedType::isDynamic(target))
5215 return rewriter.notifyMatchFailure(
5216 linearizeOp, "linearize ends with dynamic basis value");
5217
5218 int64_t sizeToSplit = 1;
5219 size_t elemsToSplit = 0;
5220 ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
5221 for (int64_t basisElem : llvm::reverse(basis)) {
5222 if (ShapedType::isDynamic(basisElem))
5223 return rewriter.notifyMatchFailure(
5224 delinearizeOp, "dynamic basis element while scanning for split");
5225 sizeToSplit *= basisElem;
5226 elemsToSplit += 1;
5227
5228 if (sizeToSplit > target)
5229 return rewriter.notifyMatchFailure(delinearizeOp,
5230 "overshot last argument size");
5231 if (sizeToSplit == target)
5232 break;
5233 }
5234
5235 if (sizeToSplit < target)
5236 return rewriter.notifyMatchFailure(
5237 delinearizeOp, "product of known basis elements doesn't exceed last "
5238 "linearize argument");
5239
5240 if (elemsToSplit < 2)
5241 return rewriter.notifyMatchFailure(
5242 delinearizeOp,
5243 "need at least two elements to form the basis product");
5244
5245 Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
5246 rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
5247 linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(),
5248 linearizeOp.getDisjoint());
5249 auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
5250 rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
5251 delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
5252 delinearizeOp.hasOuterBound());
5253 auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
5254 rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
5255 basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
5256 SmallVector<Value> results = llvm::to_vector(
5257 llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
5258 delinearizeBack.getResults()));
5259 rewriter.replaceOp(delinearizeOp, results);
5260
5261 return success();
5262 }
5263};
5264} // namespace
5265
5266void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
5267 RewritePatternSet &patterns, MLIRContext *context) {
5268 patterns
5269 .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
5270 DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
5271 context);
5272}
5273
5274//===----------------------------------------------------------------------===//
5275// LinearizeIndexOp
5276//===----------------------------------------------------------------------===//
5277
5278void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5279 OperationState &odsState,
5280 ValueRange multiIndex, ValueRange basis,
5281 bool disjoint) {
5282 if (!basis.empty() && basis.front() == Value())
5283 basis = basis.drop_front();
5284 SmallVector<Value> dynamicBasis;
5285 SmallVector<int64_t> staticBasis;
5286 dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
5287 staticBasis);
5288 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5289}
5290
5291void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5292 OperationState &odsState,
5293 ValueRange multiIndex,
5294 ArrayRef<OpFoldResult> basis,
5295 bool disjoint) {
5296 if (!basis.empty() && basis.front() == OpFoldResult())
5297 basis = basis.drop_front();
5298 SmallVector<Value> dynamicBasis;
5299 SmallVector<int64_t> staticBasis;
5300 dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
5301 build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint);
5302}
5303
5304void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5305 OperationState &odsState,
5306 ValueRange multiIndex,
5307 ArrayRef<int64_t> basis, bool disjoint) {
5308 build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint);
5309}
5310
5311LogicalResult AffineLinearizeIndexOp::verify() {
5312 size_t numIndexes = getMultiIndex().size();
5313 size_t numBasisElems = getStaticBasis().size();
5314 if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5315 return emitOpError("should be passed a basis element for each index except "
5316 "possibly the first");
5317
5318 auto dynamicMarkersCount =
5319 llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
5320 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5321 return emitOpError(
5322 "mismatch between dynamic and static basis (kDynamic marker but no "
5323 "corresponding dynamic basis entry) -- this can only happen due to an "
5324 "incorrect fold/rewrite");
5325
5326 return success();
5327}
5328
5329OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
5330 std::optional<SmallVector<int64_t>> maybeStaticBasis =
5331 foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
5332 adaptor.getDynamicBasis());
5333 if (maybeStaticBasis) {
5334 setStaticBasis(*maybeStaticBasis);
5335 return getResult();
5336 }
5337 // No indices linearizes to zero.
5338 if (getMultiIndex().empty())
5339 return IntegerAttr::get(getResult().getType(), 0);
5340
5341 // One single index linearizes to itself.
5342 if (getMultiIndex().size() == 1)
5343 return getMultiIndex().front();
5344
5345 if (llvm::is_contained(adaptor.getMultiIndex(), nullptr))
5346 return nullptr;
5347
5348 if (!adaptor.getDynamicBasis().empty())
5349 return nullptr;
5350
5351 int64_t result = 0;
5352 int64_t stride = 1;
5353 for (auto [length, indexAttr] :
5354 llvm::zip_first(llvm::reverse(getStaticBasis()),
5355 llvm::reverse(adaptor.getMultiIndex()))) {
5356 result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
5357 stride = stride * length;
5358 }
5359 // Handle the index element with no basis element.
5360 if (!hasOuterBound())
5361 result =
5362 result +
5363 cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
5364
5365 return IntegerAttr::get(getResult().getType(), result);
5366}
5367
5368SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
5369 OpBuilder builder(getContext());
5370 if (hasOuterBound()) {
5371 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5372 return getMixedValues(getStaticBasis().drop_front(),
5373 getDynamicBasis().drop_front(), builder);
5374
5375 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5376 builder);
5377 }
5378
5379 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5380}
5381
5382SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
5383 SmallVector<OpFoldResult> ret = getMixedBasis();
5384 if (!hasOuterBound())
5385 ret.insert(ret.begin(), OpFoldResult());
5386 return ret;
5387}
5388
5389namespace {
5390/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
5391/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
5392/// %...d)`.
5393
5394/// Note that `disjoint` is required here, because, without it, we could have
5395/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
5396/// is a valid operation where the `%c64` cannot be trivially dropped.
5397///
5398/// Alternatively, if `%x` in the above is a known constant 0, remove it even if
5399/// the operation isn't asserted to be `disjoint`.
5400struct DropLinearizeUnitComponentsIfDisjointOrZero final
5401 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5403
5404 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5405 PatternRewriter &rewriter) const override {
5406 ValueRange multiIndex = op.getMultiIndex();
5407 size_t numIndices = multiIndex.size();
5408 SmallVector<Value> newIndices;
5409 newIndices.reserve(numIndices);
5410 SmallVector<OpFoldResult> newBasis;
5411 newBasis.reserve(numIndices);
5412
5413 if (!op.hasOuterBound()) {
5414 newIndices.push_back(multiIndex.front());
5415 multiIndex = multiIndex.drop_front();
5416 }
5417
5418 SmallVector<OpFoldResult> basis = op.getMixedBasis();
5419 for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5420 std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
5421 if (!basisEntry || *basisEntry != 1) {
5422 newIndices.push_back(index);
5423 newBasis.push_back(basisElem);
5424 continue;
5425 }
5426
5427 std::optional<int64_t> indexValue = getConstantIntValue(index);
5428 if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5429 newIndices.push_back(index);
5430 newBasis.push_back(basisElem);
5431 continue;
5432 }
5433 }
5434 if (newIndices.size() == numIndices)
5435 return rewriter.notifyMatchFailure(op,
5436 "no unit basis entries to replace");
5437
5438 if (newIndices.empty()) {
5439 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
5440 return success();
5441 }
5442 rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5443 op, newIndices, newBasis, op.getDisjoint());
5444 return success();
5445 }
5446};
5447
5448OpFoldResult computeProduct(Location loc, OpBuilder &builder,
5449 ArrayRef<OpFoldResult> terms) {
5450 int64_t nDynamic = 0;
5451 SmallVector<Value> dynamicPart;
5452 AffineExpr result = builder.getAffineConstantExpr(1);
5453 for (OpFoldResult term : terms) {
5454 if (!term)
5455 return term;
5456 std::optional<int64_t> maybeConst = getConstantIntValue(term);
5457 if (maybeConst) {
5458 result = result * builder.getAffineConstantExpr(*maybeConst);
5459 } else {
5460 dynamicPart.push_back(cast<Value>(term));
5461 result = result * builder.getAffineSymbolExpr(nDynamic++);
5462 }
5463 }
5464 if (auto constant = dyn_cast<AffineConstantExpr>(result))
5465 return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
5466 return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult();
5467}
5468
5469/// If conseceutive outputs of a delinearize_index are linearized with the same
5470/// bounds, canonicalize away the redundant arithmetic.
5471///
5472/// That is, if we have
5473/// ```
5474/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
5475/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
5476/// by (...e, B1, B2, ..., BK, ...f)
5477/// ```
5478///
5479/// We can rewrite this to
5480/// ```
5481/// B = B1 * B2 ... BK
5482/// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b)
5483/// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
5484/// ```
5485/// where we replace all results of %s unaffected by the change with results
5486/// from %sMerged.
5487///
5488/// As a special case, if all results of the delinearize are merged in this way
5489/// we can replace those usages with %x, thus cancelling the delinearization
5490/// entirely, as in
5491/// ```
5492/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
5493/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
5494/// ```
5495/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
5496struct CancelLinearizeOfDelinearizePortion final
5497 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5499
5500private:
5501 // Struct representing a case where the cancellation pattern
5502 // applies. A `Match` means that `length` inputs to the linearize operation
5503 // starting at `linStart` can be cancelled with `length` outputs of
5504 // `delinearize`, starting from `delinStart`.
5505 struct Match {
5506 AffineDelinearizeIndexOp delinearize;
5507 unsigned linStart = 0;
5508 unsigned delinStart = 0;
5509 unsigned length = 0;
5510 };
5511
5512public:
5513 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
5514 PatternRewriter &rewriter) const override {
5515 SmallVector<Match> matches;
5516
5517 const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
5518 ArrayRef<OpFoldResult> linBasisRef = linBasis;
5519
5520 ValueRange multiIndex = linearizeOp.getMultiIndex();
5521 unsigned numLinArgs = multiIndex.size();
5522 unsigned linArgIdx = 0;
5523 // We only want to replace one run from the same delinearize op per
5524 // pattern invocation lest we run into invalidation issues.
5525 llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
5526 while (linArgIdx < numLinArgs) {
5527 auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
5528 if (!asResult) {
5529 linArgIdx++;
5530 continue;
5531 }
5532
5533 auto delinearizeOp =
5534 dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
5535 if (!delinearizeOp) {
5536 linArgIdx++;
5537 continue;
5538 }
5539
5540 /// Result 0 of the delinearize and argument 0 of the linearize can
5541 /// leave their maximum value unspecified. However, even if this happens
5542 /// we can still sometimes start the match process. Specifically, if
5543 /// - The argument we're matching is result 0 and argument 0 (so the
5544 /// bounds don't matter). For example,
5545 ///
5546 /// %0:2 = affine.delinearize_index %x into (8) : index, index
5547 /// %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...)
5548 /// allows cancellation
5549 /// - The delinearization doesn't specify a bound, but the linearization
5550 /// is `disjoint`, which asserts that the bound on the linearization is
5551 /// correct.
5552 unsigned delinArgIdx = asResult.getResultNumber();
5553 SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
5554 OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
5555 OpFoldResult firstLinBound = linBasis[linArgIdx];
5556 bool boundsMatch = firstDelinBound == firstLinBound;
5557 bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
5558 bool knownByDisjoint =
5559 linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
5560 if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5561 linArgIdx++;
5562 continue;
5563 }
5564
5565 unsigned j = 1;
5566 unsigned numDelinOuts = delinearizeOp.getNumResults();
5567 for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
5568 ++j) {
5569 if (multiIndex[linArgIdx + j] !=
5570 delinearizeOp.getResult(delinArgIdx + j))
5571 break;
5572 if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
5573 break;
5574 }
5575 // If there're multiple matches against the same delinearize_index,
5576 // only rewrite the first one we find to prevent invalidations. The next
5577 // ones will be taken care of by subsequent pattern invocations.
5578 if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
5579 linArgIdx++;
5580 continue;
5581 }
5582 matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
5583 linArgIdx += j;
5584 }
5585
5586 if (matches.empty())
5587 return rewriter.notifyMatchFailure(
5588 linearizeOp, "no run of delinearize outputs to deal with");
5589
5590 // Record all the delinearize replacements so we can do them after creating
5591 // the new linearization operation, since the new operation might use
5592 // outputs of something we're replacing.
5593 SmallVector<SmallVector<Value>> delinearizeReplacements;
5594
5595 SmallVector<Value> newIndex;
5596 newIndex.reserve(numLinArgs);
5597 SmallVector<OpFoldResult> newBasis;
5598 newBasis.reserve(numLinArgs);
5599 unsigned prevMatchEnd = 0;
5600 for (Match m : matches) {
5601 unsigned gap = m.linStart - prevMatchEnd;
5602 llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
5603 llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
5604 // Update here so we don't forget this during early continues
5605 prevMatchEnd = m.linStart + m.length;
5606
5607 PatternRewriter::InsertionGuard g(rewriter);
5608 rewriter.setInsertionPoint(m.delinearize);
5609
5610 ArrayRef<OpFoldResult> basisToMerge =
5611 linBasisRef.slice(m.linStart, m.length);
5612 // We use the slice from the linearize's basis above because of the
5613 // "bounds inferred from `disjoint`" case above.
5614 OpFoldResult newSize =
5615 computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
5616
5617 // Trivial case where we can just skip past the delinearize all together
5618 if (m.length == m.delinearize.getNumResults()) {
5619 newIndex.push_back(m.delinearize.getLinearIndex());
5620 newBasis.push_back(newSize);
5621 // Pad out set of replacements so we don't do anything with this one.
5622 delinearizeReplacements.push_back(SmallVector<Value>());
5623 continue;
5624 }
5625
5626 SmallVector<Value> newDelinResults;
5627 SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
5628 newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
5629 newDelinBasis.begin() + m.delinStart + m.length);
5630 newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
5631 auto newDelinearize = AffineDelinearizeIndexOp::create(
5632 rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5633 newDelinBasis);
5634
5635 // Since there may be other uses of the indices we just merged together,
5636 // create a residual affine.delinearize_index that delinearizes the
5637 // merged output into its component parts.
5638 Value combinedElem = newDelinearize.getResult(m.delinStart);
5639 auto residualDelinearize = AffineDelinearizeIndexOp::create(
5640 rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);
5641
5642 // Swap all the uses of the unaffected delinearize outputs to the new
5643 // delinearization so that the old code can be removed if this
5644 // linearize_index is the only user of the merged results.
5645 llvm::append_range(newDelinResults,
5646 newDelinearize.getResults().take_front(m.delinStart));
5647 llvm::append_range(newDelinResults, residualDelinearize.getResults());
5648 llvm::append_range(
5649 newDelinResults,
5650 newDelinearize.getResults().drop_front(m.delinStart + 1));
5651
5652 delinearizeReplacements.push_back(newDelinResults);
5653 newIndex.push_back(combinedElem);
5654 newBasis.push_back(newSize);
5655 }
5656 llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
5657 llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
5658 rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
5659 linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
5660
5661 for (auto [m, newResults] :
5662 llvm::zip_equal(matches, delinearizeReplacements)) {
5663 if (newResults.empty())
5664 continue;
5665 rewriter.replaceOp(m.delinearize, newResults);
5666 }
5667
5668 return success();
5669 }
5670};
5671
5672/// Strip leading zero from affine.linearize_index.
5673///
5674/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
5675/// to `affine.linearize_index [...a] by (...b)` in all cases.
5676struct DropLinearizeLeadingZero final
5677 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5679
5680 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5681 PatternRewriter &rewriter) const override {
5682 Value leadingIdx = op.getMultiIndex().front();
5683 if (!matchPattern(leadingIdx, m_Zero()))
5684 return failure();
5685
5686 if (op.getMultiIndex().size() == 1) {
5687 rewriter.replaceOp(op, leadingIdx);
5688 return success();
5689 }
5690
5691 SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
5692 ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
5693 if (op.hasOuterBound())
5694 newMixedBasis = newMixedBasis.drop_front();
5695
5696 rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5697 op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5698 return success();
5699 }
5700};
5701} // namespace
5702
5703void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5704 RewritePatternSet &patterns, MLIRContext *context) {
5705 patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5706 DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5707}
5708
5709//===----------------------------------------------------------------------===//
5710// TableGen'd op method definitions
5711//===----------------------------------------------------------------------===//
5712
5713#define GET_OP_CLASSES
5714#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
return success()
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
static LogicalResult verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
Definition AffineOps.cpp:62
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
static bool isTopLevelValueOrAbove(Value value, Region *region)
A utility function to check if a value is defined at the top level of region or is an argument of reg...
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
static bool isValidAffineIndexOperand(Value value, Region *region)
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
static LogicalResult replaceAffineDelinearizeIndexInverseExpression(AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
If this map contains of the expression x_1 + x_1 * C_1 + ... x_n * C_N + / ... (not necessarily in or...
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
static void simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
static void legalizeDemotedDims(MapOrSet &mapOrSet, SmallVectorImpl< Value > &operands)
A valid affine dimension may appear as a symbol in affine.apply operations.
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
return success()
static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map, ValueRange dims, ValueRange syms)
Assuming dimOrSym is a quantity in the apply op map map and defined by minOp = affine_min(x_1,...
static SmallVector< OpFoldResult > AffineForEmptyLoopFolder(AffineForOp forOp)
Fold the empty loop.
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms, bool replaceAffineMin)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static std::optional< uint64_t > getTrivialConstantTripCount(AffineForOp forOp)
Returns constant trip count in trivial cases.
static LogicalResult verifyAffineMinMaxOp(T op)
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
static void shortenAddChainsContainingAll(AffineExpr e, const llvm::SmallDenseSet< AffineExpr, 4 > &exprsToRemove, AffineExpr newVal, DenseMap< AffineExpr, AffineExpr > &replacementsMap)
Recursively traverse e.
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value getMemRef(Operation *memOp)
Returns the memref being read/written by a memref/affine load/store op.
Definition Utils.cpp:246
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getI64ArrayAttr(paddingDimensions)
b getContext())
auto load
if copies could not be generated due to yet unimplemented cases. copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming copies and outgoing copies should be inserted. The output argument nBegin is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
Definition SCFToGPU.cpp:75
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
Definition SCFToGPU.cpp:80
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
#define div(a, b)
#define rem(a, b)
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
AffineExprKind getKind() const
Return the classification for this type.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
MLIRContext * getContext() const
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
MLIRContext * getContext() const
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
Definition AffineMap.h:221
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
Definition AffineMap.h:267
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
Definition AffineMap.h:228
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
unsigned getNumInputs() const
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ... numSymbols) by symbols[offset + shift ... shift + numSymbols).
Definition AffineMap.h:280
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
AffineMap getDimIdentityMap()
Definition Builders.cpp:383
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:387
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:368
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:372
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition Builders.cpp:179
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
NoneType getNoneType()
Definition Builders.cpp:88
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
Definition Builders.cpp:376
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:378
AffineMap getSymbolIdentityMap()
Definition Builders.cpp:396
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:281
IndexType getIndexType()
Definition Builders.cpp:51
An attribute that represents a reference to a dense integer vector or tensor object.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
Location getLoc() const
Accessors for the implied location.
Definition Builders.h:663
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition IntegerSet.h:44
unsigned getNumDims() const
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
unsigned getNumInputs() const
ArrayRef< AffineExpr > getConstraints() const
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
unsigned getNumSymbols() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition Builders.cpp:457
This class represents a single result from folding an operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
operand_range::iterator operand_iterator
Definition Operation.h:372
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool isParent() const
Returns true if branching from the parent op.
Operation * getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
A variable that can be added to the constraint set as a "column".
static bool compare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Return "true" if "lhs cmp rhs" was proven to hold.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
AffineBound represents a lower or upper bound in the for operation.
Definition AffineOps.h:550
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
Definition AffineOps.h:106
OpOperand & getTagMemRefMutable()
Definition AffineOps.h:211
Value getTagMemRef()
Returns the Tag MemRef for this DMA operation.
Definition AffineOps.h:210
static void build(OpBuilder &builder, OperationState &result, Value srcMemRef, AffineMap srcMap, ValueRange srcIndices, Value destMemRef, AffineMap dstMap, ValueRange destIndices, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements, Value stride=nullptr, Value elementsPerStride=nullptr)
operand_range getDstIndices()
Returns the destination memref indices for this DMA operation.
Definition AffineOps.h:198
Value getNumElementsPerStride()
Returns the number of elements to transfer per stride for this DMA op.
Definition AffineOps.h:307
AffineMapAttr getTagMapAttr()
Definition AffineOps.h:225
operand_range getSrcIndices()
Returns the source memref affine map indices for this DMA operation.
Definition AffineOps.h:155
AffineMapAttr getSrcMapAttr()
Definition AffineOps.h:149
bool isStrided()
Returns true if this DMA operation is strided, returns false otherwise.
Definition AffineOps.h:294
AffineMap getDstMap()
Returns the affine map used to access the destination memref.
Definition AffineOps.h:191
void print(OpAsmPrinter &p)
OpOperand & getDstMemRefMutable()
Definition AffineOps.h:173
Value getDstMemRef()
Returns the destination MemRefType for this DMA operation.
Definition AffineOps.h:172
static StringRef getSrcMapAttrStrName()
Definition AffineOps.h:281
AffineMapAttr getDstMapAttr()
Definition AffineOps.h:192
unsigned getSrcMemRefOperandIndex()
Returns the operand index of the source memref.
Definition AffineOps.h:133
unsigned getTagMemRefOperandIndex()
Returns the operand index of the tag memref.
Definition AffineOps.h:205
static StringRef getTagMapAttrStrName()
Definition AffineOps.h:283
LogicalResult verifyInvariantsImpl()
void getEffects(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
AffineMap getSrcMap()
Returns the affine map used to access the source memref.
Definition AffineOps.h:148
Value getNumElements()
Returns the number of elements being transferred by this DMA operation.
Definition AffineOps.h:238
static AffineDmaStartOp create(OpBuilder &builder, Location location, Value srcMemRef, AffineMap srcMap, ValueRange srcIndices, Value destMemRef, AffineMap dstMap, ValueRange destIndices, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements, Value stride=nullptr, Value elementsPerStride=nullptr)
AffineMap getTagMap()
Returns the affine map used to access the tag memref.
Definition AffineOps.h:224
static ParseResult parse(OpAsmParser &parser, OperationState &result)
Value getStride()
Returns the stride value for this DMA operation.
Definition AffineOps.h:300
unsigned getDstMemRefOperandIndex()
Returns the operand index of the destination memref.
Definition AffineOps.h:167
static StringRef getDstMapAttrStrName()
Definition AffineOps.h:282
static StringRef getOperationName()
Definition AffineOps.h:285
Value getSrcMemRef()
Returns the source MemRefType for this DMA operation.
Definition AffineOps.h:136
OpOperand & getSrcMemRefMutable()
Definition AffineOps.h:137
operand_range getTagIndices()
Returns the tag memref indices for this DMA operation.
Definition AffineOps.h:231
LogicalResult fold(ArrayRef< Attribute > cstOperands, SmallVectorImpl< OpFoldResult > &results)
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Definition AffineOps.h:330
Value getNumElements()
Returns the number of elements transferred by the associated DMA op.
Definition AffineOps.h:380
LogicalResult verifyInvariantsImpl()
static StringRef getOperationName()
Definition AffineOps.h:344
Value getTagMemRef()
Returns the Tag MemRef associated with the DMA operation being waited on.
Definition AffineOps.h:347
static ParseResult parse(OpAsmParser &parser, OperationState &result)
static StringRef getTagMapAttrStrName()
Definition AffineOps.h:382
void getEffects(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
LogicalResult fold(ArrayRef< Attribute > cstOperands, SmallVectorImpl< OpFoldResult > &results)
AffineMapAttr getTagMapAttr()
Definition AffineOps.h:355
void print(OpAsmPrinter &p)
static AffineDmaWaitOp create(OpBuilder &builder, Location location, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements)
static void build(OpBuilder &builder, OperationState &result, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements)
OpOperand & getTagMemRefMutable()
Definition AffineOps.h:348
operand_range getTagIndices()
Returns the tag memref index for this DMA operation.
Definition AffineOps.h:361
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
void reset(AffineMap map, ValueRange operands, ValueRange results={})
static void difference(const AffineValueMap &a, const AffineValueMap &b, AffineValueMap *res)
Return the value map that is the difference of value maps 'a' and 'b', represented as an affine map a...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult computeProduct(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > terms)
Return the product of terms, creating an affine.apply if any of them are non-constant values.
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp whose induction variables include the provided value.
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:46
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t > > constLowerBounds, ArrayRef< std::optional< int64_t > > constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable and does not touch memory.
AffineExprKind
Definition AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
Definition AffineExpr.h:50
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ DimId
Dimensional identifier.
Definition AffineExpr.h:59
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
@ SymbolId
Symbolic identifier.
Definition AffineExpr.h:61
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Canonicalize the affine map result expression order of an affine min/max operation.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Remove duplicated expressions in affine min/max ops.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition Builders.h:285
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.