MLIR 23.0.0git
AffineOps.cpp
Go to the documentation of this file.
1//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/AffineExpr.h"
16#include "mlir/IR/IRMapping.h"
17#include "mlir/IR/IntegerSet.h"
18#include "mlir/IR/Matchers.h"
21#include "mlir/IR/Value.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
27#include "llvm/ADT/SmallVectorExtras.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Support/DebugLog.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/MathExtras.h"
32#include <numeric>
33#include <optional>
34
35using namespace mlir;
36using namespace mlir::affine;
37
38using llvm::divideCeilSigned;
39using llvm::divideFloorSigned;
40using llvm::mod;
41
42#define DEBUG_TYPE "affine-ops"
43
44#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
45
46/// A utility function to check if a value is defined at the top level of
47/// `region` or is an argument of `region`. A value of index type defined at the
48/// top level of a `AffineScope` region is always a valid symbol for all
49/// uses in that region.
51 if (auto arg = dyn_cast<BlockArgument>(value))
52 return arg.getParentRegion() == region;
53 return value.getDefiningOp()->getParentRegion() == region;
54}
55
56/// Checks if `value` known to be a legal affine dimension or symbol in `src`
57/// region remains legal if the operation that uses it is inlined into `dest`
58/// with the given value mapping. `legalityCheck` is either `isValidDim` or
59/// `isValidSymbol`, depending on the value being required to remain a valid
60/// dimension or symbol.
61static bool
63 const IRMapping &mapping,
64 function_ref<bool(Value, Region *)> legalityCheck) {
65 // If the value is a valid dimension for any other reason than being
66 // a top-level value, it will remain valid: constants get inlined
67 // with the function, transitive affine applies also get inlined and
68 // will be checked themselves, etc.
69 if (!isTopLevelValue(value, src))
70 return true;
71
72 // If it's a top-level value because it's a block operand, i.e. a
73 // function argument, check whether the value replacing it after
74 // inlining is a valid dimension in the new region.
75 if (llvm::isa<BlockArgument>(value))
76 return legalityCheck(mapping.lookup(value), dest);
77
78 // If it's a top-level value because it's defined in the region,
79 // it can only be inlined if the defining op is a constant or a
80 // `dim`, which can appear anywhere and be valid, since the defining
81 // op won't be top-level anymore after inlining.
82 Attribute operandCst;
83 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
84 return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
85 isDimLikeOp;
86}
87
88/// Checks if all values known to be legal affine dimensions or symbols in `src`
89/// remain so if their respective users are inlined into `dest`.
90static bool
92 const IRMapping &mapping,
93 function_ref<bool(Value, Region *)> legalityCheck) {
94 return llvm::all_of(values, [&](Value v) {
95 return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
96 });
97}
98
99/// Checks if an affine read or write operation remains legal after inlining
100/// from `src` to `dest`.
101template <typename OpTy>
102static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
103 const IRMapping &mapping) {
104 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
105 AffineWriteOpInterface>::value,
106 "only ops with affine read/write interface are supported");
107
108 AffineMap map = op.getAffineMap();
109 ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
110 ValueRange symbolOperands =
111 op.getMapOperands().take_back(map.getNumSymbols());
113 dimOperands, src, dest, mapping,
114 static_cast<bool (*)(Value, Region *)>(isValidDim)))
115 return false;
117 symbolOperands, src, dest, mapping,
118 static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
119 return false;
120 return true;
121}
122
123/// Checks if an affine apply operation remains legal after inlining from `src`
124/// to `dest`.
125// Use "unused attribute" marker to silence clang-tidy warning stemming from
126// the inability to see through "llvm::TypeSwitch".
127template <>
128[[maybe_unused]] bool remainsLegalAfterInline(AffineApplyOp op, Region *src,
129 Region *dest,
130 const IRMapping &mapping) {
131 // If it's a valid dimension, we need to check that it remains so.
132 if (isValidDim(op.getResult(), src))
134 op.getMapOperands(), src, dest, mapping,
135 static_cast<bool (*)(Value, Region *)>(isValidDim));
136
137 // Otherwise it must be a valid symbol, check that it remains so.
139 op.getMapOperands(), src, dest, mapping,
140 static_cast<bool (*)(Value, Region *)>(isValidSymbol));
141}
142
143//===----------------------------------------------------------------------===//
144// AffineDialect Interfaces
145//===----------------------------------------------------------------------===//
146
namespace {
/// This class defines the interface for handling inlining with affine
/// operations.
struct AffineInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// Returns true if the given region 'src' can be inlined into the region
  /// 'dest' that is attached to an operation registered to the current dialect.
  /// 'wouldBeCloned' is set if the region is cloned into its new location
  /// rather than moved, indicating there may be other users.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &valueMapping) const final {
    // We can inline into affine loops and conditionals if this doesn't break
    // affine value categorization rules.
    Operation *destOp = dest->getParentOp();
    if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
      return false;

    // Multi-block regions cannot be inlined into affine constructs, all of
    // which require single-block regions.
    if (!src->hasOneBlock())
      return false;

    // Side-effecting operations that the affine dialect cannot understand
    // should not be inlined.
    Block &srcBlock = src->front();
    for (Operation &op : srcBlock) {
      // Ops with no side effects are fine,
      if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
        if (iface.hasNoEffect())
          continue;
      }

      // Assuming the inlined region is valid, we only need to check if the
      // inlining would change it. Affine applies/reads/writes are the only
      // side-effecting ops whose legality we can re-check after the move.
      bool remainsValid =
          llvm::TypeSwitch<Operation *, bool>(&op)
              .Case<AffineApplyOp, AffineReadOpInterface,
                    AffineWriteOpInterface>([&](auto op) {
                return remainsLegalAfterInline(op, src, dest, valueMapping);
              })
              .Default([](Operation *) {
                // Conservatively disallow inlining ops we cannot reason about.
                return false;
              });

      if (!remainsValid)
        return false;
    }

    return true;
  }

  /// Returns true if the given operation 'op', that is registered to this
  /// dialect, can be inlined into the given region, false otherwise.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &valueMapping) const final {
    // Always allow inlining affine operations into a region that is marked as
    // affine scope, or into affine loops and conditionals. There are some edge
    // cases when inlining *into* affine structures, but that is handled in the
    // other 'isLegalToInline' hook above.
    Operation *parentOp = region->getParentOp();
    return parentOp->hasTrait<OpTrait::AffineScope>() ||
           isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
  }

  /// Affine regions should be analyzed recursively.
  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
};
} // namespace
221
222//===----------------------------------------------------------------------===//
223// AffineDialect
224//===----------------------------------------------------------------------===//
225
void AffineDialect::initialize() {
  // Register all affine operations generated from the ODS definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
      >();
  addInterfaces<AffineInlinerInterface>();
  // Promise ValueBoundsOpInterface implementations for these ops; the actual
  // implementations are attached elsewhere (outside this file).
  declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
                            AffineMinOp>();
}
235
236/// Materialize a single constant operation from a given attribute value with
237/// the desired resultant type.
238Operation *AffineDialect::materializeConstant(OpBuilder &builder,
239 Attribute value, Type type,
240 Location loc) {
241 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
242 return ub::PoisonOp::create(builder, loc, type, poison);
243 return arith::ConstantOp::materialize(builder, value, type, loc);
244}
245
246/// A utility function to check if a value is defined at the top level of an
247/// op with trait `AffineScope`. If the value is defined in an unlinked region,
248/// conservatively assume it is not top-level. A value of index type defined at
249/// the top level is always a valid symbol.
251 if (auto arg = dyn_cast<BlockArgument>(value)) {
252 // The block owning the argument may be unlinked, e.g. when the surrounding
253 // region has not yet been attached to an Op, at which point the parent Op
254 // is null.
255 Operation *parentOp = arg.getOwner()->getParentOp();
256 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
257 }
258 // The defining Op may live in an unlinked block so its parent Op may be null.
259 Operation *parentOp = value.getDefiningOp()->getParentOp();
260 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
261}
262
263/// Returns the closest region enclosing `op` that is held by an operation with
264/// trait `AffineScope`; `nullptr` if there is no such region.
266 auto *curOp = op;
267 while (auto *parentOp = curOp->getParentOp()) {
268 if (parentOp->hasTrait<OpTrait::AffineScope>())
269 return curOp->getParentRegion();
270 curOp = parentOp;
271 }
272 return nullptr;
273}
274
276 Operation *curOp = op;
277 while (auto *parentOp = curOp->getParentOp()) {
278 if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
279 return curOp->getParentRegion();
280 curOp = parentOp;
281 }
282 return nullptr;
283}
284
285// A Value can be used as a dimension id iff it meets one of the following
286// conditions:
287// *) It is valid as a symbol.
288// *) It is an induction variable.
289// *) It is the result of affine apply operation with dimension id arguments.
291 // The value must be an index type.
292 if (!value.getType().isIndex())
293 return false;
294
295 if (auto *defOp = value.getDefiningOp())
296 return isValidDim(value, getAffineScope(defOp));
297
298 // This value has to be a block argument for an op that has the
299 // `AffineScope` trait or an induction var of an affine.for or
300 // affine.parallel.
301 if (isAffineInductionVar(value))
302 return true;
303 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
304 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
305}
306
307// Value can be used as a dimension id iff it meets one of the following
308// conditions:
309// *) It is valid as a symbol.
310// *) It is an induction variable.
311// *) It is the result of an affine apply operation with dimension id operands.
312// *) It is the result of a more specialized index transformation (ex.
313// delinearize_index or linearize_index) with dimension id operands.
315 // The value must be an index type.
316 if (!value.getType().isIndex())
317 return false;
318
319 // All valid symbols are okay.
320 if (isValidSymbol(value, region))
321 return true;
322
323 auto *op = value.getDefiningOp();
324 if (!op) {
325 // This value has to be an induction var for an affine.for or an
326 // affine.parallel.
327 return isAffineInductionVar(value);
328 }
329
330 // Affine apply operation is ok if all of its operands are ok.
331 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
332 return applyOp.isValidDim(region);
333 // delinearize_index and linearize_index are special forms of apply
334 // and so are valid dimensions if all their arguments are valid dimensions.
335 if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
336 return llvm::all_of(op->getOperands(),
337 [&](Value arg) { return ::isValidDim(arg, region); });
338 // The dim op is okay if its operand memref/tensor is defined at the top
339 // level.
340 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
341 return isTopLevelValue(dimOp.getShapedValue());
342 return false;
343}
344
345/// Returns true if the 'index' dimension of the `memref` defined by
346/// `memrefDefOp` is a statically shaped one or defined using a valid symbol
347/// for `region`.
348template <typename AnyMemRefDefOp>
349static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
350 Region *region) {
351 MemRefType memRefType = memrefDefOp.getType();
352
353 // Dimension index is out of bounds.
354 if (index >= memRefType.getRank()) {
355 return false;
356 }
357
358 // Statically shaped.
359 if (!memRefType.isDynamicDim(index))
360 return true;
361 // Get the position of the dimension among dynamic dimensions;
362 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
363 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
364 region);
365}
366
367/// Returns true if the result of the dim op is a valid symbol for `region`.
368static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
369 // The dim op is okay if its source is defined at the top level.
370 if (isTopLevelValue(dimOp.getShapedValue()))
371 return true;
372
373 // Conservatively handle remaining BlockArguments as non-valid symbols.
374 // E.g. scf.for iterArgs.
375 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
376 return false;
377
378 // The dim op is also okay if its operand memref is a view/subview whose
379 // corresponding size is a valid symbol.
380 std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
381
382 // Be conservative if we can't understand the dimension.
383 if (!index.has_value())
384 return false;
385
386 // Skip over all memref.cast ops (if any).
387 Operation *op = dimOp.getShapedValue().getDefiningOp();
388 while (auto castOp = dyn_cast<memref::CastOp>(op)) {
389 // Bail on unranked memrefs.
390 if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
391 return false;
392 op = castOp.getSource().getDefiningOp();
393 if (!op)
394 return false;
395 }
396
397 int64_t i = index.value();
399 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
400 [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
401 .Default([](Operation *) { return false; });
402}
403
404// A value can be used as a symbol (at all its use sites) iff it meets one of
405// the following conditions:
406// *) It is a constant.
407// *) Its defining op or block arg appearance is immediately enclosed by an op
408// with `AffineScope` trait.
409// *) It is the result of an affine.apply operation with symbol operands.
410// *) It is a result of the dim op on a memref whose corresponding size is a
411// valid symbol.
413 if (!value)
414 return false;
415
416 // The value must be an index type.
417 if (!value.getType().isIndex())
418 return false;
419
420 // Check that the value is a top level value.
421 if (isTopLevelValue(value))
422 return true;
423
424 if (auto *defOp = value.getDefiningOp())
425 return isValidSymbol(value, getAffineScope(defOp));
426
427 return false;
428}
429
430/// A utility function to check if a value is defined at the top level of
431/// `region` or is an argument of `region` or is defined above the region.
432static bool isTopLevelValueOrAbove(Value value, Region *region) {
433 Region *parentRegion = value.getParentRegion();
434 do {
435 if (parentRegion == region)
436 return true;
437 Operation *regionOp = region->getParentOp();
438 if (regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
439 break;
440 region = region->getParentOp()->getParentRegion();
441 } while (region);
442 return false;
443}
444
445/// A value can be used as a symbol for `region` iff it meets one of the
446/// following conditions:
447/// *) It is a constant.
448/// *) It is a result of a `Pure` operation whose operands are valid symbolic
449/// *) identifiers.
450/// *) It is a result of the dim op on a memref whose corresponding size is
451/// a valid symbol.
452/// *) It is defined at the top level of 'region' or is its argument.
453/// *) It dominates `region`'s parent op.
454/// If `region` is null, conservatively assume the symbol definition scope does
455/// not exist and only accept the values that would be symbols regardless of
456/// the surrounding region structure, i.e. the first three cases above.
458 // The value must be an index type.
459 if (!value.getType().isIndex())
460 return false;
461
462 // A top-level value is a valid symbol.
463 if (region && isTopLevelValueOrAbove(value, region))
464 return true;
465
466 auto *defOp = value.getDefiningOp();
467 if (!defOp)
468 return false;
469
470 // Constant operation is ok.
471 Attribute operandCst;
472 if (matchPattern(defOp, m_Constant(&operandCst)))
473 return true;
474
475 // `Pure` operation that whose operands are valid symbolic identifiers.
476 if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
477 return affine::isValidSymbol(operand, region);
478 })) {
479 return true;
480 }
481
482 // Dim op results could be valid symbols at any level.
483 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
484 return isDimOpValidSymbol(dimOp, region);
485
486 return false;
487}
488
489// Returns true if 'value' is a valid index to an affine operation (e.g.
490// affine.load, affine.store, affine.dma_start, affine.dma_wait) where
491// `region` provides the polyhedral symbol scope. Returns false otherwise.
492static bool isValidAffineIndexOperand(Value value, Region *region) {
493 return isValidDim(value, region) || isValidSymbol(value, region);
494}
495
496/// Prints dimension and symbol list.
499 unsigned numDims, OpAsmPrinter &printer) {
500 OperandRange operands(begin, end);
501 printer << '(' << operands.take_front(numDims) << ')';
502 if (operands.size() > numDims)
503 printer << '[' << operands.drop_front(numDims) << ']';
504}
505
506/// Parses dimension and symbol list and returns true if parsing failed.
508 OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
511 return failure();
512 // Store number of dimensions for validation by caller.
513 numDims = opInfos.size();
514
515 // Parse the optional symbol operands.
516 auto indexTy = parser.getBuilder().getIndexType();
517 return failure(parser.parseOperandList(
519 parser.resolveOperands(opInfos, indexTy, operands));
520}
521
522/// Utility function to verify that a set of operands are valid dimension and
523/// symbol identifiers. The operands should be laid out such that the dimension
524/// operands are before the symbol operands. This function returns failure if
525/// there was an invalid operand. An operation is provided to emit any necessary
526/// errors.
527template <typename OpTy>
528static LogicalResult
530 unsigned numDims) {
531 unsigned opIt = 0;
532 for (auto operand : operands) {
533 if (opIt++ < numDims) {
534 if (!isValidDim(operand, getAffineScope(op)))
535 return op.emitOpError("operand cannot be used as a dimension id");
536 } else if (!isValidSymbol(operand, getAffineScope(op))) {
537 return op.emitOpError("operand cannot be used as a symbol");
538 }
539 }
540 return success();
541}
542
543//===----------------------------------------------------------------------===//
544// AffineApplyOp
545//===----------------------------------------------------------------------===//
546
/// Returns this apply op's map, operands, and result bundled as an
/// AffineValueMap.
AffineValueMap AffineApplyOp::getAffineValueMap() {
  return AffineValueMap(getAffineMap(), getOperands(), getResult());
}
550
551ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
552 auto &builder = parser.getBuilder();
553 auto indexTy = builder.getIndexType();
554
555 AffineMapAttr mapAttr;
556 unsigned numDims;
557 if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
558 parseDimAndSymbolList(parser, result.operands, numDims) ||
559 parser.parseOptionalAttrDict(result.attributes))
560 return failure();
561 auto map = mapAttr.getValue();
562
563 if (map.getNumDims() != numDims ||
564 numDims + map.getNumSymbols() != result.operands.size()) {
565 return parser.emitError(parser.getNameLoc(),
566 "dimension or symbol index mismatch");
567 }
568
569 result.types.append(map.getNumResults(), indexTy);
570 return success();
571}
572
573void AffineApplyOp::print(OpAsmPrinter &p) {
574 p << " " << getMapAttr();
575 printDimAndSymbolList(operand_begin(), operand_end(),
576 getAffineMap().getNumDims(), p);
577 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
578}
579
580LogicalResult AffineApplyOp::verify() {
581 // Check input and output dimensions match.
582 AffineMap affineMap = getMap();
583
584 // Verify that operand count matches affine map dimension and symbol count.
585 if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
586 return emitOpError(
587 "operand count and affine map dimension and symbol count must match");
588
589 // Verify that the map only produces one result.
590 if (affineMap.getNumResults() != 1)
591 return emitOpError("mapping must produce one value");
592
593 // Do not allow valid dims to be used in symbol positions. We do allow
594 // affine.apply to use operands for values that may neither qualify as affine
595 // dims or affine symbols due to usage outside of affine ops, analyses, etc.
596 Region *region = getAffineScope(*this);
597 for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
598 if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
599 return emitError("dimensional operand cannot be used as a symbol");
600 }
601
602 return success();
603}
604
605// The result of the affine apply operation can be used as a dimension id if all
606// its operands are valid dimension ids.
607bool AffineApplyOp::isValidDim() {
608 return llvm::all_of(getOperands(),
609 [](Value op) { return affine::isValidDim(op); });
610}
611
612// The result of the affine apply operation can be used as a dimension id if all
613// its operands are valid dimension ids with the parent operation of `region`
614// defining the polyhedral scope for symbols.
615bool AffineApplyOp::isValidDim(Region *region) {
616 return llvm::all_of(getOperands(),
617 [&](Value op) { return ::isValidDim(op, region); });
618}
619
620// The result of the affine apply operation can be used as a symbol if all its
621// operands are symbols.
622bool AffineApplyOp::isValidSymbol() {
623 return llvm::all_of(getOperands(),
624 [](Value op) { return affine::isValidSymbol(op); });
625}
626
627// The result of the affine apply operation can be used as a symbol in `region`
628// if all its operands are symbols in `region`.
629bool AffineApplyOp::isValidSymbol(Region *region) {
630 return llvm::all_of(getOperands(), [&](Value operand) {
631 return affine::isValidSymbol(operand, region);
632 });
633}
634
635OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
636 auto map = getAffineMap();
637
638 // Fold dims and symbols to existing values.
639 auto expr = map.getResult(0);
640 if (auto dim = dyn_cast<AffineDimExpr>(expr))
641 return getOperand(dim.getPosition());
642 if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
643 return getOperand(map.getNumDims() + sym.getPosition());
644
645 // Otherwise, default to folding the map.
647 bool hasPoison = false;
648 auto foldResult =
649 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
650 if (hasPoison)
651 return ub::PoisonAttr::get(getContext());
652 if (failed(foldResult))
653 return {};
654 return result[0];
655}
656
657/// Returns the largest known divisor of `e`. Exploits information from the
658/// values in `operands`.
660 // This method isn't aware of `operands`.
662
663 // We now make use of operands for the case `e` is a dim expression.
664 // TODO: More powerful simplification would have to modify
665 // getLargestKnownDivisor to take `operands` and exploit that information as
666 // well for dim/sym expressions, but in that case, getLargestKnownDivisor
667 // can't be part of the IR library but of the `Analysis` library. The IR
668 // library can only really depend on simple O(1) checks.
669 auto dimExpr = dyn_cast<AffineDimExpr>(e);
670 // If it's not a dim expr, `div` is the best we have.
671 if (!dimExpr)
672 return div;
673
674 // We simply exploit information from loop IVs.
675 // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
676 // desired simplifications are expected to be part of other
677 // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
678 // LoopAnalysis library.
679 Value operand = operands[dimExpr.getPosition()];
680 int64_t operandDivisor = 1;
681 // TODO: With the right accessors, this can be extended to
682 // LoopLikeOpInterface.
683 if (AffineForOp forOp = getForInductionVarOwner(operand)) {
684 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
685 operandDivisor = forOp.getStepAsInt();
686 } else {
687 uint64_t lbLargestKnownDivisor =
688 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
689 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
690 }
691 }
692 return operandDivisor;
693}
694
695/// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
696/// being an affine dim expression or a constant.
698 int64_t k) {
699 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
700 int64_t constVal = constExpr.getValue();
701 return constVal >= 0 && constVal < k;
702 }
703 auto dimExpr = dyn_cast<AffineDimExpr>(e);
704 if (!dimExpr)
705 return false;
706 Value operand = operands[dimExpr.getPosition()];
707 // TODO: With the right accessors, this can be extended to
708 // LoopLikeOpInterface.
709 if (AffineForOp forOp = getForInductionVarOwner(operand)) {
710 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
711 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
712 return true;
713 }
714 }
715
716 // We don't consider other cases like `operand` being defined by a constant or
717 // an affine.apply op since such cases will already be handled by other
718 // patterns and propagation of loop IVs or constant would happen.
719 return false;
720}
721
722/// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
723/// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
724/// expression is in that form.
726 AffineExpr &quotientTimesDiv, AffineExpr &rem) {
727 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
728 if (!bin || bin.getKind() != AffineExprKind::Add)
729 return false;
730
731 AffineExpr llhs = bin.getLHS();
732 AffineExpr rlhs = bin.getRHS();
733 div = getLargestKnownDivisor(llhs, operands);
734 if (isNonNegativeBoundedBy(rlhs, operands, div)) {
735 quotientTimesDiv = llhs;
736 rem = rlhs;
737 return true;
738 }
739 div = getLargestKnownDivisor(rlhs, operands);
740 if (isNonNegativeBoundedBy(llhs, operands, div)) {
741 quotientTimesDiv = rlhs;
742 rem = llhs;
743 return true;
744 }
745 return false;
746}
747
748/// Gets the constant lower bound on an `iv`.
749static std::optional<int64_t> getLowerBound(Value iv) {
750 AffineForOp forOp = getForInductionVarOwner(iv);
751 if (forOp && forOp.hasConstantLowerBound())
752 return forOp.getConstantLowerBound();
753 return std::nullopt;
754}
755
756/// Gets the constant upper bound on an affine.for `iv`.
757static std::optional<int64_t> getUpperBound(Value iv) {
758 AffineForOp forOp = getForInductionVarOwner(iv);
759 if (!forOp || !forOp.hasConstantUpperBound())
760 return std::nullopt;
761
762 // If its lower bound is also known, we can get a more precise bound
763 // whenever the step is not one.
764 if (forOp.hasConstantLowerBound()) {
765 return forOp.getConstantUpperBound() - 1 -
766 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
767 forOp.getStepAsInt();
768 }
769 return forOp.getConstantUpperBound() - 1;
770}
771
772/// Determine a constant upper bound for `expr` if one exists while exploiting
773/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
774/// is guaranteed to be less than or equal to it.
775static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
776 unsigned numSymbols,
777 ArrayRef<Value> operands) {
778 // Get the constant lower or upper bounds on the operands.
779 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
780 constLowerBounds.reserve(operands.size());
781 constUpperBounds.reserve(operands.size());
782 for (Value operand : operands) {
783 constLowerBounds.push_back(getLowerBound(operand));
784 constUpperBounds.push_back(getUpperBound(operand));
785 }
786
787 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
788 return constExpr.getValue();
789
790 return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
791 constUpperBounds,
792 /*isUpper=*/true);
793}
794
795/// Determine a constant lower bound for `expr` if one exists while exploiting
796/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
797/// is guaranteed to be less than or equal to it.
798static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
799 unsigned numSymbols,
800 ArrayRef<Value> operands) {
801 // Get the constant lower or upper bounds on the operands.
802 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
803 constLowerBounds.reserve(operands.size());
804 constUpperBounds.reserve(operands.size());
805 for (Value operand : operands) {
806 constLowerBounds.push_back(getLowerBound(operand));
807 constUpperBounds.push_back(getUpperBound(operand));
808 }
809
810 std::optional<int64_t> lowerBound;
811 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
812 lowerBound = constExpr.getValue();
813 } else {
814 lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
815 constLowerBounds, constUpperBounds,
816 /*isUpper=*/false);
817 }
818 return lowerBound;
819}
820
/// Simplify `expr` while exploiting information from the values in `operands`.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
                                    unsigned numSymbols,
                                    ArrayRef<Value> operands) {
  // We do this only for certain floordiv/mod expressions.
  auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binExpr)
    return;

  // Simplify the child expressions first (recursively), then rebuild this
  // binary expression from the simplified children.
  AffineExpr lhs = binExpr.getLHS();
  AffineExpr rhs = binExpr.getRHS();
  simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
  simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
  expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);

  // NOTE(review): this listing appears to be missing a line in the condition
  // below (the original source also accepts AffineExprKind::CeilDiv here) —
  // confirm against upstream before relying on this rendering.
  binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
  if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
                   expr.getKind() != AffineExprKind::Mod)) {
    return;
  }

  // The `lhs` and `rhs` may be different post construction of simplified expr.
  lhs = binExpr.getLHS();
  rhs = binExpr.getRHS();
  auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
  if (!rhsConst)
    return;

  int64_t rhsConstVal = rhsConst.getValue();
  // Undefined exprsessions aren't touched; IR can still be valid with them.
  if (rhsConstVal <= 0)
    return;

  // Exploit constant lower/upper bounds to simplify a floordiv or mod.
  MLIRContext *context = expr.getContext();
  std::optional<int64_t> lhsLbConst =
      getLowerBound(lhs, numDims, numSymbols, operands);
  std::optional<int64_t> lhsUbConst =
      getUpperBound(lhs, numDims, numSymbols, operands);
  if (lhsLbConst && lhsUbConst) {
    int64_t lhsLbConstVal = *lhsLbConst;
    int64_t lhsUbConstVal = *lhsUbConst;
    // lhs floordiv c is a single value lhs is bounded in a range `c` that has
    // the same quotient.
    if (binExpr.getKind() == AffineExprKind::FloorDiv &&
        divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
            divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
      // NOTE(review): a line appears to be missing here in this rendering;
      // the original assigns `expr = getAffineConstantExpr(...)` with the
      // value below — confirm against upstream.
          divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
      return;
    }
    // lhs ceildiv c is a single value if the entire range has the same ceil
    // quotient.
    if (binExpr.getKind() == AffineExprKind::CeilDiv &&
        divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
            divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
      expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
                                   context);
      return;
    }
    // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
    if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
        lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
      expr = lhs;
      return;
    }
  }

  // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
  // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
  // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
  // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
  AffineExpr quotientTimesDiv, rem;
  int64_t divisor;
  if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
898 if (rhsConstVal % divisor == 0 &&
899 binExpr.getKind() == AffineExprKind::FloorDiv) {
900 expr = quotientTimesDiv.floorDiv(rhsConst);
901 } else if (divisor % rhsConstVal == 0 &&
902 binExpr.getKind() == AffineExprKind::Mod) {
903 expr = rem % rhsConst;
904 }
905 return;
906 }
907
908 // Handle the simple case when the LHS expression can be either upper
909 // bounded or is a known multiple of RHS constant.
910 // lhs floordiv c -> 0 if 0 <= lhs < c,
911 // lhs mod c -> 0 if lhs % c = 0.
912 if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
913 binExpr.getKind() == AffineExprKind::FloorDiv) ||
914 (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
915 binExpr.getKind() == AffineExprKind::Mod)) {
916 expr = getAffineConstantExpr(0, expr.getContext());
917 }
918}
919
920/// Simplify the expressions in `map` while making use of lower or upper bounds
921/// of its operands. If `isMax` is true, the map is to be treated as a max of
922/// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
923/// d1) can be simplified to (8) if the operands are respectively lower bounded
924/// by 2 and 0 (the second expression can't be lower than 8).
926 ArrayRef<Value> operands,
927 bool isMax) {
928 // Can't simplify.
929 if (operands.empty())
930 return;
931
932 // Get the upper or lower bound on an affine.for op IV using its range.
933 // Get the constant lower or upper bounds on the operands.
934 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
935 constLowerBounds.reserve(operands.size());
936 constUpperBounds.reserve(operands.size());
937 for (Value operand : operands) {
938 constLowerBounds.push_back(getLowerBound(operand));
939 constUpperBounds.push_back(getUpperBound(operand));
940 }
941
942 // We will compute the lower and upper bounds on each of the expressions
943 // Then, we will check (depending on max or min) as to whether a specific
944 // bound is redundant by checking if its highest (in case of max) and its
945 // lowest (in the case of min) value is already lower than (or higher than)
946 // the lower bound (or upper bound in the case of min) of another bound.
947 SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
948 lowerBounds.reserve(map.getNumResults());
949 upperBounds.reserve(map.getNumResults());
950 for (AffineExpr e : map.getResults()) {
951 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
952 lowerBounds.push_back(constExpr.getValue());
953 upperBounds.push_back(constExpr.getValue());
954 } else {
955 lowerBounds.push_back(
957 constLowerBounds, constUpperBounds,
958 /*isUpper=*/false));
959 upperBounds.push_back(
961 constLowerBounds, constUpperBounds,
962 /*isUpper=*/true));
963 }
964 }
965
966 // Collect expressions that are not redundant.
967 SmallVector<AffineExpr, 4> irredundantExprs;
968 for (auto exprEn : llvm::enumerate(map.getResults())) {
969 AffineExpr e = exprEn.value();
970 unsigned i = exprEn.index();
971 // Some expressions can be turned into constants.
972 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
973 e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
974
975 // Check if the expression is redundant.
976 if (isMax) {
977 if (!upperBounds[i]) {
978 irredundantExprs.push_back(e);
979 continue;
980 }
981 // If there exists another expression such that its lower bound is greater
982 // than this expression's upper bound, it's redundant.
983 if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
984 auto otherLowerBound = en.value();
985 unsigned pos = en.index();
986 if (pos == i || !otherLowerBound)
987 return false;
988 if (*otherLowerBound > *upperBounds[i])
989 return true;
990 if (*otherLowerBound < *upperBounds[i])
991 return false;
992 // Equality case. When both expressions are considered redundant, we
993 // don't want to get both of them. We keep the one that appears
994 // first.
995 if (upperBounds[pos] && lowerBounds[i] &&
996 lowerBounds[i] == upperBounds[i] &&
997 otherLowerBound == *upperBounds[pos] && i < pos)
998 return false;
999 return true;
1000 }))
1001 irredundantExprs.push_back(e);
1002 } else {
1003 if (!lowerBounds[i]) {
1004 irredundantExprs.push_back(e);
1005 continue;
1006 }
1007 // Likewise for the `min` case. Use the complement of the condition above.
1008 if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
1009 auto otherUpperBound = en.value();
1010 unsigned pos = en.index();
1011 if (pos == i || !otherUpperBound)
1012 return false;
1013 if (*otherUpperBound < *lowerBounds[i])
1014 return true;
1015 if (*otherUpperBound > *lowerBounds[i])
1016 return false;
1017 if (lowerBounds[pos] && upperBounds[i] &&
1018 lowerBounds[i] == upperBounds[i] &&
1019 otherUpperBound == lowerBounds[pos] && i < pos)
1020 return false;
1021 return true;
1022 }))
1023 irredundantExprs.push_back(e);
1024 }
1025 }
1026
1027 // Create the map without the redundant expressions.
1028 map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
1029 map.getContext());
1030}
1031
1032/// Simplify the map while exploiting information on the values in `operands`.
1033// Use "unused attribute" marker to silence warning stemming from the inability
1034// to see through the template expansion.
1035[[maybe_unused]] static void simplifyMapWithOperands(AffineMap &map,
1036 ArrayRef<Value> operands) {
1037 assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1038 SmallVector<AffineExpr> newResults;
1039 newResults.reserve(map.getNumResults());
1040 for (AffineExpr expr : map.getResults()) {
1042 operands);
1043 newResults.push_back(expr);
1044 }
1045 map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1046 map.getContext());
1047}
1048
1049/// Assuming `dimOrSym` is a quantity in the apply op map `map` and defined by
1050/// `minOp = affine_min(x_1, ..., x_n)`. This function checks that:
1051/// `0 < affine_min(x_1, ..., x_n)` and proceeds with replacing the patterns:
1052/// ```
1053/// dimOrSym.ceildiv(x_k)
1054/// (dimOrSym + x_k - 1).floordiv(x_k)
1055/// ```
1056/// by `1` for all `k` in `1, ..., n`. This is possible because `x / x_k <= 1`.
1057///
1058///
1059/// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
1060/// `minOp` is positive.
1061static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
1062 AffineExpr dimOrSym,
1063 AffineMap *map,
1064 ValueRange dims,
1065 ValueRange syms) {
1066 LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";
1067 AffineMap affineMinMap = minOp.getAffineMap();
1068
1069 // Check the value is positive.
1070 for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
1071 // Compare each expression in the minimum against 0.
1073 getAsIndexOpFoldResult(minOp.getContext(), 0),
1076 minOp.getOperands())))
1077 return failure();
1078 }
1079
1080 /// Convert affine symbols and dimensions in minOp to symbols or dimensions in
1081 /// the apply op affine map.
1082 DenseMap<AffineExpr, AffineExpr> dimSymConversionTable;
1083 SmallVector<unsigned> unmappedDims, unmappedSyms;
1084 for (auto [i, dim] : llvm::enumerate(minOp.getDimOperands())) {
1085 auto it = llvm::find(dims, dim);
1086 if (it == dims.end()) {
1087 unmappedDims.push_back(i);
1088 continue;
1089 }
1090 dimSymConversionTable[getAffineDimExpr(i, minOp.getContext())] =
1091 getAffineDimExpr(it.getIndex(), minOp.getContext());
1092 }
1093 for (auto [i, sym] : llvm::enumerate(minOp.getSymbolOperands())) {
1094 auto it = llvm::find(syms, sym);
1095 if (it == syms.end()) {
1096 unmappedSyms.push_back(i);
1097 continue;
1098 }
1099 dimSymConversionTable[getAffineSymbolExpr(i, minOp.getContext())] =
1100 getAffineSymbolExpr(it.getIndex(), minOp.getContext());
1101 }
1102
1103 // Create the replacement map.
1105 AffineExpr c1 = getAffineConstantExpr(1, minOp.getContext());
1106 for (AffineExpr expr : affineMinMap.getResults()) {
1107 // If we cannot express the result in terms of the apply map symbols and
1108 // sims then continue.
1109 if (llvm::any_of(unmappedDims,
1110 [&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
1111 llvm::any_of(unmappedSyms,
1112 [&](unsigned i) { return expr.isFunctionOfSymbol(i); }))
1113 continue;
1114
1115 AffineExpr convertedExpr = expr.replace(dimSymConversionTable);
1116
1117 // dimOrSym.ceilDiv(expr) -> 1
1118 repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
1119 // (dimOrSym + expr - 1).floorDiv(expr) -> 1
1120 repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
1121 }
1122 AffineMap initialMap = *map;
1123 *map = initialMap.replace(repl, initialMap.getNumDims(),
1124 initialMap.getNumSymbols());
1125 return success(*map != initialMap);
1126}
1127
1128/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form
1129/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`,
1130/// place a map between e and `newVal` + sum({e1, e2, .. eK} - exprsToRemove)
1131/// into `replacementsMap`. If no entries were added to `replacementsMap`,
1132/// nothing was found.
1134 AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
1135 AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
1136 auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
1137 if (!binOp)
1138 return;
1139 AffineExpr lhs = binOp.getLHS();
1140 AffineExpr rhs = binOp.getRHS();
1141 if (binOp.getKind() != AffineExprKind::Add) {
1142 shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap);
1143 shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap);
1144 return;
1145 }
1146 SmallVector<AffineExpr> toPreserve;
1147 llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove);
1148 AffineExpr thisTerm = rhs;
1149 AffineExpr nextTerm = lhs;
1150
1151 while (thisTerm) {
1152 if (!ourTracker.erase(thisTerm)) {
1153 toPreserve.push_back(thisTerm);
1154 shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal,
1155 replacementsMap);
1156 }
1157 auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
1158 if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) {
1159 thisTerm = nextTerm;
1160 nextTerm = AffineExpr();
1161 } else {
1162 thisTerm = nextBinOp.getRHS();
1163 nextTerm = nextBinOp.getLHS();
1164 }
1165 }
1166 if (!ourTracker.empty())
1167 return;
1168 // We reverse the terms to be preserved here in order to preserve
1169 // associativity between them.
1170 AffineExpr newExpr = newVal;
1171 for (AffineExpr preserved : llvm::reverse(toPreserve))
1172 newExpr = newExpr + preserved;
1173 replacementsMap.insert({e, newExpr});
1174}
1175
1176/// If this map contains of the expression `x_1 + x_1 * C_1 + ... x_n * C_N +
1177/// ...` (not necessarily in order) where the set of the `x_i` is the set of
1178/// outputs of an `affine.delinearize_index` whos inverse is that expression,
1179/// replace that expression with the input of that delinearize_index op.
1180///
1181/// `unitDimInput` is the input that was detected as the potential start to this
1182/// replacement chain - if it isn't the rightmost result of the delinearization,
1183/// this method fails. (This is intended to ensure we don't have redundant scans
1184/// over the same expression).
1185///
1186/// While this currently only handles delinearizations with a constant basis,
1187/// that isn't a fundamental limitation.
1188///
1189/// This is a utility function for `replaceDimOrSym` below.
1191 AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
1193 if (!delinOp.getDynamicBasis().empty())
1194 return failure();
1195 if (resultToReplace != delinOp.getMultiIndex().back())
1196 return failure();
1197
1198 MLIRContext *ctx = delinOp.getContext();
1199 SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr());
1200 for (auto [pos, dim] : llvm::enumerate(dims)) {
1201 auto asResult = dyn_cast_if_present<OpResult>(dim);
1202 if (!asResult)
1203 continue;
1204 if (asResult.getOwner() == delinOp.getOperation())
1205 resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx);
1206 }
1207 for (auto [pos, sym] : llvm::enumerate(syms)) {
1208 auto asResult = dyn_cast_if_present<OpResult>(sym);
1209 if (!asResult)
1210 continue;
1211 if (asResult.getOwner() == delinOp.getOperation())
1212 resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx);
1213 }
1214 if (llvm::is_contained(resToExpr, AffineExpr()))
1215 return failure();
1216
1217 bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>);
1218 int64_t stride = 1;
1219 llvm::SmallDenseSet<AffineExpr, 4> expectedExprs;
1220 // This isn't zip_equal since sometimes the delinearize basis is missing a
1221 // size for the first result.
1222 for (auto [binding, size] : llvm::zip(
1223 llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) {
1224 expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx));
1225 stride *= size;
1226 }
1227 if (resToExpr.size() != delinOp.getStaticBasis().size())
1228 expectedExprs.insert(resToExpr[0] * stride);
1229
1231 AffineExpr delinInExpr = isDimReplacement
1232 ? getAffineDimExpr(dims.size(), ctx)
1233 : getAffineSymbolExpr(syms.size(), ctx);
1234
1235 for (AffineExpr e : map->getResults())
1236 shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements);
1237 if (replacements.empty())
1238 return failure();
1239
1240 AffineMap origMap = *map;
1241 if (isDimReplacement)
1242 dims.push_back(delinOp.getLinearIndex());
1243 else
1244 syms.push_back(delinOp.getLinearIndex());
1245 *map = origMap.replace(replacements, dims.size(), syms.size());
1246
1247 // Blank out dead dimensions and symbols
1248 for (AffineExpr e : resToExpr) {
1249 if (auto d = dyn_cast<AffineDimExpr>(e)) {
1250 unsigned pos = d.getPosition();
1251 if (!map->isFunctionOfDim(pos))
1252 dims[pos] = nullptr;
1253 }
1254 if (auto s = dyn_cast<AffineSymbolExpr>(e)) {
1255 unsigned pos = s.getPosition();
1256 if (!map->isFunctionOfSymbol(pos))
1257 syms[pos] = nullptr;
1258 }
1259 }
1260 return success();
1261}
1262
1263/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
1264/// defining AffineApplyOp expression and operands.
1265/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
1266/// When `dimOrSymbolPosition >= dims.size()`,
1267/// AffineSymbolExpr@[pos - dims.size()] is replaced.
1268/// Mutate `map`,`dims` and `syms` in place as follows:
1269/// 1. `dims` and `syms` are only appended to.
1270/// 2. `map` dim and symbols are gradually shifted to higher positions.
1271/// 3. Old `dim` and `sym` entries are replaced by nullptr
1272/// This avoids the need for any bookkeeping.
1273/// If `replaceAffineMin` is set to true, additionally triggers more expensive
1274/// replacements involving affine_min operations.
1275static LogicalResult replaceDimOrSym(AffineMap *map,
1276 unsigned dimOrSymbolPosition,
1279 bool replaceAffineMin) {
1280 MLIRContext *ctx = map->getContext();
1281 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1282 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1283 : dimOrSymbolPosition - dims.size();
1284 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1285 if (!v)
1286 return failure();
1287
1288 if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
1289 AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx)
1290 : getAffineSymbolExpr(pos, ctx);
1291 return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map, dims,
1292 syms);
1293 }
1294
1295 if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
1296 return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims,
1297 syms);
1298 }
1299
1300 auto affineApply = v.getDefiningOp<AffineApplyOp>();
1301 if (!affineApply)
1302 return failure();
1303
1304 // At this point we will perform a replacement of `v`, set the entry in `dim`
1305 // or `sym` to nullptr immediately.
1306 v = nullptr;
1307
1308 // Compute the map, dims and symbols coming from the AffineApplyOp.
1309 AffineMap composeMap = affineApply.getAffineMap();
1310 assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1311 SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1312 affineApply.getMapOperands().end());
1313 // Canonicalize the map to promote dims to symbols when possible. This is to
1314 // avoid generating invalid maps.
1315 canonicalizeMapAndOperands(&composeMap, &composeOperands);
1316 AffineExpr replacementExpr =
1317 composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
1318 ValueRange composeDims =
1319 ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
1320 ValueRange composeSyms =
1321 ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
1322 AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
1323 : getAffineSymbolExpr(pos, ctx);
1324
1325 // Append the dims and symbols where relevant and perform the replacement.
1326 dims.append(composeDims.begin(), composeDims.end());
1327 syms.append(composeSyms.begin(), composeSyms.end());
1328 *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1329
1330 return success();
1331}
1332
1333/// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1334/// iteratively. Perform canonicalization of map and operands as well as
1335/// AffineMap simplification. `map` and `operands` are mutated in place.
1337 SmallVectorImpl<Value> *operands,
1338 bool composeAffineMin = false) {
1339 if (map->getNumResults() == 0) {
1340 canonicalizeMapAndOperands(map, operands);
1341 *map = simplifyAffineMap(*map);
1342 return;
1343 }
1344
1345 MLIRContext *ctx = map->getContext();
1346 SmallVector<Value, 4> dims(operands->begin(),
1347 operands->begin() + map->getNumDims());
1348 SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1349 operands->end());
1350
1351 // Iterate over dims and symbols coming from AffineApplyOp and replace until
1352 // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1353 // and `syms` can only increase by construction.
1354 // The implementation uses a `while` loop to support the case of symbols
1355 // that may be constructed from dims ;this may be overkill.
1356 while (true) {
1357 bool changed = false;
1358 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1359 if ((changed |=
1360 succeeded(replaceDimOrSym(map, pos, dims, syms, composeAffineMin))))
1361 break;
1362 if (!changed)
1363 break;
1364 }
1365
1366 // Clear operands so we can fill them anew.
1367 operands->clear();
1368
1369 // At this point we may have introduced null operands, prune them out before
1370 // canonicalizing map and operands.
1371 unsigned nDims = 0, nSyms = 0;
1372 SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1373 dimReplacements.reserve(dims.size());
1374 symReplacements.reserve(syms.size());
1375 for (auto *container : {&dims, &syms}) {
1376 bool isDim = (container == &dims);
1377 auto &repls = isDim ? dimReplacements : symReplacements;
1378 for (const auto &en : llvm::enumerate(*container)) {
1379 Value v = en.value();
1380 if (!v) {
1381 assert(isDim ? !map->isFunctionOfDim(en.index())
1382 : !map->isFunctionOfSymbol(en.index()) &&
1383 "map is function of unexpected expr@pos");
1384 repls.push_back(getAffineConstantExpr(0, ctx));
1385 continue;
1386 }
1387 repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1388 : getAffineSymbolExpr(nSyms++, ctx));
1389 operands->push_back(v);
1390 }
1391 }
1392 *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1393 nSyms);
1394
1395 // Canonicalize and simplify before returning.
1396 canonicalizeMapAndOperands(map, operands);
1397 *map = simplifyAffineMap(*map);
1398}
1399
1401 AffineMap *map, SmallVectorImpl<Value> *operands, bool composeAffineMin) {
1402 while (llvm::any_of(*operands, [](Value v) {
1403 return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1404 })) {
1405 composeAffineMapAndOperands(map, operands, composeAffineMin);
1406 }
1407 // Additional trailing step for AffineMinOps in case no chains of AffineApply.
1408 if (composeAffineMin && llvm::any_of(*operands, [](Value v) {
1409 return isa_and_nonnull<AffineMinOp>(v.getDefiningOp());
1410 })) {
1411 composeAffineMapAndOperands(map, operands, composeAffineMin);
1412 }
1413}
1414
1415AffineApplyOp
1417 ArrayRef<OpFoldResult> operands,
1418 bool composeAffineMin) {
1419 SmallVector<Value> valueOperands;
1420 map = foldAttributesIntoMap(b, map, operands, valueOperands);
1421 composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
1422 assert(map);
1423 return AffineApplyOp::create(b, loc, map, valueOperands);
1424}
1425
1426AffineApplyOp
1428 ArrayRef<OpFoldResult> operands,
1429 bool composeAffineMin) {
1431 b, loc,
1433 .front(),
1434 operands, composeAffineMin);
1435}
1436
1437/// Composes the given affine map with the given list of operands, pulling in
1438/// the maps from any affine.apply operations that supply the operands.
1440 SmallVectorImpl<Value> &operands,
1441 bool composeAffineMin = false) {
1442 // Compose and canonicalize each expression in the map individually because
1443 // composition only applies to single-result maps, collecting potentially
1444 // duplicate operands in a single list with shifted dimensions and symbols.
1445 SmallVector<Value> dims, symbols;
1447 for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
1448 SmallVector<Value> submapOperands(operands.begin(), operands.end());
1449 AffineMap submap = map.getSubMap({i});
1450 fullyComposeAffineMapAndOperands(&submap, &submapOperands,
1451 composeAffineMin);
1452 canonicalizeMapAndOperands(&submap, &submapOperands);
1453 unsigned numNewDims = submap.getNumDims();
1454 submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1455 llvm::append_range(dims,
1456 ArrayRef<Value>(submapOperands).take_front(numNewDims));
1457 llvm::append_range(symbols,
1458 ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1459 exprs.push_back(submap.getResult(0));
1460 }
1461
1462 // Canonicalize the map created from composed expressions to deduplicate the
1463 // dimension and symbol operands.
1464 operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1465 map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1466 canonicalizeMapAndOperands(&map, &operands);
1467}
1468
1471 bool composeAffineMin) {
1472 assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1473
1474 // Create new builder without a listener, so that no notification is
1475 // triggered if the op is folded.
1476 // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1477 // workaround is no longer needed.
1478 OpBuilder newBuilder(b.getContext());
1479 newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1480
1481 // Create op.
1482 AffineApplyOp applyOp =
1483 makeComposedAffineApply(newBuilder, loc, map, operands, composeAffineMin);
1484
1485 // Get constant operands.
1486 SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1487 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1488 matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1489
1490 // Try to fold the operation.
1491 SmallVector<OpFoldResult> foldResults;
1492 if (failed(applyOp->fold(constOperands, foldResults)) ||
1493 foldResults.empty()) {
1494 if (OpBuilder::Listener *listener = b.getListener())
1495 listener->notifyOperationInserted(applyOp, /*previous=*/{});
1496 return applyOp.getResult();
1497 }
1498
1499 applyOp->erase();
1500 return llvm::getSingleElement(foldResults);
1501}
1502
1504 OpBuilder &b, Location loc, AffineExpr expr,
1505 ArrayRef<OpFoldResult> operands, bool composeAffineMin) {
1507 b, loc,
1509 .front(),
1510 operands, composeAffineMin);
1511}
1512
1516 bool composeAffineMin) {
1517 return llvm::map_to_vector(
1518 llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
1519 return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
1520 operands, composeAffineMin);
1521 });
1522}
1523
1524template <typename OpTy>
1526 ArrayRef<OpFoldResult> operands) {
1527 SmallVector<Value> valueOperands;
1528 map = foldAttributesIntoMap(b, map, operands, valueOperands);
1529 composeMultiResultAffineMap(map, valueOperands);
1530 return OpTy::create(b, loc, b.getIndexType(), map, valueOperands);
1531}
1532
1533AffineMinOp
1538
1539template <typename OpTy>
1541 AffineMap map,
1542 ArrayRef<OpFoldResult> operands) {
1543 // Create new builder without a listener, so that no notification is
1544 // triggered if the op is folded.
1545 // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1546 // workaround is no longer needed.
1547 OpBuilder newBuilder(b.getContext());
1548 newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1549
1550 // Create op.
1551 auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1552
1553 // Get constant operands.
1554 SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1555 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1556 matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1557
1558 // Try to fold the operation.
1559 SmallVector<OpFoldResult> foldResults;
1560 if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1561 foldResults.empty()) {
1562 if (OpBuilder::Listener *listener = b.getListener())
1563 listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1564 return minMaxOp.getResult();
1565 }
1566
1567 minMaxOp->erase();
1568 return llvm::getSingleElement(foldResults);
1569}
1570
1577
1584
1585// A symbol may appear as a dim in affine.apply operations. This function
1586// canonicalizes dims that are valid symbols into actual symbols.
1587template <class MapOrSet>
1588static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1589 SmallVectorImpl<Value> *operands) {
1590 if (!mapOrSet || operands->empty())
1591 return;
1592
1593 assert(mapOrSet->getNumInputs() == operands->size() &&
1594 "map/set inputs must match number of operands");
1595
1596 auto *context = mapOrSet->getContext();
1597 SmallVector<Value, 8> resultOperands;
1598 resultOperands.reserve(operands->size());
1599 SmallVector<Value, 8> remappedSymbols;
1600 remappedSymbols.reserve(operands->size());
1601 unsigned nextDim = 0;
1602 unsigned nextSym = 0;
1603 unsigned oldNumSyms = mapOrSet->getNumSymbols();
1604 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1605 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1606 if (i < mapOrSet->getNumDims()) {
1607 if (isValidSymbol((*operands)[i])) {
1608 // This is a valid symbol that appears as a dim, canonicalize it.
1609 dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1610 remappedSymbols.push_back((*operands)[i]);
1611 } else {
1612 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1613 resultOperands.push_back((*operands)[i]);
1614 }
1615 } else {
1616 resultOperands.push_back((*operands)[i]);
1617 }
1618 }
1619
1620 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1621 *operands = resultOperands;
1622 *mapOrSet = mapOrSet->replaceDimsAndSymbols(
1623 dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
1624
1625 assert(mapOrSet->getNumInputs() == operands->size() &&
1626 "map/set inputs must match number of operands");
1627}
1628
1629/// A valid affine dimension may appear as a symbol in affine.apply operations.
1630/// Given an application of `operands` to an affine map or integer set
1631/// `mapOrSet`, this function canonicalizes symbols of `mapOrSet` that are valid
1632/// dims, but not valid symbols into actual dims. Without such a legalization,
1633/// the affine.apply will be invalid. This method is the exact inverse of
1634/// canonicalizePromotedSymbols.
1635template <class MapOrSet>
1636static void legalizeDemotedDims(MapOrSet &mapOrSet,
1637 SmallVectorImpl<Value> &operands) {
1638 if (!mapOrSet || operands.empty())
1639 return;
1640
1641 unsigned numOperands = operands.size();
1642
1643 assert(mapOrSet.getNumInputs() == numOperands &&
1644 "map/set inputs must match number of operands");
1645
1646 auto *context = mapOrSet.getContext();
1647 SmallVector<Value, 8> resultOperands;
1648 resultOperands.reserve(numOperands);
1649 SmallVector<Value, 8> remappedDims;
1650 remappedDims.reserve(numOperands);
1651 SmallVector<Value, 8> symOperands;
1652 symOperands.reserve(mapOrSet.getNumSymbols());
1653 unsigned nextSym = 0;
1654 unsigned nextDim = 0;
1655 unsigned oldNumDims = mapOrSet.getNumDims();
1656 SmallVector<AffineExpr, 8> symRemapping(mapOrSet.getNumSymbols());
1657 resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
1658 for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
1659 if (operands[i] && isValidDim(operands[i]) && !isValidSymbol(operands[i])) {
1660 // This is a valid dim that appears as a symbol, legalize it.
1661 symRemapping[i - oldNumDims] =
1662 getAffineDimExpr(oldNumDims + nextDim++, context);
1663 remappedDims.push_back(operands[i]);
1664 } else {
1665 symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
1666 symOperands.push_back(operands[i]);
1667 }
1668 }
1669
1670 append_range(resultOperands, remappedDims);
1671 append_range(resultOperands, symOperands);
1672 operands = resultOperands;
1673 mapOrSet = mapOrSet.replaceDimsAndSymbols(
1674 /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
1675
1676 assert(mapOrSet.getNumInputs() == operands.size() &&
1677 "map/set inputs must match number of operands");
1678}
1679
1680// Works for either an affine map or an integer set.
1681template <class MapOrSet>
1682static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1683 SmallVectorImpl<Value> *operands) {
1684 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1685 "Argument must be either of AffineMap or IntegerSet type");
1686
1687 if (!mapOrSet || operands->empty())
1688 return;
1689
1690 assert(mapOrSet->getNumInputs() == operands->size() &&
1691 "map/set inputs must match number of operands");
1692
1693 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1694 legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);
1695
1696 // Check to see what dims are used.
1697 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1698 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1699 mapOrSet->walkExprs([&](AffineExpr expr) {
1700 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1701 usedDims[dimExpr.getPosition()] = true;
1702 else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1703 usedSyms[symExpr.getPosition()] = true;
1704 });
1705
1706 auto *context = mapOrSet->getContext();
1707
1708 SmallVector<Value, 8> resultOperands;
1709 resultOperands.reserve(operands->size());
1710
1711 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1712 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1713 unsigned nextDim = 0;
1714 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1715 if (usedDims[i]) {
1716 // Remap dim positions for duplicate operands.
1717 auto it = seenDims.find((*operands)[i]);
1718 if (it == seenDims.end()) {
1719 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1720 resultOperands.push_back((*operands)[i]);
1721 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1722 } else {
1723 dimRemapping[i] = it->second;
1724 }
1725 }
1726 }
1727 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1728 SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1729 unsigned nextSym = 0;
1730 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1731 if (!usedSyms[i])
1732 continue;
1733 // Handle constant operands (only needed for symbolic operands since
1734 // constant operands in dimensional positions would have already been
1735 // promoted to symbolic positions above).
1736 IntegerAttr operandCst;
1737 if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1738 m_Constant(&operandCst))) {
1739 symRemapping[i] =
1740 getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1741 continue;
1742 }
1743 // Remap symbol positions for duplicate operands.
1744 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1745 if (it == seenSymbols.end()) {
1746 symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1747 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1748 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1749 symRemapping[i]));
1750 } else {
1751 symRemapping[i] = it->second;
1752 }
1753 }
1754 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1755 nextDim, nextSym);
1756 *operands = resultOperands;
1757}
1758
1763
1768
1769namespace {
1770/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1771/// maps that supply results into them.
1772///
1773template <typename AffineOpTy>
1774struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1775 using OpRewritePattern<AffineOpTy>::OpRewritePattern;
1776
1777 /// Replace the affine op with another instance of it with the supplied
1778 /// map and mapOperands.
1779 void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1780 AffineMap map, ArrayRef<Value> mapOperands) const;
1781
1782 LogicalResult matchAndRewrite(AffineOpTy affineOp,
1783 PatternRewriter &rewriter) const override {
1784 static_assert(
1785 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1786 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1787 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1788 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1789 "expected");
1790 auto map = affineOp.getAffineMap();
1791 AffineMap oldMap = map;
1792 auto oldOperands = affineOp.getMapOperands();
1793 SmallVector<Value, 8> resultOperands(oldOperands);
1794 composeAffineMapAndOperands(&map, &resultOperands);
1795 canonicalizeMapAndOperands(&map, &resultOperands);
1796 simplifyMapWithOperands(map, resultOperands);
1797 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1798 resultOperands.begin()))
1799 return failure();
1800
1801 replaceAffineOp(rewriter, affineOp, map, resultOperands);
1802 return success();
1803 }
1804};
1805
1806// Specialize the template to account for the different build signatures for
1807// affine load, store, and apply ops.
1808template <>
1809void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1810 PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1811 ArrayRef<Value> mapOperands) const {
1812 rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1813 mapOperands);
1814}
1815template <>
1816void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1817 PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1818 ArrayRef<Value> mapOperands) const {
1819 rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1820 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1821 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1822}
1823template <>
1824void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1825 PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1826 ArrayRef<Value> mapOperands) const {
1827 rewriter.replaceOpWithNewOp<AffineStoreOp>(
1828 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1829}
1830template <>
1831void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1832 PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1833 ArrayRef<Value> mapOperands) const {
1834 rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1835 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1836 mapOperands);
1837}
1838template <>
1839void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1840 PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1841 ArrayRef<Value> mapOperands) const {
1842 rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1843 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1844 mapOperands);
1845}
1846
1847// Generic version for ops that don't have extra operands.
1848template <typename AffineOpTy>
1849void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1850 PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1851 ArrayRef<Value> mapOperands) const {
1852 rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1853}
1854} // namespace
1855
1856void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1857 MLIRContext *context) {
1858 results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1859}
1860
1861//===----------------------------------------------------------------------===//
1862// AffineDmaStartOp
1863//===----------------------------------------------------------------------===//
1864
1865// TODO: Check that map operands are loop IVs or symbols.
1866void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1867 Value srcMemRef, AffineMap srcMap,
1868 ValueRange srcIndices, Value destMemRef,
1869 AffineMap dstMap, ValueRange destIndices,
1870 Value tagMemRef, AffineMap tagMap,
1871 ValueRange tagIndices, Value numElements,
1872 Value stride, Value elementsPerStride) {
1873 result.addOperands(srcMemRef);
1874 result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1875 result.addOperands(srcIndices);
1876 result.addOperands(destMemRef);
1877 result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1878 result.addOperands(destIndices);
1879 result.addOperands(tagMemRef);
1880 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1881 result.addOperands(tagIndices);
1882 result.addOperands(numElements);
1883 if (stride) {
1884 result.addOperands({stride, elementsPerStride});
1885 }
1886}
1887
1888void AffineDmaStartOp::print(OpAsmPrinter &p) {
1889 p << " " << getSrcMemRef() << '[';
1890 p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1891 p << "], " << getDstMemRef() << '[';
1892 p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1893 p << "], " << getTagMemRef() << '[';
1894 p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1895 p << "], " << getNumElements();
1896 if (isStrided()) {
1897 p << ", " << getStride();
1898 p << ", " << getNumElementsPerStride();
1899 }
1900 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1901 << getTagMemRefType();
1902}
1903
1904// Parse AffineDmaStartOp.
1905// Ex:
1906// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1907// %stride, %num_elt_per_stride
1908// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1909//
1910ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
1912 OpAsmParser::UnresolvedOperand srcMemRefInfo;
1913 AffineMapAttr srcMapAttr;
1915 OpAsmParser::UnresolvedOperand dstMemRefInfo;
1916 AffineMapAttr dstMapAttr;
1918 OpAsmParser::UnresolvedOperand tagMemRefInfo;
1919 AffineMapAttr tagMapAttr;
1921 OpAsmParser::UnresolvedOperand numElementsInfo;
1923
1925 auto indexType = parser.getBuilder().getIndexType();
1926
1927 // Parse and resolve the following list of operands:
1928 // *) dst memref followed by its affine maps operands (in square brackets).
1929 // *) src memref followed by its affine map operands (in square brackets).
1930 // *) tag memref followed by its affine map operands (in square brackets).
1931 // *) number of elements transferred by DMA operation.
1932 if (parser.parseOperand(srcMemRefInfo) ||
1933 parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1934 getSrcMapAttrStrName(),
1935 result.attributes) ||
1936 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1937 parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1938 getDstMapAttrStrName(),
1939 result.attributes) ||
1940 parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1941 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1942 getTagMapAttrStrName(),
1943 result.attributes) ||
1944 parser.parseComma() || parser.parseOperand(numElementsInfo))
1945 return failure();
1946
1947 // Parse optional stride and elements per stride.
1948 if (parser.parseTrailingOperandList(strideInfo))
1949 return failure();
1950
1951 if (!strideInfo.empty() && strideInfo.size() != 2) {
1952 return parser.emitError(parser.getNameLoc(),
1953 "expected two stride related operands");
1954 }
1955 bool isStrided = strideInfo.size() == 2;
1956
1957 if (parser.parseColonTypeList(types))
1958 return failure();
1959
1960 if (types.size() != 3)
1961 return parser.emitError(parser.getNameLoc(), "expected three types");
1962
1963 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1964 parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1965 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1966 parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1967 parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1968 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1969 parser.resolveOperand(numElementsInfo, indexType, result.operands))
1970 return failure();
1971
1972 if (isStrided) {
1973 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1974 return failure();
1975 }
1976
1977 // Check that src/dst/tag operand counts match their map.numInputs.
1978 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1979 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1980 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1981 return parser.emitError(parser.getNameLoc(),
1982 "memref operand count not equal to map.numInputs");
1983 return success();
1984}
1985
1986LogicalResult AffineDmaStartOp::verify() {
1987 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1988 return emitOpError("expected DMA source to be of memref type");
1989 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1990 return emitOpError("expected DMA destination to be of memref type");
1991 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1992 return emitOpError("expected DMA tag to be of memref type");
1993
1994 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1995 getDstMap().getNumInputs() +
1996 getTagMap().getNumInputs();
1997 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1998 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1999 return emitOpError("incorrect number of operands");
2000 }
2001
2002 Region *scope = getAffineScope(*this);
2003 for (auto idx : getSrcIndices()) {
2004 if (!idx.getType().isIndex())
2005 return emitOpError("src index to dma_start must have 'index' type");
2006 if (!isValidAffineIndexOperand(idx, scope))
2007 return emitOpError(
2008 "src index must be a valid dimension or symbol identifier");
2009 }
2010 for (auto idx : getDstIndices()) {
2011 if (!idx.getType().isIndex())
2012 return emitOpError("dst index to dma_start must have 'index' type");
2013 if (!isValidAffineIndexOperand(idx, scope))
2014 return emitOpError(
2015 "dst index must be a valid dimension or symbol identifier");
2016 }
2017 for (auto idx : getTagIndices()) {
2018 if (!idx.getType().isIndex())
2019 return emitOpError("tag index to dma_start must have 'index' type");
2020 if (!isValidAffineIndexOperand(idx, scope))
2021 return emitOpError(
2022 "tag index must be a valid dimension or symbol identifier");
2023 }
2024 return success();
2025}
2026
2027LogicalResult AffineDmaStartOp::fold(FoldAdaptor adaptor,
2029 /// dma_start(memrefcast) -> dma_start
2030 return memref::foldMemRefCast(*this);
2031}
2032
2033void AffineDmaStartOp::getEffects(
2035 &effects) {
2036 effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
2038 effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
2040 effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
2042}
2043
2044//===----------------------------------------------------------------------===//
2045// AffineDmaWaitOp
2046//===----------------------------------------------------------------------===//
2047
2048// TODO: Check that map operands are loop IVs or symbols.
2049void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
2050 Value tagMemRef, AffineMap tagMap,
2051 ValueRange tagIndices, Value numElements) {
2052 result.addOperands(tagMemRef);
2053 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
2054 result.addOperands(tagIndices);
2055 result.addOperands(numElements);
2056}
2057
2058void AffineDmaWaitOp::print(OpAsmPrinter &p) {
2059 p << " " << getTagMemRef() << '[';
2060 SmallVector<Value, 2> operands(getTagIndices());
2061 p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
2062 p << "], ";
2064 p << " : " << getTagMemRef().getType();
2065}
2066
2067// Parse AffineDmaWaitOp.
2068// Eg:
2069// affine.dma_wait %tag[%index], %num_elements
2070// : memref<1 x i32, (d0) -> (d0), 4>
2071//
2072ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
2074 OpAsmParser::UnresolvedOperand tagMemRefInfo;
2075 AffineMapAttr tagMapAttr;
2077 Type type;
2078 auto indexType = parser.getBuilder().getIndexType();
2079 OpAsmParser::UnresolvedOperand numElementsInfo;
2080
2081 // Parse tag memref, its map operands, and dma size.
2082 if (parser.parseOperand(tagMemRefInfo) ||
2083 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
2084 getTagMapAttrStrName(),
2085 result.attributes) ||
2086 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
2087 parser.parseColonType(type) ||
2088 parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
2089 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
2090 parser.resolveOperand(numElementsInfo, indexType, result.operands))
2091 return failure();
2092
2093 if (!llvm::isa<MemRefType>(type))
2094 return parser.emitError(parser.getNameLoc(),
2095 "expected tag to be of memref type");
2096
2097 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
2098 return parser.emitError(parser.getNameLoc(),
2099 "tag memref operand count != to map.numInputs");
2100 return success();
2101}
2102
2103LogicalResult AffineDmaWaitOp::verify() {
2104 if (!llvm::isa<MemRefType>(getOperand(0).getType()))
2105 return emitOpError("expected DMA tag to be of memref type");
2106 Region *scope = getAffineScope(*this);
2107 for (auto idx : getTagIndices()) {
2108 if (!idx.getType().isIndex())
2109 return emitOpError("index to dma_wait must have 'index' type");
2110 if (!isValidAffineIndexOperand(idx, scope))
2111 return emitOpError(
2112 "index must be a valid dimension or symbol identifier");
2113 }
2114 return success();
2115}
2116
2117LogicalResult AffineDmaWaitOp::fold(FoldAdaptor adaptor,
2119 /// dma_wait(memrefcast) -> dma_wait
2120 return memref::foldMemRefCast(*this);
2121}
2122
2123void AffineDmaWaitOp::getEffects(
2125 &effects) {
2126 effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
2128}
2129
2130//===----------------------------------------------------------------------===//
2131// AffineForOp
2132//===----------------------------------------------------------------------===//
2133
2134/// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
2135/// bodyBuilder are empty/null, we include default terminator op.
2136void AffineForOp::build(OpBuilder &builder, OperationState &result,
2137 ValueRange lbOperands, AffineMap lbMap,
2138 ValueRange ubOperands, AffineMap ubMap, int64_t step,
2139 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
2140 assert(((!lbMap && lbOperands.empty()) ||
2141 lbOperands.size() == lbMap.getNumInputs()) &&
2142 "lower bound operand count does not match the affine map");
2143 assert(((!ubMap && ubOperands.empty()) ||
2144 ubOperands.size() == ubMap.getNumInputs()) &&
2145 "upper bound operand count does not match the affine map");
2146 assert(step > 0 && "step has to be a positive integer constant");
2147
2148 OpBuilder::InsertionGuard guard(builder);
2149
2150 // Set variadic segment sizes.
2151 result.addAttribute(
2152 getOperandSegmentSizeAttr(),
2153 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
2154 static_cast<int32_t>(ubOperands.size()),
2155 static_cast<int32_t>(iterArgs.size())}));
2156
2157 for (Value val : iterArgs)
2158 result.addTypes(val.getType());
2159
2160 // Add an attribute for the step.
2161 result.addAttribute(getStepAttrName(result.name),
2162 builder.getIntegerAttr(builder.getIndexType(), step));
2163
2164 // Add the lower bound.
2165 result.addAttribute(getLowerBoundMapAttrName(result.name),
2166 AffineMapAttr::get(lbMap));
2167 result.addOperands(lbOperands);
2168
2169 // Add the upper bound.
2170 result.addAttribute(getUpperBoundMapAttrName(result.name),
2171 AffineMapAttr::get(ubMap));
2172 result.addOperands(ubOperands);
2173
2174 result.addOperands(iterArgs);
2175 // Create a region and a block for the body. The argument of the region is
2176 // the loop induction variable.
2177 Region *bodyRegion = result.addRegion();
2178 Block *bodyBlock = builder.createBlock(bodyRegion);
2179 Value inductionVar =
2180 bodyBlock->addArgument(builder.getIndexType(), result.location);
2181 for (Value val : iterArgs)
2182 bodyBlock->addArgument(val.getType(), val.getLoc());
2183
2184 // Create the default terminator if the builder is not provided and if the
2185 // iteration arguments are not provided. Otherwise, leave this to the caller
2186 // because we don't know which values to return from the loop.
2187 if (iterArgs.empty() && !bodyBuilder) {
2188 ensureTerminator(*bodyRegion, builder, result.location);
2189 } else if (bodyBuilder) {
2190 OpBuilder::InsertionGuard guard(builder);
2191 builder.setInsertionPointToStart(bodyBlock);
2192 bodyBuilder(builder, result.location, inductionVar,
2193 bodyBlock->getArguments().drop_front());
2194 }
2195}
2196
2197void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
2198 int64_t ub, int64_t step, ValueRange iterArgs,
2199 BodyBuilderFn bodyBuilder) {
2200 auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
2201 auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
2202 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
2203 bodyBuilder);
2204}
2205
2206LogicalResult AffineForOp::verifyRegions() {
2207 // Step must be a strictly positive integer.
2208 if (getStepAsInt() <= 0)
2209 return emitOpError("expected step to be a positive integer, got ")
2210 << getStepAsInt();
2211
2212 // Check that the body defines as single block argument for the induction
2213 // variable.
2214 auto *body = getBody();
2215 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
2216 return emitOpError("expected body to have a single index argument for the "
2217 "induction variable");
2218
2219 // Verify that the bound operands are valid dimension/symbols.
2220 /// Lower bound.
2221 if (getLowerBoundMap().getNumInputs() > 0)
2223 getLowerBoundMap().getNumDims())))
2224 return failure();
2225 /// Upper bound.
2226 if (getUpperBoundMap().getNumInputs() > 0)
2228 getUpperBoundMap().getNumDims())))
2229 return failure();
2230 if (getLowerBoundMap().getNumResults() < 1)
2231 return emitOpError("expected lower bound map to have at least one result");
2232 if (getUpperBoundMap().getNumResults() < 1)
2233 return emitOpError("expected upper bound map to have at least one result");
2234
2235 unsigned opNumResults = getNumResults();
2236 if (opNumResults == 0)
2237 return success();
2238
2239 // If ForOp defines values, check that the number and types of the defined
2240 // values match ForOp initial iter operands and backedge basic block
2241 // arguments.
2242 if (getNumIterOperands() != opNumResults)
2243 return emitOpError(
2244 "mismatch between the number of loop-carried values and results");
2245 if (getNumRegionIterArgs() != opNumResults)
2246 return emitOpError(
2247 "mismatch between the number of basic block args and results");
2248
2249 return success();
2250}
2251
2252/// Parse a for operation loop bounds.
2253static ParseResult parseBound(bool isLower, OperationState &result,
2254 OpAsmParser &p) {
2255 // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
2256 // the map has multiple results.
2257 bool failedToParsedMinMax =
2258 failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
2259
2260 auto &builder = p.getBuilder();
2261 auto boundAttrStrName =
2262 isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
2263 : AffineForOp::getUpperBoundMapAttrName(result.name);
2264
2265 // Parse ssa-id as identity map.
2267 if (p.parseOperandList(boundOpInfos))
2268 return failure();
2269
2270 if (!boundOpInfos.empty()) {
2271 // Check that only one operand was parsed.
2272 if (boundOpInfos.size() > 1)
2273 return p.emitError(p.getNameLoc(),
2274 "expected only one loop bound operand");
2275
2276 // TODO: improve error message when SSA value is not of index type.
2277 // Currently it is 'use of value ... expects different type than prior uses'
2278 if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
2279 result.operands))
2280 return failure();
2281
2282 // Create an identity map using symbol id. This representation is optimized
2283 // for storage. Analysis passes may expand it into a multi-dimensional map
2284 // if desired.
2285 AffineMap map = builder.getSymbolIdentityMap();
2286 result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
2287 return success();
2288 }
2289
2290 // Get the attribute location.
2291 SMLoc attrLoc = p.getCurrentLocation();
2292
2293 Attribute boundAttr;
2294 if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
2295 result.attributes))
2296 return failure();
2297
2298 // Parse full form - affine map followed by dim and symbol list.
2299 if (auto affineMapAttr = dyn_cast<AffineMapAttr>(boundAttr)) {
2300 unsigned currentNumOperands = result.operands.size();
2301 unsigned numDims;
2302 if (parseDimAndSymbolList(p, result.operands, numDims))
2303 return failure();
2304
2305 auto map = affineMapAttr.getValue();
2306 if (map.getNumDims() != numDims)
2307 return p.emitError(
2308 p.getNameLoc(),
2309 "dim operand count and affine map dim count must match");
2310
2311 unsigned numDimAndSymbolOperands =
2312 result.operands.size() - currentNumOperands;
2313 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2314 return p.emitError(
2315 p.getNameLoc(),
2316 "symbol operand count and affine map symbol count must match");
2317
2318 // If the map has multiple results, make sure that we parsed the min/max
2319 // prefix.
2320 if (map.getNumResults() > 1 && failedToParsedMinMax) {
2321 if (isLower) {
2322 return p.emitError(attrLoc, "lower loop bound affine map with "
2323 "multiple results requires 'max' prefix");
2324 }
2325 return p.emitError(attrLoc, "upper loop bound affine map with multiple "
2326 "results requires 'min' prefix");
2327 }
2328 return success();
2329 }
2330
2331 // Parse custom assembly form.
2332 if (auto integerAttr = dyn_cast<IntegerAttr>(boundAttr)) {
2333 result.attributes.pop_back();
2334 result.addAttribute(
2335 boundAttrStrName,
2336 AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2337 return success();
2338 }
2339
2340 return p.emitError(
2341 p.getNameLoc(),
2342 "expected valid affine map representation for loop bounds");
2343}
2344
2345ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2346 auto &builder = parser.getBuilder();
2347 OpAsmParser::Argument inductionVariable;
2348 inductionVariable.type = builder.getIndexType();
2349 // Parse the induction variable followed by '='.
2350 if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2351 return failure();
2352
2353 // Parse loop bounds.
2354 int64_t numOperands = result.operands.size();
2355 if (parseBound(/*isLower=*/true, result, parser))
2356 return failure();
2357 int64_t numLbOperands = result.operands.size() - numOperands;
2358 if (parser.parseKeyword("to", " between bounds"))
2359 return failure();
2360 numOperands = result.operands.size();
2361 if (parseBound(/*isLower=*/false, result, parser))
2362 return failure();
2363 int64_t numUbOperands = result.operands.size() - numOperands;
2364
2365 // Parse the optional loop step, we default to 1 if one is not present.
2366 if (parser.parseOptionalKeyword("step")) {
2367 result.addAttribute(
2368 getStepAttrName(result.name),
2369 builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2370 } else {
2371 SMLoc stepLoc = parser.getCurrentLocation();
2372 IntegerAttr stepAttr;
2373 if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2374 getStepAttrName(result.name).data(),
2375 result.attributes))
2376 return failure();
2377
2378 if (!stepAttr.getValue().isStrictlyPositive())
2379 return parser.emitError(
2380 stepLoc,
2381 "expected step to be representable as a positive signed integer");
2382 }
2383
2384 // Parse the optional initial iteration arguments.
2385 SmallVector<OpAsmParser::Argument, 4> regionArgs;
2386 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2387
2388 // Induction variable.
2389 regionArgs.push_back(inductionVariable);
2390
2391 if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2392 // Parse assignment list and results type list.
2393 if (parser.parseAssignmentList(regionArgs, operands) ||
2394 parser.parseArrowTypeList(result.types))
2395 return failure();
2396 // Resolve input operands.
2397 for (auto argOperandType :
2398 llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2399 Type type = std::get<2>(argOperandType);
2400 std::get<0>(argOperandType).type = type;
2401 if (parser.resolveOperand(std::get<1>(argOperandType), type,
2402 result.operands))
2403 return failure();
2404 }
2405 }
2406
2407 result.addAttribute(
2408 getOperandSegmentSizeAttr(),
2409 builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2410 static_cast<int32_t>(numUbOperands),
2411 static_cast<int32_t>(operands.size())}));
2412
2413 // Parse the body region.
2414 Region *body = result.addRegion();
2415 if (regionArgs.size() != result.types.size() + 1)
2416 return parser.emitError(
2417 parser.getNameLoc(),
2418 "mismatch between the number of loop-carried values and results");
2419 if (parser.parseRegion(*body, regionArgs))
2420 return failure();
2421
2422 AffineForOp::ensureTerminator(*body, builder, result.location);
2423
2424 // Parse the optional attribute list.
2425 return parser.parseOptionalAttrDict(result.attributes);
2426}
2427
2428static void printBound(AffineMapAttr boundMap,
2429 Operation::operand_range boundOperands,
2430 const char *prefix, OpAsmPrinter &p) {
2431 AffineMap map = boundMap.getValue();
2432
2433 // Check if this bound should be printed using custom assembly form.
2434 // The decision to restrict printing custom assembly form to trivial cases
2435 // comes from the will to roundtrip MLIR binary -> text -> binary in a
2436 // lossless way.
2437 // Therefore, custom assembly form parsing and printing is only supported for
2438 // zero-operand constant maps and single symbol operand identity maps.
2439 if (map.getNumResults() == 1) {
2440 AffineExpr expr = map.getResult(0);
2441
2442 // Print constant bound.
2443 if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2444 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2445 p << constExpr.getValue();
2446 return;
2447 }
2448 }
2449
2450 // Print bound that consists of a single SSA symbol if the map is over a
2451 // single symbol.
2452 if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2453 if (isa<AffineSymbolExpr>(expr)) {
2454 p.printOperand(*boundOperands.begin());
2455 return;
2456 }
2457 }
2458 } else {
2459 // Map has multiple results. Print 'min' or 'max' prefix.
2460 p << prefix << ' ';
2461 }
2462
2463 // Print the map and its operands.
2464 p << boundMap;
2465 printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2466 map.getNumDims(), p);
2467}
2468
2469unsigned AffineForOp::getNumIterOperands() {
2470 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2471 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2472
2473 return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2474}
2475
2476std::optional<MutableArrayRef<OpOperand>>
2477AffineForOp::getYieldedValuesMutable() {
2478 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2479}
2480
2481void AffineForOp::print(OpAsmPrinter &p) {
2482 p << ' ';
2483 p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2484 /*omitType=*/true);
2485 p << " = ";
2486 printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2487 p << " to ";
2488 printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2489
2490 if (getStepAsInt() != 1)
2491 p << " step " << getStepAsInt();
2492
2493 bool printBlockTerminators = false;
2494 if (getNumIterOperands() > 0) {
2495 p << " iter_args(";
2496 auto regionArgs = getRegionIterArgs();
2497 auto operands = getInits();
2498
2499 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2500 p << std::get<0>(it) << " = " << std::get<1>(it);
2501 });
2502 p << ") -> (" << getResultTypes() << ")";
2503 printBlockTerminators = true;
2504 }
2505
2506 p << ' ';
2507 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2508 printBlockTerminators);
2510 (*this)->getAttrs(),
2511 /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2512 getUpperBoundMapAttrName(getOperation()->getName()),
2513 getStepAttrName(getOperation()->getName()),
2514 getOperandSegmentSizeAttr()});
2515}
2516
2517/// Fold the constant bounds of a loop.
2518static LogicalResult foldLoopBounds(AffineForOp forOp) {
2519 auto foldLowerOrUpperBound = [&forOp](bool lower) {
2520 // Check to see if each of the operands is the result of a constant. If
2521 // so, get the value. If not, ignore it.
2522 SmallVector<Attribute, 8> operandConstants;
2523 auto boundOperands =
2524 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2525 for (auto operand : boundOperands) {
2526 Attribute operandCst;
2527 matchPattern(operand, m_Constant(&operandCst));
2528 operandConstants.push_back(operandCst);
2529 }
2530
2531 AffineMap boundMap =
2532 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2533 assert(boundMap.getNumResults() >= 1 &&
2534 "bound maps should have at least one result");
2535 SmallVector<Attribute, 4> foldedResults;
2536 if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2537 return failure();
2538
2539 // Compute the max or min as applicable over the results.
2540 assert(!foldedResults.empty() && "bounds should have at least one result");
2541 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2542 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2543 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2544 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2545 : llvm::APIntOps::smin(maxOrMin, foldedResult);
2546 }
2547 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2548 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2549 return success();
2550 };
2551
2552 // Try to fold the lower bound.
2553 bool folded = false;
2554 if (!forOp.hasConstantLowerBound())
2555 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2556
2557 // Try to fold the upper bound.
2558 if (!forOp.hasConstantUpperBound())
2559 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2560 return success(folded);
2561}
2562
2563/// Returns constant trip count in trivial cases.
2564static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2565 int64_t step = forOp.getStepAsInt();
2566 if (!forOp.hasConstantBounds() || step <= 0)
2567 return std::nullopt;
2568 int64_t lb = forOp.getConstantLowerBound();
2569 int64_t ub = forOp.getConstantUpperBound();
2570 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2571}
2572
2573/// Fold the empty loop.
2575 if (!llvm::hasSingleElement(*forOp.getBody()))
2576 return {};
2577 if (forOp.getNumResults() == 0)
2578 return {};
2579 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2580 if (tripCount == 0) {
2581 // The initial values of the iteration arguments would be the op's
2582 // results.
2583 return forOp.getInits();
2584 }
2585 SmallVector<Value, 4> replacements;
2586 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2587 auto iterArgs = forOp.getRegionIterArgs();
2588 bool hasValDefinedOutsideLoop = false;
2589 bool iterArgsNotInOrder = false;
2590 for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2591 Value val = yieldOp.getOperand(i);
2592 BlockArgument *iterArgIt = llvm::find(iterArgs, val);
2593 // TODO: It should be possible to perform a replacement by computing the
2594 // last value of the IV based on the bounds and the step.
2595 if (val == forOp.getInductionVar())
2596 return {};
2597 if (iterArgIt == iterArgs.end()) {
2598 // `val` is defined outside of the loop.
2599 assert(forOp.isDefinedOutsideOfLoop(val) &&
2600 "must be defined outside of the loop");
2601 hasValDefinedOutsideLoop = true;
2602 replacements.push_back(val);
2603 } else {
2604 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2605 if (pos != i)
2606 iterArgsNotInOrder = true;
2607 replacements.push_back(forOp.getInits()[pos]);
2608 }
2609 }
2610 // Bail out when the trip count is unknown and the loop returns any value
2611 // defined outside of the loop or any iterArg out of order.
2612 if (!tripCount.has_value() &&
2613 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2614 return {};
2615 // Bail out when the loop iterates more than once and it returns any iterArg
2616 // out of order.
2617 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2618 return {};
2619 return llvm::to_vector_of<OpFoldResult>(replacements);
2620}
2621
/// Canonicalize the bounds of the given loop: compose any affine.apply ops
/// feeding the bound operands into the bound maps, canonicalize the
/// (map, operands) pairs, simplify min/max expressions, and drop duplicate
/// map results. Returns success iff either bound changed.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());

  auto lbMap = forOp.getLowerBoundMap();
  auto ubMap = forOp.getUpperBoundMap();
  // Snapshot the maps so changes can be detected afterwards.
  auto prevLbMap = lbMap;
  auto prevUbMap = ubMap;

  composeAffineMapAndOperands(&lbMap, &lbOperands);
  canonicalizeMapAndOperands(&lbMap, &lbOperands);
  // A multi-result lower bound is a max of its results; an upper bound a min.
  simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
  simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
  lbMap = removeDuplicateExprs(lbMap);

  composeAffineMapAndOperands(&ubMap, &ubOperands);
  canonicalizeMapAndOperands(&ubMap, &ubOperands);
  ubMap = removeDuplicateExprs(ubMap);

  // Any canonicalization change always leads to updated map(s).
  if (lbMap == prevLbMap && ubMap == prevUbMap)
    return failure();

  // Only write back the bound(s) that actually changed.
  if (lbMap != prevLbMap)
    forOp.setLowerBound(lbOperands, lbMap);
  if (ubMap != prevUbMap)
    forOp.setUpperBound(ubOperands, ubMap);
  return success();
}
2652
2653/// Returns true if the affine.for has zero iterations in trivial cases.
2654static bool hasTrivialZeroTripCount(AffineForOp op) {
2655 return getTrivialConstantTripCount(op) == 0;
2656}
2657
2658LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2660 bool folded = succeeded(foldLoopBounds(*this));
2661 folded |= succeeded(canonicalizeLoopBounds(*this));
2662 if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2663 // The initial values of the loop-carried variables (iter_args) are the
2664 // results of the op. But this must be avoided for an affine.for op that
2665 // does not return any results. Since ops that do not return results cannot
2666 // be folded away, we would enter an infinite loop of folds on the same
2667 // affine.for op.
2668 results.assign(getInits().begin(), getInits().end());
2669 folded = true;
2670 }
2671 SmallVector<OpFoldResult> foldResults = AffineForEmptyLoopFolder(*this);
2672 if (!foldResults.empty()) {
2673 results.assign(foldResults);
2674 folded = true;
2675 }
2676 return success(folded);
2677}
2678
2679OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2680 assert((successor.isParent() || successor.getSuccessor() == &getRegion()) &&
2681 "invalid region point");
2682
2683 // The initial operands map to the loop arguments after the induction
2684 // variable or are forwarded to the results when the trip count is zero.
2685 return getInits();
2686}
2687
2688void AffineForOp::getSuccessorRegions(
2690 assert((point.isParent() ||
2691 point.getTerminatorPredecessorOrNull()->getParentRegion() ==
2692 &getRegion()) &&
2693 "expected loop region");
2694 // The loop may typically branch back to its body or to the parent operation.
2695 // If the predecessor is the parent op and the trip count is known to be at
2696 // least one, branch into the body using the iterator arguments. And in cases
2697 // we know the trip count is zero, it can only branch back to its parent.
2698 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2699 if (tripCount.has_value()) {
2700 if (!point.isParent()) {
2701 // From the loop body, if the trip count is one, we can only branch back
2702 // to the parent.
2703 if (tripCount == 1) {
2704 regions.push_back(RegionSuccessor::parent());
2705 return;
2706 }
2707 if (tripCount == 0)
2708 return;
2709 } else {
2710 if (tripCount.value() > 0) {
2711 regions.push_back(RegionSuccessor(&getRegion()));
2712 return;
2713 }
2714 if (tripCount.value() == 0) {
2715 regions.push_back(RegionSuccessor::parent());
2716 return;
2717 }
2718 }
2719 }
2720
2721 // In all other cases, the loop may branch back to itself or the parent
2722 // operation.
2723 regions.push_back(RegionSuccessor(&getRegion()));
2724 regions.push_back(RegionSuccessor::parent());
2725}
2726
2727ValueRange AffineForOp::getSuccessorInputs(RegionSuccessor successor) {
2728 if (successor.isParent())
2729 return getResults();
2730 return getRegionIterArgs();
2731}
2732
2733AffineBound AffineForOp::getLowerBound() {
2734 return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2735}
2736
2737AffineBound AffineForOp::getUpperBound() {
2738 return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2739}
2740
2741void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2742 assert(lbOperands.size() == map.getNumInputs());
2743 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2744 getLowerBoundOperandsMutable().assign(lbOperands);
2745 setLowerBoundMap(map);
2746}
2747
2748void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2749 assert(ubOperands.size() == map.getNumInputs());
2750 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2751 getUpperBoundOperandsMutable().assign(ubOperands);
2752 setUpperBoundMap(map);
2753}
2754
2755bool AffineForOp::hasConstantLowerBound() {
2756 return getLowerBoundMap().isSingleConstant();
2757}
2758
2759bool AffineForOp::hasConstantUpperBound() {
2760 return getUpperBoundMap().isSingleConstant();
2761}
2762
2763int64_t AffineForOp::getConstantLowerBound() {
2764 return getLowerBoundMap().getSingleConstantResult();
2765}
2766
2767int64_t AffineForOp::getConstantUpperBound() {
2768 return getUpperBoundMap().getSingleConstantResult();
2769}
2770
2771void AffineForOp::setConstantLowerBound(int64_t value) {
2772 setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2773}
2774
2775void AffineForOp::setConstantUpperBound(int64_t value) {
2776 setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2777}
2778
2779AffineForOp::operand_range AffineForOp::getControlOperands() {
2780 return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2781 getUpperBoundOperands().size()};
2782}
2783
2784bool AffineForOp::matchingBoundOperandList() {
2785 auto lbMap = getLowerBoundMap();
2786 auto ubMap = getUpperBoundMap();
2787 if (lbMap.getNumDims() != ubMap.getNumDims() ||
2788 lbMap.getNumSymbols() != ubMap.getNumSymbols())
2789 return false;
2790
2791 unsigned numOperands = lbMap.getNumInputs();
2792 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2793 // Compare Value 's.
2794 if (getOperand(i) != getOperand(numOperands + i))
2795 return false;
2796 }
2797 return true;
2798}
2799
2800SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2801
2802std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2803 return SmallVector<Value>{getInductionVar()};
2804}
2805
2806std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2807 if (!hasConstantLowerBound())
2808 return std::nullopt;
2809 OpBuilder b(getContext());
2810 return SmallVector<OpFoldResult>{
2811 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2812}
2813
2814std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2815 OpBuilder b(getContext());
2816 return SmallVector<OpFoldResult>{
2817 OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2818}
2819
2820std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2821 if (!hasConstantUpperBound())
2822 return {};
2823 OpBuilder b(getContext());
2824 return SmallVector<OpFoldResult>{
2825 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2826}
2827
/// Replaces this loop with a new affine.for that additionally carries
/// `newInitOperands` as iter_args; `newYieldValuesFn` produces the matching
/// extra yielded values. When `replaceInitOperandUsesInLoop` is set, uses of
/// the new init operands inside the loop are rewritten to the new region
/// iter_args. Returns the new loop.
FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
    RewriterBase &rewriter, ValueRange newInitOperands,
    bool replaceInitOperandUsesInLoop,
    const NewYieldValuesFn &newYieldValuesFn) {
  // Create a new loop before the existing one, with the extra operands.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(getOperation());
  auto inits = llvm::to_vector(getInits());
  inits.append(newInitOperands.begin(), newInitOperands.end());
  AffineForOp newLoop = AffineForOp::create(
      rewriter, getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
      getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);

  // Generate the new yield values and append them to the affine.yield
  // operation of the (old) body.
  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
  ArrayRef<BlockArgument> newIterArgs =
      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
  {
    OpBuilder::InsertionGuard g(rewriter);
    rewriter.setInsertionPoint(yieldOp);
    SmallVector<Value> newYieldedValues =
        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
    assert(newInitOperands.size() == newYieldedValues.size() &&
           "expected as many new yield values as new iter operands");
    rewriter.modifyOpInPlace(yieldOp, [&]() {
      yieldOp.getOperandsMutable().append(newYieldedValues);
    });
  }

  // Move the loop body to the new op.
  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
                       newLoop.getBody()->getArguments().take_front(
                           getBody()->getNumArguments()));

  if (replaceInitOperandUsesInLoop) {
    // Replace all uses of `newInitOperands` with the corresponding basic block
    // arguments.
    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
                                 [&](OpOperand &use) {
                                   Operation *user = use.getOwner();
                                   return newLoop->isProperAncestor(user);
                                 });
    }
  }

  // Replace the old loop. The new loop may have more results; only the first
  // getNumResults() of them correspond to the old ones.
  rewriter.replaceOp(getOperation(),
                     newLoop->getResults().take_front(getNumResults()));
  return cast<LoopLikeOpInterface>(newLoop.getOperation());
}
2879
2880Speculation::Speculatability AffineForOp::getSpeculatability() {
2881 // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2882 // Start and End.
2883 //
2884 // For Step != 1, the loop may not terminate. We can add more smarts here if
2885 // needed.
2886 return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2888}
2889
2890/// Returns true if the provided value is the induction variable of a
2891/// AffineForOp.
2893 return getForInductionVarOwner(val) != AffineForOp();
2894}
2895
2899
2903
2905 auto ivArg = dyn_cast<BlockArgument>(val);
2906 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2907 return AffineForOp();
2908 if (auto forOp =
2909 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2910 // Check to make sure `val` is the induction variable, not an iter_arg.
2911 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2912 return AffineForOp();
2913}
2914
2916 auto ivArg = dyn_cast<BlockArgument>(val);
2917 if (!ivArg || !ivArg.getOwner())
2918 return nullptr;
2919 Operation *containingOp = ivArg.getOwner()->getParentOp();
2920 auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
2921 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2922 return parallelOp;
2923 return nullptr;
2924}
2925
2926/// Extracts the induction variables from a list of AffineForOps and returns
2927/// them.
2930 ivs->reserve(forInsts.size());
2931 for (auto forInst : forInsts)
2932 ivs->push_back(forInst.getInductionVar());
2933}
2934
2937 ivs.reserve(affineOps.size());
2938 for (Operation *op : affineOps) {
2939 // Add constraints from forOp's bounds.
2940 if (auto forOp = dyn_cast<AffineForOp>(op))
2941 ivs.push_back(forOp.getInductionVar());
2942 else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2943 for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
2944 ivs.push_back(parallelOp.getBody()->getArgument(i));
2945 }
2946}
2947
2948/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
2949/// operations.
2950template <typename BoundListTy, typename LoopCreatorTy>
2952 OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
2953 ArrayRef<int64_t> steps,
2954 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
2955 LoopCreatorTy &&loopCreatorFn) {
2956 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
2957 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
2958
2959 // If there are no loops to be constructed, construct the body anyway.
2960 OpBuilder::InsertionGuard guard(builder);
2961 if (lbs.empty()) {
2962 if (bodyBuilderFn)
2963 bodyBuilderFn(builder, loc, ValueRange());
2964 return;
2965 }
2966
2967 // Create the loops iteratively and store the induction variables.
2969 ivs.reserve(lbs.size());
2970 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
2971 // Callback for creating the loop body, always creates the terminator.
2972 auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
2973 ValueRange iterArgs) {
2974 ivs.push_back(iv);
2975 // In the innermost loop, call the body builder.
2976 if (i == e - 1 && bodyBuilderFn) {
2977 OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
2978 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
2979 }
2980 AffineYieldOp::create(nestedBuilder, nestedLoc);
2981 };
2982
2983 // Delegate actual loop creation to the callback in order to dispatch
2984 // between constant- and variable-bound loops.
2985 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
2986 builder.setInsertionPointToStart(loop.getBody());
2987 }
2988}
2989
2990/// Creates an affine loop from the bounds known to be constants.
2991static AffineForOp
2993 int64_t ub, int64_t step,
2994 AffineForOp::BodyBuilderFn bodyBuilderFn) {
2995 return AffineForOp::create(builder, loc, lb, ub, step,
2996 /*iterArgs=*/ValueRange(), bodyBuilderFn);
2997}
2998
2999/// Creates an affine loop from the bounds that may or may not be constants.
3000static AffineForOp
3002 int64_t step,
3003 AffineForOp::BodyBuilderFn bodyBuilderFn) {
3004 std::optional<int64_t> lbConst = getConstantIntValue(lb);
3005 std::optional<int64_t> ubConst = getConstantIntValue(ub);
3006 if (lbConst && ubConst)
3007 return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
3008 ubConst.value(), step, bodyBuilderFn);
3009 return AffineForOp::create(builder, loc, lb, builder.getDimIdentityMap(), ub,
3010 builder.getDimIdentityMap(), step,
3011 /*iterArgs=*/ValueRange(), bodyBuilderFn);
3012}
3013
3015 OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
3017 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
3018 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
3020}
3021
3023 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
3024 ArrayRef<int64_t> steps,
3025 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
3026 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
3028}
3029
3030//===----------------------------------------------------------------------===//
3031// AffineIfOp
3032//===----------------------------------------------------------------------===//
3033
3034namespace {
3035/// Remove else blocks that have nothing other than a zero value yield.
3036struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
3037 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
3038
3039 LogicalResult matchAndRewrite(AffineIfOp ifOp,
3040 PatternRewriter &rewriter) const override {
3041 if (ifOp.getElseRegion().empty() ||
3042 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
3043 return failure();
3044
3045 rewriter.startOpModification(ifOp);
3046 rewriter.eraseBlock(ifOp.getElseBlock());
3047 rewriter.finalizeOpModification(ifOp);
3048 return success();
3049 }
3050};
3051
/// Removes affine.if cond if the condition is always true or false in certain
/// trivial cases. Promotes the then/else block in the parent operation block.
struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp op,
                                PatternRewriter &rewriter) const override {

    // An empty integer set has no satisfying assignment: always false.
    auto isTriviallyFalse = [](IntegerSet iSet) {
      return iSet.isEmptyIntegerSet();
    };

    // The canonical always-true set is a single equality `0 == 0` with no
    // inequalities.
    auto isTriviallyTrue = [](IntegerSet iSet) {
      return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
              iSet.getConstraint(0) == 0);
    };

    IntegerSet affineIfConditions = op.getIntegerSet();
    Block *blockToMove;
    if (isTriviallyFalse(affineIfConditions)) {
      // The absence, or equivalently, the emptiness of the else region need not
      // be checked when affine.if is returning results because if an affine.if
      // operation is returning results, it always has a non-empty else region.
      if (op.getNumResults() == 0 && !op.hasElse()) {
        // If the else region is absent, or equivalently, empty, remove the
        // affine.if operation (which is not returning any results).
        rewriter.eraseOp(op);
        return success();
      }
      blockToMove = op.getElseBlock();
    } else if (isTriviallyTrue(affineIfConditions)) {
      blockToMove = op.getThenBlock();
    } else {
      return failure();
    }
    Operation *blockToMoveTerminator = blockToMove->getTerminator();
    // Promote the "blockToMove" block to the parent operation block between the
    // prologue and epilogue of "op".
    rewriter.inlineBlockBefore(blockToMove, op);
    // Replace the "op" operation with the operands of the
    // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
    // the affine.yield operation present in the "blockToMove" block. It has no
    // operands when affine.if is not returning results and therefore, in that
    // case, replaceOp just erases "op". When affine.if is not returning
    // results, the affine.yield operation can be omitted. It gets inserted
    // implicitly.
    rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
    // Erase the "blockToMoveTerminator" operation since it is now in the parent
    // operation block, which already has its own terminator.
    rewriter.eraseOp(blockToMoveTerminator);
    return success();
  }
};
3105} // namespace
3106
3107/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
3108/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
3109void AffineIfOp::getSuccessorRegions(
3111 // If the predecessor is an AffineIfOp, then branching into both `then` and
3112 // `else` region is valid.
3113 if (point.isParent()) {
3114 regions.reserve(2);
3115 regions.push_back(RegionSuccessor(&getThenRegion()));
3116 // If the "else" region is empty, branch bach into parent.
3117 if (getElseRegion().empty()) {
3118 regions.push_back(RegionSuccessor::parent());
3119 } else {
3120 regions.push_back(RegionSuccessor(&getElseRegion()));
3121 }
3122 return;
3123 }
3124
3125 // If the predecessor is the `else`/`then` region, then branching into parent
3126 // op is valid.
3127 regions.push_back(RegionSuccessor::parent());
3128}
3129
3130ValueRange AffineIfOp::getSuccessorInputs(RegionSuccessor successor) {
3131 if (successor.isParent())
3132 return getResults();
3133 if (successor == &getThenRegion())
3134 return getThenRegion().getArguments();
3135 if (successor == &getElseRegion())
3136 return getElseRegion().getArguments();
3137 llvm_unreachable("invalid region successor");
3138}
3139
3140LogicalResult AffineIfOp::verify() {
3141 // Verify that we have a condition attribute.
3142 // FIXME: This should be specified in the arguments list in ODS.
3143 auto conditionAttr =
3144 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3145 if (!conditionAttr)
3146 return emitOpError("requires an integer set attribute named 'condition'");
3147
3148 // Verify that there are enough operands for the condition.
3149 IntegerSet condition = conditionAttr.getValue();
3150 if (getNumOperands() != condition.getNumInputs())
3151 return emitOpError("operand count and condition integer set dimension and "
3152 "symbol count must match");
3153
3154 // Verify that the operands are valid dimension/symbols.
3155 if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
3156 condition.getNumDims())))
3157 return failure();
3158
3159 return success();
3160}
3161
/// Parses an affine.if op of the form:
///   `affine.if` integer-set dim-and-symbol-list (`->` type-list)?
///       region (`else` region)? attr-dict?
ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the condition attribute set.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(conditionAttr,
                            AffineIfOp::getConditionAttrStrName(),
                            result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();

  // Verify the condition operands.
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        parser.getNameLoc(),
        "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        parser.getNameLoc(),
        "symbol operand count and integer set symbol count must match");

  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Create the regions for 'then' and 'else'. The latter must be created even
  // if it remains empty for the validity of the operation.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, {}, {}))
    return failure();
  // Insert the implicit affine.yield terminator when it was elided.
  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
                               result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return failure();
    AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
                                 result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
3212
3213void AffineIfOp::print(OpAsmPrinter &p) {
3214 auto conditionAttr =
3215 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3216 p << " " << conditionAttr;
3217 printDimAndSymbolList(operand_begin(), operand_end(),
3218 conditionAttr.getValue().getNumDims(), p);
3219 p.printOptionalArrowTypeList(getResultTypes());
3220 p << ' ';
3221 p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
3222 /*printBlockTerminators=*/getNumResults());
3223
3224 // Print the 'else' regions if it has any blocks.
3225 auto &elseRegion = this->getElseRegion();
3226 if (!elseRegion.empty()) {
3227 p << " else ";
3228 p.printRegion(elseRegion,
3229 /*printEntryBlockArgs=*/false,
3230 /*printBlockTerminators=*/getNumResults());
3231 }
3232
3233 // Print the attribute list.
3234 p.printOptionalAttrDict((*this)->getAttrs(),
3235 /*elidedAttrs=*/getConditionAttrStrName());
3236}
3237
3238IntegerSet AffineIfOp::getIntegerSet() {
3239 return (*this)
3240 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
3241 .getValue();
3242}
3243
3244void AffineIfOp::setIntegerSet(IntegerSet newSet) {
3245 (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
3246}
3247
3248void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
3249 setIntegerSet(set);
3250 (*this)->setOperands(operands);
3251}
3252
3253void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3254 TypeRange resultTypes, IntegerSet set, ValueRange args,
3255 bool withElseRegion) {
3256 assert(resultTypes.empty() || withElseRegion);
3257 OpBuilder::InsertionGuard guard(builder);
3258
3259 result.addTypes(resultTypes);
3260 result.addOperands(args);
3261 result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
3262
3263 Region *thenRegion = result.addRegion();
3264 builder.createBlock(thenRegion);
3265 if (resultTypes.empty())
3266 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
3267
3268 Region *elseRegion = result.addRegion();
3269 if (withElseRegion) {
3270 builder.createBlock(elseRegion);
3271 if (resultTypes.empty())
3272 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
3273 }
3274}
3275
3276void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3277 IntegerSet set, ValueRange args, bool withElseRegion) {
3278 AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
3279 withElseRegion);
3280}
3281
3282/// Compose any affine.apply ops feeding into `operands` of the integer set
3283/// `set` by composing the maps of such affine.apply ops with the integer
3284/// set constraints.
3286 SmallVectorImpl<Value> &operands,
3287 bool composeAffineMin = false) {
3288 // We will simply reuse the API of the map composition by viewing the LHSs of
3289 // the equalities and inequalities of `set` as the affine exprs of an affine
3290 // map. Convert to equivalent map, compose, and convert back to set.
3291 auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
3292 set.getConstraints(), set.getContext());
3293 // Check if any composition is possible.
3294 if (llvm::none_of(operands,
3295 [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
3296 return;
3297
3298 composeAffineMapAndOperands(&map, &operands, composeAffineMin);
3299 set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
3300 set.getEqFlags());
3301}
3302
3303/// Canonicalize an affine if op's conditional (integer set + operands).
3304LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3305 auto set = getIntegerSet();
3306 SmallVector<Value, 4> operands(getOperands());
3307 composeSetAndOperands(set, operands);
3308 canonicalizeSetAndOperands(&set, &operands);
3309
3310 // Check if the canonicalization or composition led to any change.
3311 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
3312 return failure();
3313
3314 setConditional(set, operands);
3315 return success();
3316}
3317
3318void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
3319 MLIRContext *context) {
3320 results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
3321}
3322
3323//===----------------------------------------------------------------------===//
3324// AffineLoadOp
3325//===----------------------------------------------------------------------===//
3326
3327void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3328 AffineMap map, ValueRange operands) {
3329 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3330 result.addOperands(operands);
3331 if (map)
3332 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3333 auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
3334 result.types.push_back(memrefType.getElementType());
3335}
3336
3337void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3338 Value memref, AffineMap map, ValueRange mapOperands) {
3339 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3340 result.addOperands(memref);
3341 result.addOperands(mapOperands);
3342 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3343 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3344 result.types.push_back(memrefType.getElementType());
3345}
3346
/// Builds an affine.load that uses the rank-matched identity indexing map
/// (or the empty map `() -> ()` for 0-d memrefs).
/// NOTE(review): the continuation of this overload's parameter list is not
/// visible in this excerpt (expected: the memref and indices parameters).
void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
  // Create identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, memref, map, indices);
}
3357
/// Parses an affine.load in the form:
///   affine.load %memref[<affine-map-of-ssa-ids>] {attrs} : memref-type
/// The map operands are resolved to `index` type; the op result is the
/// memref's element type.
/// NOTE(review): the local declarations of `memrefInfo`/`mapOperands` are not
/// visible in this excerpt.
ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  AffineMapAttr mapAttr;
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineLoadOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}
3377
3378void AffineLoadOp::print(OpAsmPrinter &p) {
3379 p << " " << getMemRef() << '[';
3380 if (AffineMapAttr mapAttr =
3381 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3382 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3383 p << ']';
3384 p.printOptionalAttrDict((*this)->getAttrs(),
3385 /*elidedAttrs=*/{getMapAttrStrName()});
3386 p << " : " << getMemRefType();
3387}
3388
3389/// Verify common indexing invariants of affine.load, affine.store,
3390/// affine.vector_load and affine.vector_store.
3391template <typename AffineMemOpTy>
3392static LogicalResult
3393verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr,
3394 Operation::operand_range mapOperands,
3395 MemRefType memrefType, unsigned numIndexOperands) {
3396 AffineMap map = mapAttr.getValue();
3397 if (map.getNumResults() != memrefType.getRank())
3398 return op->emitOpError("affine map num results must equal memref rank");
3399 if (map.getNumInputs() != numIndexOperands)
3400 return op->emitOpError("expects as many subscripts as affine map inputs");
3401
3402 for (auto idx : mapOperands) {
3403 if (!idx.getType().isIndex())
3404 return op->emitOpError("index to load must have 'index' type");
3405 }
3406 if (failed(verifyDimAndSymbolIdentifiers(op, mapOperands, map.getNumDims())))
3407 return failure();
3408
3409 return success();
3410}
3411
/// Verifies that the load result type matches the memref element type and
/// that the indexing is well-formed (delegated to verifyMemoryOpIndexing).
/// NOTE(review): the line invoking verifyMemoryOpIndexing is not visible in
/// this excerpt; only its argument list remains below.
LogicalResult AffineLoadOp::verify() {
  auto memrefType = getMemRefType();
  if (getType() != memrefType.getElementType())
    return emitOpError("result type must match element type of memref");

      *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
      getMapOperands(), memrefType,
      /*numIndexOperands=*/getNumOperands() - 1)))
    return failure();

  return success();
}
3425
/// Registers the affine.load canonicalization pattern that simplifies the
/// op's affine map and operands (SimplifyAffineOp).
void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineLoadOp>>(context);
}
3430
/// Folds an affine.load: (1) looks through memref.cast producers, and
/// (2) folds loads from constant memref.global buffers when the indices are
/// compile-time constants (or the initializer is a splat).
/// NOTE(review): the line performing the symbol-table lookup of the global
/// is not visible in this excerpt.
OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
  /// load(memrefcast) -> load
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();

  // Fold load from a global constant memref.
  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobalOp)
    return {};
  // Get to the memref.global defining the symbol.
      getGlobalOp, getGlobalOp.getNameAttr());
  if (!global)
    return {};

  // Check if the global memref is a constant.
  auto cstAttr =
      dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
  if (!cstAttr)
    return {};
  // If it's a splat constant, we can fold irrespective of indices.
  if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr))
    return splatAttr.getSplatValue<Attribute>();
  // Otherwise, we can fold only if we know the indices.
  if (!getAffineMap().isConstant())
    return {};
  // All map results are constants: use them directly as element indices.
  auto indices =
      llvm::map_to_vector<4>(getAffineMap().getConstantResults(),
                             [](int64_t v) -> uint64_t { return v; });
  return cstAttr.getValues<Attribute>()[indices];
}
3462
3463//===----------------------------------------------------------------------===//
3464// AffineStoreOp
3465//===----------------------------------------------------------------------===//
3466
3467void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3468 Value valueToStore, Value memref, AffineMap map,
3469 ValueRange mapOperands) {
3470 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3471 result.addOperands(valueToStore);
3472 result.addOperands(memref);
3473 result.addOperands(mapOperands);
3474 result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3475}
3476
// Use identity map.
/// Builds an affine.store using the rank-matched identity indexing map (or
/// the empty map `() -> ()` for 0-d memrefs).
/// NOTE(review): the continuation of this overload's parameter list (the
/// indices parameter) is not visible in this excerpt.
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref,
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  int64_t rank = memrefType.getRank();
  // Create identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, valueToStore, memref, map, indices);
}
3489
/// Parses an affine.store in the form:
///   affine.store %value, %memref[<affine-map-of-ssa-ids>] {attrs} : type
/// The stored value resolves to the memref element type; map operands
/// resolve to `index`.
/// NOTE(review): the declarations of `memrefInfo`/`mapOperands` and the
/// call line of parseAffineMapOfSSAIds are not visible in this excerpt.
ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType type;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  AffineMapAttr mapAttr;
  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
                 parser.parseOperand(memrefInfo) ||
                     mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
                     result.attributes) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.resolveOperand(storeValueInfo, type.getElementType(),
                                       result.operands) ||
                 parser.resolveOperand(memrefInfo, type, result.operands) ||
                 parser.resolveOperands(mapOperands, indexTy, result.operands));
}
3510
3511void AffineStoreOp::print(OpAsmPrinter &p) {
3512 p << " " << getValueToStore();
3513 p << ", " << getMemRef() << '[';
3514 if (AffineMapAttr mapAttr =
3515 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3516 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3517 p << ']';
3518 p.printOptionalAttrDict((*this)->getAttrs(),
3519 /*elidedAttrs=*/{getMapAttrStrName()});
3520 p << " : " << getMemRefType();
3521}
3522
/// Verifies that the stored value type matches the memref element type and
/// that the indexing is well-formed (delegated to verifyMemoryOpIndexing).
/// NOTE(review): the line invoking verifyMemoryOpIndexing is not visible in
/// this excerpt; only its argument list remains below.
LogicalResult AffineStoreOp::verify() {
  // The value to store must have the same type as memref element type.
  auto memrefType = getMemRefType();
  if (getValueToStore().getType() != memrefType.getElementType())
    return emitOpError(
        "value to store must have the same type as memref element type");

      *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
      getMapOperands(), memrefType,
      /*numIndexOperands=*/getNumOperands() - 2)))
    return failure();

  return success();
}
3538
/// Registers the affine.store canonicalization pattern that simplifies the
/// op's affine map and operands (SimplifyAffineOp).
void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineStoreOp>>(context);
}

/// Folds an affine.store by looking through memref.cast producers.
/// NOTE(review): the results-parameter line of this signature is not visible
/// in this excerpt.
LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
  /// store(memrefcast) -> store
  return memref::foldMemRefCast(*this, getValueToStore());
}
3549
3550//===----------------------------------------------------------------------===//
3551// AffineMinMaxOpBase
3552//===----------------------------------------------------------------------===//
3553
3554template <typename T>
3555static LogicalResult verifyAffineMinMaxOp(T op) {
3556 // Verify that operand count matches affine map dimension and symbol count.
3557 if (op.getNumOperands() !=
3558 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3559 return op.emitOpError(
3560 "operand count and affine map dimension and symbol count must match");
3561
3562 if (op.getMap().getNumResults() == 0)
3563 return op.emitOpError("affine map expect at least one result");
3564 return success();
3565}
3566
3567template <typename T>
3568static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3569 p << ' ' << op->getAttr(T::getMapAttrStrName());
3570 auto operands = op.getOperands();
3571 unsigned numDims = op.getMap().getNumDims();
3572 p << '(' << operands.take_front(numDims) << ')';
3573
3574 if (operands.size() != numDims)
3575 p << '[' << operands.drop_front(numDims) << ']';
3576 p.printOptionalAttrDict(op->getAttrs(),
3577 /*elidedAttrs=*/{T::getMapAttrStrName()});
3578}
3579
/// Common parser for affine.min/affine.max: map attribute, parenthesized dim
/// operands, optional bracketed symbol operands, optional attribute dict.
/// The single result is of `index` type.
/// NOTE(review): the result-parameter line of the signature and the
/// `dimInfos`/`symInfos` declarations are not visible in this excerpt.
template <typename T>
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
  AffineMapAttr mapAttr;
  return failure(
      parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
                            result.attributes) ||
      parser.parseOperandList(symInfos,
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.resolveOperands(dimInfos, indexType, result.operands) ||
      parser.resolveOperands(symInfos, indexType, result.operands) ||
      parser.addTypeToList(indexType, result.types));
}
3599
3600/// Fold an affine min or max operation with the given operands. The operand
3601/// list may contain nulls, which are interpreted as the operand not being a
3602/// constant.
/// NOTE(review): the signature line of this helper (taking the op and its
/// constant operand attributes) and the `results` vector declaration are not
/// visible in this excerpt.
template <typename T>
  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
                "expected affine min or max op");

  // Fold the affine map.
  // TODO: Fold more cases:
  // min(some_affine, some_affine + constant, ...), etc.
  auto foldedMap = op.getMap().partialConstantFold(operands, &results);

  // A single-symbol identity map means the op trivially forwards its sole
  // operand.
  if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
    return op.getOperand(0);

  // If some of the map results are not constant, try changing the map in-place.
  if (results.empty()) {
    // If the map is the same, report that folding did not happen.
    if (foldedMap == op.getMap())
      return {};
    op->setAttr("map", AffineMapAttr::get(foldedMap));
    return op.getResult();
  }

  // Otherwise, completely fold the op into a constant.
  // Pick the smallest (min) or largest (max) of the fully-folded results.
  auto resultIt = std::is_same<T, AffineMinOp>::value
                      ? llvm::min_element(results)
                      : llvm::max_element(results);
  if (resultIt == results.end())
    return {};
  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
}
3634
3635/// Remove duplicated expressions in affine min/max ops.
/// NOTE(review): the struct declaration line of this OpRewritePattern
/// (DeduplicateAffineMinMaxExpressions) and the `newExprs` declaration are
/// not visible in this excerpt.
template <typename T>

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap oldMap = affineOp.getAffineMap();

    // Keep only the first occurrence of each result expression.
    for (AffineExpr expr : oldMap.getResults()) {
      // This is a linear scan over newExprs, but it should be fine given that
      // we typically just have a few expressions per op.
      if (!llvm::is_contained(newExprs, expr))
        newExprs.push_back(expr);
    }

    // Nothing was removed: report no match.
    if (newExprs.size() == oldMap.getNumResults())
      return failure();

    auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
                                 newExprs, rewriter.getContext());
    rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());

    return success();
  }
};
3662
3663/// Merge an affine min/max op to its consumers if its consumer is also an
3664/// affine min/max op.
3665///
3666/// This pattern requires the producer affine min/max op is bound to a
3667/// dimension/symbol that is used as a standalone expression in the consumer
3668/// affine op's map.
3669///
3670/// For example, a pattern like the following:
3671///
3672/// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3673/// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3674///
3675/// Can be turned into:
3676///
3677/// %1 = affine.min affine_map<
3678/// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
/// NOTE(review): the struct declaration line of this OpRewritePattern
/// (MergeAffineMinMaxOp) and the `newExprs` declaration are not visible in
/// this excerpt.
template <typename T>

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    AffineMap oldMap = affineOp.getAffineMap();
    ValueRange dimOperands =
        affineOp.getMapOperands().take_front(oldMap.getNumDims());
    ValueRange symOperands =
        affineOp.getMapOperands().take_back(oldMap.getNumSymbols());

    auto newDimOperands = llvm::to_vector<8>(dimOperands);
    auto newSymOperands = llvm::to_vector<8>(symOperands);
    SmallVector<T, 4> producerOps;

    // Go over each expression to see whether it's a single dimension/symbol
    // with the corresponding operand which is the result of another affine
    // min/max op. If So it can be merged into this affine op.
    for (AffineExpr expr : oldMap.getResults()) {
      if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
        Value symValue = symOperands[symExpr.getPosition()];
        if (auto producerOp = symValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
        Value dimValue = dimOperands[dimExpr.getPosition()];
        if (auto producerOp = dimValue.getDefiningOp<T>()) {
          producerOps.push_back(producerOp);
          continue;
        }
      }
      // For the above cases we will remove the expression by merging the
      // producer affine min/max's affine expressions. Otherwise we need to
      // keep the existing expression.
      newExprs.push_back(expr);
    }

    // No standalone dim/symbol expression was fed by a producer min/max op.
    if (producerOps.empty())
      return failure();

    unsigned numUsedDims = oldMap.getNumDims();
    unsigned numUsedSyms = oldMap.getNumSymbols();

    // Now go over all producer affine ops and merge their expressions.
    for (T producerOp : producerOps) {
      AffineMap producerMap = producerOp.getAffineMap();
      unsigned numProducerDims = producerMap.getNumDims();
      unsigned numProducerSyms = producerMap.getNumSymbols();

      // Collect all dimension/symbol values.
      ValueRange dimValues =
          producerOp.getMapOperands().take_front(numProducerDims);
      ValueRange symValues =
          producerOp.getMapOperands().take_back(numProducerSyms);
      newDimOperands.append(dimValues.begin(), dimValues.end());
      newSymOperands.append(symValues.begin(), symValues.end());

      // For expressions we need to shift to avoid overlap.
      for (AffineExpr expr : producerMap.getResults()) {
        newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
                               .shiftSymbols(numProducerSyms, numUsedSyms));
      }

      numUsedDims += numProducerDims;
      numUsedSyms += numProducerSyms;
    }

    auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
                                 rewriter.getContext());
    // Final operand order: all dim operands first, then all symbol operands.
    auto newOperands =
        llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
    rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);

    return success();
  }
};
3758
3759/// Canonicalize the result expression order of an affine map and return success
3760/// if the order changed.
3761///
3762/// The function flattens the map's affine expressions to coefficient arrays and
3763/// sorts them in lexicographic order. A coefficient array contains a multiplier
3764/// for every dimension/symbol and a constant term. The canonicalization fails
3765/// if a result expression is not pure or if the flattening requires local
3766/// variables that, unlike dimensions and symbols, have no global order.
3767static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3768 SmallVector<SmallVector<int64_t>> flattenedExprs;
3769 for (const AffineExpr &resultExpr : map.getResults()) {
3770 // Fail if the expression is not pure.
3771 if (!resultExpr.isPureAffine())
3772 return failure();
3773
3774 SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3775 auto flattenResult = flattener.walkPostOrder(resultExpr);
3776 if (failed(flattenResult))
3777 return failure();
3778
3779 // Fail if the flattened expression has local variables.
3780 if (flattener.operandExprStack.back().size() !=
3781 map.getNumDims() + map.getNumSymbols() + 1)
3782 return failure();
3783
3784 flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3785 flattener.operandExprStack.back().end());
3786 }
3787
3788 // Fail if sorting is not necessary.
3789 if (llvm::is_sorted(flattenedExprs))
3790 return failure();
3791
3792 // Reorder the result expressions according to their flattened form.
3793 SmallVector<unsigned> resultPermutation =
3794 llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3795 llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3796 return flattenedExprs[lhs] < flattenedExprs[rhs];
3797 });
3798 SmallVector<AffineExpr> newExprs;
3799 for (unsigned idx : resultPermutation)
3800 newExprs.push_back(map.getResult(idx));
3801
3802 map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3803 map.getContext());
3804 return success();
3805}
3806
3807/// Canonicalize the affine map result expression order of an affine min/max
3808/// operation.
3809///
3810/// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3811/// expressions and replaces the operation if the order changed.
3812///
3813/// For example, the following operation:
3814///
3815/// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3816///
3817/// Turns into:
3818///
3819/// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
/// NOTE(review): the struct declaration line of this OpRewritePattern is not
/// visible in this excerpt.
template <typename T>
  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    // Reorder the map results into canonical order; fail if unchanged.
    AffineMap map = affineOp.getAffineMap();
    if (failed(canonicalizeMapExprAndTermOrder(map)))
      return failure();
    rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
    return success();
  }
};
3833
/// Rewrites a single-result affine.min/affine.max into an equivalent
/// affine.apply (min/max of one expression is that expression).
/// NOTE(review): the struct declaration line of this OpRewritePattern is not
/// visible in this excerpt.
template <typename T>

  LogicalResult matchAndRewrite(T affineOp,
                                PatternRewriter &rewriter) const override {
    if (affineOp.getMap().getNumResults() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
                                               affineOp.getOperands());
    return success();
  }
};
3847
3848//===----------------------------------------------------------------------===//
3849// AffineMinOp
3850//===----------------------------------------------------------------------===//
3851//
3852// %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
3853//
3854
/// Folds affine.min via the shared min/max folding helper.
OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(*this, adaptor.getOperands());
}

/// Registers affine.min canonicalization patterns.
/// NOTE(review): the first lines of the patterns.add<...> call are not
/// visible in this excerpt.
void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
      MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
      context);
}

LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }

/// NOTE(review): the delegating return statement of this parser is not
/// visible in this excerpt.
ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
}

void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }

//===----------------------------------------------------------------------===//
// AffineMaxOp
//===----------------------------------------------------------------------===//
//
// %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
//

/// Folds affine.max via the shared min/max folding helper.
OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
  return foldMinMaxOp(*this, adaptor.getOperands());
}

/// Registers affine.max canonicalization patterns.
/// NOTE(review): the first lines of the patterns.add<...> call are not
/// visible in this excerpt.
void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
      MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
      context);
}

LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }

/// NOTE(review): the delegating return statement of this parser is not
/// visible in this excerpt.
ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
}

void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3903
3904//===----------------------------------------------------------------------===//
3905// AffinePrefetchOp
3906//===----------------------------------------------------------------------===//
3907
3908//
3909// affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3910//
/// Parses an affine.prefetch in the form:
///   affine.prefetch %memref[<map>(%ops)], (read|write), locality<N>,
///   (data|instr) : memref-type
/// NOTE(review): the result-parameter line of the signature and the
/// `memrefInfo`/`mapOperands` declarations are not visible in this excerpt.
ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  IntegerAttr hintInfo;
  auto i32Type = parser.getBuilder().getIntegerType(32);
  StringRef readOrWrite, cacheType;

  AffineMapAttr mapAttr;
  if (parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffinePrefetchOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(hintInfo, i32Type,
                            AffinePrefetchOp::getLocalityHintAttrStrName(),
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands))
    return failure();

  // The read/write keyword is stored as a boolean `isWrite` attribute.
  if (readOrWrite != "read" && readOrWrite != "write")
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
                      parser.getBuilder().getBoolAttr(readOrWrite == "write"));

  // The cache-type keyword is stored as a boolean `isDataCache` attribute.
  if (cacheType != "data" && cacheType != "instr")
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
                      parser.getBuilder().getBoolAttr(cacheType == "data"));

  return success();
}
3957
/// Prints an affine.prefetch, mirroring the custom parse form above.
/// NOTE(review): the line invoking printOptionalAttrDict is not visible in
/// this excerpt; only its argument list remains below.
void AffinePrefetchOp::print(OpAsmPrinter &p) {
  p << " " << getMemref() << '[';
  AffineMapAttr mapAttr =
      (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
  if (mapAttr)
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", " << "locality<"
    << getLocalityHint() << ">, " << (getIsDataCache() ? "data" : "instr");
      (*this)->getAttrs(),
      /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
                       getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
  p << " : " << getMemRefType();
}
3972
3973LogicalResult AffinePrefetchOp::verify() {
3974 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
3975 if (mapAttr) {
3976 AffineMap map = mapAttr.getValue();
3977 if (map.getNumResults() != getMemRefType().getRank())
3978 return emitOpError("affine.prefetch affine map num results must equal"
3979 " memref rank");
3980 if (map.getNumInputs() + 1 != getNumOperands())
3981 return emitOpError("too few operands");
3982 } else {
3983 if (getNumOperands() != 1)
3984 return emitOpError("too few operands");
3985 }
3986
3987 Region *scope = getAffineScope(*this);
3988 for (auto idx : getMapOperands()) {
3989 if (!isValidAffineIndexOperand(idx, scope))
3990 return emitOpError(
3991 "index must be a valid dimension or symbol identifier");
3992 }
3993 return success();
3994}
3995
/// Registers the affine.prefetch canonicalization pattern that simplifies
/// the op's affine map and operands (SimplifyAffineOp).
void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  // prefetch(memrefcast) -> prefetch
  results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
}

/// Folds an affine.prefetch by looking through memref.cast producers.
/// NOTE(review): the results-parameter line of this signature is not visible
/// in this excerpt.
LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
  /// prefetch(memrefcast) -> prefetch
  return memref::foldMemRefCast(*this);
}
4007
4008//===----------------------------------------------------------------------===//
4009// AffineParallelOp
4010//===----------------------------------------------------------------------===//
4011
/// Builds an affine.parallel over constant ranges [0, range) with unit
/// steps, delegating to the map-based build below.
/// NOTE(review): the reductions parameter line of this signature is not
/// visible in this excerpt.
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<int64_t> ranges) {
  // Lower bounds are all the constant-0 map; upper bounds are constant maps
  // of each range value.
  SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
  auto ubs = llvm::map_to_vector<4>(ranges, [&](int64_t value) {
    return builder.getConstantAffineMap(value);
  });
  SmallVector<int64_t> steps(ranges.size(), 1);
  build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
        /*ubArgs=*/{}, steps);
}
4024
/// Builds an affine.parallel from per-dimension lower/upper bound maps,
/// their operands, and steps. Bound maps for each side must share one input
/// space; they are concatenated into a single map per side, with group-size
/// attributes recording how many results belong to each dimension.
/// NOTE(review): the reductions parameter line of this signature and the
/// `exprs` declaration inside the lambda are not visible in this excerpt.
void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
                             TypeRange resultTypes,
                             ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
                             ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
                             ArrayRef<int64_t> steps) {
  assert(llvm::all_of(lbMaps,
                      [lbMaps](AffineMap m) {
                        return m.getNumDims() == lbMaps[0].getNumDims() &&
                               m.getNumSymbols() == lbMaps[0].getNumSymbols();
                      }) &&
         "expected all lower bounds maps to have the same number of dimensions "
         "and symbols");
  assert(llvm::all_of(ubMaps,
                      [ubMaps](AffineMap m) {
                        return m.getNumDims() == ubMaps[0].getNumDims() &&
                               m.getNumSymbols() == ubMaps[0].getNumSymbols();
                      }) &&
         "expected all upper bounds maps to have the same number of dimensions "
         "and symbols");
  assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
         "expected lower bound maps to have as many inputs as lower bound "
         "operands");
  assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
         "expected upper bound maps to have as many inputs as upper bound "
         "operands");

  OpBuilder::InsertionGuard guard(builder);
  result.addTypes(resultTypes);

  // Convert the reductions to integer attributes.
  SmallVector<Attribute, 4> reductionAttrs;
  for (arith::AtomicRMWKind reduction : reductions)
    reductionAttrs.push_back(
        builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
  result.addAttribute(getReductionsAttrStrName(),
                      builder.getArrayAttr(reductionAttrs));

  // Concatenates maps defined in the same input space (same dimensions and
  // symbols), assumes there is at least one map.
  auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
                                        SmallVectorImpl<int32_t> &groups) {
    if (maps.empty())
      return AffineMap::get(builder.getContext());
    groups.reserve(groups.size() + maps.size());
    exprs.reserve(maps.size());
    for (AffineMap m : maps) {
      llvm::append_range(exprs, m.getResults());
      groups.push_back(m.getNumResults());
    }
    return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
                          maps[0].getContext());
  };

  // Set up the bounds.
  SmallVector<int32_t> lbGroups, ubGroups;
  AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
  AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
  result.addAttribute(getLowerBoundsMapAttrStrName(),
                      AffineMapAttr::get(lbMap));
  result.addAttribute(getLowerBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(lbGroups));
  result.addAttribute(getUpperBoundsMapAttrStrName(),
                      AffineMapAttr::get(ubMap));
  result.addAttribute(getUpperBoundsGroupsAttrStrName(),
                      builder.getI32TensorAttr(ubGroups));
  result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
  result.addOperands(lbArgs);
  result.addOperands(ubArgs);

  // Create a region and a block for the body.
  auto *bodyRegion = result.addRegion();
  Block *body = builder.createBlock(bodyRegion);

  // Add all the block arguments.
  for (unsigned i = 0, e = steps.size(); i < e; ++i)
    body->addArgument(IndexType::get(builder.getContext()), result.location);
  if (resultTypes.empty())
    ensureTerminator(*bodyRegion, builder, result.location);
}
4106
/// The loop body is the op's single region.
SmallVector<Region *> AffineParallelOp::getLoopRegions() {
  return {&getRegion()};
}

/// Number of parallel induction variables (one step entry per dimension).
unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }

/// Lower-bound operands occupy the front of the operand list, as many as the
/// lower bounds map has inputs.
AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
  return getOperands().take_front(getLowerBoundsMap().getNumInputs());
}

/// Upper-bound operands are everything after the lower-bound operands.
AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
  return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
}
4120
4121AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
4122 auto values = getLowerBoundsGroups().getValues<int32_t>();
4123 unsigned start = 0;
4124 for (unsigned i = 0; i < pos; ++i)
4125 start += values[i];
4126 return getLowerBoundsMap().getSliceMap(start, values[pos]);
4127}
4128
4129AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
4130 auto values = getUpperBoundsGroups().getValues<int32_t>();
4131 unsigned start = 0;
4132 for (unsigned i = 0; i < pos; ++i)
4133 start += values[i];
4134 return getUpperBoundsMap().getSliceMap(start, values[pos]);
4135}
4136
/// Returns the combined lower bounds map bundled with its operands.
AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
  return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
}

/// Returns the combined upper bounds map bundled with its operands.
AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
  return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
}
4144
/// Returns the trip counts (ub - lb) per dimension if they all fold to
/// constants, and std::nullopt otherwise (including when the op has min/max
/// bounds).
/// NOTE(review): the declaration of the `out` vector is not visible in this
/// excerpt.
std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
  if (hasMinMaxBounds())
    return std::nullopt;

  // Try to convert all the ranges to constant expressions.
  AffineValueMap rangesValueMap;
  AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
                             &rangesValueMap);
  out.reserve(rangesValueMap.getNumResults());
  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
    auto expr = rangesValueMap.getResult(i);
    auto cst = dyn_cast<AffineConstantExpr>(expr);
    if (!cst)
      return std::nullopt;
    out.push_back(cst.getValue());
  }
  return out;
}

/// The body is the single block of the op's region.
Block *AffineParallelOp::getBody() { return &getRegion().front(); }

/// Returns a builder positioned just before the body's terminator.
OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));
}
4170
4171void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
4172 assert(lbOperands.size() == map.getNumInputs() &&
4173 "operands to map must match number of inputs");
4174
4175 auto ubOperands = getUpperBoundsOperands();
4176
4177 SmallVector<Value, 4> newOperands(lbOperands);
4178 newOperands.append(ubOperands.begin(), ubOperands.end());
4179 (*this)->setOperands(newOperands);
4180
4181 setLowerBoundsMapAttr(AffineMapAttr::get(map));
4182}
4183
4184void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
4185 assert(ubOperands.size() == map.getNumInputs() &&
4186 "operands to map must match number of inputs");
4187
4188 SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
4189 newOperands.append(ubOperands.begin(), ubOperands.end());
4190 (*this)->setOperands(newOperands);
4191
4192 setUpperBoundsMapAttr(AffineMapAttr::get(map));
4193}
4194
/// Replaces the per-dimension step sizes.
void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
  setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}
4198
// Checks whether `resultType` is a valid result type for the given
// reduction kind in affine.parallel: float kinds require a float type,
// integer kinds an integer type, and the signed/unsigned min/max kinds
// additionally require matching integer signedness; `assign` accepts any
// type.
// NOTE(review): the signature line of this helper is not visible in this
// excerpt.
                                   arith::AtomicRMWKind op) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::addi:
    return isa<IntegerType>(resultType);
  case arith::AtomicRMWKind::assign:
    return true;
  case arith::AtomicRMWKind::mulf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::muli:
    return isa<IntegerType>(resultType);
  case arith::AtomicRMWKind::maximumf:
  case arith::AtomicRMWKind::maxnumf:
  case arith::AtomicRMWKind::minimumf:
  case arith::AtomicRMWKind::minnumf:
    return isa<FloatType>(resultType);
  case arith::AtomicRMWKind::maxs: {
    auto intType = dyn_cast<IntegerType>(resultType);
    return intType && intType.isSigned();
  }
  case arith::AtomicRMWKind::mins: {
    auto intType = dyn_cast<IntegerType>(resultType);
    return intType && intType.isSigned();
  }
  case arith::AtomicRMWKind::maxu: {
    auto intType = dyn_cast<IntegerType>(resultType);
    return intType && intType.isUnsigned();
  }
  case arith::AtomicRMWKind::minu: {
    auto intType = dyn_cast<IntegerType>(resultType);
    return intType && intType.isUnsigned();
  }
  case arith::AtomicRMWKind::ori:
  case arith::AtomicRMWKind::andi:
  case arith::AtomicRMWKind::xori:
    return isa<IntegerType>(resultType);
  }
  llvm_unreachable("Unhandled atomic rmw kind");
}
4241
LogicalResult AffineParallelOp::verify() {
  // The number of induction variables (region arguments), bound groups on
  // each side, and steps must all agree.
  auto numDims = getNumDims();
  if (getLowerBoundsGroups().getNumElements() != numDims ||
      getUpperBoundsGroups().getNumElements() != numDims ||
      getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
    return emitOpError() << "the number of region arguments ("
                         << getBody()->getNumArguments()
                         << ") and the number of map groups for lower ("
                         << getLowerBoundsGroups().getNumElements()
                         << ") and upper bound ("
                         << getUpperBoundsGroups().getNumElements()
                         << "), and the number of steps (" << getSteps().size()
                         << ") must all match";
  }

  // Every lower-bound group must contain at least one map result, and the
  // group sizes must sum to the number of results of the lower-bounds map.
  unsigned expectedNumLBResults = 0;
  for (APInt v : getLowerBoundsGroups()) {
    unsigned results = v.getZExtValue();
    if (results == 0)
      return emitOpError()
             << "expected lower bound map to have at least one result";
    expectedNumLBResults += results;
  }
  if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
    return emitOpError() << "expected lower bounds map to have "
                         << expectedNumLBResults << " results";
  // Same check for the upper-bound groups and map.
  unsigned expectedNumUBResults = 0;
  for (APInt v : getUpperBoundsGroups()) {
    unsigned results = v.getZExtValue();
    if (results == 0)
      return emitOpError()
             << "expected upper bound map to have at least one result";
    expectedNumUBResults += results;
  }
  if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
    return emitOpError() << "expected upper bounds map to have "
                         << expectedNumUBResults << " results";

  // One reduction attribute per op result.
  if (getReductions().size() != getNumResults())
    return emitOpError("a reduction must be specified for each output");

  // Verify reduction ops are all valid and each result type matches reduction
  // ops
  for (auto it : llvm::enumerate((getReductions()))) {
    Attribute attr = it.value();
    auto intAttr = dyn_cast<IntegerAttr>(attr);
    if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
      return emitOpError("invalid reduction attribute");
    auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
    if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
      return emitOpError("result type cannot match reduction attribute");
  }

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bounds.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
                                           getLowerBoundsMap().getNumDims())))
    return failure();
  /// Upper bounds.
  if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
                                           getUpperBoundsMap().getNumDims())))
    return failure();
  return success();
}
4306
4308 SmallVector<Value, 4> newOperands{operands};
4309 auto newMap = getAffineMap();
4310 composeAffineMapAndOperands(&newMap, &newOperands);
4311 if (newMap == getAffineMap() && newOperands == operands)
4312 return failure();
4313 reset(newMap, newOperands);
4314 return success();
4315}
4316
4317/// Canonicalize the bounds of the given loop.
4318static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
4319 AffineValueMap lb = op.getLowerBoundsValueMap();
4320 bool lbCanonicalized = succeeded(lb.canonicalize());
4321
4322 AffineValueMap ub = op.getUpperBoundsValueMap();
4323 bool ubCanonicalized = succeeded(ub.canonicalize());
4324
4325 // Any canonicalization change always leads to updated map(s).
4326 if (!lbCanonicalized && !ubCanonicalized)
4327 return failure();
4328
4329 if (lbCanonicalized)
4330 op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
4331 if (ubCanonicalized)
4332 op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
4333
4334 return success();
4335}
4336
/// Folds by canonicalizing the op's bound maps/operands in place; this fold
/// never produces replacement values (`results` is left empty).
LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
                                     SmallVectorImpl<OpFoldResult> &results) {
  return canonicalizeLoopBounds(*this);
}
4341
4342/// Prints a lower(upper) bound of an affine parallel loop with max(min)
4343/// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
4344/// identifies which of the those expressions form max/min groups. `operands`
4345/// are the SSA values of dimensions and symbols and `keyword` is either "min"
4346/// or "max".
4347static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
4348 DenseIntElementsAttr group, ValueRange operands,
4349 StringRef keyword) {
4350 AffineMap map = mapAttr.getValue();
4351 unsigned numDims = map.getNumDims();
4352 ValueRange dimOperands = operands.take_front(numDims);
4353 ValueRange symOperands = operands.drop_front(numDims);
4354 unsigned start = 0;
4355 for (llvm::APInt groupSize : group) {
4356 if (start != 0)
4357 p << ", ";
4358
4359 unsigned size = groupSize.getZExtValue();
4360 if (size == 1) {
4361 p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4362 ++start;
4363 } else {
4364 p << keyword << '(';
4365 AffineMap submap = map.getSliceMap(start, size);
4366 p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4367 p << ')';
4368 start += size;
4369 }
4370 }
4371}
4372
4373void AffineParallelOp::print(OpAsmPrinter &p) {
4374 p << " (" << getBody()->getArguments() << ") = (";
4375 printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4376 getLowerBoundsOperands(), "max");
4377 p << ") to (";
4378 printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4379 getUpperBoundsOperands(), "min");
4380 p << ')';
4381 SmallVector<int64_t, 8> steps = getSteps();
4382 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4383 if (!elideSteps) {
4384 p << " step (";
4385 llvm::interleaveComma(steps, p);
4386 p << ')';
4387 }
4388 if (getNumResults()) {
4389 p << " reduce (";
4390 llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4391 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4392 llvm::cast<IntegerAttr>(attr).getInt());
4393 p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4394 });
4395 p << ") -> (" << getResultTypes() << ")";
4396 }
4397
4398 p << ' ';
4399 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4400 /*printBlockTerminators=*/getNumResults());
4402 (*this)->getAttrs(),
4403 /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4404 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4405 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4406 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4407 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4408 AffineParallelOp::getStepsAttrStrName()});
4409}
4410
4411/// Given a list of lists of parsed operands, populates `uniqueOperands` with
4412/// unique operands. Also populates `replacements with affine expressions of
4413/// `kind` that can be used to update affine maps previously accepting a
4414/// `operands` to accept `uniqueOperands` instead.
4415static ParseResult deduplicateAndResolveOperands(
4416 OpAsmParser &parser,
4417 ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
4418 SmallVectorImpl<Value> &uniqueOperands,
4419 SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4420 assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4421 "expected operands to be dim or symbol expression");
4422
4423 Type indexType = parser.getBuilder().getIndexType();
4424 for (const auto &list : operands) {
4425 SmallVector<Value> valueOperands;
4426 if (parser.resolveOperands(list, indexType, valueOperands))
4427 return failure();
4428 for (Value operand : valueOperands) {
4429 unsigned pos = std::distance(uniqueOperands.begin(),
4430 llvm::find(uniqueOperands, operand));
4431 if (pos == uniqueOperands.size())
4432 uniqueOperands.push_back(operand);
4433 replacements.push_back(
4434 kind == AffineExprKind::DimId
4435 ? getAffineDimExpr(pos, parser.getContext())
4436 : getAffineSymbolExpr(pos, parser.getContext()));
4437 }
4438 }
4439 return success();
4440}
4441
namespace {
// Identifies whether a parsed affine.parallel bound group uses `min` (upper
// bounds) or `max` (lower bounds).
enum class MinMaxKind { Min, Max };
} // namespace
4445
4446/// Parses an affine map that can contain a min/max for groups of its results,
4447/// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4448/// `result` attributes with the map (flat list of expressions) and the grouping
4449/// (list of integers that specify how many expressions to put into each
4450/// min/max) attributes. Deduplicates repeated operands.
4451///
4452/// parallel-bound ::= `(` parallel-group-list `)`
4453/// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4454/// parallel-group ::= simple-group | min-max-group
4455/// simple-group ::= expr-of-ssa-ids
4456/// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4457/// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4458///
4459/// Examples:
4460/// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4461/// (%0, max(%1 - 2 * %2))
4462static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4463 OperationState &result,
4464 MinMaxKind kind) {
4465 // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4466 // see: https://reviews.llvm.org/D134227#3821753
4467 const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4468
4469 StringRef mapName = kind == MinMaxKind::Min
4470 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4471 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4472 StringRef groupsName =
4473 kind == MinMaxKind::Min
4474 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4475 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4476
4477 if (failed(parser.parseLParen()))
4478 return failure();
4479
4480 if (succeeded(parser.parseOptionalRParen())) {
4481 result.addAttribute(
4482 mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4483 result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4484 return success();
4485 }
4486
4487 SmallVector<AffineExpr> flatExprs;
4488 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
4489 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
4490 SmallVector<int32_t> numMapsPerGroup;
4491 SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
4492 auto parseOperands = [&]() {
4493 if (succeeded(parser.parseOptionalKeyword(
4494 kind == MinMaxKind::Min ? "min" : "max"))) {
4495 mapOperands.clear();
4496 AffineMapAttr map;
4497 if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4498 result.attributes,
4500 return failure();
4501 result.attributes.erase(tmpAttrStrName);
4502 llvm::append_range(flatExprs, map.getValue().getResults());
4503 auto operandsRef = llvm::ArrayRef(mapOperands);
4504 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4505 SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef);
4506 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4507 SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef);
4508 flatDimOperands.append(map.getValue().getNumResults(), dims);
4509 flatSymOperands.append(map.getValue().getNumResults(), syms);
4510 numMapsPerGroup.push_back(map.getValue().getNumResults());
4511 } else {
4512 if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4513 flatSymOperands.emplace_back(),
4514 flatExprs.emplace_back())))
4515 return failure();
4516 numMapsPerGroup.push_back(1);
4517 }
4518 return success();
4519 };
4520 if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4521 return failure();
4522
4523 unsigned totalNumDims = 0;
4524 unsigned totalNumSyms = 0;
4525 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4526 unsigned numDims = flatDimOperands[i].size();
4527 unsigned numSyms = flatSymOperands[i].size();
4528 flatExprs[i] = flatExprs[i]
4529 .shiftDims(numDims, totalNumDims)
4530 .shiftSymbols(numSyms, totalNumSyms);
4531 totalNumDims += numDims;
4532 totalNumSyms += numSyms;
4533 }
4534
4535 // Deduplicate map operands.
4536 SmallVector<Value> dimOperands, symOperands;
4537 SmallVector<AffineExpr> dimRplacements, symRepacements;
4538 if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4539 dimRplacements, AffineExprKind::DimId) ||
4540 deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4541 symRepacements, AffineExprKind::SymbolId))
4542 return failure();
4543
4544 result.operands.append(dimOperands.begin(), dimOperands.end());
4545 result.operands.append(symOperands.begin(), symOperands.end());
4546
4547 Builder &builder = parser.getBuilder();
4548 auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4549 parser.getContext());
4550 flatMap = flatMap.replaceDimsAndSymbols(
4551 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4552
4553 result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4554 result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4555 return success();
4556}
4557
4558//
4559// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4560// `to` parallel-bound steps? region attr-dict?
4561// steps ::= `steps` `(` integer-literals `)`
4562//
4563ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4564 OperationState &result) {
4565 auto &builder = parser.getBuilder();
4566 auto indexType = builder.getIndexType();
4567 SmallVector<OpAsmParser::Argument, 4> ivs;
4569 parser.parseEqual() ||
4570 parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4571 parser.parseKeyword("to") ||
4572 parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4573 return failure();
4574
4575 AffineMapAttr stepsMapAttr;
4576 NamedAttrList stepsAttrs;
4577 SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
4578 if (failed(parser.parseOptionalKeyword("step"))) {
4579 SmallVector<int64_t, 4> steps(ivs.size(), 1);
4580 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4581 builder.getI64ArrayAttr(steps));
4582 } else {
4583 if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4584 AffineParallelOp::getStepsAttrStrName(),
4585 stepsAttrs,
4587 return failure();
4588
4589 // Convert steps from an AffineMap into an I64ArrayAttr.
4590 SmallVector<int64_t, 4> steps;
4591 auto stepsMap = stepsMapAttr.getValue();
4592 for (const auto &result : stepsMap.getResults()) {
4593 auto constExpr = dyn_cast<AffineConstantExpr>(result);
4594 if (!constExpr)
4595 return parser.emitError(parser.getNameLoc(),
4596 "steps must be constant integers");
4597 steps.push_back(constExpr.getValue());
4598 }
4599 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4600 builder.getI64ArrayAttr(steps));
4601 }
4602
4603 // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4604 // quoted strings are a member of the enum AtomicRMWKind.
4605 SmallVector<Attribute, 4> reductions;
4606 if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4607 if (parser.parseLParen())
4608 return failure();
4609 auto parseAttributes = [&]() -> ParseResult {
4610 // Parse a single quoted string via the attribute parsing, and then
4611 // verify it is a member of the enum and convert to it's integer
4612 // representation.
4613 StringAttr attrVal;
4614 NamedAttrList attrStorage;
4615 auto loc = parser.getCurrentLocation();
4616 if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4617 attrStorage))
4618 return failure();
4619 std::optional<arith::AtomicRMWKind> reduction =
4620 arith::symbolizeAtomicRMWKind(attrVal.getValue());
4621 if (!reduction)
4622 return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4623 reductions.push_back(
4624 builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4625 // While we keep getting commas, keep parsing.
4626 return success();
4627 };
4628 if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4629 return failure();
4630 }
4631 result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4632 builder.getArrayAttr(reductions));
4633
4634 // Parse return types of reductions (if any)
4635 if (parser.parseOptionalArrowTypeList(result.types))
4636 return failure();
4637
4638 // Now parse the body.
4639 Region *body = result.addRegion();
4640 for (auto &iv : ivs)
4641 iv.type = indexType;
4642 if (parser.parseRegion(*body, ivs) ||
4643 parser.parseOptionalAttrDict(result.attributes))
4644 return failure();
4645
4646 // Add a terminator if none was parsed.
4647 AffineParallelOp::ensureTerminator(*body, builder, result.location);
4648 return success();
4649}
4650
4651//===----------------------------------------------------------------------===//
4652// AffineYieldOp
4653//===----------------------------------------------------------------------===//
4654
4655LogicalResult AffineYieldOp::verify() {
4656 auto *parentOp = (*this)->getParentOp();
4657 auto results = parentOp->getResults();
4658 auto operands = getOperands();
4659
4660 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4661 return emitOpError() << "only terminates affine.if/for/parallel regions";
4662 if (parentOp->getNumResults() != getNumOperands())
4663 return emitOpError() << "parent of yield must have same number of "
4664 "results as the yield operands";
4665 for (auto it : llvm::zip(results, operands)) {
4666 if (std::get<0>(it).getType() != std::get<1>(it).getType())
4667 return emitOpError() << "types mismatch between yield op and its parent";
4668 }
4669
4670 return success();
4671}
4672
4673//===----------------------------------------------------------------------===//
4674// AffineVectorLoadOp
4675//===----------------------------------------------------------------------===//
4676
/// Builds an affine vector load from a combined operand list (memref followed
/// by the map operands) and an optional access map.
void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, AffineMap map,
                               ValueRange operands) {
  // One memref operand plus one operand per map input.
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  result.addOperands(operands);
  if (map)
    result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  result.types.push_back(resultType);
}
4686
/// Builds an affine vector load from a memref, an access map, and the map's
/// operands (dims followed by symbols).
void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
                               VectorType resultType, Value memref,
                               AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
  result.types.push_back(resultType);
}
4696
4697void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4698 VectorType resultType, Value memref,
4700 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4701 int64_t rank = memrefType.getRank();
4702 // Create identity map for memrefs with at least one dimension or () -> ()
4703 // for zero-dimensional memrefs.
4704 auto map =
4705 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4706 build(builder, result, resultType, memref, map, indices);
4707}
4708
void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  // Registers the shared SimplifyAffineOp pattern, specialized for this op.
  results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
}
4713
/// Parses affine.vector_load: a memref operand subscripted by an affine map
/// of SSA ids, an optional attr-dict, then `: memref-type, vector-type`.
ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType memrefType;
  VectorType resultType;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  // The short-circuiting `||` chain stops at the first parse error; map
  // operands resolve to index type.
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineVectorLoadOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(memrefType) || parser.parseComma() ||
      parser.parseType(resultType) ||
      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(resultType, result.types));
}
4736
4737void AffineVectorLoadOp::print(OpAsmPrinter &p) {
4738 p << " " << getMemRef() << '[';
4739 if (AffineMapAttr mapAttr =
4740 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4741 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4742 p << ']';
4743 p.printOptionalAttrDict((*this)->getAttrs(),
4744 /*elidedAttrs=*/{getMapAttrStrName()});
4745 p << " : " << getMemRefType() << ", " << getType();
4746}
4747
4748/// Verify common invariants of affine.vector_load and affine.vector_store.
4749static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4750 VectorType vectorType) {
4751 // Check that memref and vector element types match.
4752 if (memrefType.getElementType() != vectorType.getElementType())
4753 return op->emitOpError(
4754 "requires memref and vector types of the same elemental type");
4755 return success();
4756}
4757
4758LogicalResult AffineVectorLoadOp::verify() {
4759 MemRefType memrefType = getMemRefType();
4761 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4762 getMapOperands(), memrefType,
4763 /*numIndexOperands=*/getNumOperands() - 1)))
4764 return failure();
4765
4766 if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4767 return failure();
4768
4769 return success();
4770}
4771
4772//===----------------------------------------------------------------------===//
4773// AffineVectorStoreOp
4774//===----------------------------------------------------------------------===//
4775
/// Builds an affine vector store from a value, a memref, an access map, and
/// the map's operands.
void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
                                Value valueToStore, Value memref, AffineMap map,
                                ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(valueToStore);
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
}
4785
4786// Use identity map.
4787void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4788 Value valueToStore, Value memref,
4790 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4791 int64_t rank = memrefType.getRank();
4792 // Create identity map for memrefs with at least one dimension or () -> ()
4793 // for zero-dimensional memrefs.
4794 auto map =
4795 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4796 build(builder, result, valueToStore, memref, map, indices);
4797}
void AffineVectorStoreOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Registers the shared SimplifyAffineOp pattern, specialized for this op.
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
}
4802
/// Parses affine.vector_store: the stored value, a comma, a memref operand
/// subscripted by an affine map of SSA ids, an optional attr-dict, then
/// `: memref-type, vector-type`.
ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType memrefType;
  VectorType resultType;
  OpAsmParser::UnresolvedOperand storeValueInfo;
  OpAsmParser::UnresolvedOperand memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
  // The short-circuiting `||` chain stops at the first parse error; the stored
  // value resolves to the vector type, map operands to index type.
  return failure(
      parser.parseOperand(storeValueInfo) || parser.parseComma() ||
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineVectorStoreOp::getMapAttrStrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(memrefType) || parser.parseComma() ||
      parser.parseType(resultType) ||
      parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands));
}
4826
4827void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4828 p << " " << getValueToStore();
4829 p << ", " << getMemRef() << '[';
4830 if (AffineMapAttr mapAttr =
4831 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4832 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4833 p << ']';
4834 p.printOptionalAttrDict((*this)->getAttrs(),
4835 /*elidedAttrs=*/{getMapAttrStrName()});
4836 p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4837}
4838
4839LogicalResult AffineVectorStoreOp::verify() {
4840 MemRefType memrefType = getMemRefType();
4842 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4843 getMapOperands(), memrefType,
4844 /*numIndexOperands=*/getNumOperands() - 2)))
4845 return failure();
4846
4847 if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4848 return failure();
4849
4850 return success();
4851}
4852
4853//===----------------------------------------------------------------------===//
4854// DelinearizeIndexOp
4855//===----------------------------------------------------------------------===//
4856
4857void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4858 OperationState &odsState,
4859 Value linearIndex, ValueRange dynamicBasis,
4860 ArrayRef<int64_t> staticBasis,
4861 bool hasOuterBound) {
4862 SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4863 : staticBasis.size() + 1,
4864 linearIndex.getType());
4865 build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4866 staticBasis);
4867}
4868
/// Builds a delinearize op from a basis given as SSA values; a null leading
/// value encodes the absence of an outer bound.
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex, ValueRange basis,
                                     bool hasOuterBound) {
  // Normalize: drop a null leading basis value and record that there is no
  // outer bound.
  if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
    hasOuterBound = false;
    basis = basis.drop_front();
  }
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  // Split the basis into dynamic SSA values and static integer entries.
  dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
                             staticBasis);
  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
        hasOuterBound);
}
4884
/// Builds a delinearize op from a mixed (attribute-or-value) basis; an empty
/// leading OpFoldResult encodes the absence of an outer bound.
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex,
                                     ArrayRef<OpFoldResult> basis,
                                     bool hasOuterBound) {
  // Normalize: drop an empty leading entry and record that there is no outer
  // bound.
  if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
    hasOuterBound = false;
    basis = basis.drop_front();
  }
  SmallVector<Value> dynamicBasis;
  SmallVector<int64_t> staticBasis;
  // Split the basis into dynamic SSA values and static integer entries.
  dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
  build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
        hasOuterBound);
}
4900
/// Builds a delinearize op from a purely static basis (no dynamic values).
void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
                                     OperationState &odsState,
                                     Value linearIndex, ArrayRef<int64_t> basis,
                                     bool hasOuterBound) {
  build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
}
4907
LogicalResult AffineDelinearizeIndexOp::verify() {
  ArrayRef<int64_t> staticBasis = getStaticBasis();
  // There is one result per basis element, plus optionally one extra leading
  // result when the outermost dimension has no bound.
  if (getNumResults() != staticBasis.size() &&
      getNumResults() != staticBasis.size() + 1)
    return emitOpError("should return an index for each basis element and up "
                       "to one extra index");

  // Each kDynamic marker in the static basis must be backed by exactly one
  // dynamic basis operand.
  auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
  if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
    return emitOpError(
        "mismatch between dynamic and static basis (kDynamic marker but no "
        "corresponding dynamic basis entry) -- this can only happen due to an "
        "incorrect fold/rewrite");

  // Static basis entries must be strictly positive (or the dynamic marker).
  if (!llvm::all_of(staticBasis, [](int64_t v) {
        return v > 0 || ShapedType::isDynamic(v);
      }))
    return emitOpError("no basis element may be statically non-positive");

  return success();
}
4929
4930/// Given mixed basis of affine.delinearize_index/linearize_index replace
4931/// constant SSA values with the constant integer value and return the new
4932/// static basis. In case no such candidate for replacement exists, this utility
4933/// returns std::nullopt.
4934static std::optional<SmallVector<int64_t>>
4936 MutableOperandRange mutableDynamicBasis,
4937 ArrayRef<Attribute> dynamicBasis) {
4938 uint64_t dynamicBasisIndex = 0;
4939 for (Attribute basis : dynamicBasis) {
4940 // Skip poison values: they don't have a concrete integer value, so erasing
4941 // them from the dynamic operands would create an inconsistency between
4942 // the static basis (which would still hold kDynamic) and the dynamic
4943 // operand list (which would be one element shorter).
4944 if (basis && isa<IntegerAttr>(basis)) {
4945 mutableDynamicBasis.erase(dynamicBasisIndex);
4946 } else {
4947 ++dynamicBasisIndex;
4948 }
4949 }
4950
4951 // No constant SSA value exists.
4952 if (dynamicBasisIndex == dynamicBasis.size())
4953 return std::nullopt;
4954
4955 SmallVector<int64_t> staticBasis;
4956 for (OpFoldResult basis : mixedBasis) {
4957 std::optional<int64_t> basisVal = getConstantIntValue(basis);
4958 if (!basisVal)
4959 staticBasis.push_back(ShapedType::kDynamic);
4960 else
4961 staticBasis.push_back(*basisVal);
4962 }
4963
4964 return staticBasis;
4965}
4966
LogicalResult
AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &result) {
  // First, try moving constant-valued dynamic basis operands into the static
  // basis (in-place update, no replacement values).
  std::optional<SmallVector<int64_t>> maybeStaticBasis =
      foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
                                 adaptor.getDynamicBasis());
  if (maybeStaticBasis) {
    setStaticBasis(*maybeStaticBasis);
    return success();
  }
  // If we won't be doing any division or modulo (no basis or the one basis
  // element is purely advisory), simply return the input value.
  if (getNumResults() == 1) {
    result.push_back(getLinearIndex());
    return success();
  }

  // Constant folding below needs a constant linear index and a fully static
  // basis.
  if (adaptor.getLinearIndex() == nullptr)
    return failure();

  if (!adaptor.getDynamicBasis().empty())
    return failure();

  int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
  Type attrType = getLinearIndex().getType();

  ArrayRef<int64_t> staticBasis = getStaticBasis();
  if (hasOuterBound())
    staticBasis = staticBasis.drop_front();
  // Peel digits from the innermost dimension outward, then reverse into
  // outermost-first result order.
  for (int64_t modulus : llvm::reverse(staticBasis)) {
    result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
    highPart = llvm::divideFloorSigned(highPart, modulus);
  }
  result.push_back(IntegerAttr::get(attrType, highPart));
  std::reverse(result.begin(), result.end());
  return success();
}
5004
5005SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
5006 OpBuilder builder(getContext());
5007 if (hasOuterBound()) {
5008 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5009 return getMixedValues(getStaticBasis().drop_front(),
5010 getDynamicBasis().drop_front(), builder);
5011
5012 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5013 builder);
5014 }
5015
5016 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5017}
5018
5019SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
5020 SmallVector<OpFoldResult> ret = getMixedBasis();
5021 if (!hasOuterBound())
5022 ret.insert(ret.begin(), OpFoldResult());
5023 return ret;
5024}
5025
5026namespace {
5027
5028// Drops delinearization indices that correspond to unit-extent basis
5029struct DropUnitExtentBasis
5030 : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5032
5033 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5034 PatternRewriter &rewriter) const override {
5035 SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
5036 std::optional<Value> zero = std::nullopt;
5037 Location loc = delinearizeOp->getLoc();
5038 Type indexType = delinearizeOp.getLinearIndex().getType();
5039 auto getZero = [&]() -> Value {
5040 if (!zero)
5041 zero = arith::ConstantOp::create(rewriter, loc,
5042 rewriter.getZeroAttr(indexType));
5043 return zero.value();
5044 };
5045
5046 // Replace all indices corresponding to unit-extent basis with 0.
5047 // Remaining basis can be used to get a new `affine.delinearize_index` op.
5048 SmallVector<OpFoldResult> newBasis;
5049 for (auto [index, basis] :
5050 llvm::enumerate(delinearizeOp.getPaddedBasis())) {
5051 std::optional<int64_t> basisVal =
5052 basis ? getConstantIntValue(basis) : std::nullopt;
5053 if (basisVal == 1)
5054 replacements[index] = getZero();
5055 else
5056 newBasis.push_back(basis);
5057 }
5058
5059 if (newBasis.size() == delinearizeOp.getNumResults())
5060 return rewriter.notifyMatchFailure(delinearizeOp,
5061 "no unit basis elements");
5062
5063 if (!newBasis.empty()) {
5064 // Will drop the leading nullptr from `basis` if there was no outer bound.
5065 auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
5066 rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
5067 int newIndex = 0;
5068 // Map back the new delinearized indices to the values they replace.
5069 for (auto &replacement : replacements) {
5070 if (replacement)
5071 continue;
5072 replacement = newDelinearizeOp->getResult(newIndex++);
5073 }
5074 }
5075
5076 rewriter.replaceOp(delinearizeOp, replacements);
5077 return success();
5078 }
5079};
5080
5081/// If a `affine.delinearize_index`'s input is a `affine.linearize_index
5082/// disjoint` and the two operations end with the same basis elements,
5083/// cancel those parts of the operations out because they are inverses
5084/// of each other.
5085///
5086/// If the operations have the same basis, cancel them entirely.
5087///
5088/// The `disjoint` flag is needed on the `affine.linearize_index` because
5089/// otherwise, there is no guarantee that the inputs to the linearization are
5090/// in-bounds the way the outputs of the delinearization would be.
5091struct CancelDelinearizeOfLinearizeDisjointExactTail
5092 : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5094
5095 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5096 PatternRewriter &rewriter) const override {
5097 auto linearizeOp = delinearizeOp.getLinearIndex()
5098 .getDefiningOp<affine::AffineLinearizeIndexOp>();
5099 if (!linearizeOp)
5100 return rewriter.notifyMatchFailure(delinearizeOp,
5101 "index doesn't come from linearize");
5102
5103 if (!linearizeOp.getDisjoint())
5104 return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
5105
5106 ValueRange linearizeIns = linearizeOp.getMultiIndex();
5107 // Note: we use the full basis so we don't lose outer bounds later.
5108 SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
5109 SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
5110 size_t numMatches = 0;
5111 for (auto [linSize, delinSize] : llvm::zip(
5112 llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
5113 if (linSize != delinSize)
5114 break;
5115 ++numMatches;
5116 }
5117
5118 if (numMatches == 0)
5119 return rewriter.notifyMatchFailure(
5120 delinearizeOp, "final basis element doesn't match linearize");
5121
5122 // The easy case: everything lines up and the basis match sup completely.
5123 if (numMatches == linearizeBasis.size() &&
5124 numMatches == delinearizeBasis.size() &&
5125 linearizeIns.size() == delinearizeOp.getNumResults()) {
5126 rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
5127 return success();
5128 }
5129
5130 Value newLinearize = affine::AffineLinearizeIndexOp::create(
5131 rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
5132 ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
5133 linearizeOp.getDisjoint());
5134 auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
5135 rewriter, delinearizeOp.getLoc(), newLinearize,
5136 ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
5137 delinearizeOp.hasOuterBound());
5138 SmallVector<Value> mergedResults(newDelinearize.getResults());
5139 mergedResults.append(linearizeIns.take_back(numMatches).begin(),
5140 linearizeIns.take_back(numMatches).end());
5141 rewriter.replaceOp(delinearizeOp, mergedResults);
5142 return success();
5143 }
5144};
5145
5146/// If the input to a delinearization is a disjoint linearization, and the
5147/// last k > 1 components of the delinearization basis multiply to the
5148/// last component of the linearization basis, break the linearization and
5149/// delinearization into two parts, peeling off the last input to linearization.
5150///
5151/// For example:
5152/// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
5153/// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
5154/// becomes
5155/// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
5156/// %1:2 = affine.delinearize_index %0 by (2, 3) : index
5157/// %2:2 = affine.delinearize_index %x by (8, 4) : index
5158/// where the original %1:4 is replaced by %1:2 ++ %2:2
5159struct SplitDelinearizeSpanningLastLinearizeArg final
5160 : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5162
5163 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5164 PatternRewriter &rewriter) const override {
5165 auto linearizeOp = delinearizeOp.getLinearIndex()
5166 .getDefiningOp<affine::AffineLinearizeIndexOp>();
5167 if (!linearizeOp)
5168 return rewriter.notifyMatchFailure(delinearizeOp,
5169 "index doesn't come from linearize");
5170
5171 if (!linearizeOp.getDisjoint())
5172 return rewriter.notifyMatchFailure(linearizeOp,
5173 "linearize isn't disjoint");
5174
5175 int64_t target = linearizeOp.getStaticBasis().back();
5176 if (ShapedType::isDynamic(target))
5177 return rewriter.notifyMatchFailure(
5178 linearizeOp, "linearize ends with dynamic basis value");
5179
5180 int64_t sizeToSplit = 1;
5181 size_t elemsToSplit = 0;
5182 ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
5183 for (int64_t basisElem : llvm::reverse(basis)) {
5184 if (ShapedType::isDynamic(basisElem))
5185 return rewriter.notifyMatchFailure(
5186 delinearizeOp, "dynamic basis element while scanning for split");
5187 sizeToSplit *= basisElem;
5188 elemsToSplit += 1;
5189
5190 if (sizeToSplit > target)
5191 return rewriter.notifyMatchFailure(delinearizeOp,
5192 "overshot last argument size");
5193 if (sizeToSplit == target)
5194 break;
5195 }
5196
5197 if (sizeToSplit < target)
5198 return rewriter.notifyMatchFailure(
5199 delinearizeOp, "product of known basis elements doesn't exceed last "
5200 "linearize argument");
5201
5202 if (elemsToSplit < 2)
5203 return rewriter.notifyMatchFailure(
5204 delinearizeOp,
5205 "need at least two elements to form the basis product");
5206
5207 Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
5208 rewriter, linearizeOp.getLoc(), linearizeOp.getLinearIndex().getType(),
5209 linearizeOp.getMultiIndex().drop_back(), linearizeOp.getDynamicBasis(),
5210 linearizeOp.getStaticBasis().drop_back(), linearizeOp.getDisjoint());
5211 auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
5212 rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
5213 delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
5214 delinearizeOp.hasOuterBound());
5215 auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
5216 rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
5217 basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
5218 SmallVector<Value> results = llvm::to_vector(
5219 llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
5220 delinearizeBack.getResults()));
5221 rewriter.replaceOp(delinearizeOp, results);
5222
5223 return success();
5224 }
5225};
5226} // namespace
5227
5228void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
5229 RewritePatternSet &patterns, MLIRContext *context) {
5230 patterns
5231 .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
5232 DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
5233 context);
5234}
5235
5236//===----------------------------------------------------------------------===//
5237// LinearizeIndexOp
5238//===----------------------------------------------------------------------===//
5239
5240/// Infer the index type from a set of multi-index values. Returns the common
5241/// type (index or vector<...xindex>), or IndexType if the set is empty.
5242static Type inferIndexType(MLIRContext *ctx, ValueRange multiIndex) {
5243 if (multiIndex.empty())
5244 return IndexType::get(ctx);
5245 return multiIndex.front().getType();
5246}
5247
5248void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5249 OperationState &odsState,
5250 ValueRange multiIndex, ValueRange basis,
5251 bool disjoint) {
5252 if (!basis.empty() && basis.front() == Value())
5253 basis = basis.drop_front();
5254 SmallVector<Value> dynamicBasis;
5255 SmallVector<int64_t> staticBasis;
5256 dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
5257 staticBasis);
5258 Type resultType = inferIndexType(odsBuilder.getContext(), multiIndex);
5259 build(odsBuilder, odsState, resultType, multiIndex, dynamicBasis, staticBasis,
5260 disjoint);
5261}
5262
5263void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5264 OperationState &odsState,
5265 ValueRange multiIndex,
5266 ArrayRef<OpFoldResult> basis,
5267 bool disjoint) {
5268 if (!basis.empty() && basis.front() == OpFoldResult())
5269 basis = basis.drop_front();
5270 SmallVector<Value> dynamicBasis;
5271 SmallVector<int64_t> staticBasis;
5272 dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
5273 Type resultType = inferIndexType(odsBuilder.getContext(), multiIndex);
5274 build(odsBuilder, odsState, resultType, multiIndex, dynamicBasis, staticBasis,
5275 disjoint);
5276}
5277
5278void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5279 OperationState &odsState,
5280 ValueRange multiIndex,
5281 ArrayRef<int64_t> basis, bool disjoint) {
5282 Type resultType = inferIndexType(odsBuilder.getContext(), multiIndex);
5283 build(odsBuilder, odsState, resultType, multiIndex, ValueRange{}, basis,
5284 disjoint);
5285}
5286
5287LogicalResult AffineLinearizeIndexOp::verify() {
5288 size_t numIndexes = getMultiIndex().size();
5289 size_t numBasisElems = getStaticBasis().size();
5290 if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5291 return emitOpError("should be passed a basis element for each index except "
5292 "possibly the first");
5293
5294 auto dynamicMarkersCount =
5295 llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
5296 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5297 return emitOpError(
5298 "mismatch between dynamic and static basis (kDynamic marker but no "
5299 "corresponding dynamic basis entry) -- this can only happen due to an "
5300 "incorrect fold/rewrite");
5301
5302 return success();
5303}
5304
5305OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
5306 std::optional<SmallVector<int64_t>> maybeStaticBasis =
5307 foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
5308 adaptor.getDynamicBasis());
5309 if (maybeStaticBasis) {
5310 setStaticBasis(*maybeStaticBasis);
5311 return getResult();
5312 }
5313 // No indices linearizes to zero.
5314 if (getMultiIndex().empty())
5315 return IntegerAttr::get(getResult().getType(), 0);
5316
5317 // One single index linearizes to itself.
5318 if (getMultiIndex().size() == 1)
5319 return getMultiIndex().front();
5320
5321 // Return nullptr if any multi-index attribute has not been folded to a
5322 // concrete integer (e.g. it is still a runtime value or has folded to a
5323 // non-integer attribute such as #ub.poison).
5324 if (llvm::any_of(adaptor.getMultiIndex(), [](Attribute a) {
5325 return !isa_and_nonnull<IntegerAttr>(a);
5326 }))
5327 return nullptr;
5328
5329 if (!adaptor.getDynamicBasis().empty())
5330 return nullptr;
5331
5332 int64_t result = 0;
5333 int64_t stride = 1;
5334 for (auto [length, indexAttr] :
5335 llvm::zip_first(llvm::reverse(getStaticBasis()),
5336 llvm::reverse(adaptor.getMultiIndex()))) {
5337 result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
5338 stride = stride * length;
5339 }
5340 // Handle the index element with no basis element.
5341 if (!hasOuterBound())
5342 result =
5343 result +
5344 cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
5345
5346 return IntegerAttr::get(getResult().getType(), result);
5347}
5348
5349SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
5350 OpBuilder builder(getContext());
5351 if (hasOuterBound()) {
5352 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5353 return getMixedValues(getStaticBasis().drop_front(),
5354 getDynamicBasis().drop_front(), builder);
5355
5356 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5357 builder);
5358 }
5359
5360 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5361}
5362
5363SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
5364 SmallVector<OpFoldResult> ret = getMixedBasis();
5365 if (!hasOuterBound())
5366 ret.insert(ret.begin(), OpFoldResult());
5367 return ret;
5368}
5369
5370namespace {
5371/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
5372/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
5373/// %...d)`.
5374
5375/// Note that `disjoint` is required here, because, without it, we could have
5376/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
5377/// is a valid operation where the `%c64` cannot be trivially dropped.
5378///
5379/// Alternatively, if `%x` in the above is a known constant 0, remove it even if
5380/// the operation isn't asserted to be `disjoint`.
5381struct DropLinearizeUnitComponentsIfDisjointOrZero final
5382 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5384
5385 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5386 PatternRewriter &rewriter) const override {
5387 ValueRange multiIndex = op.getMultiIndex();
5388 size_t numIndices = multiIndex.size();
5389 SmallVector<Value> newIndices;
5390 newIndices.reserve(numIndices);
5391 SmallVector<OpFoldResult> newBasis;
5392 newBasis.reserve(numIndices);
5393
5394 if (!op.hasOuterBound()) {
5395 newIndices.push_back(multiIndex.front());
5396 multiIndex = multiIndex.drop_front();
5397 }
5398
5399 SmallVector<OpFoldResult> basis = op.getMixedBasis();
5400 for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5401 std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
5402 if (!basisEntry || *basisEntry != 1) {
5403 newIndices.push_back(index);
5404 newBasis.push_back(basisElem);
5405 continue;
5406 }
5407
5408 std::optional<int64_t> indexValue = getConstantIntValue(index);
5409 if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5410 newIndices.push_back(index);
5411 newBasis.push_back(basisElem);
5412 continue;
5413 }
5414 }
5415 if (newIndices.size() == numIndices)
5416 return rewriter.notifyMatchFailure(op,
5417 "no unit basis entries to replace");
5418
5419 if (newIndices.empty()) {
5420 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
5421 op, rewriter.getZeroAttr(op.getLinearIndex().getType()));
5422 return success();
5423 }
5424 rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5425 op, newIndices, newBasis, op.getDisjoint());
5426 return success();
5427 }
5428};
5429
5430OpFoldResult computeProduct(Location loc, OpBuilder &builder,
5431 ArrayRef<OpFoldResult> terms) {
5432 int64_t nDynamic = 0;
5433 SmallVector<Value> dynamicPart;
5434 AffineExpr result = builder.getAffineConstantExpr(1);
5435 for (OpFoldResult term : terms) {
5436 if (!term)
5437 return term;
5438 std::optional<int64_t> maybeConst = getConstantIntValue(term);
5439 if (maybeConst) {
5440 result = result * builder.getAffineConstantExpr(*maybeConst);
5441 } else {
5442 dynamicPart.push_back(cast<Value>(term));
5443 result = result * builder.getAffineSymbolExpr(nDynamic++);
5444 }
5445 }
5446 if (auto constant = dyn_cast<AffineConstantExpr>(result))
5447 return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
5448 return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult();
5449}
5450
/// If consecutive outputs of a delinearize_index are linearized with the same
/// bounds, canonicalize away the redundant arithmetic.
///
/// That is, if we have
/// ```
/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
///   by (...e, B1, B2, ..., BK, ...f)
/// ```
///
/// We can rewrite this to
/// ```
/// B = B1 * B2 ... BK
/// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b)
/// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
/// ```
/// where we replace all results of %s unaffected by the change with results
/// from %sMerged.
///
/// As a special case, if all results of the delinearize are merged in this way
/// we can replace those usages with %x, thus cancelling the delinearization
/// entirely, as in
/// ```
/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
/// ```
/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
struct CancelLinearizeOfDelinearizePortion final
    : OpRewritePattern<affine::AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;

private:
  // Struct representing a case where the cancellation pattern
  // applies. A `Match` means that `length` inputs to the linearize operation
  // starting at `linStart` can be cancelled with `length` outputs of
  // `delinearize`, starting from `delinStart`.
  struct Match {
    AffineDelinearizeIndexOp delinearize;
    unsigned linStart = 0;
    unsigned delinStart = 0;
    unsigned length = 0;
  };

public:
  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Match> matches;

    // Padded bases have one entry per index/result, with a null entry standing
    // for a missing outer bound.
    const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
    ArrayRef<OpFoldResult> linBasisRef = linBasis;

    ValueRange multiIndex = linearizeOp.getMultiIndex();
    unsigned numLinArgs = multiIndex.size();
    unsigned linArgIdx = 0;
    // We only want to replace one run from the same delinearize op per
    // pattern invocation lest we run into invalidation issues.
    llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
    // Scan the linearize arguments for runs of consecutive delinearize
    // results whose basis elements match the linearize basis.
    while (linArgIdx < numLinArgs) {
      auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
      if (!asResult) {
        linArgIdx++;
        continue;
      }

      auto delinearizeOp =
          dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
      if (!delinearizeOp) {
        linArgIdx++;
        continue;
      }

      /// Result 0 of the delinearize and argument 0 of the linearize can
      /// leave their maximum value unspecified. However, even if this happens
      /// we can still sometimes start the match process. Specifically, if
      /// - The argument we're matching is result 0 and argument 0 (so the
      ///   bounds don't matter). For example,
      ///
      ///     %0:2 = affine.delinearize_index %x into (8) : index, index
      ///     %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...)
      ///   allows cancellation
      /// - The delinearization doesn't specify a bound, but the linearization
      ///  is `disjoint`, which asserts that the bound on the linearization is
      ///  correct.
      unsigned delinArgIdx = asResult.getResultNumber();
      SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
      OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
      OpFoldResult firstLinBound = linBasis[linArgIdx];
      bool boundsMatch = firstDelinBound == firstLinBound;
      bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
      bool knownByDisjoint =
          linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
      if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
        linArgIdx++;
        continue;
      }

      // Extend the match past the first element as far as both the result
      // values and the basis elements keep agreeing.
      unsigned j = 1;
      unsigned numDelinOuts = delinearizeOp.getNumResults();
      for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
           ++j) {
        if (multiIndex[linArgIdx + j] !=
            delinearizeOp.getResult(delinArgIdx + j))
          break;
        if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
          break;
      }
      // If there're multiple matches against the same delinearize_index,
      // only rewrite the first one we find to prevent invalidations. The next
      // ones will be taken care of by subsequent pattern invocations.
      if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
        linArgIdx++;
        continue;
      }
      matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
      linArgIdx += j;
    }

    if (matches.empty())
      return rewriter.notifyMatchFailure(
          linearizeOp, "no run of delinearize outputs to deal with");

    // Record all the delinearize replacements so we can do them after creating
    // the new linearization operation, since the new operation might use
    // outputs of something we're replacing.
    SmallVector<SmallVector<Value>> delinearizeReplacements;

    SmallVector<Value> newIndex;
    newIndex.reserve(numLinArgs);
    SmallVector<OpFoldResult> newBasis;
    newBasis.reserve(numLinArgs);
    unsigned prevMatchEnd = 0;
    for (Match m : matches) {
      // Copy through the unmatched arguments between the previous match and
      // this one unchanged.
      unsigned gap = m.linStart - prevMatchEnd;
      llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
      llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
      // Update here so we don't forget this during early continues
      prevMatchEnd = m.linStart + m.length;

      PatternRewriter::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(m.delinearize);

      ArrayRef<OpFoldResult> basisToMerge =
          linBasisRef.slice(m.linStart, m.length);
      // We use the slice from the linearize's basis above because of the
      // "bounds inferred from `disjoint`" case above.
      OpFoldResult newSize =
          computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);

      // Trivial case where we can just skip past the delinearize all together
      if (m.length == m.delinearize.getNumResults()) {
        newIndex.push_back(m.delinearize.getLinearIndex());
        newBasis.push_back(newSize);
        // Pad out set of replacements so we don't do anything with this one.
        delinearizeReplacements.push_back(SmallVector<Value>());
        continue;
      }

      SmallVector<Value> newDelinResults;
      // Collapse the matched run of basis elements into a single merged
      // element of size `newSize`.
      SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
      newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
                          newDelinBasis.begin() + m.delinStart + m.length);
      newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
      auto newDelinearize = AffineDelinearizeIndexOp::create(
          rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
          newDelinBasis);

      // Since there may be other uses of the indices we just merged together,
      // create a residual affine.delinearize_index that delinearizes the
      // merged output into its component parts.
      Value combinedElem = newDelinearize.getResult(m.delinStart);
      auto residualDelinearize = AffineDelinearizeIndexOp::create(
          rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);

      // Swap all the uses of the unaffected delinearize outputs to the new
      // delinearization so that the old code can be removed if this
      // linearize_index is the only user of the merged results.
      llvm::append_range(newDelinResults,
                         newDelinearize.getResults().take_front(m.delinStart));
      llvm::append_range(newDelinResults, residualDelinearize.getResults());
      llvm::append_range(
          newDelinResults,
          newDelinearize.getResults().drop_front(m.delinStart + 1));

      delinearizeReplacements.push_back(newDelinResults);
      newIndex.push_back(combinedElem);
      newBasis.push_back(newSize);
    }
    // Copy through any arguments after the final match.
    llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
    llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
    rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
        linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());

    // Now that the new linearize op is in place, retarget the uses of each
    // partially-cancelled delinearize op to its merged replacement.
    for (auto [m, newResults] :
         llvm::zip_equal(matches, delinearizeReplacements)) {
      if (newResults.empty())
        continue;
      rewriter.replaceOp(m.delinearize, newResults);
    }

    return success();
  }
};
5653
5654/// Strip leading zero from affine.linearize_index.
5655///
5656/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
5657/// to `affine.linearize_index [...a] by (...b)` in all cases.
5658struct DropLinearizeLeadingZero final
5659 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5661
5662 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5663 PatternRewriter &rewriter) const override {
5664 Value leadingIdx = op.getMultiIndex().front();
5665 if (!matchPattern(leadingIdx, m_Zero()))
5666 return failure();
5667
5668 if (op.getMultiIndex().size() == 1) {
5669 rewriter.replaceOp(op, leadingIdx);
5670 return success();
5671 }
5672
5673 SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
5674 ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
5675 if (op.hasOuterBound())
5676 newMixedBasis = newMixedBasis.drop_front();
5677
5678 rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5679 op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5680 return success();
5681 }
5682};
5683} // namespace
5684
5685void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5686 RewritePatternSet &patterns, MLIRContext *context) {
5687 patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5688 DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5689}
5690
5691//===----------------------------------------------------------------------===//
5692// TableGen'd op method definitions
5693//===----------------------------------------------------------------------===//
5694
5695#define GET_OP_CLASSES
5696#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
return success()
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
static Type inferIndexType(MLIRContext *ctx, ValueRange multiIndex)
Infer the index type from a set of multi-index values. Returns the common type (index or vector<....
static LogicalResult verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
Definition AffineOps.cpp:62
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
static bool isTopLevelValueOrAbove(Value value, Region *region)
A utility function to check if a value is defined at the top level of region or is an argument of reg...
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
static bool isValidAffineIndexOperand(Value value, Region *region)
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
static LogicalResult replaceAffineDelinearizeIndexInverseExpression(AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
If this map contains of the expression x_1 + x_1 * C_1 + ... x_n * C_N + / ... (not necessarily in or...
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
static void simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
static void legalizeDemotedDims(MapOrSet &mapOrSet, SmallVectorImpl< Value > &operands)
A valid affine dimension may appear as a symbol in affine.apply operations.
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
return success()
static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map, ValueRange dims, ValueRange syms)
Assuming dimOrSym is a quantity in the apply op map map and defined by minOp = affine_min(x_1,...
static SmallVector< OpFoldResult > AffineForEmptyLoopFolder(AffineForOp forOp)
Fold the empty loop.
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms, bool replaceAffineMin)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static std::optional< uint64_t > getTrivialConstantTripCount(AffineForOp forOp)
Returns constant trip count in trivial cases.
static LogicalResult verifyAffineMinMaxOp(T op)
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
static void shortenAddChainsContainingAll(AffineExpr e, const llvm::SmallDenseSet< AffineExpr, 4 > &exprsToRemove, AffineExpr newVal, DenseMap< AffineExpr, AffineExpr > &replacementsMap)
Recursively traverse e.
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value getMemRef(Operation *memOp)
Returns the memref being read/written by a memref/affine load/store op.
Definition Utils.cpp:247
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getI64ArrayAttr(paddingDimensions)
b getContext())
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
Definition SCFToGPU.cpp:75
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
Definition SCFToGPU.cpp:80
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
#define div(a, b)
#define rem(a, b)
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr floorDiv(uint64_t v) const
AffineExprKind getKind() const
Return the classification for this type.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
MLIRContext * getContext() const
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
MLIRContext * getContext() const
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
Definition AffineMap.h:221
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
Definition AffineMap.h:267
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
Definition AffineMap.h:228
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
unsigned getNumInputs() const
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ... numSymbols) by symbols[offset + shift ... shift + numSymbols).
Definition AffineMap.h:280
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
AffineMap getDimIdentityMap()
Definition Builders.cpp:387
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:391
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:372
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:376
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition Builders.cpp:183
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
NoneType getNoneType()
Definition Builders.cpp:92
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
Definition Builders.cpp:380
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:382
AffineMap getSymbolIdentityMap()
Definition Builders.cpp:400
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:285
IndexType getIndexType()
Definition Builders.cpp:55
An attribute that represents a reference to a dense integer vector or tensor object.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition IntegerSet.h:44
unsigned getNumDims() const
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
unsigned getNumInputs() const
ArrayRef< AffineExpr > getConstraints() const
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
unsigned getNumSymbols() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
This class represents a single result from folding an operation.
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
OperandRange operand_range
Definition Operation.h:397
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:248
operand_range::iterator operand_iterator
Definition Operation.h:398
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
Region * getSuccessor() const
Return the given region successor.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
A variable that can be added to the constraint set as a "column".
static bool compare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Return "true" if "lhs cmp rhs" was proven to hold.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
AffineBound represents a lower or upper bound in the for operation.
Definition AffineOps.h:217
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
void reset(AffineMap map, ValueRange operands, ValueRange results={})
static void difference(const AffineValueMap &a, const AffineValueMap &b, AffineValueMap *res)
Return the value map that is the difference of value maps 'a' and 'b', represented as an affine map a...
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult computeProduct(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > terms)
Return the product of terms, creating an affine.apply if any of them are non-constant values.
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a maximum across the results of applying map to operands,...
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns true if the provided value is among the induction variables of an AffineParallelOp.
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t > > constLowerBounds, ArrayRef< std::optional< int64_t > > constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
AffineExprKind
Definition AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
Definition AffineExpr.h:50
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ DimId
Dimensional identifier.
Definition AffineExpr.h:59
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
@ SymbolId
Symbolic identifier.
Definition AffineExpr.h:61
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Canonicalize the affine map result expression order of an affine min/max operation.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Remove duplicated expressions in affine min/max ops.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition Builders.h:287
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.