MLIR 23.0.0git
AffineOps.cpp
Go to the documentation of this file.
1//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/AffineExpr.h"
16#include "mlir/IR/IRMapping.h"
17#include "mlir/IR/IntegerSet.h"
18#include "mlir/IR/Matchers.h"
21#include "mlir/IR/Value.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
27#include "llvm/ADT/SmallVectorExtras.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Support/DebugLog.h"
30#include "llvm/Support/LogicalResult.h"
31#include "llvm/Support/MathExtras.h"
32#include <numeric>
33#include <optional>
34
35using namespace mlir;
36using namespace mlir::affine;
37
38using llvm::divideCeilSigned;
39using llvm::divideFloorSigned;
40using llvm::mod;
41
42#define DEBUG_TYPE "affine-ops"
43
44#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
45
46/// A utility function to check if a value is defined at the top level of
47/// `region` or is an argument of `region`. A value of index type defined at the
48/// top level of a `AffineScope` region is always a valid symbol for all
49/// uses in that region.
51 if (auto arg = dyn_cast<BlockArgument>(value))
52 return arg.getParentRegion() == region;
53 return value.getDefiningOp()->getParentRegion() == region;
54}
55
56/// Checks if `value` known to be a legal affine dimension or symbol in `src`
57/// region remains legal if the operation that uses it is inlined into `dest`
58/// with the given value mapping. `legalityCheck` is either `isValidDim` or
59/// `isValidSymbol`, depending on the value being required to remain a valid
60/// dimension or symbol.
61static bool
63 const IRMapping &mapping,
64 function_ref<bool(Value, Region *)> legalityCheck) {
65 // If the value is a valid dimension for any other reason than being
66 // a top-level value, it will remain valid: constants get inlined
67 // with the function, transitive affine applies also get inlined and
68 // will be checked themselves, etc.
69 if (!isTopLevelValue(value, src))
70 return true;
71
72 // If it's a top-level value because it's a block operand, i.e. a
73 // function argument, check whether the value replacing it after
74 // inlining is a valid dimension in the new region.
75 if (llvm::isa<BlockArgument>(value))
76 return legalityCheck(mapping.lookup(value), dest);
77
78 // If it's a top-level value because it's defined in the region,
79 // it can only be inlined if the defining op is a constant or a
80 // `dim`, which can appear anywhere and be valid, since the defining
81 // op won't be top-level anymore after inlining.
82 Attribute operandCst;
83 bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
84 return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
85 isDimLikeOp;
86}
87
88/// Checks if all values known to be legal affine dimensions or symbols in `src`
89/// remain so if their respective users are inlined into `dest`.
90static bool
92 const IRMapping &mapping,
93 function_ref<bool(Value, Region *)> legalityCheck) {
94 return llvm::all_of(values, [&](Value v) {
95 return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
96 });
97}
98
99/// Checks if an affine read or write operation remains legal after inlining
100/// from `src` to `dest`.
101template <typename OpTy>
102static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
103 const IRMapping &mapping) {
104 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
105 AffineWriteOpInterface>::value,
106 "only ops with affine read/write interface are supported");
107
108 AffineMap map = op.getAffineMap();
109 ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
110 ValueRange symbolOperands =
111 op.getMapOperands().take_back(map.getNumSymbols());
113 dimOperands, src, dest, mapping,
114 static_cast<bool (*)(Value, Region *)>(isValidDim)))
115 return false;
117 symbolOperands, src, dest, mapping,
118 static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
119 return false;
120 return true;
121}
122
123/// Checks if an affine apply operation remains legal after inlining from `src`
124/// to `dest`.
125// Use "unused attribute" marker to silence clang-tidy warning stemming from
126// the inability to see through "llvm::TypeSwitch".
127template <>
128[[maybe_unused]] bool remainsLegalAfterInline(AffineApplyOp op, Region *src,
129 Region *dest,
130 const IRMapping &mapping) {
131 // If it's a valid dimension, we need to check that it remains so.
132 if (isValidDim(op.getResult(), src))
134 op.getMapOperands(), src, dest, mapping,
135 static_cast<bool (*)(Value, Region *)>(isValidDim));
136
137 // Otherwise it must be a valid symbol, check that it remains so.
139 op.getMapOperands(), src, dest, mapping,
140 static_cast<bool (*)(Value, Region *)>(isValidSymbol));
141}
142
143//===----------------------------------------------------------------------===//
144// AffineDialect Interfaces
145//===----------------------------------------------------------------------===//
146
147namespace {
148/// This class defines the interface for handling inlining with affine
149/// operations.
150struct AffineInlinerInterface : public DialectInlinerInterface {
151 using DialectInlinerInterface::DialectInlinerInterface;
152
153 //===--------------------------------------------------------------------===//
154 // Analysis Hooks
155 //===--------------------------------------------------------------------===//
156
157 /// Returns true if the given region 'src' can be inlined into the region
158 /// 'dest' that is attached to an operation registered to the current dialect.
159 /// 'wouldBeCloned' is set if the region is cloned into its new location
160 /// rather than moved, indicating there may be other users.
161 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
162 IRMapping &valueMapping) const final {
163 // We can inline into affine loops and conditionals if this doesn't break
164 // affine value categorization rules.
165 Operation *destOp = dest->getParentOp();
166 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
167 return false;
168
169 // Multi-block regions cannot be inlined into affine constructs, all of
170 // which require single-block regions.
171 if (!src->hasOneBlock())
172 return false;
173
174 // Side-effecting operations that the affine dialect cannot understand
175 // should not be inlined.
176 Block &srcBlock = src->front();
177 for (Operation &op : srcBlock) {
178 // Ops with no side effects are fine,
179 if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
180 if (iface.hasNoEffect())
181 continue;
182 }
183
184 // Assuming the inlined region is valid, we only need to check if the
185 // inlining would change it.
186 bool remainsValid =
187 llvm::TypeSwitch<Operation *, bool>(&op)
188 .Case<AffineApplyOp, AffineReadOpInterface,
189 AffineWriteOpInterface>([&](auto op) {
190 return remainsLegalAfterInline(op, src, dest, valueMapping);
191 })
192 .Default([](Operation *) {
193 // Conservatively disallow inlining ops we cannot reason about.
194 return false;
195 });
196
197 if (!remainsValid)
198 return false;
199 }
200
201 return true;
202 }
203
204 /// Returns true if the given operation 'op', that is registered to this
205 /// dialect, can be inlined into the given region, false otherwise.
206 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
207 IRMapping &valueMapping) const final {
208 // Always allow inlining affine operations into a region that is marked as
209 // affine scope, or into affine loops and conditionals. There are some edge
210 // cases when inlining *into* affine structures, but that is handled in the
211 // other 'isLegalToInline' hook above.
212 Operation *parentOp = region->getParentOp();
213 return parentOp->hasTrait<OpTrait::AffineScope>() ||
214 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
215 }
216
217 /// Affine regions should be analyzed recursively.
218 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
219};
220} // namespace
221
222//===----------------------------------------------------------------------===//
223// AffineDialect
224//===----------------------------------------------------------------------===//
225
/// Dialect initialization hook. In order, it:
///  - adds the operation list pulled in from `AffineOps.cpp.inc` under
///    `GET_OP_LIST`,
///  - adds `AffineInlinerInterface` as a dialect interface, and
///  - declares `ValueBoundsOpInterface` as a promised interface for
///    `AffineApplyOp`, `AffineMaxOp` and `AffineMinOp`.
/// NOTE(review): the implementation of the promised interface is presumably
/// provided by another library; confirm where it is attached.
226void AffineDialect::initialize() {
227 addOperations<
228#define GET_OP_LIST
229#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
230 >();
231 addInterfaces<AffineInlinerInterface>();
232 declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
233 AffineMinOp>();
234}
235
236/// Materialize a single constant operation from a given attribute value with
237/// the desired resultant type.
238Operation *AffineDialect::materializeConstant(OpBuilder &builder,
239 Attribute value, Type type,
240 Location loc) {
241 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
242 return ub::PoisonOp::create(builder, loc, type, poison);
243 return arith::ConstantOp::materialize(builder, value, type, loc);
244}
245
246/// A utility function to check if a value is defined at the top level of an
247/// op with trait `AffineScope`. If the value is defined in an unlinked region,
248/// conservatively assume it is not top-level. A value of index type defined at
249/// the top level is always a valid symbol.
251 if (auto arg = dyn_cast<BlockArgument>(value)) {
252 // The block owning the argument may be unlinked, e.g. when the surrounding
253 // region has not yet been attached to an Op, at which point the parent Op
254 // is null.
255 Operation *parentOp = arg.getOwner()->getParentOp();
256 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
257 }
258 // The defining Op may live in an unlinked block so its parent Op may be null.
259 Operation *parentOp = value.getDefiningOp()->getParentOp();
260 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
261}
262
263/// Returns the closest region enclosing `op` that is held by an operation with
264/// trait `AffineScope`; `nullptr` if there is no such region.
266 auto *curOp = op;
267 while (auto *parentOp = curOp->getParentOp()) {
268 if (parentOp->hasTrait<OpTrait::AffineScope>())
269 return curOp->getParentRegion();
270 curOp = parentOp;
271 }
272 return nullptr;
273}
274
276 Operation *curOp = op;
277 while (auto *parentOp = curOp->getParentOp()) {
278 if (!isa<AffineForOp, AffineIfOp, AffineParallelOp>(parentOp))
279 return curOp->getParentRegion();
280 curOp = parentOp;
281 }
282 return nullptr;
283}
284
285// A Value can be used as a dimension id iff it meets one of the following
286// conditions:
287// *) It is valid as a symbol.
288// *) It is an induction variable.
289// *) It is the result of affine apply operation with dimension id arguments.
291 // The value must be an index type.
292 if (!value.getType().isIndex())
293 return false;
294
295 if (auto *defOp = value.getDefiningOp())
296 return isValidDim(value, getAffineScope(defOp));
297
298 // This value has to be a block argument for an op that has the
299 // `AffineScope` trait or an induction var of an affine.for or
300 // affine.parallel.
301 if (isAffineInductionVar(value))
302 return true;
303 auto *parentOp = llvm::cast<BlockArgument>(value).getOwner()->getParentOp();
304 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
305}
306
307// Value can be used as a dimension id iff it meets one of the following
308// conditions:
309// *) It is valid as a symbol.
310// *) It is an induction variable.
311// *) It is the result of an affine apply operation with dimension id operands.
312// *) It is the result of a more specialized index transformation (ex.
313// delinearize_index or linearize_index) with dimension id operands.
315 // The value must be an index type.
316 if (!value.getType().isIndex())
317 return false;
318
319 // All valid symbols are okay.
320 if (isValidSymbol(value, region))
321 return true;
322
323 auto *op = value.getDefiningOp();
324 if (!op) {
325 // This value has to be an induction var for an affine.for or an
326 // affine.parallel.
327 return isAffineInductionVar(value);
328 }
329
330 // Affine apply operation is ok if all of its operands are ok.
331 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
332 return applyOp.isValidDim(region);
333 // delinearize_index and linearize_index are special forms of apply
334 // and so are valid dimensions if all their arguments are valid dimensions.
335 if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
336 return llvm::all_of(op->getOperands(),
337 [&](Value arg) { return ::isValidDim(arg, region); });
338 // The dim op is okay if its operand memref/tensor is defined at the top
339 // level.
340 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
341 return isTopLevelValue(dimOp.getShapedValue());
342 return false;
343}
344
345/// Returns true if the 'index' dimension of the `memref` defined by
346/// `memrefDefOp` is a statically shaped one or defined using a valid symbol
347/// for `region`.
348template <typename AnyMemRefDefOp>
349static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
350 Region *region) {
351 MemRefType memRefType = memrefDefOp.getType();
352
353 // Dimension index is out of bounds.
354 if (index >= memRefType.getRank()) {
355 return false;
356 }
357
358 // Statically shaped.
359 if (!memRefType.isDynamicDim(index))
360 return true;
361 // Get the position of the dimension among dynamic dimensions;
362 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
363 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
364 region);
365}
366
367/// Returns true if the result of the dim op is a valid symbol for `region`.
368static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
369 // The dim op is okay if its source is defined at the top level.
370 if (isTopLevelValue(dimOp.getShapedValue()))
371 return true;
372
373 // Conservatively handle remaining BlockArguments as non-valid symbols.
374 // E.g. scf.for iterArgs.
375 if (llvm::isa<BlockArgument>(dimOp.getShapedValue()))
376 return false;
377
378 // The dim op is also okay if its operand memref is a view/subview whose
379 // corresponding size is a valid symbol.
380 std::optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
381
382 // Be conservative if we can't understand the dimension.
383 if (!index.has_value())
384 return false;
385
386 // Skip over all memref.cast ops (if any).
387 Operation *op = dimOp.getShapedValue().getDefiningOp();
388 while (auto castOp = dyn_cast<memref::CastOp>(op)) {
389 // Bail on unranked memrefs.
390 if (isa<UnrankedMemRefType>(castOp.getSource().getType()))
391 return false;
392 op = castOp.getSource().getDefiningOp();
393 if (!op)
394 return false;
395 }
396
397 int64_t i = index.value();
399 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
400 [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
401 .Default([](Operation *) { return false; });
402}
403
404// A value can be used as a symbol (at all its use sites) iff it meets one of
405// the following conditions:
406// *) It is a constant.
407// *) Its defining op or block arg appearance is immediately enclosed by an op
408// with `AffineScope` trait.
409// *) It is the result of an affine.apply operation with symbol operands.
410// *) It is a result of the dim op on a memref whose corresponding size is a
411// valid symbol.
413 if (!value)
414 return false;
415
416 // The value must be an index type.
417 if (!value.getType().isIndex())
418 return false;
419
420 // Check that the value is a top level value.
421 if (isTopLevelValue(value))
422 return true;
423
424 if (auto *defOp = value.getDefiningOp())
425 return isValidSymbol(value, getAffineScope(defOp));
426
427 return false;
428}
429
430/// A utility function to check if a value is defined at the top level of
431/// `region` or is an argument of `region` or is defined above the region.
432static bool isTopLevelValueOrAbove(Value value, Region *region) {
433 Region *parentRegion = value.getParentRegion();
434 do {
435 if (parentRegion == region)
436 return true;
437 Operation *regionOp = region->getParentOp();
438 if (regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
439 break;
440 region = region->getParentOp()->getParentRegion();
441 } while (region);
442 return false;
443}
444
445/// A value can be used as a symbol for `region` iff it meets one of the
446/// following conditions:
447/// *) It is a constant.
448/// *) It is a result of a `Pure` operation whose operands are valid symbolic
449/// identifiers.
450/// *) It is a result of the dim op on a memref whose corresponding size is
451/// a valid symbol.
452/// *) It is defined at the top level of 'region' or is its argument.
453/// *) It dominates `region`'s parent op.
454/// If `region` is null, conservatively assume the symbol definition scope does
455/// not exist and only accept the values that would be symbols regardless of
456/// the surrounding region structure, i.e. the first three cases above.
458 // The value must be an index type.
459 if (!value.getType().isIndex())
460 return false;
461
462 // A top-level value is a valid symbol.
463 if (region && isTopLevelValueOrAbove(value, region))
464 return true;
465
466 auto *defOp = value.getDefiningOp();
467 if (!defOp)
468 return false;
469
470 // Constant operation is ok.
471 Attribute operandCst;
472 if (matchPattern(defOp, m_Constant(&operandCst)))
473 return true;
474
475 // A `Pure` operation whose operands are valid symbolic identifiers.
476 if (isPure(defOp) && llvm::all_of(defOp->getOperands(), [&](Value operand) {
477 return affine::isValidSymbol(operand, region);
478 })) {
479 return true;
480 }
481
482 // Dim op results could be valid symbols at any level.
483 if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
484 return isDimOpValidSymbol(dimOp, region);
485
486 return false;
487}
488
489// Returns true if 'value' is a valid index to an affine operation (e.g.
490// affine.load, affine.store, affine.dma_start, affine.dma_wait) where
491// `region` provides the polyhedral symbol scope. Returns false otherwise.
492static bool isValidAffineIndexOperand(Value value, Region *region) {
493 return isValidDim(value, region) || isValidSymbol(value, region);
494}
495
496/// Prints dimension and symbol list.
499 unsigned numDims, OpAsmPrinter &printer) {
500 OperandRange operands(begin, end);
501 printer << '(' << operands.take_front(numDims) << ')';
502 if (operands.size() > numDims)
503 printer << '[' << operands.drop_front(numDims) << ']';
504}
505
506/// Parses dimension and symbol list and returns true if parsing failed.
508 OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
511 return failure();
512 // Store number of dimensions for validation by caller.
513 numDims = opInfos.size();
514
515 // Parse the optional symbol operands.
516 auto indexTy = parser.getBuilder().getIndexType();
517 return failure(parser.parseOperandList(
519 parser.resolveOperands(opInfos, indexTy, operands));
520}
521
522/// Utility function to verify that a set of operands are valid dimension and
523/// symbol identifiers. The operands should be laid out such that the dimension
524/// operands are before the symbol operands. This function returns failure if
525/// there was an invalid operand. An operation is provided to emit any necessary
526/// errors.
527template <typename OpTy>
528static LogicalResult
530 unsigned numDims) {
531 unsigned opIt = 0;
532 for (auto operand : operands) {
533 if (opIt++ < numDims) {
534 if (!isValidDim(operand, getAffineScope(op)))
535 return op.emitOpError("operand cannot be used as a dimension id");
536 } else if (!isValidSymbol(operand, getAffineScope(op))) {
537 return op.emitOpError("operand cannot be used as a symbol");
538 }
539 }
540 return success();
541}
542
543//===----------------------------------------------------------------------===//
544// AffineApplyOp
545//===----------------------------------------------------------------------===//
546
547AffineValueMap AffineApplyOp::getAffineValueMap() {
548 return AffineValueMap(getAffineMap(), getOperands(), getResult());
549}
550
551ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
552 auto &builder = parser.getBuilder();
553 auto indexTy = builder.getIndexType();
554
555 AffineMapAttr mapAttr;
556 unsigned numDims;
557 if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
558 parseDimAndSymbolList(parser, result.operands, numDims) ||
559 parser.parseOptionalAttrDict(result.attributes))
560 return failure();
561 auto map = mapAttr.getValue();
562
563 if (map.getNumDims() != numDims ||
564 numDims + map.getNumSymbols() != result.operands.size()) {
565 return parser.emitError(parser.getNameLoc(),
566 "dimension or symbol index mismatch");
567 }
568
569 result.types.append(map.getNumResults(), indexTy);
570 return success();
571}
572
573void AffineApplyOp::print(OpAsmPrinter &p) {
574 p << " " << getMapAttr();
575 printDimAndSymbolList(operand_begin(), operand_end(),
576 getAffineMap().getNumDims(), p);
577 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
578}
579
580LogicalResult AffineApplyOp::verify() {
581 // Check input and output dimensions match.
582 AffineMap affineMap = getMap();
583
584 // Verify that operand count matches affine map dimension and symbol count.
585 if (getNumOperands() != affineMap.getNumDims() + affineMap.getNumSymbols())
586 return emitOpError(
587 "operand count and affine map dimension and symbol count must match");
588
589 // Verify that the map only produces one result.
590 if (affineMap.getNumResults() != 1)
591 return emitOpError("mapping must produce one value");
592
593 // Do not allow valid dims to be used in symbol positions. We do allow
594 // affine.apply to use operands for values that may neither qualify as affine
595 // dims or affine symbols due to usage outside of affine ops, analyses, etc.
596 Region *region = getAffineScope(*this);
597 for (Value operand : getMapOperands().drop_front(affineMap.getNumDims())) {
598 if (::isValidDim(operand, region) && !::isValidSymbol(operand, region))
599 return emitError("dimensional operand cannot be used as a symbol");
600 }
601
602 return success();
603}
604
605// The result of the affine apply operation can be used as a dimension id if all
606// its operands are valid dimension ids.
607bool AffineApplyOp::isValidDim() {
608 return llvm::all_of(getOperands(),
609 [](Value op) { return affine::isValidDim(op); });
610}
611
612// The result of the affine apply operation can be used as a dimension id if all
613// its operands are valid dimension ids with the parent operation of `region`
614// defining the polyhedral scope for symbols.
615bool AffineApplyOp::isValidDim(Region *region) {
616 return llvm::all_of(getOperands(),
617 [&](Value op) { return ::isValidDim(op, region); });
618}
619
620// The result of the affine apply operation can be used as a symbol if all its
621// operands are symbols.
622bool AffineApplyOp::isValidSymbol() {
623 return llvm::all_of(getOperands(),
624 [](Value op) { return affine::isValidSymbol(op); });
625}
626
627// The result of the affine apply operation can be used as a symbol in `region`
628// if all its operands are symbols in `region`.
629bool AffineApplyOp::isValidSymbol(Region *region) {
630 return llvm::all_of(getOperands(), [&](Value operand) {
631 return affine::isValidSymbol(operand, region);
632 });
633}
634
635OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) {
636 auto map = getAffineMap();
637
638 // Fold dims and symbols to existing values.
639 auto expr = map.getResult(0);
640 if (auto dim = dyn_cast<AffineDimExpr>(expr))
641 return getOperand(dim.getPosition());
642 if (auto sym = dyn_cast<AffineSymbolExpr>(expr))
643 return getOperand(map.getNumDims() + sym.getPosition());
644
645 // Otherwise, default to folding the map.
647 bool hasPoison = false;
648 auto foldResult =
649 map.constantFold(adaptor.getMapOperands(), result, &hasPoison);
650 if (hasPoison)
651 return ub::PoisonAttr::get(getContext());
652 if (failed(foldResult))
653 return {};
654 return result[0];
655}
656
657/// Returns the largest known divisor of `e`. Exploits information from the
658/// values in `operands`.
660 // This method isn't aware of `operands`.
662
663 // We now make use of operands for the case `e` is a dim expression.
664 // TODO: More powerful simplification would have to modify
665 // getLargestKnownDivisor to take `operands` and exploit that information as
666 // well for dim/sym expressions, but in that case, getLargestKnownDivisor
667 // can't be part of the IR library but of the `Analysis` library. The IR
668 // library can only really depend on simple O(1) checks.
669 auto dimExpr = dyn_cast<AffineDimExpr>(e);
670 // If it's not a dim expr, `div` is the best we have.
671 if (!dimExpr)
672 return div;
673
674 // We simply exploit information from loop IVs.
675 // We don't need to use mlir::getLargestKnownDivisorOfValue since the other
676 // desired simplifications are expected to be part of other
677 // canonicalizations. Also, mlir::getLargestKnownDivisorOfValue is part of the
678 // LoopAnalysis library.
679 Value operand = operands[dimExpr.getPosition()];
680 int64_t operandDivisor = 1;
681 // TODO: With the right accessors, this can be extended to
682 // LoopLikeOpInterface.
683 if (AffineForOp forOp = getForInductionVarOwner(operand)) {
684 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() == 0) {
685 operandDivisor = forOp.getStepAsInt();
686 } else {
687 uint64_t lbLargestKnownDivisor =
688 forOp.getLowerBoundMap().getLargestKnownDivisorOfMapExprs();
689 operandDivisor = std::gcd(lbLargestKnownDivisor, forOp.getStepAsInt());
690 }
691 }
692 return operandDivisor;
693}
694
695/// Check if `e` is known to be: 0 <= `e` < `k`. Handles the simple cases of `e`
696/// being an affine dim expression or a constant.
698 int64_t k) {
699 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
700 int64_t constVal = constExpr.getValue();
701 return constVal >= 0 && constVal < k;
702 }
703 auto dimExpr = dyn_cast<AffineDimExpr>(e);
704 if (!dimExpr)
705 return false;
706 Value operand = operands[dimExpr.getPosition()];
707 // TODO: With the right accessors, this can be extended to
708 // LoopLikeOpInterface.
709 if (AffineForOp forOp = getForInductionVarOwner(operand)) {
710 if (forOp.hasConstantLowerBound() && forOp.getConstantLowerBound() >= 0 &&
711 forOp.hasConstantUpperBound() && forOp.getConstantUpperBound() <= k) {
712 return true;
713 }
714 }
715
716 // We don't consider other cases like `operand` being defined by a constant or
717 // an affine.apply op since such cases will already be handled by other
718 // patterns and propagation of loop IVs or constant would happen.
719 return false;
720}
721
722/// Check if expression `e` is of the form d*e_1 + e_2 where 0 <= e_2 < d.
723/// Set `div` to `d`, `quotientTimesDiv` to e_1 and `rem` to e_2 if the
724/// expression is in that form.
726 AffineExpr &quotientTimesDiv, AffineExpr &rem) {
727 auto bin = dyn_cast<AffineBinaryOpExpr>(e);
728 if (!bin || bin.getKind() != AffineExprKind::Add)
729 return false;
730
731 AffineExpr llhs = bin.getLHS();
732 AffineExpr rlhs = bin.getRHS();
733 div = getLargestKnownDivisor(llhs, operands);
734 if (isNonNegativeBoundedBy(rlhs, operands, div)) {
735 quotientTimesDiv = llhs;
736 rem = rlhs;
737 return true;
738 }
739 div = getLargestKnownDivisor(rlhs, operands);
740 if (isNonNegativeBoundedBy(llhs, operands, div)) {
741 quotientTimesDiv = rlhs;
742 rem = llhs;
743 return true;
744 }
745 return false;
746}
747
748/// Gets the constant lower bound on an `iv`.
749static std::optional<int64_t> getLowerBound(Value iv) {
750 AffineForOp forOp = getForInductionVarOwner(iv);
751 if (forOp && forOp.hasConstantLowerBound())
752 return forOp.getConstantLowerBound();
753 return std::nullopt;
754}
755
756/// Gets the constant upper bound on an affine.for `iv`.
757static std::optional<int64_t> getUpperBound(Value iv) {
758 AffineForOp forOp = getForInductionVarOwner(iv);
759 if (!forOp || !forOp.hasConstantUpperBound())
760 return std::nullopt;
761
762 // If its lower bound is also known, we can get a more precise bound
763 // whenever the step is not one.
764 if (forOp.hasConstantLowerBound()) {
765 return forOp.getConstantUpperBound() - 1 -
766 (forOp.getConstantUpperBound() - forOp.getConstantLowerBound() - 1) %
767 forOp.getStepAsInt();
768 }
769 return forOp.getConstantUpperBound() - 1;
770}
771
772/// Determine a constant upper bound for `expr` if one exists while exploiting
773/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
774/// is guaranteed to be less than or equal to it.
775static std::optional<int64_t> getUpperBound(AffineExpr expr, unsigned numDims,
776 unsigned numSymbols,
777 ArrayRef<Value> operands) {
778 // Get the constant lower or upper bounds on the operands.
779 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
780 constLowerBounds.reserve(operands.size());
781 constUpperBounds.reserve(operands.size());
782 for (Value operand : operands) {
783 constLowerBounds.push_back(getLowerBound(operand));
784 constUpperBounds.push_back(getUpperBound(operand));
785 }
786
787 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr))
788 return constExpr.getValue();
789
790 return getBoundForAffineExpr(expr, numDims, numSymbols, constLowerBounds,
791 constUpperBounds,
792 /*isUpper=*/true);
793}
794
795/// Determine a constant lower bound for `expr` if one exists while exploiting
796/// values in `operands`. Note that the upper bound is an inclusive one. `expr`
797/// is guaranteed to be less than or equal to it.
798static std::optional<int64_t> getLowerBound(AffineExpr expr, unsigned numDims,
799 unsigned numSymbols,
800 ArrayRef<Value> operands) {
801 // Get the constant lower or upper bounds on the operands.
802 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
803 constLowerBounds.reserve(operands.size());
804 constUpperBounds.reserve(operands.size());
805 for (Value operand : operands) {
806 constLowerBounds.push_back(getLowerBound(operand));
807 constUpperBounds.push_back(getUpperBound(operand));
808 }
809
810 std::optional<int64_t> lowerBound;
811 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
812 lowerBound = constExpr.getValue();
813 } else {
814 lowerBound = getBoundForAffineExpr(expr, numDims, numSymbols,
815 constLowerBounds, constUpperBounds,
816 /*isUpper=*/false);
817 }
818 return lowerBound;
819}
820
821/// Simplify `expr` while exploiting information from the values in `operands`.
/// `expr` is updated in place. The children of a binary expression are
/// simplified first (recursively); further rewriting is then attempted only
/// when the right-hand side is a positive constant.
822static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims,
823 unsigned numSymbols,
824 ArrayRef<Value> operands) {
825 // We do this only for certain floordiv/mod expressions.
826 auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
827 if (!binExpr)
828 return;
829
830 // Simplify the child expressions first.
831 AffineExpr lhs = binExpr.getLHS();
832 AffineExpr rhs = binExpr.getRHS();
833 simplifyExprAndOperands(lhs, numDims, numSymbols, operands);
834 simplifyExprAndOperands(rhs, numDims, numSymbols, operands);
835 expr = getAffineBinaryOpExpr(binExpr.getKind(), lhs, rhs);
836
837 binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
// NOTE(review): listing line 839 is absent between the two lines of this
// condition. Given the CeilDiv branch further down, it presumably also tests
// for AffineExprKind::CeilDiv -- confirm against the real file.
838 if (!binExpr || (expr.getKind() != AffineExprKind::FloorDiv &&
840 expr.getKind() != AffineExprKind::Mod)) {
841 return;
842 }
843
844 // The `lhs` and `rhs` may be different post construction of simplified expr.
845 lhs = binExpr.getLHS();
846 rhs = binExpr.getRHS();
847 auto rhsConst = dyn_cast<AffineConstantExpr>(rhs);
848 if (!rhsConst)
849 return;
850
851 int64_t rhsConstVal = rhsConst.getValue();
852 // Undefined expressions aren't touched; IR can still be valid with them.
853 if (rhsConstVal <= 0)
854 return;
855
856 // Exploit constant lower/upper bounds to simplify a floordiv or mod.
857 MLIRContext *context = expr.getContext();
858 std::optional<int64_t> lhsLbConst =
859 getLowerBound(lhs, numDims, numSymbols, operands);
860 std::optional<int64_t> lhsUbConst =
861 getUpperBound(lhs, numDims, numSymbols, operands);
862 if (lhsLbConst && lhsUbConst) {
863 int64_t lhsLbConstVal = *lhsLbConst;
864 int64_t lhsUbConstVal = *lhsUbConst;
865 // lhs floordiv c is a single value if both ends of the range of lhs have
866 // the same floor quotient w.r.t. c.
867 if (binExpr.getKind() == AffineExprKind::FloorDiv &&
868 divideFloorSigned(lhsLbConstVal, rhsConstVal) ==
869 divideFloorSigned(lhsUbConstVal, rhsConstVal)) {
// NOTE(review): listing line 870 is absent here; by analogy with the CeilDiv
// branch below it presumably reads `expr = getAffineConstantExpr(` -- confirm.
871 divideFloorSigned(lhsLbConstVal, rhsConstVal), context);
872 return;
873 }
874 // lhs ceildiv c is a single value if the entire range has the same ceil
875 // quotient.
876 if (binExpr.getKind() == AffineExprKind::CeilDiv &&
877 divideCeilSigned(lhsLbConstVal, rhsConstVal) ==
878 divideCeilSigned(lhsUbConstVal, rhsConstVal)) {
879 expr = getAffineConstantExpr(divideCeilSigned(lhsLbConstVal, rhsConstVal),
880 context);
881 return;
882 }
883 // lhs mod c is lhs if the entire range has quotient 0 w.r.t the rhs.
884 if (binExpr.getKind() == AffineExprKind::Mod && lhsLbConstVal >= 0 &&
885 lhsLbConstVal < rhsConstVal && lhsUbConstVal < rhsConstVal) {
886 expr = lhs;
887 return;
888 }
889 }
890
891 // Simplify expressions of the form e = (e_1 + e_2) floordiv c or (e_1 + e_2)
892 // mod c, where e_1 is a multiple of `k` and 0 <= e_2 < k. In such cases, if
893 // `c` % `k` == 0, (e_1 + e_2) floordiv c can be simplified to e_1 floordiv c.
894 // And when k % c == 0, (e_1 + e_2) mod c can be simplified to e_2 mod c.
895 AffineExpr quotientTimesDiv, rem;
896 int64_t divisor;
897 if (isQTimesDPlusR(lhs, operands, divisor, quotientTimesDiv, rem)) {
898 if (rhsConstVal % divisor == 0 &&
899 binExpr.getKind() == AffineExprKind::FloorDiv) {
900 expr = quotientTimesDiv.floorDiv(rhsConst);
901 } else if (divisor % rhsConstVal == 0 &&
902 binExpr.getKind() == AffineExprKind::Mod) {
903 expr = rem % rhsConst;
904 }
// Once the q * d + r form matched, return even if neither rewrite applied;
// the simple cases below are not tried in that situation.
905 return;
906 }
907
908 // Handle the simple case when the LHS expression can be either upper
909 // bounded or is a known multiple of RHS constant.
910 // lhs floordiv c -> 0 if 0 <= lhs < c,
911 // lhs mod c -> 0 if lhs % c = 0.
912 if ((isNonNegativeBoundedBy(lhs, operands, rhsConstVal) &&
913 binExpr.getKind() == AffineExprKind::FloorDiv) ||
914 (getLargestKnownDivisor(lhs, operands) % rhsConstVal == 0 &&
915 binExpr.getKind() == AffineExprKind::Mod)) {
916 expr = getAffineConstantExpr(0, expr.getContext());
917 }
918}
919
920/// Simplify the expressions in `map` while making use of lower or upper bounds
921/// of its operands. If `isMax` is true, the map is to be treated as a max of
922/// its result expressions, and min otherwise. Eg: min (d0, d1) -> (8, 4 * d0 +
923/// d1) can be simplified to (8) if the operands are respectively lower bounded
924/// by 2 and 0 (the second expression can't be lower than 8).
/// `map` is rebuilt at the end of the function from the results that were not
/// found redundant; results whose lower and upper bounds coincide are replaced
/// by that constant.
// NOTE(review): listing line 925 (the start of the signature, with the
// function name and the `map` parameter) is absent from this listing.
926 ArrayRef<Value> operands,
927 bool isMax) {
928 // Can't simplify.
929 if (operands.empty())
930 return;
931
932 // Get the upper or lower bound on an affine.for op IV using its range.
933 // Get the constant lower or upper bounds on the operands.
934 SmallVector<std::optional<int64_t>> constLowerBounds, constUpperBounds;
935 constLowerBounds.reserve(operands.size());
936 constUpperBounds.reserve(operands.size());
937 for (Value operand : operands) {
938 constLowerBounds.push_back(getLowerBound(operand));
939 constUpperBounds.push_back(getUpperBound(operand));
940 }
941
942 // We will compute the lower and upper bounds on each of the expressions
943 // Then, we will check (depending on max or min) as to whether a specific
944 // bound is redundant by checking if its highest (in case of max) and its
945 // lowest (in the case of min) value is already lower than (or higher than)
946 // the lower bound (or upper bound in the case of min) of another bound.
947 SmallVector<std::optional<int64_t>, 4> lowerBounds, upperBounds;
948 lowerBounds.reserve(map.getNumResults());
949 upperBounds.reserve(map.getNumResults());
950 for (AffineExpr e : map.getResults()) {
951 if (auto constExpr = dyn_cast<AffineConstantExpr>(e)) {
952 lowerBounds.push_back(constExpr.getValue());
953 upperBounds.push_back(constExpr.getValue());
954 } else {
// NOTE(review): listing lines 956 and 960 (the callee and the leading
// arguments of the two bound computations) are absent; the trailing arguments
// match a call to getBoundForAffineExpr -- confirm against the real file.
955 lowerBounds.push_back(
957 constLowerBounds, constUpperBounds,
958 /*isUpper=*/false));
959 upperBounds.push_back(
961 constLowerBounds, constUpperBounds,
962 /*isUpper=*/true));
963 }
964 }
965
966 // Collect expressions that are not redundant.
967 SmallVector<AffineExpr, 4> irredundantExprs;
968 for (auto exprEn : llvm::enumerate(map.getResults())) {
969 AffineExpr e = exprEn.value();
970 unsigned i = exprEn.index();
971 // Some expressions can be turned into constants.
972 if (lowerBounds[i] && upperBounds[i] && *lowerBounds[i] == *upperBounds[i])
973 e = getAffineConstantExpr(*lowerBounds[i], e.getContext());
974
975 // Check if the expression is redundant.
976 if (isMax) {
// Without a known upper bound nothing can be proven about this result, so
// it is always kept.
977 if (!upperBounds[i]) {
978 irredundantExprs.push_back(e);
979 continue;
980 }
981 // If there exists another expression such that its lower bound is greater
982 // than this expression's upper bound, it's redundant.
983 if (!llvm::any_of(llvm::enumerate(lowerBounds), [&](const auto &en) {
984 auto otherLowerBound = en.value();
985 unsigned pos = en.index();
986 if (pos == i || !otherLowerBound)
987 return false;
988 if (*otherLowerBound > *upperBounds[i])
989 return true;
990 if (*otherLowerBound < *upperBounds[i])
991 return false;
992 // Equality case. When both expressions are considered redundant, we
993 // don't want to get both of them. We keep the one that appears
994 // first.
995 if (upperBounds[pos] && lowerBounds[i] &&
996 lowerBounds[i] == upperBounds[i] &&
997 otherLowerBound == *upperBounds[pos] && i < pos)
998 return false;
999 return true;
1000 }))
1001 irredundantExprs.push_back(e);
1002 } else {
// Symmetric to the max case: without a known lower bound the result is kept.
1003 if (!lowerBounds[i]) {
1004 irredundantExprs.push_back(e);
1005 continue;
1006 }
1007 // Likewise for the `min` case. Use the complement of the condition above.
1008 if (!llvm::any_of(llvm::enumerate(upperBounds), [&](const auto &en) {
1009 auto otherUpperBound = en.value();
1010 unsigned pos = en.index();
1011 if (pos == i || !otherUpperBound)
1012 return false;
1013 if (*otherUpperBound < *lowerBounds[i])
1014 return true;
1015 if (*otherUpperBound > *lowerBounds[i])
1016 return false;
1017 if (lowerBounds[pos] && upperBounds[i] &&
1018 lowerBounds[i] == upperBounds[i] &&
1019 otherUpperBound == lowerBounds[pos] && i < pos)
1020 return false;
1021 return true;
1022 }))
1023 irredundantExprs.push_back(e);
1024 }
1025 }
1026
1027 // Create the map without the redundant expressions.
1028 map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), irredundantExprs,
1029 map.getContext());
1030}
1031
1032/// Simplify the map while exploiting information on the values in `operands`.
/// `map` is replaced by a map with the same number of dims and symbols whose
/// results are the (possibly rewritten) result expressions. Asserts that the
/// number of map inputs equals the number of operands.
1033// Use "unused attribute" marker to silence warning stemming from the inability
1034// to see through the template expansion.
1035[[maybe_unused]] static void simplifyMapWithOperands(AffineMap &map,
1036 ArrayRef<Value> operands) {
1037 assert(map.getNumInputs() == operands.size() && "invalid operands for map");
1038 SmallVector<AffineExpr> newResults;
1039 newResults.reserve(map.getNumResults());
1040 for (AffineExpr expr : map.getResults()) {
// NOTE(review): listing line 1041 (the start of the call that rewrites `expr`
// in place) is absent; the trailing `operands);` suggests a call to
// simplifyExprAndOperands -- confirm against the real file.
1042 operands);
1043 newResults.push_back(expr);
1044 }
1045 map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newResults,
1046 map.getContext());
1047}
1048
1049/// Assuming `dimOrSym` is a quantity in the apply op map `map` and defined by
1050/// `minOp = affine_min(x_1, ..., x_n)`. This function checks that:
1051/// `0 < affine_min(x_1, ..., x_n)` and proceeds with replacing the patterns:
1052/// ```
1053/// dimOrSym.ceildiv(x_k)
1054/// (dimOrSym + x_k - 1).floordiv(x_k)
1055/// ```
1056/// by `1` for all `k` in `1, ..., n`. This is possible because `x / x_k <= 1`.
1057///
/// Returns success only if `*map` was actually changed by the replacement.
1058///
1059/// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
1060/// `minOp` is positive.
1061static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
1062 AffineExpr dimOrSym,
1063 AffineMap *map,
1064 ValueRange dims,
1065 ValueRange syms) {
1066 LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";
1067 AffineMap affineMinMap = minOp.getAffineMap();
1068
1069 // Check the value is positive.
1070 for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
1071 // Compare each expression in the minimum against 0.
// NOTE(review): listing lines 1072, 1074 and 1075 (the comparison call, its
// operator and the start of its second operand) are absent from this listing;
// only the `0` operand and the min op operands remain visible.
1073 getAsIndexOpFoldResult(minOp.getContext(), 0),
1076 minOp.getOperands())))
1077 return failure();
1078 }
1079
1080 /// Convert affine symbols and dimensions in minOp to symbols or dimensions in
1081 /// the apply op affine map.
1082 DenseMap<AffineExpr, AffineExpr> dimSymConversionTable;
// Positions of min-op dims/symbols whose operand does not occur in `dims` /
// `syms`; results depending on them cannot be translated and are skipped below.
1083 SmallVector<unsigned> unmappedDims, unmappedSyms;
1084 for (auto [i, dim] : llvm::enumerate(minOp.getDimOperands())) {
1085 auto it = llvm::find(dims, dim);
1086 if (it == dims.end()) {
1087 unmappedDims.push_back(i);
1088 continue;
1089 }
1090 dimSymConversionTable[getAffineDimExpr(i, minOp.getContext())] =
1091 getAffineDimExpr(it.getIndex(), minOp.getContext());
1092 }
1093 for (auto [i, sym] : llvm::enumerate(minOp.getSymbolOperands())) {
1094 auto it = llvm::find(syms, sym);
1095 if (it == syms.end()) {
1096 unmappedSyms.push_back(i);
1097 continue;
1098 }
1099 dimSymConversionTable[getAffineSymbolExpr(i, minOp.getContext())] =
1100 getAffineSymbolExpr(it.getIndex(), minOp.getContext());
1101 }
1102
1103 // Create the replacement map.
// NOTE(review): listing line 1104 is absent; `repl` used below is presumably
// declared there as a map from AffineExpr to AffineExpr -- confirm.
1105 AffineExpr c1 = getAffineConstantExpr(1, minOp.getContext());
1106 for (AffineExpr expr : affineMinMap.getResults()) {
1107 // If we cannot express the result in terms of the apply map symbols and
1108 // dims then continue.
1109 if (llvm::any_of(unmappedDims,
1110 [&](unsigned i) { return expr.isFunctionOfDim(i); }) ||
1111 llvm::any_of(unmappedSyms,
1112 [&](unsigned i) { return expr.isFunctionOfSymbol(i); }))
1113 continue;
1114
1115 AffineExpr convertedExpr = expr.replace(dimSymConversionTable);
1116
1117 // dimOrSym.ceilDiv(expr) -> 1
1118 repl[dimOrSym.ceilDiv(convertedExpr)] = c1;
1119 // (dimOrSym + expr - 1).floorDiv(expr) -> 1
1120 repl[(dimOrSym + convertedExpr - 1).floorDiv(convertedExpr)] = c1;
1121 }
1122 AffineMap initialMap = *map;
1123 *map = initialMap.replace(repl, initialMap.getNumDims(),
1124 initialMap.getNumSymbols());
1125 return success(*map != initialMap);
1126}
1127
1128/// Recursively traverse `e`. If `e` or one of its sub-expressions has the form
1129/// e1 + e2 + ... + eK, where the e_i are a super(multi)set of `exprsToRemove`,
1130/// place a map between e and `newVal` + sum({e1, e2, .. eK} - exprsToRemove)
1131/// into `replacementsMap`. If no entries were added to `replacementsMap`,
1132/// nothing was found.
// NOTE(review): listing line 1133 (the start of the signature) is absent; the
// recursive calls below show the function is named
// shortenAddChainsContainingAll.
1134 AffineExpr e, const llvm::SmallDenseSet<AffineExpr, 4> &exprsToRemove,
1135 AffineExpr newVal, DenseMap<AffineExpr, AffineExpr> &replacementsMap) {
1136 auto binOp = dyn_cast<AffineBinaryOpExpr>(e);
1137 if (!binOp)
1138 return;
1139 AffineExpr lhs = binOp.getLHS();
1140 AffineExpr rhs = binOp.getRHS();
// A non-add binary expression is not a chain itself; only look inside its
// two children.
1141 if (binOp.getKind() != AffineExprKind::Add) {
1142 shortenAddChainsContainingAll(lhs, exprsToRemove, newVal, replacementsMap);
1143 shortenAddChainsContainingAll(rhs, exprsToRemove, newVal, replacementsMap);
1144 return;
1145 }
1146 SmallVector<AffineExpr> toPreserve;
// Working copy of the set; every term of the chain that matches is erased
// from it, so it is empty at the end iff all expressions were found.
1147 llvm::SmallDenseSet<AffineExpr, 4> ourTracker(exprsToRemove);
1148 AffineExpr thisTerm = rhs;
1149 AffineExpr nextTerm = lhs;
1150
// Walk the chain from its rightmost term, following the LHS while it is
// itself an add. A null `thisTerm` terminates the walk.
1151 while (thisTerm) {
1152 if (!ourTracker.erase(thisTerm)) {
1153 toPreserve.push_back(thisTerm);
1154 shortenAddChainsContainingAll(thisTerm, exprsToRemove, newVal,
1155 replacementsMap);
1156 }
1157 auto nextBinOp = dyn_cast_if_present<AffineBinaryOpExpr>(nextTerm);
1158 if (!nextBinOp || nextBinOp.getKind() != AffineExprKind::Add) {
1159 thisTerm = nextTerm;
1160 nextTerm = AffineExpr();
1161 } else {
1162 thisTerm = nextBinOp.getRHS();
1163 nextTerm = nextBinOp.getLHS();
1164 }
1165 }
// Not every expression to remove occurred in this chain: no replacement.
1166 if (!ourTracker.empty())
1167 return;
1168 // We reverse the terms to be preserved here in order to preserve
1169 // associativity between them.
1170 AffineExpr newExpr = newVal;
1171 for (AffineExpr preserved : llvm::reverse(toPreserve))
1172 newExpr = newExpr + preserved;
1173 replacementsMap.insert({e, newExpr});
1174}
1175
1176/// If this map contains the expression `x_1 + x_1 * C_1 + ... x_n * C_N +
1177/// ...` (not necessarily in order) where the set of the `x_i` is the set of
1178/// outputs of an `affine.delinearize_index` whose inverse is that expression,
1179/// replace that expression with the input of that delinearize_index op.
1180///
1181/// `unitDimInput` is the input that was detected as the potential start to this
1182/// replacement chain - if it isn't the rightmost result of the delinearization,
1183/// this method fails. (This is intended to ensure we don't have redundant scans
1184/// over the same expression).
/// NOTE(review): the visible parameter carrying this value is named
/// `resultToReplace`; `unitDimInput` looks like a stale name -- confirm.
1185///
1186/// While this currently only handles delinearizations with a constant basis,
1187/// that isn't a fundamental limitation.
1188///
1189/// This is a utility function for `replaceDimOrSym` below.
// NOTE(review): listing lines 1190 and 1192 (the start of the signature and
// the `dims` / `syms` parameters) are absent from this listing.
1191 AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map,
// Only a fully static basis is handled.
1193 if (!delinOp.getDynamicBasis().empty())
1194 return failure();
// Only trigger from the last result of the delinearize op.
1195 if (resultToReplace != delinOp.getMultiIndex().back())
1196 return failure();
1197
1198 MLIRContext *ctx = delinOp.getContext();
// For each result of `delinOp`, the dim or symbol expression it is bound to
// in `map`; stays null for results not found among `dims` / `syms`.
1199 SmallVector<AffineExpr> resToExpr(delinOp.getNumResults(), AffineExpr());
1200 for (auto [pos, dim] : llvm::enumerate(dims)) {
1201 auto asResult = dyn_cast_if_present<OpResult>(dim);
1202 if (!asResult)
1203 continue;
1204 if (asResult.getOwner() == delinOp.getOperation())
1205 resToExpr[asResult.getResultNumber()] = getAffineDimExpr(pos, ctx);
1206 }
1207 for (auto [pos, sym] : llvm::enumerate(syms)) {
1208 auto asResult = dyn_cast_if_present<OpResult>(sym);
1209 if (!asResult)
1210 continue;
1211 if (asResult.getOwner() == delinOp.getOperation())
1212 resToExpr[asResult.getResultNumber()] = getAffineSymbolExpr(pos, ctx);
1213 }
// Every result must be an operand of the map, otherwise give up.
1214 if (llvm::is_contained(resToExpr, AffineExpr()))
1215 return failure();
1216
1217 bool isDimReplacement = llvm::all_of(resToExpr, llvm::IsaPred<AffineDimExpr>);
1218 int64_t stride = 1;
1219 llvm::SmallDenseSet<AffineExpr, 4> expectedExprs;
1220 // This isn't zip_equal since sometimes the delinearize basis is missing a
1221 // size for the first result.
1222 for (auto [binding, size] : llvm::zip(
1223 llvm::reverse(resToExpr), llvm::reverse(delinOp.getStaticBasis()))) {
1224 expectedExprs.insert(binding * getAffineConstantExpr(stride, ctx));
1225 stride *= size;
1226 }
1227 if (resToExpr.size() != delinOp.getStaticBasis().size())
1228 expectedExprs.insert(resToExpr[0] * stride);
1229
// NOTE(review): listing line 1230 is absent; `replacements` used below is
// presumably declared there as a map from AffineExpr to AffineExpr -- confirm.
// The linear index gets the next free dim (or symbol) position; the value
// itself is appended to `dims` / `syms` only once a replacement is found.
1231 AffineExpr delinInExpr = isDimReplacement
1232 ? getAffineDimExpr(dims.size(), ctx)
1233 : getAffineSymbolExpr(syms.size(), ctx);
1234
1235 for (AffineExpr e : map->getResults())
1236 shortenAddChainsContainingAll(e, expectedExprs, delinInExpr, replacements);
1237 if (replacements.empty())
1238 return failure();
1239
1240 AffineMap origMap = *map;
1241 if (isDimReplacement)
1242 dims.push_back(delinOp.getLinearIndex());
1243 else
1244 syms.push_back(delinOp.getLinearIndex());
1245 *map = origMap.replace(replacements, dims.size(), syms.size());
1246
1247 // Blank out dead dimensions and symbols
1248 for (AffineExpr e : resToExpr) {
1249 if (auto d = dyn_cast<AffineDimExpr>(e)) {
1250 unsigned pos = d.getPosition();
1251 if (!map->isFunctionOfDim(pos))
1252 dims[pos] = nullptr;
1253 }
1254 if (auto s = dyn_cast<AffineSymbolExpr>(e)) {
1255 unsigned pos = s.getPosition();
1256 if (!map->isFunctionOfSymbol(pos))
1257 syms[pos] = nullptr;
1258 }
1259 }
1260 return success();
1261}
1262
1263/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
1264/// defining AffineApplyOp expression and operands.
1265/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
1266/// When `dimOrSymbolPosition >= dims.size()`,
1267/// AffineSymbolExpr@[pos - dims.size()] is replaced.
1268/// Mutate `map`,`dims` and `syms` in place as follows:
1269/// 1. `dims` and `syms` are only appended to.
1270/// 2. `map` dim and symbols are gradually shifted to higher positions.
1271/// 3. Old `dim` and `sym` entries are replaced by nullptr
1272/// This avoids the need for any bookkeeping.
1273/// If `replaceAffineMin` is set to true, additionally triggers more expensive
1274/// replacements involving affine_min operations.
/// Returns failure when the selected entry is already null or is not defined
/// by an op handled here; values defined by an affine.min (when
/// `replaceAffineMin` is set) or by an affine.delinearize_index are delegated
/// to the dedicated helpers and their result is returned.
1275static LogicalResult replaceDimOrSym(AffineMap *map,
1276 unsigned dimOrSymbolPosition,
// NOTE(review): listing lines 1277-1278 (the `dims` and `syms` parameters)
// are absent from this listing.
1279 bool replaceAffineMin) {
1280 MLIRContext *ctx = map->getContext();
1281 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
1282 unsigned pos = isDimReplacement ? dimOrSymbolPosition
1283 : dimOrSymbolPosition - dims.size();
// Reference into `dims` / `syms` so the entry can be nulled out below.
1284 Value &v = isDimReplacement ? dims[pos] : syms[pos];
1285 if (!v)
1286 return failure();
1287
1288 if (auto minOp = v.getDefiningOp<AffineMinOp>(); minOp && replaceAffineMin) {
1289 AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx)
1290 : getAffineSymbolExpr(pos, ctx);
1291 return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map, dims,
1292 syms);
1293 }
1294
1295 if (auto delinOp = v.getDefiningOp<affine::AffineDelinearizeIndexOp>()) {
1296 return replaceAffineDelinearizeIndexInverseExpression(delinOp, v, map, dims,
1297 syms);
1298 }
1299
1300 auto affineApply = v.getDefiningOp<AffineApplyOp>();
1301 if (!affineApply)
1302 return failure();
1303
1304 // At this point we will perform a replacement of `v`, set the entry in `dim`
1305 // or `sym` to nullptr immediately.
1306 v = nullptr;
1307
1308 // Compute the map, dims and symbols coming from the AffineApplyOp.
1309 AffineMap composeMap = affineApply.getAffineMap();
1310 assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
1311 SmallVector<Value> composeOperands(affineApply.getMapOperands().begin(),
1312 affineApply.getMapOperands().end());
1313 // Canonicalize the map to promote dims to symbols when possible. This is to
1314 // avoid generating invalid maps.
1315 canonicalizeMapAndOperands(&composeMap, &composeOperands);
// Shift the producer's dims/symbols past the existing ones so that they line
// up with the positions of the operands appended below.
1316 AffineExpr replacementExpr =
1317 composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
1318 ValueRange composeDims =
1319 ArrayRef<Value>(composeOperands).take_front(composeMap.getNumDims());
1320 ValueRange composeSyms =
1321 ArrayRef<Value>(composeOperands).take_back(composeMap.getNumSymbols());
1322 AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
1323 : getAffineSymbolExpr(pos, ctx);
1324
1325 // Append the dims and symbols where relevant and perform the replacement.
1326 dims.append(composeDims.begin(), composeDims.end());
1327 syms.append(composeSyms.begin(), composeSyms.end());
1328 *map = map->replace(toReplace, replacementExpr, dims.size(), syms.size());
1329
1330 return success();
1331}
1332
1333/// Iterate over `operands` and fold away all those produced by an AffineApplyOp
1334/// iteratively. Perform canonicalization of map and operands as well as
1335/// AffineMap simplification. `map` and `operands` are mutated in place.
/// `composeAffineMin` is forwarded to `replaceDimOrSym`.
// NOTE(review): listing line 1336 (the start of the signature, including the
// `map` parameter) is absent; callers in this file refer to this function as
// composeAffineMapAndOperands.
1337 SmallVectorImpl<Value> *operands,
1338 bool composeAffineMin = false) {
// A map without results only needs canonicalization and simplification.
1339 if (map->getNumResults() == 0) {
1340 canonicalizeMapAndOperands(map, operands);
1341 *map = simplifyAffineMap(*map);
1342 return;
1343 }
1344
1345 MLIRContext *ctx = map->getContext();
1346 SmallVector<Value, 4> dims(operands->begin(),
1347 operands->begin() + map->getNumDims());
1348 SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
1349 operands->end());
1350
1351 // Iterate over dims and symbols coming from AffineApplyOp and replace until
1352 // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
1353 // and `syms` can only increase by construction.
1354 // The implementation uses a `while` loop to support the case of symbols
1355 // that may be constructed from dims; this may be overkill.
1356 while (true) {
1357 bool changed = false;
// Restart the scan from position 0 after every successful replacement,
// since `dims` and `syms` may have grown.
1358 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
1359 if ((changed |=
1360 succeeded(replaceDimOrSym(map, pos, dims, syms, composeAffineMin))))
1361 break;
1362 if (!changed)
1363 break;
1364 }
1365
1366 // Clear operands so we can fill them anew.
1367 operands->clear();
1368
1369 // At this point we may have introduced null operands, prune them out before
1370 // canonicalizing map and operands.
1371 unsigned nDims = 0, nSyms = 0;
1372 SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
1373 dimReplacements.reserve(dims.size());
1374 symReplacements.reserve(syms.size());
1375 for (auto *container : {&dims, &syms}) {
1376 bool isDim = (container == &dims);
1377 auto &repls = isDim ? dimReplacements : symReplacements;
1378 for (const auto &en : llvm::enumerate(*container)) {
1379 Value v = en.value();
1380 if (!v) {
// NOTE(review): `&&` binds tighter than `?:`, so the message string is
// attached to the symbol branch only. The checked condition is unaffected
// (a string literal is always true), but parenthesizing the conditional
// would make the intent explicit.
1381 assert(isDim ? !map->isFunctionOfDim(en.index())
1382 : !map->isFunctionOfSymbol(en.index()) &&
1383 "map is function of unexpected expr@pos");
// A pruned position is unused by the map, so any placeholder will do.
1384 repls.push_back(getAffineConstantExpr(0, ctx));
1385 continue;
1386 }
1387 repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
1388 : getAffineSymbolExpr(nSyms++, ctx));
1389 operands->push_back(v);
1390 }
1391 }
1392 *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
1393 nSyms);
1394
1395 // Canonicalize and simplify before returning.
1396 canonicalizeMapAndOperands(map, operands);
1397 *map = simplifyAffineMap(*map);
1398}
1399
// NOTE(review): listing line 1400 (the start of the signature) is absent; a
// later call in this file refers to this function as
// fullyComposeAffineMapAndOperands.
// Composes `map` with its `operands` for as long as at least one operand is
// defined by an AffineApplyOp. `map` and `operands` are updated in place.
1401 AffineMap *map, SmallVectorImpl<Value> *operands, bool composeAffineMin) {
1402 while (llvm::any_of(*operands, [](Value v) {
1403 return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
1404 })) {
1405 composeAffineMapAndOperands(map, operands, composeAffineMin);
1406 }
1407 // Additional trailing step for AffineMinOps in case no chains of AffineApply.
1408 if (composeAffineMin && llvm::any_of(*operands, [](Value v) {
1409 return isa_and_nonnull<AffineMinOp>(v.getDefiningOp());
1410 })) {
1411 composeAffineMapAndOperands(map, operands, composeAffineMin);
1412 }
1413}
1414
// Builds an AffineApplyOp for `map` applied to `operands`: attribute operands
// are first folded into the map (foldAttributesIntoMap), then the map is
// composed with the remaining value operands before the op is created.
1415AffineApplyOp
// NOTE(review): listing line 1416 (function name and leading parameters; the
// body uses `b`, `loc` and `map`) is absent from this listing.
1417 ArrayRef<OpFoldResult> operands,
1418 bool composeAffineMin) {
1419 SmallVector<Value> valueOperands;
1420 map = foldAttributesIntoMap(b, map, operands, valueOperands);
1421 composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
1422 assert(map);
1423 return AffineApplyOp::create(b, loc, map, valueOperands);
1424}
1425
// NOTE(review): listing lines 1427, 1430 and 1432 are absent, i.e. the
// function name with its leading parameters, the callee of the returned call,
// and the expression whose `.front()` is passed as the map argument. From what
// remains, this overload forwards `b`, `loc`, `operands` and
// `composeAffineMin` to another function -- confirm against the real file.
1426AffineApplyOp
1428 ArrayRef<OpFoldResult> operands,
1429 bool composeAffineMin) {
1431 b, loc,
1433 .front(),
1434 operands, composeAffineMin);
1435}
1436
1437/// Composes the given affine map with the given list of operands, pulling in
1438/// the maps from any affine.apply operations that supply the operands.
/// Both `map` and `operands` are overwritten with the composed, canonicalized
/// result.
// NOTE(review): listing line 1439 (the start of the signature, including the
// `map` parameter) is absent; a caller in this file refers to this function as
// composeMultiResultAffineMap.
1440 SmallVectorImpl<Value> &operands,
1441 bool composeAffineMin = false) {
1442 // Compose and canonicalize each expression in the map individually because
1443 // composition only applies to single-result maps, collecting potentially
1444 // duplicate operands in a single list with shifted dimensions and symbols.
1445 SmallVector<Value> dims, symbols;
// NOTE(review): listing line 1446 is absent; `exprs` used below is presumably
// declared there as a vector of AffineExpr -- confirm.
1447 for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
// Each result starts from a fresh copy of the original operand list.
1448 SmallVector<Value> submapOperands(operands.begin(), operands.end());
1449 AffineMap submap = map.getSubMap({i});
1450 fullyComposeAffineMapAndOperands(&submap, &submapOperands,
1451 composeAffineMin);
1452 canonicalizeMapAndOperands(&submap, &submapOperands);
1453 unsigned numNewDims = submap.getNumDims();
// Renumber this submap's dims/symbols to follow those already collected.
1454 submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
1455 llvm::append_range(dims,
1456 ArrayRef<Value>(submapOperands).take_front(numNewDims));
1457 llvm::append_range(symbols,
1458 ArrayRef<Value>(submapOperands).drop_front(numNewDims));
1459 exprs.push_back(submap.getResult(0));
1460 }
1461
1462 // Canonicalize the map created from composed expressions to deduplicate the
1463 // dimension and symbol operands.
1464 operands = llvm::to_vector(llvm::concat<Value>(dims, symbols));
1465 map = AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
1466 canonicalizeMapAndOperands(&map, &operands);
1467}
1468
// NOTE(review): listing lines 1469-1470 (return type, function name and the
// leading parameters; the body uses `b`, `loc`, `map` and `operands`) are
// absent; callers in this file refer to this function as
// makeComposedFoldedAffineApply.
// Creates a composed affine.apply and immediately tries to fold it. If folding
// fails, the op is kept and its result returned; otherwise the op is erased
// and the single fold result is returned.
1471 bool composeAffineMin) {
1472 assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");
1473
1474 // Create new builder without a listener, so that no notification is
1475 // triggered if the op is folded.
1476 // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1477 // workaround is no longer needed.
1478 OpBuilder newBuilder(b.getContext());
1479 newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1480
1481 // Create op.
1482 AffineApplyOp applyOp =
1483 makeComposedAffineApply(newBuilder, loc, map, operands, composeAffineMin);
1484
1485 // Get constant operands.
// Entries for operands that do not match m_Constant are left as default
// (null) attributes.
1486 SmallVector<Attribute> constOperands(applyOp->getNumOperands());
1487 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1488 matchPattern(applyOp->getOperand(i), m_Constant(&constOperands[i]));
1489
1490 // Try to fold the operation.
1491 SmallVector<OpFoldResult> foldResults;
1492 if (failed(applyOp->fold(constOperands, foldResults)) ||
1493 foldResults.empty()) {
// The op survives: notify the listener of the original builder now, since
// the op was created through the listener-less builder.
1494 if (OpBuilder::Listener *listener = b.getListener())
1495 listener->notifyOperationInserted(applyOp, /*previous=*/{});
1496 return applyOp.getResult();
1497 }
1498
1499 applyOp->erase();
1500 return llvm::getSingleElement(foldResults);
1501}
1502
// NOTE(review): listing lines 1503, 1506 and 1508 are absent, i.e. the return
// type with the function name, the callee of the returned call, and the
// expression whose `.front()` is passed as the map argument. From what
// remains, this overload takes an AffineExpr instead of a map and forwards
// `b`, `loc`, `operands` and `composeAffineMin` -- confirm against the real
// file.
1504 OpBuilder &b, Location loc, AffineExpr expr,
1505 ArrayRef<OpFoldResult> operands, bool composeAffineMin) {
1507 b, loc,
1509 .front(),
1510 operands, composeAffineMin);
1511}
1512
// NOTE(review): listing lines 1513-1515 (return type, function name and the
// leading parameters; the body uses `b`, `loc`, `map` and `operands`) are
// absent from this listing.
// Applies makeComposedFoldedAffineApply to each single-result sub-map of `map`
// (in result order) and returns the collected results.
1516 bool composeAffineMin) {
1517 return llvm::map_to_vector(
1518 llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
1519 return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
1520 operands, composeAffineMin);
1521 });
1522}
1523
// Creates an `OpTy` (instantiated with a min/max op type elsewhere in this
// file, judging by the name used at the call site) of index type from `map`
// and `operands`: attribute operands are folded into the map, then the
// multi-result map is composed with the remaining value operands.
1524template <typename OpTy>
// NOTE(review): listing line 1525 (return type, function name and leading
// parameters; the body uses `b`, `loc` and `map`) is absent; a caller in this
// file refers to this function as makeComposedMinMax.
1526 ArrayRef<OpFoldResult> operands) {
1527 SmallVector<Value> valueOperands;
1528 map = foldAttributesIntoMap(b, map, operands, valueOperands);
1529 composeMultiResultAffineMap(map, valueOperands);
1530 return OpTy::create(b, loc, b.getIndexType(), map, valueOperands);
1531}
1532
1533AffineMinOp
1538
// Creates a composed `OpTy` via makeComposedMinMax and immediately tries to
// fold it. If folding fails, the op is kept and its result returned;
// otherwise the op is erased and the single fold result is returned.
1539template <typename OpTy>
// NOTE(review): listing line 1540 (return type, function name and leading
// parameters; the body uses `b` and `loc`) is absent from this listing.
1541 AffineMap map,
1542 ArrayRef<OpFoldResult> operands) {
1543 // Create new builder without a listener, so that no notification is
1544 // triggered if the op is folded.
1545 // TODO: OpBuilder::createOrFold should return OpFoldResults, then this
1546 // workaround is no longer needed.
1547 OpBuilder newBuilder(b.getContext());
1548 newBuilder.setInsertionPoint(b.getInsertionBlock(), b.getInsertionPoint());
1549
1550 // Create op.
1551 auto minMaxOp = makeComposedMinMax<OpTy>(newBuilder, loc, map, operands);
1552
1553 // Get constant operands.
// Entries for operands that do not match m_Constant are left as default
// (null) attributes.
1554 SmallVector<Attribute> constOperands(minMaxOp->getNumOperands());
1555 for (unsigned i = 0, e = constOperands.size(); i != e; ++i)
1556 matchPattern(minMaxOp->getOperand(i), m_Constant(&constOperands[i]));
1557
1558 // Try to fold the operation.
1559 SmallVector<OpFoldResult> foldResults;
1560 if (failed(minMaxOp->fold(constOperands, foldResults)) ||
1561 foldResults.empty()) {
// The op survives: notify the listener of the original builder now, since
// the op was created through the listener-less builder.
1562 if (OpBuilder::Listener *listener = b.getListener())
1563 listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
1564 return minMaxOp.getResult();
1565 }
1566
1567 minMaxOp->erase();
1568 return llvm::getSingleElement(foldResults);
1569}
1570
1577
1584
1585// A symbol may appear as a dim in affine.apply operations. This function
1586// canonicalizes dims that are valid symbols into actual symbols.
1587template <class MapOrSet>
1588static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
1589 SmallVectorImpl<Value> *operands) {
1590 if (!mapOrSet || operands->empty())
1591 return;
1592
1593 assert(mapOrSet->getNumInputs() == operands->size() &&
1594 "map/set inputs must match number of operands");
1595
1596 auto *context = mapOrSet->getContext();
1597 SmallVector<Value, 8> resultOperands;
1598 resultOperands.reserve(operands->size());
1599 SmallVector<Value, 8> remappedSymbols;
1600 remappedSymbols.reserve(operands->size());
1601 unsigned nextDim = 0;
1602 unsigned nextSym = 0;
1603 unsigned oldNumSyms = mapOrSet->getNumSymbols();
1604 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1605 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
1606 if (i < mapOrSet->getNumDims()) {
1607 if (isValidSymbol((*operands)[i])) {
1608 // This is a valid symbol that appears as a dim, canonicalize it.
1609 dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
1610 remappedSymbols.push_back((*operands)[i]);
1611 } else {
1612 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1613 resultOperands.push_back((*operands)[i]);
1614 }
1615 } else {
1616 resultOperands.push_back((*operands)[i]);
1617 }
1618 }
1619
1620 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
1621 *operands = resultOperands;
1622 *mapOrSet = mapOrSet->replaceDimsAndSymbols(
1623 dimRemapping, /*symReplacements=*/{}, nextDim, oldNumSyms + nextSym);
1624
1625 assert(mapOrSet->getNumInputs() == operands->size() &&
1626 "map/set inputs must match number of operands");
1627}
1628
1629/// A valid affine dimension may appear as a symbol in affine.apply operations.
1630/// Given an application of `operands` to an affine map or integer set
1631/// `mapOrSet`, this function canonicalizes symbols of `mapOrSet` that are valid
1632/// dims, but not valid symbols into actual dims. Without such a legalization,
1633/// the affine.apply will be invalid. This method is the exact inverse of
1634/// canonicalizePromotedSymbols.
1635template <class MapOrSet>
1636static void legalizeDemotedDims(MapOrSet &mapOrSet,
1637 SmallVectorImpl<Value> &operands) {
1638 if (!mapOrSet || operands.empty())
1639 return;
1640
1641 unsigned numOperands = operands.size();
1642
1643 assert(mapOrSet.getNumInputs() == numOperands &&
1644 "map/set inputs must match number of operands");
1645
1646 auto *context = mapOrSet.getContext();
1647 SmallVector<Value, 8> resultOperands;
1648 resultOperands.reserve(numOperands);
1649 SmallVector<Value, 8> remappedDims;
1650 remappedDims.reserve(numOperands);
1651 SmallVector<Value, 8> symOperands;
1652 symOperands.reserve(mapOrSet.getNumSymbols());
1653 unsigned nextSym = 0;
1654 unsigned nextDim = 0;
1655 unsigned oldNumDims = mapOrSet.getNumDims();
1656 SmallVector<AffineExpr, 8> symRemapping(mapOrSet.getNumSymbols());
1657 resultOperands.assign(operands.begin(), operands.begin() + oldNumDims);
1658 for (unsigned i = oldNumDims, e = mapOrSet.getNumInputs(); i != e; ++i) {
1659 if (operands[i] && isValidDim(operands[i]) && !isValidSymbol(operands[i])) {
1660 // This is a valid dim that appears as a symbol, legalize it.
1661 symRemapping[i - oldNumDims] =
1662 getAffineDimExpr(oldNumDims + nextDim++, context);
1663 remappedDims.push_back(operands[i]);
1664 } else {
1665 symRemapping[i - oldNumDims] = getAffineSymbolExpr(nextSym++, context);
1666 symOperands.push_back(operands[i]);
1667 }
1668 }
1669
1670 append_range(resultOperands, remappedDims);
1671 append_range(resultOperands, symOperands);
1672 operands = resultOperands;
1673 mapOrSet = mapOrSet.replaceDimsAndSymbols(
1674 /*dimReplacements=*/{}, symRemapping, oldNumDims + nextDim, nextSym);
1675
1676 assert(mapOrSet.getNumInputs() == operands.size() &&
1677 "map/set inputs must match number of operands");
1678}
1679
1680// Works for either an affine map or an integer set.
1681template <class MapOrSet>
1682static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
1683 SmallVectorImpl<Value> *operands) {
1684 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
1685 "Argument must be either of AffineMap or IntegerSet type");
1686
1687 if (!mapOrSet || operands->empty())
1688 return;
1689
1690 assert(mapOrSet->getNumInputs() == operands->size() &&
1691 "map/set inputs must match number of operands");
1692
1693 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
1694 legalizeDemotedDims<MapOrSet>(*mapOrSet, *operands);
1695
1696 // Check to see what dims are used.
1697 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
1698 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
1699 mapOrSet->walkExprs([&](AffineExpr expr) {
1700 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
1701 usedDims[dimExpr.getPosition()] = true;
1702 else if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr))
1703 usedSyms[symExpr.getPosition()] = true;
1704 });
1705
1706 auto *context = mapOrSet->getContext();
1707
1708 SmallVector<Value, 8> resultOperands;
1709 resultOperands.reserve(operands->size());
1710
1711 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
1712 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
1713 unsigned nextDim = 0;
1714 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
1715 if (usedDims[i]) {
1716 // Remap dim positions for duplicate operands.
1717 auto it = seenDims.find((*operands)[i]);
1718 if (it == seenDims.end()) {
1719 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
1720 resultOperands.push_back((*operands)[i]);
1721 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
1722 } else {
1723 dimRemapping[i] = it->second;
1724 }
1725 }
1726 }
1727 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
1728 SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
1729 unsigned nextSym = 0;
1730 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
1731 if (!usedSyms[i])
1732 continue;
1733 // Handle constant operands (only needed for symbolic operands since
1734 // constant operands in dimensional positions would have already been
1735 // promoted to symbolic positions above).
1736 IntegerAttr operandCst;
1737 if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
1738 m_Constant(&operandCst))) {
1739 symRemapping[i] =
1740 getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
1741 continue;
1742 }
1743 // Remap symbol positions for duplicate operands.
1744 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
1745 if (it == seenSymbols.end()) {
1746 symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
1747 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
1748 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
1749 symRemapping[i]));
1750 } else {
1751 symRemapping[i] = it->second;
1752 }
1753 }
1754 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
1755 nextDim, nextSym);
1756 *operands = resultOperands;
1757}
1758
1763
1768
1769namespace {
1770/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
1771/// maps that supply results into them.
1772///
1773template <typename AffineOpTy>
1774struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
1775 using OpRewritePattern<AffineOpTy>::OpRewritePattern;
1776
1777 /// Replace the affine op with another instance of it with the supplied
1778 /// map and mapOperands.
1779 void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
1780 AffineMap map, ArrayRef<Value> mapOperands) const;
1781
1782 LogicalResult matchAndRewrite(AffineOpTy affineOp,
1783 PatternRewriter &rewriter) const override {
1784 static_assert(
1785 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
1786 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
1787 AffineVectorStoreOp, AffineVectorLoadOp>::value,
1788 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
1789 "expected");
1790 auto map = affineOp.getAffineMap();
1791 AffineMap oldMap = map;
1792 auto oldOperands = affineOp.getMapOperands();
1793 SmallVector<Value, 8> resultOperands(oldOperands);
1794 composeAffineMapAndOperands(&map, &resultOperands);
1795 canonicalizeMapAndOperands(&map, &resultOperands);
1796 simplifyMapWithOperands(map, resultOperands);
1797 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
1798 resultOperands.begin()))
1799 return failure();
1800
1801 replaceAffineOp(rewriter, affineOp, map, resultOperands);
1802 return success();
1803 }
1804};
1805
1806// Specialize the template to account for the different build signatures for
1807// affine load, store, and apply ops.
1808template <>
1809void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
1810 PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
1811 ArrayRef<Value> mapOperands) const {
1812 rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
1813 mapOperands);
1814}
1815template <>
1816void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
1817 PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
1818 ArrayRef<Value> mapOperands) const {
1819 rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1820 prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1821 prefetch.getLocalityHint(), prefetch.getIsDataCache());
1822}
1823template <>
1824void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
1825 PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
1826 ArrayRef<Value> mapOperands) const {
1827 rewriter.replaceOpWithNewOp<AffineStoreOp>(
1828 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
1829}
1830template <>
1831void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
1832 PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
1833 ArrayRef<Value> mapOperands) const {
1834 rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
1835 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
1836 mapOperands);
1837}
1838template <>
1839void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
1840 PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
1841 ArrayRef<Value> mapOperands) const {
1842 rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
1843 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
1844 mapOperands);
1845}
1846
1847// Generic version for ops that don't have extra operands.
1848template <typename AffineOpTy>
1849void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
1850 PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
1851 ArrayRef<Value> mapOperands) const {
1852 rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
1853}
1854} // namespace
1855
1856void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
1857 MLIRContext *context) {
1858 results.add<SimplifyAffineOp<AffineApplyOp>>(context);
1859}
1860
1861//===----------------------------------------------------------------------===//
1862// AffineDmaStartOp
1863//===----------------------------------------------------------------------===//
1864
1865// TODO: Check that map operands are loop IVs or symbols.
1866void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
1867 Value srcMemRef, AffineMap srcMap,
1868 ValueRange srcIndices, Value destMemRef,
1869 AffineMap dstMap, ValueRange destIndices,
1870 Value tagMemRef, AffineMap tagMap,
1871 ValueRange tagIndices, Value numElements,
1872 Value stride, Value elementsPerStride) {
1873 result.addOperands(srcMemRef);
1874 result.addAttribute(getSrcMapAttrStrName(), AffineMapAttr::get(srcMap));
1875 result.addOperands(srcIndices);
1876 result.addOperands(destMemRef);
1877 result.addAttribute(getDstMapAttrStrName(), AffineMapAttr::get(dstMap));
1878 result.addOperands(destIndices);
1879 result.addOperands(tagMemRef);
1880 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
1881 result.addOperands(tagIndices);
1882 result.addOperands(numElements);
1883 if (stride) {
1884 result.addOperands({stride, elementsPerStride});
1885 }
1886}
1887
1888void AffineDmaStartOp::print(OpAsmPrinter &p) {
1889 p << " " << getSrcMemRef() << '[';
1890 p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
1891 p << "], " << getDstMemRef() << '[';
1892 p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
1893 p << "], " << getTagMemRef() << '[';
1894 p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1895 p << "], " << getNumElements();
1896 if (isStrided()) {
1897 p << ", " << getStride();
1898 p << ", " << getNumElementsPerStride();
1899 }
1900 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1901 << getTagMemRefType();
1902}
1903
1904// Parse AffineDmaStartOp.
1905// Ex:
1906// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1907// %stride, %num_elt_per_stride
1908// : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1909//
// The trailing `%stride, %num_elt_per_stride` pair is optional, but if any
// trailing operand is given there must be exactly two. Exactly three types
// are expected after the colon (source, destination, tag memref types); all
// index, size and stride operands are resolved as `index`.
//
// NOTE(review): `result`, `srcMapOperands`, `dstMapOperands`,
// `tagMapOperands`, `strideInfo` and `types` are used below but have no
// visible declaration; the lines declaring them (including the rest of the
// parameter list) appear to have been dropped from this listing -- confirm
// against the upstream file.
1910ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
1912 OpAsmParser::UnresolvedOperand srcMemRefInfo;
1913 AffineMapAttr srcMapAttr;
1915 OpAsmParser::UnresolvedOperand dstMemRefInfo;
1916 AffineMapAttr dstMapAttr;
1918 OpAsmParser::UnresolvedOperand tagMemRefInfo;
1919 AffineMapAttr tagMapAttr;
1921 OpAsmParser::UnresolvedOperand numElementsInfo;
1923
1925 auto indexType = parser.getBuilder().getIndexType();
1926
1927 // Parse the following list of operands, in this order:
1928 // *) src memref followed by its affine map operands (in square brackets).
1929 // *) dst memref followed by its affine map operands (in square brackets).
1930 // *) tag memref followed by its affine map operands (in square brackets).
1931 // *) number of elements transferred by DMA operation.
  // Each map is also recorded in `result.attributes` under its attribute name.
1932 if (parser.parseOperand(srcMemRefInfo) ||
1933 parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1934 getSrcMapAttrStrName(),
1935 result.attributes) ||
1936 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1937 parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1938 getDstMapAttrStrName(),
1939 result.attributes) ||
1940 parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1941 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1942 getTagMapAttrStrName(),
1943 result.attributes) ||
1944 parser.parseComma() || parser.parseOperand(numElementsInfo))
1945 return failure();
1946
1947 // Parse optional stride and elements per stride.
1948 if (parser.parseTrailingOperandList(strideInfo))
1949 return failure();
1950
  // Either no stride operands at all, or exactly the (stride, elements per
  // stride) pair.
1951 if (!strideInfo.empty() && strideInfo.size() != 2) {
1952 return parser.emitError(parser.getNameLoc(),
1953 "expected two stride related operands");
1954 }
1955 bool isStrided = strideInfo.size() == 2;
1956
1957 if (parser.parseColonTypeList(types))
1958 return failure();
1959
1960 if (types.size() != 3)
1961 return parser.emitError(parser.getNameLoc(), "expected three types");
1962
  // Resolve operands in the same order they are laid out on the op: each
  // memref (with its own type from the type list) followed by its index
  // operands, then the element count.
1963 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1964 parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1965 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1966 parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1967 parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1968 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1969 parser.resolveOperand(numElementsInfo, indexType, result.operands))
1970 return failure();
1971
1972 if (isStrided) {
1973 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1974 return failure();
1975 }
1976
1977 // Check that src/dst/tag operand counts match their map.numInputs.
1978 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1979 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1980 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1981 return parser.emitError(parser.getNameLoc(),
1982 "memref operand count not equal to map.numInputs");
1983 return success();
1984}
1985
1986LogicalResult AffineDmaStartOp::verify() {
1987 if (!llvm::isa<MemRefType>(getOperand(getSrcMemRefOperandIndex()).getType()))
1988 return emitOpError("expected DMA source to be of memref type");
1989 if (!llvm::isa<MemRefType>(getOperand(getDstMemRefOperandIndex()).getType()))
1990 return emitOpError("expected DMA destination to be of memref type");
1991 if (!llvm::isa<MemRefType>(getOperand(getTagMemRefOperandIndex()).getType()))
1992 return emitOpError("expected DMA tag to be of memref type");
1993
1994 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1995 getDstMap().getNumInputs() +
1996 getTagMap().getNumInputs();
1997 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1998 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1999 return emitOpError("incorrect number of operands");
2000 }
2001
2002 Region *scope = getAffineScope(*this);
2003 for (auto idx : getSrcIndices()) {
2004 if (!idx.getType().isIndex())
2005 return emitOpError("src index to dma_start must have 'index' type");
2006 if (!isValidAffineIndexOperand(idx, scope))
2007 return emitOpError(
2008 "src index must be a valid dimension or symbol identifier");
2009 }
2010 for (auto idx : getDstIndices()) {
2011 if (!idx.getType().isIndex())
2012 return emitOpError("dst index to dma_start must have 'index' type");
2013 if (!isValidAffineIndexOperand(idx, scope))
2014 return emitOpError(
2015 "dst index must be a valid dimension or symbol identifier");
2016 }
2017 for (auto idx : getTagIndices()) {
2018 if (!idx.getType().isIndex())
2019 return emitOpError("tag index to dma_start must have 'index' type");
2020 if (!isValidAffineIndexOperand(idx, scope))
2021 return emitOpError(
2022 "tag index must be a valid dimension or symbol identifier");
2023 }
2024 return success();
2025}
2026
// Folder for affine.dma_start; delegates entirely to memref::foldMemRefCast
// and returns its result.
// NOTE(review): the remainder of the parameter list and the opening brace are
// missing from this listing -- confirm against the upstream file.
2027LogicalResult AffineDmaStartOp::fold(FoldAdaptor adaptor,
2029 /// dma_start(memrefcast) -> dma_start
2030 return memref::foldMemRefCast(*this);
2031}
2032
// Appends the memory effects of affine.dma_start to `effects`: a read of the
// source memref operand, a write of the destination memref operand, and a
// read of the tag memref operand.
// NOTE(review): the parameter type and the trailing argument(s) of each
// emplace_back call are missing from this listing -- confirm against the
// upstream file.
2033void AffineDmaStartOp::getEffects(
2035 &effects) {
2036 effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
2038 effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
2040 effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
2042}
2043
2044//===----------------------------------------------------------------------===//
2045// AffineDmaWaitOp
2046//===----------------------------------------------------------------------===//
2047
2048// TODO: Check that map operands are loop IVs or symbols.
2049void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
2050 Value tagMemRef, AffineMap tagMap,
2051 ValueRange tagIndices, Value numElements) {
2052 result.addOperands(tagMemRef);
2053 result.addAttribute(getTagMapAttrStrName(), AffineMapAttr::get(tagMap));
2054 result.addOperands(tagIndices);
2055 result.addOperands(numElements);
2056}
2057
// Prints affine.dma_wait as ` tag[map(indices)], ... : tagType`.
// NOTE(review): the statement between the `"], "` separator and the type
// suffix is missing from this listing; presumably it prints the
// number-of-elements operand (cf. the parse example
// `%tag[%index], %num_elements`) -- confirm against the upstream file.
2058void AffineDmaWaitOp::print(OpAsmPrinter &p) {
2059 p << " " << getTagMemRef() << '[';
  // Copy the tag indices into a local vector before handing them to the
  // printer.
2060 SmallVector<Value, 2> operands(getTagIndices());
2061 p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
2062 p << "], ";
2064 p << " : " << getTagMemRef().getType();
2065}
2066
2067// Parse AffineDmaWaitOp.
2068// Eg:
2069// affine.dma_wait %tag[%index], %num_elements
2070// : memref<1 x i32, (d0) -> (d0), 4>
2071//
// The single type after the colon is the tag memref type; the map operands
// and the element count are resolved as `index`.
//
// NOTE(review): `result` and `tagMapOperands` are used below without a
// visible declaration; the lines declaring them (including the rest of the
// parameter list) appear to have been dropped from this listing -- confirm
// against the upstream file.
2072ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
2074 OpAsmParser::UnresolvedOperand tagMemRefInfo;
2075 AffineMapAttr tagMapAttr;
2077 Type type;
2078 auto indexType = parser.getBuilder().getIndexType();
2079 OpAsmParser::UnresolvedOperand numElementsInfo;
2080
2081 // Parse tag memref, its map operands, and dma size, then the tag type, and
  // resolve the operands in the order: tag memref, map operands, dma size.
2082 if (parser.parseOperand(tagMemRefInfo) ||
2083 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
2084 getTagMapAttrStrName(),
2085 result.attributes) ||
2086 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
2087 parser.parseColonType(type) ||
2088 parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
2089 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
2090 parser.resolveOperand(numElementsInfo, indexType, result.operands))
2091 return failure();
2092
  // This type check runs after the operands have already been resolved
  // against `type`.
2093 if (!llvm::isa<MemRefType>(type))
2094 return parser.emitError(parser.getNameLoc(),
2095 "expected tag to be of memref type");
2096
  // The number of parsed map operands must equal the number of map inputs.
2097 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
2098 return parser.emitError(parser.getNameLoc(),
2099 "tag memref operand count != to map.numInputs");
2100 return success();
2101}
2102
2103LogicalResult AffineDmaWaitOp::verify() {
2104 if (!llvm::isa<MemRefType>(getOperand(0).getType()))
2105 return emitOpError("expected DMA tag to be of memref type");
2106 Region *scope = getAffineScope(*this);
2107 for (auto idx : getTagIndices()) {
2108 if (!idx.getType().isIndex())
2109 return emitOpError("index to dma_wait must have 'index' type");
2110 if (!isValidAffineIndexOperand(idx, scope))
2111 return emitOpError(
2112 "index must be a valid dimension or symbol identifier");
2113 }
2114 return success();
2115}
2116
// Folder for affine.dma_wait; delegates entirely to memref::foldMemRefCast
// and returns its result.
// NOTE(review): the remainder of the parameter list and the opening brace are
// missing from this listing -- confirm against the upstream file.
2117LogicalResult AffineDmaWaitOp::fold(FoldAdaptor adaptor,
2119 /// dma_wait(memrefcast) -> dma_wait
2120 return memref::foldMemRefCast(*this);
2121}
2122
// Appends the memory effect of affine.dma_wait to `effects`: a read of the
// tag memref operand.
// NOTE(review): the parameter type and the trailing argument(s) of the
// emplace_back call are missing from this listing -- confirm against the
// upstream file.
2123void AffineDmaWaitOp::getEffects(
2125 &effects) {
2126 effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
2128}
2129
2130//===----------------------------------------------------------------------===//
2131// AffineForOp
2132//===----------------------------------------------------------------------===//
2133
2134/// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
2135/// bodyBuilder are empty/null, we include default terminator op.
2136void AffineForOp::build(OpBuilder &builder, OperationState &result,
2137 ValueRange lbOperands, AffineMap lbMap,
2138 ValueRange ubOperands, AffineMap ubMap, int64_t step,
2139 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
2140 assert(((!lbMap && lbOperands.empty()) ||
2141 lbOperands.size() == lbMap.getNumInputs()) &&
2142 "lower bound operand count does not match the affine map");
2143 assert(((!ubMap && ubOperands.empty()) ||
2144 ubOperands.size() == ubMap.getNumInputs()) &&
2145 "upper bound operand count does not match the affine map");
2146 assert(step > 0 && "step has to be a positive integer constant");
2147
2148 OpBuilder::InsertionGuard guard(builder);
2149
2150 // Set variadic segment sizes.
2151 result.addAttribute(
2152 getOperandSegmentSizeAttr(),
2153 builder.getDenseI32ArrayAttr({static_cast<int32_t>(lbOperands.size()),
2154 static_cast<int32_t>(ubOperands.size()),
2155 static_cast<int32_t>(iterArgs.size())}));
2156
2157 for (Value val : iterArgs)
2158 result.addTypes(val.getType());
2159
2160 // Add an attribute for the step.
2161 result.addAttribute(getStepAttrName(result.name),
2162 builder.getIntegerAttr(builder.getIndexType(), step));
2163
2164 // Add the lower bound.
2165 result.addAttribute(getLowerBoundMapAttrName(result.name),
2166 AffineMapAttr::get(lbMap));
2167 result.addOperands(lbOperands);
2168
2169 // Add the upper bound.
2170 result.addAttribute(getUpperBoundMapAttrName(result.name),
2171 AffineMapAttr::get(ubMap));
2172 result.addOperands(ubOperands);
2173
2174 result.addOperands(iterArgs);
2175 // Create a region and a block for the body. The argument of the region is
2176 // the loop induction variable.
2177 Region *bodyRegion = result.addRegion();
2178 Block *bodyBlock = builder.createBlock(bodyRegion);
2179 Value inductionVar =
2180 bodyBlock->addArgument(builder.getIndexType(), result.location);
2181 for (Value val : iterArgs)
2182 bodyBlock->addArgument(val.getType(), val.getLoc());
2183
2184 // Create the default terminator if the builder is not provided and if the
2185 // iteration arguments are not provided. Otherwise, leave this to the caller
2186 // because we don't know which values to return from the loop.
2187 if (iterArgs.empty() && !bodyBuilder) {
2188 ensureTerminator(*bodyRegion, builder, result.location);
2189 } else if (bodyBuilder) {
2190 OpBuilder::InsertionGuard guard(builder);
2191 builder.setInsertionPointToStart(bodyBlock);
2192 bodyBuilder(builder, result.location, inductionVar,
2193 bodyBlock->getArguments().drop_front());
2194 }
2195}
2196
2197void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
2198 int64_t ub, int64_t step, ValueRange iterArgs,
2199 BodyBuilderFn bodyBuilder) {
2200 auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
2201 auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
2202 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
2203 bodyBuilder);
2204}
2205
// Verifies the structural invariants of affine.for: positive step, an
// index-typed first body argument, valid bound operands, non-empty bound
// maps, and (when the op has results) matching counts of results, iter
// operands and region iter args.
2206LogicalResult AffineForOp::verifyRegions() {
2207 // Step must be a strictly positive integer.
2208 if (getStepAsInt() <= 0)
2209 return emitOpError("expected step to be a positive integer, got ")
2210 << getStepAsInt();
2211
2212 // Check that the body defines a block argument for the induction variable:
2213 // there must be at least one argument and the first one must be an index.
2214 auto *body = getBody();
2215 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
2216 return emitOpError("expected body to have a single index argument for the "
2217 "induction variable");
2218
2219 // Verify that the bound operands are valid dimension/symbols.
  // NOTE(review): in both bound checks below, the head of the call guarded by
  // the `if` is missing from this listing; only its last argument (the map's
  // dim count) and the `return failure()` it guards are visible -- confirm
  // against the upstream file.
2220 /// Lower bound.
2221 if (getLowerBoundMap().getNumInputs() > 0)
2223 getLowerBoundMap().getNumDims())))
2224 return failure();
2225 /// Upper bound.
2226 if (getUpperBoundMap().getNumInputs() > 0)
2228 getUpperBoundMap().getNumDims())))
2229 return failure();
  // Each bound map must produce at least one result.
2230 if (getLowerBoundMap().getNumResults() < 1)
2231 return emitOpError("expected lower bound map to have at least one result");
2232 if (getUpperBoundMap().getNumResults() < 1)
2233 return emitOpError("expected upper bound map to have at least one result");
2234
  // Nothing more to check for loops that define no values.
2235 unsigned opNumResults = getNumResults();
2236 if (opNumResults == 0)
2237 return success();
2238
2239 // If ForOp defines values, check that the number of the defined values
2240 // matches the number of ForOp initial iter operands and of the region
2241 // iter args.
2242 if (getNumIterOperands() != opNumResults)
2243 return emitOpError(
2244 "mismatch between the number of loop-carried values and results");
2245 if (getNumRegionIterArgs() != opNumResults)
2246 return emitOpError(
2247 "mismatch between the number of basic block args and results");
2248
2249 return success();
2250}
2251
2252/// Parse a for operation loop bounds.
/// `isLower` selects between the lower bound (optional `max` prefix, lower
/// bound map attribute) and the upper bound (optional `min` prefix, upper
/// bound map attribute). Three spellings are accepted: a single SSA operand,
/// an affine map followed by its dim and symbol operand lists, or an integer
/// constant. On success the bound map attribute has been added to `result`
/// and any bound operands have been appended to `result.operands`.
/// NOTE(review): `boundOpInfos` is used below without a visible declaration;
/// the declaring line appears to have been dropped from this listing --
/// confirm against the upstream file.
2253static ParseResult parseBound(bool isLower, OperationState &result,
2254 OpAsmParser &p) {
2255 // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
2256 // the map has multiple results.
2257 bool failedToParsedMinMax =
2258 failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
2259
2260 auto &builder = p.getBuilder();
2261 auto boundAttrStrName =
2262 isLower ? AffineForOp::getLowerBoundMapAttrName(result.name)
2263 : AffineForOp::getUpperBoundMapAttrName(result.name);
2264
2265 // Parse ssa-id as identity map.
2267 if (p.parseOperandList(boundOpInfos))
2268 return failure();
2269
2270 if (!boundOpInfos.empty()) {
2271 // Check that only one operand was parsed.
2272 if (boundOpInfos.size() > 1)
2273 return p.emitError(p.getNameLoc(),
2274 "expected only one loop bound operand");
2275
2276 // TODO: improve error message when SSA value is not of index type.
2277 // Currently it is 'use of value ... expects different type than prior uses'
2278 if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
2279 result.operands))
2280 return failure();
2281
2282 // Create an identity map using symbol id. This representation is optimized
2283 // for storage. Analysis passes may expand it into a multi-dimensional map
2284 // if desired.
2285 AffineMap map = builder.getSymbolIdentityMap();
2286 result.addAttribute(boundAttrStrName, AffineMapAttr::get(map));
2287 return success();
2288 }
2289
2290 // Get the attribute location.
2291 SMLoc attrLoc = p.getCurrentLocation();
2292
  // No SSA operand was given: the bound is spelled as an attribute, which is
  // parsed with index type under the bound map attribute name.
2293 Attribute boundAttr;
2294 if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrStrName,
2295 result.attributes))
2296 return failure();
2297
2298 // Parse full form - affine map followed by dim and symbol list.
2299 if (auto affineMapAttr = dyn_cast<AffineMapAttr>(boundAttr)) {
    // Remember the operand count so the number of operands appended by the
    // dim/symbol list can be computed afterwards.
2300 unsigned currentNumOperands = result.operands.size();
2301 unsigned numDims;
2302 if (parseDimAndSymbolList(p, result.operands, numDims))
2303 return failure();
2304
2305 auto map = affineMapAttr.getValue();
2306 if (map.getNumDims() != numDims)
2307 return p.emitError(
2308 p.getNameLoc(),
2309 "dim operand count and affine map dim count must match");
2310
2311 unsigned numDimAndSymbolOperands =
2312 result.operands.size() - currentNumOperands;
2313 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
2314 return p.emitError(
2315 p.getNameLoc(),
2316 "symbol operand count and affine map symbol count must match");
2317
2318 // If the map has multiple results, make sure that we parsed the min/max
2319 // prefix.
2320 if (map.getNumResults() > 1 && failedToParsedMinMax) {
2321 if (isLower) {
2322 return p.emitError(attrLoc, "lower loop bound affine map with "
2323 "multiple results requires 'max' prefix");
2324 }
2325 return p.emitError(attrLoc, "upper loop bound affine map with multiple "
2326 "results requires 'min' prefix");
2327 }
2328 return success();
2329 }
2330
2331 // Parse custom assembly form.
2332 if (auto integerAttr = dyn_cast<IntegerAttr>(boundAttr)) {
    // Drop the last entry of result.attributes (presumably the integer
    // attribute that parseAttribute recorded above) and register a constant
    // affine map under the same name instead.
2333 result.attributes.pop_back();
2334 result.addAttribute(
2335 boundAttrStrName,
2336 AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
2337 return success();
2338 }
2339
  // Any other attribute kind is not a valid bound.
2340 return p.emitError(
2341 p.getNameLoc(),
2342 "expected valid affine map representation for loop bounds");
2343}
2344
2345ParseResult AffineForOp::parse(OpAsmParser &parser, OperationState &result) {
2346 auto &builder = parser.getBuilder();
2347 OpAsmParser::Argument inductionVariable;
2348 inductionVariable.type = builder.getIndexType();
2349 // Parse the induction variable followed by '='.
2350 if (parser.parseArgument(inductionVariable) || parser.parseEqual())
2351 return failure();
2352
2353 // Parse loop bounds.
2354 int64_t numOperands = result.operands.size();
2355 if (parseBound(/*isLower=*/true, result, parser))
2356 return failure();
2357 int64_t numLbOperands = result.operands.size() - numOperands;
2358 if (parser.parseKeyword("to", " between bounds"))
2359 return failure();
2360 numOperands = result.operands.size();
2361 if (parseBound(/*isLower=*/false, result, parser))
2362 return failure();
2363 int64_t numUbOperands = result.operands.size() - numOperands;
2364
2365 // Parse the optional loop step, we default to 1 if one is not present.
2366 if (parser.parseOptionalKeyword("step")) {
2367 result.addAttribute(
2368 getStepAttrName(result.name),
2369 builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
2370 } else {
2371 SMLoc stepLoc = parser.getCurrentLocation();
2372 IntegerAttr stepAttr;
2373 if (parser.parseAttribute(stepAttr, builder.getIndexType(),
2374 getStepAttrName(result.name).data(),
2375 result.attributes))
2376 return failure();
2377
2378 if (!stepAttr.getValue().isStrictlyPositive())
2379 return parser.emitError(
2380 stepLoc,
2381 "expected step to be representable as a positive signed integer");
2382 }
2383
2384 // Parse the optional initial iteration arguments.
2385 SmallVector<OpAsmParser::Argument, 4> regionArgs;
2386 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2387
2388 // Induction variable.
2389 regionArgs.push_back(inductionVariable);
2390
2391 if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
2392 // Parse assignment list and results type list.
2393 if (parser.parseAssignmentList(regionArgs, operands) ||
2394 parser.parseArrowTypeList(result.types))
2395 return failure();
2396 // Resolve input operands.
2397 for (auto argOperandType :
2398 llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
2399 Type type = std::get<2>(argOperandType);
2400 std::get<0>(argOperandType).type = type;
2401 if (parser.resolveOperand(std::get<1>(argOperandType), type,
2402 result.operands))
2403 return failure();
2404 }
2405 }
2406
2407 result.addAttribute(
2408 getOperandSegmentSizeAttr(),
2409 builder.getDenseI32ArrayAttr({static_cast<int32_t>(numLbOperands),
2410 static_cast<int32_t>(numUbOperands),
2411 static_cast<int32_t>(operands.size())}));
2412
2413 // Parse the body region.
2414 Region *body = result.addRegion();
2415 if (regionArgs.size() != result.types.size() + 1)
2416 return parser.emitError(
2417 parser.getNameLoc(),
2418 "mismatch between the number of loop-carried values and results");
2419 if (parser.parseRegion(*body, regionArgs))
2420 return failure();
2421
2422 AffineForOp::ensureTerminator(*body, builder, result.location);
2423
2424 // Parse the optional attribute list.
2425 return parser.parseOptionalAttrDict(result.attributes);
2426}
2427
2428static void printBound(AffineMapAttr boundMap,
2429 Operation::operand_range boundOperands,
2430 const char *prefix, OpAsmPrinter &p) {
2431 AffineMap map = boundMap.getValue();
2432
2433 // Check if this bound should be printed using custom assembly form.
2434 // The decision to restrict printing custom assembly form to trivial cases
2435 // comes from the will to roundtrip MLIR binary -> text -> binary in a
2436 // lossless way.
2437 // Therefore, custom assembly form parsing and printing is only supported for
2438 // zero-operand constant maps and single symbol operand identity maps.
2439 if (map.getNumResults() == 1) {
2440 AffineExpr expr = map.getResult(0);
2441
2442 // Print constant bound.
2443 if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
2444 if (auto constExpr = dyn_cast<AffineConstantExpr>(expr)) {
2445 p << constExpr.getValue();
2446 return;
2447 }
2448 }
2449
2450 // Print bound that consists of a single SSA symbol if the map is over a
2451 // single symbol.
2452 if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
2453 if (isa<AffineSymbolExpr>(expr)) {
2454 p.printOperand(*boundOperands.begin());
2455 return;
2456 }
2457 }
2458 } else {
2459 // Map has multiple results. Print 'min' or 'max' prefix.
2460 p << prefix << ' ';
2461 }
2462
2463 // Print the map and its operands.
2464 p << boundMap;
2465 printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
2466 map.getNumDims(), p);
2467}
2468
2469unsigned AffineForOp::getNumIterOperands() {
2470 AffineMap lbMap = getLowerBoundMapAttr().getValue();
2471 AffineMap ubMap = getUpperBoundMapAttr().getValue();
2472
2473 return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
2474}
2475
2476std::optional<MutableArrayRef<OpOperand>>
2477AffineForOp::getYieldedValuesMutable() {
2478 return cast<AffineYieldOp>(getBody()->getTerminator()).getOperandsMutable();
2479}
2480
// Prints affine.for in its custom form: induction variable, ` = ` lower bound
// ` to ` upper bound, an optional ` step N`, an optional iter_args clause
// with result types, the body region, and the remaining attributes.
// NOTE(review): the head of the call that consumes the attribute list at the
// end of this function is missing from this listing -- confirm against the
// upstream file.
2481void AffineForOp::print(OpAsmPrinter &p) {
2482 p << ' ';
  // The induction variable is printed without its type.
2483 p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
2484 /*omitType=*/true);
2485 p << " = ";
2486 printBound(getLowerBoundMapAttr(), getLowerBoundOperands(), "max", p);
2487 p << " to ";
2488 printBound(getUpperBoundMapAttr(), getUpperBoundOperands(), "min", p);
2489
  // A step of 1 is left implicit.
2490 if (getStepAsInt() != 1)
2491 p << " step " << getStepAsInt();
2492
  // Block terminators are only printed when the loop carries values.
2493 bool printBlockTerminators = false;
2494 if (getNumIterOperands() > 0) {
2495 p << " iter_args(";
2496 auto regionArgs = getRegionIterArgs();
2497 auto operands = getInits();
2498
    // Each entry is printed as `<region arg> = <init operand>`.
2499 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
2500 p << std::get<0>(it) << " = " << std::get<1>(it);
2501 });
2502 p << ") -> (" << getResultTypes() << ")";
2503 printBlockTerminators = true;
2504 }
2505
2506 p << ' ';
2507 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2508 printBlockTerminators);
  // The bound maps, the step and the operand segment sizes are elided from
  // the attribute list passed below.
2510 (*this)->getAttrs(),
2511 /*elidedAttrs=*/{getLowerBoundMapAttrName(getOperation()->getName()),
2512 getUpperBoundMapAttrName(getOperation()->getName()),
2513 getStepAttrName(getOperation()->getName()),
2514 getOperandSegmentSizeAttr()});
2515}
2516
2517/// Fold the constant bounds of a loop.
2518static LogicalResult foldLoopBounds(AffineForOp forOp) {
2519 auto foldLowerOrUpperBound = [&forOp](bool lower) {
2520 // Check to see if each of the operands is the result of a constant. If
2521 // so, get the value. If not, ignore it.
2522 SmallVector<Attribute, 8> operandConstants;
2523 auto boundOperands =
2524 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
2525 for (auto operand : boundOperands) {
2526 Attribute operandCst;
2527 matchPattern(operand, m_Constant(&operandCst));
2528 operandConstants.push_back(operandCst);
2529 }
2530
2531 AffineMap boundMap =
2532 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
2533 assert(boundMap.getNumResults() >= 1 &&
2534 "bound maps should have at least one result");
2535 SmallVector<Attribute, 4> foldedResults;
2536 if (failed(boundMap.constantFold(operandConstants, foldedResults)))
2537 return failure();
2538
2539 // Compute the max or min as applicable over the results.
2540 assert(!foldedResults.empty() && "bounds should have at least one result");
2541 auto maxOrMin = llvm::cast<IntegerAttr>(foldedResults[0]).getValue();
2542 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
2543 auto foldedResult = llvm::cast<IntegerAttr>(foldedResults[i]).getValue();
2544 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
2545 : llvm::APIntOps::smin(maxOrMin, foldedResult);
2546 }
2547 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
2548 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
2549 return success();
2550 };
2551
2552 // Try to fold the lower bound.
2553 bool folded = false;
2554 if (!forOp.hasConstantLowerBound())
2555 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
2556
2557 // Try to fold the upper bound.
2558 if (!forOp.hasConstantUpperBound())
2559 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
2560 return success(folded);
2561}
2562
2563/// Returns constant trip count in trivial cases.
2564static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
2565 int64_t step = forOp.getStepAsInt();
2566 if (!forOp.hasConstantBounds() || step <= 0)
2567 return std::nullopt;
2568 int64_t lb = forOp.getConstantLowerBound();
2569 int64_t ub = forOp.getConstantUpperBound();
2570 return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
2571}
2572
2573/// Fold the empty loop.
2575 if (!llvm::hasSingleElement(*forOp.getBody()))
2576 return {};
2577 if (forOp.getNumResults() == 0)
2578 return {};
2579 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2580 if (tripCount == 0) {
2581 // The initial values of the iteration arguments would be the op's
2582 // results.
2583 return forOp.getInits();
2584 }
2585 SmallVector<Value, 4> replacements;
2586 auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
2587 auto iterArgs = forOp.getRegionIterArgs();
2588 bool hasValDefinedOutsideLoop = false;
2589 bool iterArgsNotInOrder = false;
2590 for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
2591 Value val = yieldOp.getOperand(i);
2592 BlockArgument *iterArgIt = llvm::find(iterArgs, val);
2593 // TODO: It should be possible to perform a replacement by computing the
2594 // last value of the IV based on the bounds and the step.
2595 if (val == forOp.getInductionVar())
2596 return {};
2597 if (iterArgIt == iterArgs.end()) {
2598 // `val` is defined outside of the loop.
2599 assert(forOp.isDefinedOutsideOfLoop(val) &&
2600 "must be defined outside of the loop");
2601 hasValDefinedOutsideLoop = true;
2602 replacements.push_back(val);
2603 } else {
2604 unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
2605 if (pos != i)
2606 iterArgsNotInOrder = true;
2607 replacements.push_back(forOp.getInits()[pos]);
2608 }
2609 }
2610 // Bail out when the trip count is unknown and the loop returns any value
2611 // defined outside of the loop or any iterArg out of order.
2612 if (!tripCount.has_value() &&
2613 (hasValDefinedOutsideLoop || iterArgsNotInOrder))
2614 return {};
2615 // Bail out when the loop iterates more than once and it returns any iterArg
2616 // out of order.
2617 if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
2618 return {};
2619 return llvm::to_vector_of<OpFoldResult>(replacements);
2620}
2621
2622/// Canonicalize the bounds of the given loop.
2623static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
2624 SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
2625 SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
2626
2627 auto lbMap = forOp.getLowerBoundMap();
2628 auto ubMap = forOp.getUpperBoundMap();
2629 auto prevLbMap = lbMap;
2630 auto prevUbMap = ubMap;
2631
2632 composeAffineMapAndOperands(&lbMap, &lbOperands);
2633 canonicalizeMapAndOperands(&lbMap, &lbOperands);
2634 simplifyMinOrMaxExprWithOperands(lbMap, lbOperands, /*isMax=*/true);
2635 simplifyMinOrMaxExprWithOperands(ubMap, ubOperands, /*isMax=*/false);
2636 lbMap = removeDuplicateExprs(lbMap);
2637
2638 composeAffineMapAndOperands(&ubMap, &ubOperands);
2639 canonicalizeMapAndOperands(&ubMap, &ubOperands);
2640 ubMap = removeDuplicateExprs(ubMap);
2641
2642 // Any canonicalization change always leads to updated map(s).
2643 if (lbMap == prevLbMap && ubMap == prevUbMap)
2644 return failure();
2645
2646 if (lbMap != prevLbMap)
2647 forOp.setLowerBound(lbOperands, lbMap);
2648 if (ubMap != prevUbMap)
2649 forOp.setUpperBound(ubOperands, ubMap);
2650 return success();
2651}
2652
2653/// Returns true if the affine.for has zero iterations in trivial cases.
2654static bool hasTrivialZeroTripCount(AffineForOp op) {
2655 return getTrivialConstantTripCount(op) == 0;
2656}
2657
2658LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
2660 bool folded = succeeded(foldLoopBounds(*this));
2661 folded |= succeeded(canonicalizeLoopBounds(*this));
2662 if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
2663 // The initial values of the loop-carried variables (iter_args) are the
2664 // results of the op. But this must be avoided for an affine.for op that
2665 // does not return any results. Since ops that do not return results cannot
2666 // be folded away, we would enter an infinite loop of folds on the same
2667 // affine.for op.
2668 results.assign(getInits().begin(), getInits().end());
2669 folded = true;
2670 }
2671 SmallVector<OpFoldResult> foldResults = AffineForEmptyLoopFolder(*this);
2672 if (!foldResults.empty()) {
2673 results.assign(foldResults);
2674 folded = true;
2675 }
2676 return success(folded);
2677}
2678
2679OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
2680 assert(
2681 (successor.isOperation() || successor.getSuccessor() == &getRegion()) &&
2682 "invalid region point");
2683
2684 // The initial operands map to the loop arguments after the induction
2685 // variable or are forwarded to the results when the trip count is zero.
2686 return getInits();
2687}
2688
2689void AffineForOp::getSuccessorRegions(
2691 assert((point.isParent() ||
2692 point.getTerminatorPredecessorOrNull()->getParentRegion() ==
2693 &getRegion()) &&
2694 "expected loop region");
2695 // The loop may typically branch back to its body or to the parent operation.
2696 // If the predecessor is the parent op and the trip count is known to be at
2697 // least one, branch into the body using the iterator arguments. And in cases
2698 // we know the trip count is zero, it can only branch back to its parent.
2699 std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
2700 if (tripCount.has_value()) {
2701 if (!point.isParent()) {
2702 // From the loop body, if the trip count is one, we can only branch back
2703 // to the parent.
2704 if (tripCount == 1) {
2705 regions.push_back(RegionSuccessor(getOperation()));
2706 return;
2707 }
2708 if (tripCount == 0)
2709 return;
2710 } else {
2711 if (tripCount.value() > 0) {
2712 regions.push_back(RegionSuccessor(&getRegion()));
2713 return;
2714 }
2715 if (tripCount.value() == 0) {
2716 regions.push_back(RegionSuccessor(getOperation()));
2717 return;
2718 }
2719 }
2720 }
2721
2722 // In all other cases, the loop may branch back to itself or the parent
2723 // operation.
2724 regions.push_back(RegionSuccessor(&getRegion()));
2725 regions.push_back(RegionSuccessor(getOperation()));
2726}
2727
2728ValueRange AffineForOp::getSuccessorInputs(RegionSuccessor successor) {
2729 if (successor.isOperation())
2730 return getResults();
2731 return getRegionIterArgs();
2732}
2733
2734AffineBound AffineForOp::getLowerBound() {
2735 return AffineBound(*this, getLowerBoundOperands(), getLowerBoundMap());
2736}
2737
2738AffineBound AffineForOp::getUpperBound() {
2739 return AffineBound(*this, getUpperBoundOperands(), getUpperBoundMap());
2740}
2741
2742void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
2743 assert(lbOperands.size() == map.getNumInputs());
2744 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2745 getLowerBoundOperandsMutable().assign(lbOperands);
2746 setLowerBoundMap(map);
2747}
2748
2749void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
2750 assert(ubOperands.size() == map.getNumInputs());
2751 assert(map.getNumResults() >= 1 && "bound map has at least one result");
2752 getUpperBoundOperandsMutable().assign(ubOperands);
2753 setUpperBoundMap(map);
2754}
2755
2756bool AffineForOp::hasConstantLowerBound() {
2757 return getLowerBoundMap().isSingleConstant();
2758}
2759
2760bool AffineForOp::hasConstantUpperBound() {
2761 return getUpperBoundMap().isSingleConstant();
2762}
2763
2764int64_t AffineForOp::getConstantLowerBound() {
2765 return getLowerBoundMap().getSingleConstantResult();
2766}
2767
2768int64_t AffineForOp::getConstantUpperBound() {
2769 return getUpperBoundMap().getSingleConstantResult();
2770}
2771
2772void AffineForOp::setConstantLowerBound(int64_t value) {
2773 setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
2774}
2775
2776void AffineForOp::setConstantUpperBound(int64_t value) {
2777 setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
2778}
2779
2780AffineForOp::operand_range AffineForOp::getControlOperands() {
2781 return {operand_begin(), operand_begin() + getLowerBoundOperands().size() +
2782 getUpperBoundOperands().size()};
2783}
2784
2785bool AffineForOp::matchingBoundOperandList() {
2786 auto lbMap = getLowerBoundMap();
2787 auto ubMap = getUpperBoundMap();
2788 if (lbMap.getNumDims() != ubMap.getNumDims() ||
2789 lbMap.getNumSymbols() != ubMap.getNumSymbols())
2790 return false;
2791
2792 unsigned numOperands = lbMap.getNumInputs();
2793 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
2794 // Compare Value 's.
2795 if (getOperand(i) != getOperand(numOperands + i))
2796 return false;
2797 }
2798 return true;
2799}
2800
2801SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
2802
2803std::optional<SmallVector<Value>> AffineForOp::getLoopInductionVars() {
2804 return SmallVector<Value>{getInductionVar()};
2805}
2806
2807std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopLowerBounds() {
2808 if (!hasConstantLowerBound())
2809 return std::nullopt;
2810 OpBuilder b(getContext());
2811 return SmallVector<OpFoldResult>{
2812 OpFoldResult(b.getI64IntegerAttr(getConstantLowerBound()))};
2813}
2814
2815std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopSteps() {
2816 OpBuilder b(getContext());
2817 return SmallVector<OpFoldResult>{
2818 OpFoldResult(b.getI64IntegerAttr(getStepAsInt()))};
2819}
2820
2821std::optional<SmallVector<OpFoldResult>> AffineForOp::getLoopUpperBounds() {
2822 if (!hasConstantUpperBound())
2823 return {};
2824 OpBuilder b(getContext());
2825 return SmallVector<OpFoldResult>{
2826 OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()))};
2827}
2828
2829std::optional<APInt> AffineForOp::getStaticTripCount() {
2830 MLIRContext *context = getContext();
2831 int64_t step = getStepAsInt();
2832 if (step <= 0)
2833 return std::nullopt;
2834
2835 if (hasConstantBounds()) {
2836 int64_t lb = getConstantLowerBound();
2837 int64_t ub = getConstantUpperBound();
2838 int64_t loopSpan = ub - lb;
2839 if (loopSpan < 0)
2840 loopSpan = 0;
2841 return APInt(64, llvm::divideCeilSigned(loopSpan, step));
2842 }
2843
2844 auto lbMap = getLowerBoundMap();
2845 auto ubMap = getUpperBoundMap();
2846 if (lbMap.getNumResults() != 1)
2847 return std::nullopt;
2848
2849 // Difference of each upper bound expression from the single lower bound
2850 // expression (divided by the step) provides the expressions for the trip
2851 // count map.
2852 AffineValueMap ubValueMap(ubMap, getUpperBoundOperands());
2853
2854 SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
2855 lbMap.getResult(0));
2856 auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
2857 lbSplatExpr, context);
2858 AffineValueMap lbSplatValueMap(lbMapSplat, getLowerBoundOperands());
2859
2860 AffineValueMap tripCountValueMap;
2861 AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
2862
2863 // Take the min if all trip counts are constant.
2864 std::optional<uint64_t> tripCount;
2865 for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i) {
2866 AffineExpr expr = tripCountValueMap.getResult(i).ceilDiv(step);
2867 if (auto constExpr = llvm::dyn_cast<AffineConstantExpr>(expr)) {
2868 uint64_t value = constExpr.getValue();
2869 if (tripCount.has_value())
2870 tripCount = std::min(*tripCount, value);
2871 else
2872 tripCount = value;
2873 } else {
2874 return std::nullopt;
2875 }
2876 }
2877
2878 if (tripCount.has_value())
2879 return APInt(64, *tripCount);
2880
2881 return std::nullopt;
2882}
2883
2884FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
2885 RewriterBase &rewriter, ValueRange newInitOperands,
2886 bool replaceInitOperandUsesInLoop,
2887 const NewYieldValuesFn &newYieldValuesFn) {
2888 // Create a new loop before the existing one, with the extra operands.
2889 OpBuilder::InsertionGuard g(rewriter);
2890 rewriter.setInsertionPoint(getOperation());
2891 auto inits = llvm::to_vector(getInits());
2892 inits.append(newInitOperands.begin(), newInitOperands.end());
2893 AffineForOp newLoop = AffineForOp::create(
2894 rewriter, getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2895 getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
2896
2897 // Generate the new yield values and append them to the scf.yield operation.
2898 auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
2899 ArrayRef<BlockArgument> newIterArgs =
2900 newLoop.getBody()->getArguments().take_back(newInitOperands.size());
2901 {
2902 OpBuilder::InsertionGuard g(rewriter);
2903 rewriter.setInsertionPoint(yieldOp);
2904 SmallVector<Value> newYieldedValues =
2905 newYieldValuesFn(rewriter, getLoc(), newIterArgs);
2906 assert(newInitOperands.size() == newYieldedValues.size() &&
2907 "expected as many new yield values as new iter operands");
2908 rewriter.modifyOpInPlace(yieldOp, [&]() {
2909 yieldOp.getOperandsMutable().append(newYieldedValues);
2910 });
2911 }
2912
2913 // Move the loop body to the new op.
2914 rewriter.mergeBlocks(getBody(), newLoop.getBody(),
2915 newLoop.getBody()->getArguments().take_front(
2916 getBody()->getNumArguments()));
2917
2918 if (replaceInitOperandUsesInLoop) {
2919 // Replace all uses of `newInitOperands` with the corresponding basic block
2920 // arguments.
2921 for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
2922 rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
2923 [&](OpOperand &use) {
2924 Operation *user = use.getOwner();
2925 return newLoop->isProperAncestor(user);
2926 });
2927 }
2928 }
2929
2930 // Replace the old loop.
2931 rewriter.replaceOp(getOperation(),
2932 newLoop->getResults().take_front(getNumResults()));
2933 return cast<LoopLikeOpInterface>(newLoop.getOperation());
2934}
2935
2936Speculation::Speculatability AffineForOp::getSpeculatability() {
2937 // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
2938 // Start and End.
2939 //
2940 // For Step != 1, the loop may not terminate. We can add more smarts here if
2941 // needed.
2942 return getStepAsInt() == 1 ? Speculation::RecursivelySpeculatable
2944}
2945
2946/// Returns true if the provided value is the induction variable of a
2947/// AffineForOp.
2949 return getForInductionVarOwner(val) != AffineForOp();
2950}
2951
2955
2959
2961 auto ivArg = dyn_cast<BlockArgument>(val);
2962 if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent())
2963 return AffineForOp();
2964 if (auto forOp =
2965 ivArg.getOwner()->getParent()->getParentOfType<AffineForOp>())
2966 // Check to make sure `val` is the induction variable, not an iter_arg.
2967 return forOp.getInductionVar() == val ? forOp : AffineForOp();
2968 return AffineForOp();
2969}
2970
2972 auto ivArg = dyn_cast<BlockArgument>(val);
2973 if (!ivArg || !ivArg.getOwner())
2974 return nullptr;
2975 Operation *containingOp = ivArg.getOwner()->getParentOp();
2976 auto parallelOp = dyn_cast_if_present<AffineParallelOp>(containingOp);
2977 if (parallelOp && llvm::is_contained(parallelOp.getIVs(), val))
2978 return parallelOp;
2979 return nullptr;
2980}
2981
2982/// Extracts the induction variables from a list of AffineForOps and returns
2983/// them.
2986 ivs->reserve(forInsts.size());
2987 for (auto forInst : forInsts)
2988 ivs->push_back(forInst.getInductionVar());
2989}
2990
2993 ivs.reserve(affineOps.size());
2994 for (Operation *op : affineOps) {
2995 // Add constraints from forOp's bounds.
2996 if (auto forOp = dyn_cast<AffineForOp>(op))
2997 ivs.push_back(forOp.getInductionVar());
2998 else if (auto parallelOp = dyn_cast<AffineParallelOp>(op))
2999 for (size_t i = 0; i < parallelOp.getBody()->getNumArguments(); i++)
3000 ivs.push_back(parallelOp.getBody()->getArgument(i));
3001 }
3002}
3003
3004/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
3005/// operations.
3006template <typename BoundListTy, typename LoopCreatorTy>
3008 OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
3009 ArrayRef<int64_t> steps,
3010 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
3011 LoopCreatorTy &&loopCreatorFn) {
3012 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
3013 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
3014
3015 // If there are no loops to be constructed, construct the body anyway.
3016 OpBuilder::InsertionGuard guard(builder);
3017 if (lbs.empty()) {
3018 if (bodyBuilderFn)
3019 bodyBuilderFn(builder, loc, ValueRange());
3020 return;
3021 }
3022
3023 // Create the loops iteratively and store the induction variables.
3025 ivs.reserve(lbs.size());
3026 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
3027 // Callback for creating the loop body, always creates the terminator.
3028 auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
3029 ValueRange iterArgs) {
3030 ivs.push_back(iv);
3031 // In the innermost loop, call the body builder.
3032 if (i == e - 1 && bodyBuilderFn) {
3033 OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
3034 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
3035 }
3036 AffineYieldOp::create(nestedBuilder, nestedLoc);
3037 };
3038
3039 // Delegate actual loop creation to the callback in order to dispatch
3040 // between constant- and variable-bound loops.
3041 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
3042 builder.setInsertionPointToStart(loop.getBody());
3043 }
3044}
3045
3046/// Creates an affine loop from the bounds known to be constants.
3047static AffineForOp
3049 int64_t ub, int64_t step,
3050 AffineForOp::BodyBuilderFn bodyBuilderFn) {
3051 return AffineForOp::create(builder, loc, lb, ub, step,
3052 /*iterArgs=*/ValueRange(), bodyBuilderFn);
3053}
3054
3055/// Creates an affine loop from the bounds that may or may not be constants.
3056static AffineForOp
3058 int64_t step,
3059 AffineForOp::BodyBuilderFn bodyBuilderFn) {
3060 std::optional<int64_t> lbConst = getConstantIntValue(lb);
3061 std::optional<int64_t> ubConst = getConstantIntValue(ub);
3062 if (lbConst && ubConst)
3063 return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
3064 ubConst.value(), step, bodyBuilderFn);
3065 return AffineForOp::create(builder, loc, lb, builder.getDimIdentityMap(), ub,
3066 builder.getDimIdentityMap(), step,
3067 /*iterArgs=*/ValueRange(), bodyBuilderFn);
3068}
3069
3071 OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
3073 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
3074 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
3076}
3077
3079 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
3080 ArrayRef<int64_t> steps,
3081 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
3082 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
3084}
3085
3086//===----------------------------------------------------------------------===//
3087// AffineIfOp
3088//===----------------------------------------------------------------------===//
3089
3090namespace {
3091/// Remove else blocks that have nothing other than a zero value yield.
3092struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
3093 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
3094
3095 LogicalResult matchAndRewrite(AffineIfOp ifOp,
3096 PatternRewriter &rewriter) const override {
3097 if (ifOp.getElseRegion().empty() ||
3098 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
3099 return failure();
3100
3101 rewriter.startOpModification(ifOp);
3102 rewriter.eraseBlock(ifOp.getElseBlock());
3103 rewriter.finalizeOpModification(ifOp);
3104 return success();
3105 }
3106};
3107
3108/// Removes affine.if cond if the condition is always true or false in certain
3109/// trivial cases. Promotes the then/else block in the parent operation block.
3110struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
3111 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
3112
3113 LogicalResult matchAndRewrite(AffineIfOp op,
3114 PatternRewriter &rewriter) const override {
3115
3116 auto isTriviallyFalse = [](IntegerSet iSet) {
3117 return iSet.isEmptyIntegerSet();
3118 };
3119
3120 auto isTriviallyTrue = [](IntegerSet iSet) {
3121 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
3122 iSet.getConstraint(0) == 0);
3123 };
3124
3125 IntegerSet affineIfConditions = op.getIntegerSet();
3126 Block *blockToMove;
3127 if (isTriviallyFalse(affineIfConditions)) {
3128 // The absence, or equivalently, the emptiness of the else region need not
3129 // be checked when affine.if is returning results because if an affine.if
3130 // operation is returning results, it always has a non-empty else region.
3131 if (op.getNumResults() == 0 && !op.hasElse()) {
3132 // If the else region is absent, or equivalently, empty, remove the
3133 // affine.if operation (which is not returning any results).
3134 rewriter.eraseOp(op);
3135 return success();
3136 }
3137 blockToMove = op.getElseBlock();
3138 } else if (isTriviallyTrue(affineIfConditions)) {
3139 blockToMove = op.getThenBlock();
3140 } else {
3141 return failure();
3142 }
3143 Operation *blockToMoveTerminator = blockToMove->getTerminator();
3144 // Promote the "blockToMove" block to the parent operation block between the
3145 // prologue and epilogue of "op".
3146 rewriter.inlineBlockBefore(blockToMove, op);
3147 // Replace the "op" operation with the operands of the
3148 // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
3149 // the affine.yield operation present in the "blockToMove" block. It has no
3150 // operands when affine.if is not returning results and therefore, in that
3151 // case, replaceOp just erases "op". When affine.if is not returning
3152 // results, the affine.yield operation can be omitted. It gets inserted
3153 // implicitly.
3154 rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
3155 // Erase the "blockToMoveTerminator" operation since it is now in the parent
3156 // operation block, which already has its own terminator.
3157 rewriter.eraseOp(blockToMoveTerminator);
3158 return success();
3159 }
3160};
3161} // namespace
3162
3163/// AffineIfOp has two regions -- `then` and `else`. The flow of data should be
3164/// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp
3165void AffineIfOp::getSuccessorRegions(
3167 // If the predecessor is an AffineIfOp, then branching into both `then` and
3168 // `else` region is valid.
3169 if (point.isParent()) {
3170 regions.reserve(2);
3171 regions.push_back(RegionSuccessor(&getThenRegion()));
3172 // If the "else" region is empty, branch bach into parent.
3173 if (getElseRegion().empty()) {
3174 regions.push_back(RegionSuccessor(getOperation()));
3175 } else {
3176 regions.push_back(RegionSuccessor(&getElseRegion()));
3177 }
3178 return;
3179 }
3180
3181 // If the predecessor is the `else`/`then` region, then branching into parent
3182 // op is valid.
3183 regions.push_back(RegionSuccessor(getOperation()));
3184}
3185
3186ValueRange AffineIfOp::getSuccessorInputs(RegionSuccessor successor) {
3187 if (successor.isOperation())
3188 return getResults();
3189 if (successor == &getThenRegion())
3190 return getThenRegion().getArguments();
3191 if (successor == &getElseRegion())
3192 return getElseRegion().getArguments();
3193 llvm_unreachable("invalid region successor");
3194}
3195
3196LogicalResult AffineIfOp::verify() {
3197 // Verify that we have a condition attribute.
3198 // FIXME: This should be specified in the arguments list in ODS.
3199 auto conditionAttr =
3200 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3201 if (!conditionAttr)
3202 return emitOpError("requires an integer set attribute named 'condition'");
3203
3204 // Verify that there are enough operands for the condition.
3205 IntegerSet condition = conditionAttr.getValue();
3206 if (getNumOperands() != condition.getNumInputs())
3207 return emitOpError("operand count and condition integer set dimension and "
3208 "symbol count must match");
3209
3210 // Verify that the operands are valid dimension/symbols.
3211 if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(),
3212 condition.getNumDims())))
3213 return failure();
3214
3215 return success();
3216}
3217
3218ParseResult AffineIfOp::parse(OpAsmParser &parser, OperationState &result) {
3219 // Parse the condition attribute set.
3220 IntegerSetAttr conditionAttr;
3221 unsigned numDims;
3222 if (parser.parseAttribute(conditionAttr,
3223 AffineIfOp::getConditionAttrStrName(),
3224 result.attributes) ||
3225 parseDimAndSymbolList(parser, result.operands, numDims))
3226 return failure();
3227
3228 // Verify the condition operands.
3229 auto set = conditionAttr.getValue();
3230 if (set.getNumDims() != numDims)
3231 return parser.emitError(
3232 parser.getNameLoc(),
3233 "dim operand count and integer set dim count must match");
3234 if (numDims + set.getNumSymbols() != result.operands.size())
3235 return parser.emitError(
3236 parser.getNameLoc(),
3237 "symbol operand count and integer set symbol count must match");
3238
3239 if (parser.parseOptionalArrowTypeList(result.types))
3240 return failure();
3241
3242 // Create the regions for 'then' and 'else'. The latter must be created even
3243 // if it remains empty for the validity of the operation.
3244 result.regions.reserve(2);
3245 Region *thenRegion = result.addRegion();
3246 Region *elseRegion = result.addRegion();
3247
3248 // Parse the 'then' region.
3249 if (parser.parseRegion(*thenRegion, {}, {}))
3250 return failure();
3251 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
3252 result.location);
3253
3254 // If we find an 'else' keyword then parse the 'else' region.
3255 if (!parser.parseOptionalKeyword("else")) {
3256 if (parser.parseRegion(*elseRegion, {}, {}))
3257 return failure();
3258 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
3259 result.location);
3260 }
3261
3262 // Parse the optional attribute list.
3263 if (parser.parseOptionalAttrDict(result.attributes))
3264 return failure();
3265
3266 return success();
3267}
3268
3269void AffineIfOp::print(OpAsmPrinter &p) {
3270 auto conditionAttr =
3271 (*this)->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName());
3272 p << " " << conditionAttr;
3273 printDimAndSymbolList(operand_begin(), operand_end(),
3274 conditionAttr.getValue().getNumDims(), p);
3275 p.printOptionalArrowTypeList(getResultTypes());
3276 p << ' ';
3277 p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/false,
3278 /*printBlockTerminators=*/getNumResults());
3279
3280 // Print the 'else' regions if it has any blocks.
3281 auto &elseRegion = this->getElseRegion();
3282 if (!elseRegion.empty()) {
3283 p << " else ";
3284 p.printRegion(elseRegion,
3285 /*printEntryBlockArgs=*/false,
3286 /*printBlockTerminators=*/getNumResults());
3287 }
3288
3289 // Print the attribute list.
3290 p.printOptionalAttrDict((*this)->getAttrs(),
3291 /*elidedAttrs=*/getConditionAttrStrName());
3292}
3293
3294IntegerSet AffineIfOp::getIntegerSet() {
3295 return (*this)
3296 ->getAttrOfType<IntegerSetAttr>(getConditionAttrStrName())
3297 .getValue();
3298}
3299
3300void AffineIfOp::setIntegerSet(IntegerSet newSet) {
3301 (*this)->setAttr(getConditionAttrStrName(), IntegerSetAttr::get(newSet));
3302}
3303
3304void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
3305 setIntegerSet(set);
3306 (*this)->setOperands(operands);
3307}
3308
3309void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3310 TypeRange resultTypes, IntegerSet set, ValueRange args,
3311 bool withElseRegion) {
3312 assert(resultTypes.empty() || withElseRegion);
3313 OpBuilder::InsertionGuard guard(builder);
3314
3315 result.addTypes(resultTypes);
3316 result.addOperands(args);
3317 result.addAttribute(getConditionAttrStrName(), IntegerSetAttr::get(set));
3318
3319 Region *thenRegion = result.addRegion();
3320 builder.createBlock(thenRegion);
3321 if (resultTypes.empty())
3322 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
3323
3324 Region *elseRegion = result.addRegion();
3325 if (withElseRegion) {
3326 builder.createBlock(elseRegion);
3327 if (resultTypes.empty())
3328 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
3329 }
3330}
3331
3332void AffineIfOp::build(OpBuilder &builder, OperationState &result,
3333 IntegerSet set, ValueRange args, bool withElseRegion) {
3334 AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
3335 withElseRegion);
3336}
3337
3338/// Compose any affine.apply ops feeding into `operands` of the integer set
3339/// `set` by composing the maps of such affine.apply ops with the integer
3340/// set constraints.
3342 SmallVectorImpl<Value> &operands,
3343 bool composeAffineMin = false) {
3344 // We will simply reuse the API of the map composition by viewing the LHSs of
3345 // the equalities and inequalities of `set` as the affine exprs of an affine
3346 // map. Convert to equivalent map, compose, and convert back to set.
3347 auto map = AffineMap::get(set.getNumDims(), set.getNumSymbols(),
3348 set.getConstraints(), set.getContext());
3349 // Check if any composition is possible.
3350 if (llvm::none_of(operands,
3351 [](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
3352 return;
3353
3354 composeAffineMapAndOperands(&map, &operands, composeAffineMin);
3355 set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
3356 set.getEqFlags());
3357}
3358
3359/// Canonicalize an affine if op's conditional (integer set + operands).
3360LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3361 auto set = getIntegerSet();
3362 SmallVector<Value, 4> operands(getOperands());
3363 composeSetAndOperands(set, operands);
3364 canonicalizeSetAndOperands(&set, &operands);
3365
3366 // Check if the canonicalization or composition led to any change.
3367 if (getIntegerSet() == set && llvm::equal(operands, getOperands()))
3368 return failure();
3369
3370 setConditional(set, operands);
3371 return success();
3372}
3373
3374void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
3375 MLIRContext *context) {
3376 results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
3377}
3378
3379//===----------------------------------------------------------------------===//
3380// AffineLoadOp
3381//===----------------------------------------------------------------------===//
3382
3383void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3384 AffineMap map, ValueRange operands) {
3385 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3386 result.addOperands(operands);
3387 if (map)
3388 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3389 auto memrefType = llvm::cast<MemRefType>(operands[0].getType());
3390 result.types.push_back(memrefType.getElementType());
3391}
3392
3393void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3394 Value memref, AffineMap map, ValueRange mapOperands) {
3395 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3396 result.addOperands(memref);
3397 result.addOperands(mapOperands);
3398 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3399 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
3400 result.types.push_back(memrefType.getElementType());
3401}
3402
// Builds an affine.load whose access map is derived from the memref rank
// instead of being supplied by the caller; the work is delegated to another
// `build` overload that receives the memref, the map and `indices`.
// NOTE(review): this listing skips source line 3404 (the remainder of the
// parameter list, which must declare `memref` and `indices`); restore it from
// the original file before compiling.
3403void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
3405 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3406 int64_t rank = memrefType.getRank();
3407 // Create identity map for memrefs with at least one dimension or () -> ()
3408 // for zero-dimensional memrefs.
3409 auto map =
3410 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3411 build(builder, result, memref, map, indices);
3412}
3413
// Parses the custom assembly of affine.load: the memref operand, the affine
// map of SSA ids (stored under the map attribute name), an optional attribute
// dictionary and the colon-separated memref type. Operands are then resolved
// (memref against the parsed type, map operands against `index`) and the
// element type of the memref is added as the result type. The chained `||`
// short-circuits, so the first failing step makes the whole parse fail.
// NOTE(review): this listing skips source lines 3419 and 3421, which
// presumably declare `memrefInfo` and `mapOperands`; restore them from the
// original file.
3414ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
3415 auto &builder = parser.getBuilder();
3416 auto indexTy = builder.getIndexType();
3417
3418 MemRefType type;
3420 AffineMapAttr mapAttr;
3422 return failure(
3423 parser.parseOperand(memrefInfo) ||
3424 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3425 AffineLoadOp::getMapAttrStrName(),
3426 result.attributes) ||
3427 parser.parseOptionalAttrDict(result.attributes) ||
3428 parser.parseColonType(type) ||
3429 parser.resolveOperand(memrefInfo, type, result.operands) ||
3430 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3431 parser.addTypeToList(type.getElementType(), result.types));
3432}
3433
3434void AffineLoadOp::print(OpAsmPrinter &p) {
3435 p << " " << getMemRef() << '[';
3436 if (AffineMapAttr mapAttr =
3437 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3438 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3439 p << ']';
3440 p.printOptionalAttrDict((*this)->getAttrs(),
3441 /*elidedAttrs=*/{getMapAttrStrName()});
3442 p << " : " << getMemRefType();
3443}
3444
3445/// Verify common indexing invariants of affine.load, affine.store,
3446/// affine.vector_load and affine.vector_store.
3447template <typename AffineMemOpTy>
3448static LogicalResult
3449verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr,
3450 Operation::operand_range mapOperands,
3451 MemRefType memrefType, unsigned numIndexOperands) {
3452 AffineMap map = mapAttr.getValue();
3453 if (map.getNumResults() != memrefType.getRank())
3454 return op->emitOpError("affine map num results must equal memref rank");
3455 if (map.getNumInputs() != numIndexOperands)
3456 return op->emitOpError("expects as many subscripts as affine map inputs");
3457
3458 for (auto idx : mapOperands) {
3459 if (!idx.getType().isIndex())
3460 return op->emitOpError("index to load must have 'index' type");
3461 }
3462 if (failed(verifyDimAndSymbolIdentifiers(op, mapOperands, map.getNumDims())))
3463 return failure();
3464
3465 return success();
3466}
3467
// Verifies an affine.load: the result type must equal the memref element
// type, and the indexing arguments listed below (map attribute, map operands,
// memref type, operand count minus the memref) must pass a further check.
// NOTE(review): this listing skips source line 3473, the opening of the
// `if (failed(...))` call that consumes those arguments; restore it from the
// original file.
3468LogicalResult AffineLoadOp::verify() {
3469 auto memrefType = getMemRefType();
3470 if (getType() != memrefType.getElementType())
3471 return emitOpError("result type must match element type of memref");
3472
3474 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3475 getMapOperands(), memrefType,
3476 /*numIndexOperands=*/getNumOperands() - 1)))
3477 return failure();
3478
3479 return success();
3480}
3481
3482void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3483 MLIRContext *context) {
3484 results.add<SimplifyAffineOp<AffineLoadOp>>(context);
3485}
3486
3487OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) {
3488 /// load(memrefcast) -> load
3489 if (succeeded(memref::foldMemRefCast(*this)))
3490 return getResult();
3491
3492 // Fold load from a global constant memref.
3493 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
3494 if (!getGlobalOp)
3495 return {};
3496 // Get to the memref.global defining the symbol.
3498 getGlobalOp, getGlobalOp.getNameAttr());
3499 if (!global)
3500 return {};
3501
3502 // Check if the global memref is a constant.
3503 auto cstAttr =
3504 dyn_cast_or_null<DenseElementsAttr>(global.getConstantInitValue());
3505 if (!cstAttr)
3506 return {};
3507 // If it's a splat constant, we can fold irrespective of indices.
3508 if (auto splatAttr = dyn_cast<SplatElementsAttr>(cstAttr))
3509 return splatAttr.getSplatValue<Attribute>();
3510 // Otherwise, we can fold only if we know the indices.
3511 if (!getAffineMap().isConstant())
3512 return {};
3513 auto indices =
3514 llvm::map_to_vector<4>(getAffineMap().getConstantResults(),
3515 [](int64_t v) -> uint64_t { return v; });
3516 return cstAttr.getValues<Attribute>()[indices];
3517}
3518
3519//===----------------------------------------------------------------------===//
3520// AffineStoreOp
3521//===----------------------------------------------------------------------===//
3522
3523void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3524 Value valueToStore, Value memref, AffineMap map,
3525 ValueRange mapOperands) {
3526 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3527 result.addOperands(valueToStore);
3528 result.addOperands(memref);
3529 result.addOperands(mapOperands);
3530 result.getOrAddProperties<Properties>().map = AffineMapAttr::get(map);
3531}
3532
// Builds an affine.store whose access map is derived from the memref rank
// instead of being supplied by the caller; the work is delegated to another
// `build` overload that receives the value, the memref, the map and `indices`.
// NOTE(review): this listing skips source line 3536 (the end of the parameter
// list, which must declare `indices`); restore it from the original file.
3533// Use identity map.
3534void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
3535 Value valueToStore, Value memref,
3537 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3538 int64_t rank = memrefType.getRank();
3539 // Create identity map for memrefs with at least one dimension or () -> ()
3540 // for zero-dimensional memrefs.
3541 auto map =
3542 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3543 build(builder, result, valueToStore, memref, map, indices);
3544}
3545
// Parses the custom assembly of affine.store: the stored value, a comma, the
// memref operand, the subscripts, an optional attribute dictionary and the
// colon-separated memref type. Operands are resolved in the order stored
// value (against the memref element type), memref, map operands (`index`).
// The chained `||` short-circuits on the first failing step.
// NOTE(review): this listing skips source lines 3551, 3553 and 3556, which
// presumably declare `memrefInfo` and `mapOperands` and open the call that
// parses the affine map of SSA ids; restore them from the original file.
3546ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
3547 auto indexTy = parser.getBuilder().getIndexType();
3548
3549 MemRefType type;
3550 OpAsmParser::UnresolvedOperand storeValueInfo;
3552 AffineMapAttr mapAttr;
3554 return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3555 parser.parseOperand(memrefInfo) ||
3557 mapOperands, mapAttr, AffineStoreOp::getMapAttrStrName(),
3558 result.attributes) ||
3559 parser.parseOptionalAttrDict(result.attributes) ||
3560 parser.parseColonType(type) ||
3561 parser.resolveOperand(storeValueInfo, type.getElementType(),
3562 result.operands) ||
3563 parser.resolveOperand(memrefInfo, type, result.operands) ||
3564 parser.resolveOperands(mapOperands, indexTy, result.operands));
3565}
3566
3567void AffineStoreOp::print(OpAsmPrinter &p) {
3568 p << " " << getValueToStore();
3569 p << ", " << getMemRef() << '[';
3570 if (AffineMapAttr mapAttr =
3571 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
3572 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
3573 p << ']';
3574 p.printOptionalAttrDict((*this)->getAttrs(),
3575 /*elidedAttrs=*/{getMapAttrStrName()});
3576 p << " : " << getMemRefType();
3577}
3578
// Verifies an affine.store: the stored value's type must equal the memref
// element type, and the indexing arguments listed below (map attribute, map
// operands, memref type, operand count minus value and memref) must pass a
// further check.
// NOTE(review): this listing skips source line 3586, the opening of the
// `if (failed(...))` call that consumes those arguments; restore it from the
// original file.
3579LogicalResult AffineStoreOp::verify() {
3580 // The value to store must have the same type as memref element type.
3581 auto memrefType = getMemRefType();
3582 if (getValueToStore().getType() != memrefType.getElementType())
3583 return emitOpError(
3584 "value to store must have the same type as memref element type");
3585
3587 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
3588 getMapOperands(), memrefType,
3589 /*numIndexOperands=*/getNumOperands() - 2)))
3590 return failure();
3591
3592 return success();
3593}
3594
3595void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
3596 MLIRContext *context) {
3597 results.add<SimplifyAffineOp<AffineStoreOp>>(context);
3598}
3599
// Folds a memref cast feeding this store by delegating to
// memref::foldMemRefCast, passing the stored value as its second argument.
// NOTE(review): this listing skips source line 3601 (the rest of the
// parameter list and the opening brace); restore it from the original file.
3600LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor,
3602 /// store(memrefcast) -> store
3603 return memref::foldMemRefCast(*this, getValueToStore());
3604}
3605
3606//===----------------------------------------------------------------------===//
3607// AffineMinMaxOpBase
3608//===----------------------------------------------------------------------===//
3609
3610template <typename T>
3611static LogicalResult verifyAffineMinMaxOp(T op) {
3612 // Verify that operand count matches affine map dimension and symbol count.
3613 if (op.getNumOperands() !=
3614 op.getMap().getNumDims() + op.getMap().getNumSymbols())
3615 return op.emitOpError(
3616 "operand count and affine map dimension and symbol count must match");
3617
3618 if (op.getMap().getNumResults() == 0)
3619 return op.emitOpError("affine map expect at least one result");
3620 return success();
3621}
3622
3623template <typename T>
3624static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
3625 p << ' ' << op->getAttr(T::getMapAttrStrName());
3626 auto operands = op.getOperands();
3627 unsigned numDims = op.getMap().getNumDims();
3628 p << '(' << operands.take_front(numDims) << ')';
3629
3630 if (operands.size() != numDims)
3631 p << '[' << operands.drop_front(numDims) << ']';
3632 p.printOptionalAttrDict(op->getAttrs(),
3633 /*elidedAttrs=*/{T::getMapAttrStrName()});
3634}
3635
// Parses an affine min/max op: the map attribute, operand lists, and an
// optional attribute dictionary; dimension and symbol operands are resolved
// as `index` (dimensions first) and a single `index` result type is added.
// The chained `||` short-circuits on the first failing step.
// NOTE(review): this listing skips source lines 3638, 3641-3642, 3647 and
// 3649 (the `result` parameter, the declarations of `dimInfos`/`symInfos`,
// the dimension operand-list parse and the delimiter argument of the symbol
// operand-list parse); restore them from the original file.
3636template <typename T>
3637static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
3639 auto &builder = parser.getBuilder();
3640 auto indexType = builder.getIndexType();
3643 AffineMapAttr mapAttr;
3644 return failure(
3645 parser.parseAttribute(mapAttr, T::getMapAttrStrName(),
3646 result.attributes) ||
3648 parser.parseOperandList(symInfos,
3650 parser.parseOptionalAttrDict(result.attributes) ||
3651 parser.resolveOperands(dimInfos, indexType, result.operands) ||
3652 parser.resolveOperands(symInfos, indexType, result.operands) ||
3653 parser.addTypeToList(indexType, result.types));
3654}
3655
3656/// Fold an affine min or max operation with the given operands. The operand
3657/// list may contain nulls, which are interpreted as the operand not being a
3658/// constant.
3659template <typename T>
3661 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
3662 "expected affine min or max op");
3663
3664 // Fold the affine map.
3665 // TODO: Fold more cases:
3666 // min(some_affine, some_affine + constant, ...), etc.
3668 auto foldedMap = op.getMap().partialConstantFold(operands, &results);
3669
3670 if (foldedMap.getNumSymbols() == 1 && foldedMap.isSymbolIdentity())
3671 return op.getOperand(0);
3672
3673 // If some of the map results are not constant, try changing the map in-place.
3674 if (results.empty()) {
3675 // If the map is the same, report that folding did not happen.
3676 if (foldedMap == op.getMap())
3677 return {};
3678 op->setAttr("map", AffineMapAttr::get(foldedMap));
3679 return op.getResult();
3680 }
3681
3682 // Otherwise, completely fold the op into a constant.
3683 auto resultIt = std::is_same<T, AffineMinOp>::value
3684 ? llvm::min_element(results)
3685 : llvm::max_element(results);
3686 if (resultIt == results.end())
3687 return {};
3688 return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
3689}
3690
// Rewrite pattern: collects the map results into `newExprs`, keeping only the
// first occurrence of each expression, and replaces the op with one using the
// shorter map and the same map operands. The match fails when no duplicate
// was found.
// NOTE(review): this listing skips source lines 3693-3694 (the pattern
// struct's opening) and 3700 (the declaration of `newExprs`); restore them
// from the original file.
3691/// Remove duplicated expressions in affine min/max ops.
3692template <typename T>
3695
3696 LogicalResult matchAndRewrite(T affineOp,
3697 PatternRewriter &rewriter) const override {
3698 AffineMap oldMap = affineOp.getAffineMap();
3699
3701 for (AffineExpr expr : oldMap.getResults()) {
3702 // This is a linear scan over newExprs, but it should be fine given that
3703 // we typically just have a few expressions per op.
3704 if (!llvm::is_contained(newExprs, expr))
3705 newExprs.push_back(expr);
3706 }
3707
3708 if (newExprs.size() == oldMap.getNumResults())
3709 return failure();
3710
3711 auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
3712 newExprs, rewriter.getContext());
3713 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
3714
3715 return success();
3716 }
3717};
3718
// NOTE(review): this listing skips source lines 3736-3737 (the pattern
// struct's opening) and 3749 (the declaration of `newExprs`); restore them
// from the original file.
3719/// Merge an affine min/max op to its consumers if its consumer is also an
3720/// affine min/max op.
3721///
3722/// This pattern requires the producer affine min/max op is bound to a
3723/// dimension/symbol that is used as a standalone expression in the consumer
3724/// affine op's map.
3725///
3726/// For example, a pattern like the following:
3727///
3728/// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
3729/// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
3730///
3731/// Can be turned into:
3732///
3733/// %1 = affine.min affine_map<
3734/// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
3735template <typename T>
3738
3739 LogicalResult matchAndRewrite(T affineOp,
3740 PatternRewriter &rewriter) const override {
3741 AffineMap oldMap = affineOp.getAffineMap();
3742 ValueRange dimOperands =
3743 affineOp.getMapOperands().take_front(oldMap.getNumDims());
3744 ValueRange symOperands =
3745 affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
3746
3747 auto newDimOperands = llvm::to_vector<8>(dimOperands);
3748 auto newSymOperands = llvm::to_vector<8>(symOperands);
3750 SmallVector<T, 4> producerOps;
3751
3752 // Go over each expression to see whether it's a single dimension/symbol
3753 // with the corresponding operand which is the result of another affine
3754 // min/max op. If so, it can be merged into this affine op.
3755 for (AffineExpr expr : oldMap.getResults()) {
3756 if (auto symExpr = dyn_cast<AffineSymbolExpr>(expr)) {
3757 Value symValue = symOperands[symExpr.getPosition()];
3758 if (auto producerOp = symValue.getDefiningOp<T>()) {
3759 producerOps.push_back(producerOp);
3760 continue;
3761 }
3762 } else if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
3763 Value dimValue = dimOperands[dimExpr.getPosition()];
3764 if (auto producerOp = dimValue.getDefiningOp<T>()) {
3765 producerOps.push_back(producerOp);
3766 continue;
3767 }
3768 }
3769 // For the above cases we will remove the expression by merging the
3770 // producer affine min/max's affine expressions. Otherwise we need to
3771 // keep the existing expression.
3772 newExprs.push_back(expr);
3773 }
3774
3775 if (producerOps.empty())
3776 return failure();
3777
// The original operands keep their positions; each producer's dims/symbols
// are appended after them, so these counters give the next free position.
3778 unsigned numUsedDims = oldMap.getNumDims();
3779 unsigned numUsedSyms = oldMap.getNumSymbols();
3780
3781 // Now go over all producer affine ops and merge their expressions.
3782 for (T producerOp : producerOps) {
3783 AffineMap producerMap = producerOp.getAffineMap();
3784 unsigned numProducerDims = producerMap.getNumDims();
3785 unsigned numProducerSyms = producerMap.getNumSymbols();
3786
3787 // Collect all dimension/symbol values.
3788 ValueRange dimValues =
3789 producerOp.getMapOperands().take_front(numProducerDims);
3790 ValueRange symValues =
3791 producerOp.getMapOperands().take_back(numProducerSyms);
3792 newDimOperands.append(dimValues.begin(), dimValues.end());
3793 newSymOperands.append(symValues.begin(), symValues.end());
3794
3795 // For expressions we need to shift to avoid overlap.
3796 for (AffineExpr expr : producerMap.getResults()) {
3797 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
3798 .shiftSymbols(numProducerSyms, numUsedSyms));
3799 }
3800
3801 numUsedDims += numProducerDims;
3802 numUsedSyms += numProducerSyms;
3803 }
3804
3805 auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
3806 rewriter.getContext());
// Operand list of the replacement op: all dimension operands, then all
// symbol operands.
3807 auto newOperands =
3808 llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
3809 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
3810
3811 return success();
3812 }
3813};
3814
3815/// Canonicalize the result expression order of an affine map and return success
3816/// if the order changed.
3817///
3818/// The function flattens the map's affine expressions to coefficient arrays and
3819/// sorts them in lexicographic order. A coefficient array contains a multiplier
3820/// for every dimension/symbol and a constant term. The canonicalization fails
3821/// if a result expression is not pure or if the flattening requires local
3822/// variables that, unlike dimensions and symbols, have no global order.
3823static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) {
3824 SmallVector<SmallVector<int64_t>> flattenedExprs;
3825 for (const AffineExpr &resultExpr : map.getResults()) {
3826 // Fail if the expression is not pure.
3827 if (!resultExpr.isPureAffine())
3828 return failure();
3829
3830 SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols());
3831 auto flattenResult = flattener.walkPostOrder(resultExpr);
3832 if (failed(flattenResult))
3833 return failure();
3834
3835 // Fail if the flattened expression has local variables.
3836 if (flattener.operandExprStack.back().size() !=
3837 map.getNumDims() + map.getNumSymbols() + 1)
3838 return failure();
3839
3840 flattenedExprs.emplace_back(flattener.operandExprStack.back().begin(),
3841 flattener.operandExprStack.back().end());
3842 }
3843
3844 // Fail if sorting is not necessary.
3845 if (llvm::is_sorted(flattenedExprs))
3846 return failure();
3847
3848 // Reorder the result expressions according to their flattened form.
3849 SmallVector<unsigned> resultPermutation =
3850 llvm::to_vector(llvm::seq<unsigned>(0, map.getNumResults()));
3851 llvm::sort(resultPermutation, [&](unsigned lhs, unsigned rhs) {
3852 return flattenedExprs[lhs] < flattenedExprs[rhs];
3853 });
3854 SmallVector<AffineExpr> newExprs;
3855 for (unsigned idx : resultPermutation)
3856 newExprs.push_back(map.getResult(idx));
3857
3858 map = AffineMap::get(map.getNumDims(), map.getNumSymbols(), newExprs,
3859 map.getContext());
3860 return success();
3861}
3862
// NOTE(review): this listing skips source lines 3877-3878 (the pattern
// struct's opening); restore them from the original file.
3863/// Canonicalize the affine map result expression order of an affine min/max
3864/// operation.
3865///
3866/// The pattern calls `canonicalizeMapExprAndTermOrder` to order the result
3867/// expressions and replaces the operation if the order changed.
3868///
3869/// For example, the following operation:
3870///
3871/// %0 = affine.min affine_map<(d0, d1) -> (d0 + d1, d1 + 16, 32)> (%i0, %i1)
3872///
3873/// Turns into:
3874///
3875/// %0 = affine.min affine_map<(d0, d1) -> (32, d1 + 16, d0 + d1)> (%i0, %i1)
3876template <typename T>
3879
3880 LogicalResult matchAndRewrite(T affineOp,
3881 PatternRewriter &rewriter) const override {
// `map` is a local copy; the helper reorders it in place on success.
3882 AffineMap map = affineOp.getAffineMap();
3883 if (failed(canonicalizeMapExprAndTermOrder(map)))
3884 return failure();
3885 rewriter.replaceOpWithNewOp<T>(affineOp, map, affineOp.getMapOperands());
3886 return success();
3887 }
3888};
3889
// Rewrite pattern: an affine min/max op whose map has exactly one result is
// replaced by an affine.apply with the same map and operands.
// NOTE(review): this listing skips source lines 3891-3892 (the pattern
// struct's opening); restore them from the original file.
3890template <typename T>
3893
3894 LogicalResult matchAndRewrite(T affineOp,
3895 PatternRewriter &rewriter) const override {
3896 if (affineOp.getMap().getNumResults() != 1)
3897 return failure();
3898 rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.getMap(),
3899 affineOp.getOperands());
3900 return success();
3901 }
3902};
3903
3904//===----------------------------------------------------------------------===//
3905// AffineMinOp
3906//===----------------------------------------------------------------------===//
3907//
3908// %0 = affine.min affine_map<(d0) -> (1000, d0 + 512)> (%i0)
3909//
3910
3911OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) {
3912 return foldMinMaxOp(*this, adaptor.getOperands());
3913}
3914
// Registers the canonicalization patterns of affine.min.
// NOTE(review): this listing skips source lines 3917-3918 and 3920, so only
// part of the pattern list (and not the `patterns.add<` opening) is visible;
// restore them from the original file.
3915void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3916 MLIRContext *context) {
3919 MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>,
3921 context);
3922}
3923
3924LogicalResult AffineMinOp::verify() { return verifyAffineMinMaxOp(*this); }
3925
// Parses the custom assembly of affine.min.
// NOTE(review): this listing skips source line 3927, the only statement of
// the body; restore it from the original file.
3926ParseResult AffineMinOp::parse(OpAsmParser &parser, OperationState &result) {
3928}
3929
3930void AffineMinOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3931
3932//===----------------------------------------------------------------------===//
3933// AffineMaxOp
3934//===----------------------------------------------------------------------===//
3935//
3936// %0 = affine.max affine_map<(d0) -> (1000, d0 + 512)> (%i0)
3937//
3938
3939OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) {
3940 return foldMinMaxOp(*this, adaptor.getOperands());
3941}
3942
// Registers the canonicalization patterns of affine.max.
// NOTE(review): this listing skips source lines 3945-3946 and 3948, so only
// part of the pattern list (and not the `patterns.add<` opening) is visible;
// restore them from the original file.
3943void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
3944 MLIRContext *context) {
3947 MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>,
3949 context);
3950}
3951
3952LogicalResult AffineMaxOp::verify() { return verifyAffineMinMaxOp(*this); }
3953
// Parses the custom assembly of affine.max.
// NOTE(review): this listing skips source line 3955, the only statement of
// the body; restore it from the original file.
3954ParseResult AffineMaxOp::parse(OpAsmParser &parser, OperationState &result) {
3956}
3957
3958void AffineMaxOp::print(OpAsmPrinter &p) { printAffineMinMaxOp(p, *this); }
3959
3960//===----------------------------------------------------------------------===//
3961// AffinePrefetchOp
3962//===----------------------------------------------------------------------===//
3963
3964//
3965// affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
3966//
// Parses the custom assembly of affine.prefetch:
//   memref `[` subscripts `]` `,` rw-keyword `,` `locality` `<` i32 hint `>`
//   `,` cache-keyword [attr-dict] `:` memref-type
// The rw keyword must be `read` or `write` and the cache keyword `data` or
// `instr`; both are validated only after the whole op has been parsed and are
// recorded as boolean attributes.
// NOTE(review): this listing skips source lines 3968, 3973 and 3979 (the
// `result` parameter and the declarations of `memrefInfo` and `mapOperands`);
// restore them from the original file.
3967ParseResult AffinePrefetchOp::parse(OpAsmParser &parser,
3969 auto &builder = parser.getBuilder();
3970 auto indexTy = builder.getIndexType();
3971
3972 MemRefType type;
3974 IntegerAttr hintInfo;
3975 auto i32Type = parser.getBuilder().getIntegerType(32);
3976 StringRef readOrWrite, cacheType;
3977
3978 AffineMapAttr mapAttr;
3980 if (parser.parseOperand(memrefInfo) ||
3981 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3982 AffinePrefetchOp::getMapAttrStrName(),
3983 result.attributes) ||
3984 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
3985 parser.parseComma() || parser.parseKeyword("locality") ||
3986 parser.parseLess() ||
3987 parser.parseAttribute(hintInfo, i32Type,
3988 AffinePrefetchOp::getLocalityHintAttrStrName(),
3989 result.attributes) ||
3990 parser.parseGreater() || parser.parseComma() ||
3991 parser.parseKeyword(&cacheType) ||
3992 parser.parseOptionalAttrDict(result.attributes) ||
3993 parser.parseColonType(type) ||
3994 parser.resolveOperand(memrefInfo, type, result.operands) ||
3995 parser.resolveOperands(mapOperands, indexTy, result.operands))
3996 return failure();
3997
3998 if (readOrWrite != "read" && readOrWrite != "write")
3999 return parser.emitError(parser.getNameLoc(),
4000 "rw specifier has to be 'read' or 'write'");
4001 result.addAttribute(AffinePrefetchOp::getIsWriteAttrStrName(),
4002 parser.getBuilder().getBoolAttr(readOrWrite == "write"));
4003
4004 if (cacheType != "data" && cacheType != "instr")
4005 return parser.emitError(parser.getNameLoc(),
4006 "cache type has to be 'data' or 'instr'");
4007
4008 result.addAttribute(AffinePrefetchOp::getIsDataCacheAttrStrName(),
4009 parser.getBuilder().getBoolAttr(cacheType == "data"));
4010
4011 return success();
4012}
4013
// Prints the custom assembly of affine.prefetch: the memref with its
// bracketed subscripts (when a map attribute is present), `read`/`write`,
// the locality hint, `data`/`instr`, and finally the memref type. The
// attribute names listed as elided are the ones already rendered inline.
// NOTE(review): this listing skips source line 4022, the opening of the call
// that receives the attribute list and the elided names; restore it from the
// original file.
4014void AffinePrefetchOp::print(OpAsmPrinter &p) {
4015 p << " " << getMemref() << '[';
4016 AffineMapAttr mapAttr =
4017 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
4018 if (mapAttr)
4019 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4020 p << ']' << ", " << (getIsWrite() ? "write" : "read") << ", " << "locality<"
4021 << getLocalityHint() << ">, " << (getIsDataCache() ? "data" : "instr");
4023 (*this)->getAttrs(),
4024 /*elidedAttrs=*/{getMapAttrStrName(), getLocalityHintAttrStrName(),
4025 getIsDataCacheAttrStrName(), getIsWriteAttrStrName()});
4026 p << " : " << getMemRefType();
4027}
4028
4029LogicalResult AffinePrefetchOp::verify() {
4030 auto mapAttr = (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName());
4031 if (mapAttr) {
4032 AffineMap map = mapAttr.getValue();
4033 if (map.getNumResults() != getMemRefType().getRank())
4034 return emitOpError("affine.prefetch affine map num results must equal"
4035 " memref rank");
4036 if (map.getNumInputs() + 1 != getNumOperands())
4037 return emitOpError("too few operands");
4038 } else {
4039 if (getNumOperands() != 1)
4040 return emitOpError("too few operands");
4041 }
4042
4043 Region *scope = getAffineScope(*this);
4044 for (auto idx : getMapOperands()) {
4045 if (!isValidAffineIndexOperand(idx, scope))
4046 return emitOpError(
4047 "index must be a valid dimension or symbol identifier");
4048 }
4049 return success();
4050}
4051
4052void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
4053 MLIRContext *context) {
4054 // prefetch(memrefcast) -> prefetch
4055 results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
4056}
4057
// Folds a memref cast feeding this prefetch by delegating to
// memref::foldMemRefCast.
// NOTE(review): this listing skips source line 4059 (the rest of the
// parameter list and the opening brace); restore it from the original file.
4058LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor,
4060 /// prefetch(memrefcast) -> prefetch
4061 return memref::foldMemRefCast(*this);
4062}
4063
4064//===----------------------------------------------------------------------===//
4065// AffineParallelOp
4066//===----------------------------------------------------------------------===//
4067
// Builds an affine.parallel over constant ranges: for each entry of `ranges`
// the lower bound is the constant map 0, the upper bound is the constant map
// of that entry, and the step is 1. The work is delegated to another `build`
// overload with empty bound operand lists.
// NOTE(review): this listing skips source line 4070, which must declare the
// `reductions` parameter; restore it from the original file.
4068void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
4069 TypeRange resultTypes,
4071 ArrayRef<int64_t> ranges) {
4072 SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
4073 auto ubs = llvm::map_to_vector<4>(ranges, [&](int64_t value) {
4074 return builder.getConstantAffineMap(value);
4075 });
4076 SmallVector<int64_t> steps(ranges.size(), 1);
4077 build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
4078 /*ubArgs=*/{}, steps);
4079}
4080
// Builds an affine.parallel from per-dimension lower/upper bound maps.
//
// All lower bound maps must share one input space (same number of dims and
// symbols) whose inputs are `lbArgs`; likewise for the upper bound maps and
// `ubArgs`. The per-dimension maps are concatenated into one lower and one
// upper bounds map, with the number of results contributed by each original
// map recorded in the corresponding "groups" attribute. The body region gets
// one block with one `index` argument per entry of `steps`; a terminator is
// ensured only when the op has no results.
// NOTE(review): this listing skips source lines 4083 (which must declare the
// `reductions` parameter) and 4125 (which must declare `exprs` inside the
// lambda); restore them from the original file.
4081void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
4082 TypeRange resultTypes,
4084 ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
4085 ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
4086 ArrayRef<int64_t> steps) {
4087 assert(llvm::all_of(lbMaps,
4088 [lbMaps](AffineMap m) {
4089 return m.getNumDims() == lbMaps[0].getNumDims() &&
4090 m.getNumSymbols() == lbMaps[0].getNumSymbols();
4091 }) &&
4092 "expected all lower bounds maps to have the same number of dimensions "
4093 "and symbols");
4094 assert(llvm::all_of(ubMaps,
4095 [ubMaps](AffineMap m) {
4096 return m.getNumDims() == ubMaps[0].getNumDims() &&
4097 m.getNumSymbols() == ubMaps[0].getNumSymbols();
4098 }) &&
4099 "expected all upper bounds maps to have the same number of dimensions "
4100 "and symbols");
4101 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
4102 "expected lower bound maps to have as many inputs as lower bound "
4103 "operands");
4104 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
4105 "expected upper bound maps to have as many inputs as upper bound "
4106 "operands");
4107
4108 OpBuilder::InsertionGuard guard(builder);
4109 result.addTypes(resultTypes);
4110
4111 // Convert the reductions to integer attributes.
4112 SmallVector<Attribute, 4> reductionAttrs;
4113 for (arith::AtomicRMWKind reduction : reductions)
4114 reductionAttrs.push_back(
4115 builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
4116 result.addAttribute(getReductionsAttrStrName(),
4117 builder.getArrayAttr(reductionAttrs));
4118
4119 // Concatenates maps defined in the same input space (same dimensions and
4120 // symbols), assumes there is at least one map.
// An empty `maps` list is handled separately by the early return below.
4121 auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
4122 SmallVectorImpl<int32_t> &groups) {
4123 if (maps.empty())
4124 return AffineMap::get(builder.getContext());
4126 groups.reserve(groups.size() + maps.size());
4127 exprs.reserve(maps.size());
4128 for (AffineMap m : maps) {
4129 llvm::append_range(exprs, m.getResults());
4130 groups.push_back(m.getNumResults());
4131 }
4132 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
4133 maps[0].getContext());
4134 };
4135
4136 // Set up the bounds.
4137 SmallVector<int32_t> lbGroups, ubGroups;
4138 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
4139 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
4140 result.addAttribute(getLowerBoundsMapAttrStrName(),
4141 AffineMapAttr::get(lbMap));
4142 result.addAttribute(getLowerBoundsGroupsAttrStrName(),
4143 builder.getI32TensorAttr(lbGroups));
4144 result.addAttribute(getUpperBoundsMapAttrStrName(),
4145 AffineMapAttr::get(ubMap));
4146 result.addAttribute(getUpperBoundsGroupsAttrStrName(),
4147 builder.getI32TensorAttr(ubGroups));
4148 result.addAttribute(getStepsAttrStrName(), builder.getI64ArrayAttr(steps));
// Operand order: lower bound operands first, then upper bound operands.
4149 result.addOperands(lbArgs);
4150 result.addOperands(ubArgs);
4151
4152 // Create a region and a block for the body.
4153 auto *bodyRegion = result.addRegion();
4154 Block *body = builder.createBlock(bodyRegion);
4155
4156 // Add all the block arguments.
4157 for (unsigned i = 0, e = steps.size(); i < e; ++i)
4158 body->addArgument(IndexType::get(builder.getContext()), result.location);
4159 if (resultTypes.empty())
4160 ensureTerminator(*bodyRegion, builder, result.location);
4161}
4162
4163SmallVector<Region *> AffineParallelOp::getLoopRegions() {
4164 return {&getRegion()};
4165}
4166
4167unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
4168
4169AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
4170 return getOperands().take_front(getLowerBoundsMap().getNumInputs());
4171}
4172
4173AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
4174 return getOperands().drop_front(getLowerBoundsMap().getNumInputs());
4175}
4176
4177AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
4178 auto values = getLowerBoundsGroups().getValues<int32_t>();
4179 unsigned start = 0;
4180 for (unsigned i = 0; i < pos; ++i)
4181 start += values[i];
4182 return getLowerBoundsMap().getSliceMap(start, values[pos]);
4183}
4184
4185AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
4186 auto values = getUpperBoundsGroups().getValues<int32_t>();
4187 unsigned start = 0;
4188 for (unsigned i = 0; i < pos; ++i)
4189 start += values[i];
4190 return getUpperBoundsMap().getSliceMap(start, values[pos]);
4191}
4192
4193AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
4194 return AffineValueMap(getLowerBoundsMap(), getLowerBoundsOperands());
4195}
4196
4197AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
4198 return AffineValueMap(getUpperBoundsMap(), getUpperBoundsOperands());
4199}
4200
// Returns the constant extent (upper bound minus lower bound) of every
// result, or std::nullopt when the op has min/max bounds or when any
// difference is not a constant expression.
// NOTE(review): this listing skips source line 4206, which must declare
// `out`; restore it from the original file.
4201std::optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
4202 if (hasMinMaxBounds())
4203 return std::nullopt;
4204
4205 // Try to convert all the ranges to constant expressions.
4207 AffineValueMap rangesValueMap;
4208 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
4209 &rangesValueMap);
4210 out.reserve(rangesValueMap.getNumResults());
4211 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
4212 auto expr = rangesValueMap.getResult(i);
4213 auto cst = dyn_cast<AffineConstantExpr>(expr);
4214 if (!cst)
4215 return std::nullopt;
4216 out.push_back(cst.getValue());
4217 }
4218 return out;
4219}
4220
4221Block *AffineParallelOp::getBody() { return &getRegion().front(); }
4222
4223OpBuilder AffineParallelOp::getBodyBuilder() {
4224 return OpBuilder(getBody(), std::prev(getBody()->end()));
4225}
4226
4227void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
4228 assert(lbOperands.size() == map.getNumInputs() &&
4229 "operands to map must match number of inputs");
4230
4231 auto ubOperands = getUpperBoundsOperands();
4232
4233 SmallVector<Value, 4> newOperands(lbOperands);
4234 newOperands.append(ubOperands.begin(), ubOperands.end());
4235 (*this)->setOperands(newOperands);
4236
4237 setLowerBoundsMapAttr(AffineMapAttr::get(map));
4238}
4239
4240void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
4241 assert(ubOperands.size() == map.getNumInputs() &&
4242 "operands to map must match number of inputs");
4243
4244 SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
4245 newOperands.append(ubOperands.begin(), ubOperands.end());
4246 (*this)->setOperands(newOperands);
4247
4248 setUpperBoundsMapAttr(AffineMapAttr::get(map));
4249}
4250
4251void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
4252 setStepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
4253}
4254
4255// check whether resultType match op or not in affine.parallel
4257 arith::AtomicRMWKind op) {
4258 switch (op) {
4259 case arith::AtomicRMWKind::addf:
4260 return isa<FloatType>(resultType);
4261 case arith::AtomicRMWKind::addi:
4262 return isa<IntegerType>(resultType);
4263 case arith::AtomicRMWKind::assign:
4264 return true;
4265 case arith::AtomicRMWKind::mulf:
4266 return isa<FloatType>(resultType);
4267 case arith::AtomicRMWKind::muli:
4268 return isa<IntegerType>(resultType);
4269 case arith::AtomicRMWKind::maximumf:
4270 case arith::AtomicRMWKind::maxnumf:
4271 case arith::AtomicRMWKind::minimumf:
4272 case arith::AtomicRMWKind::minnumf:
4273 return isa<FloatType>(resultType);
4274 case arith::AtomicRMWKind::maxs: {
4275 auto intType = dyn_cast<IntegerType>(resultType);
4276 return intType && intType.isSigned();
4277 }
4278 case arith::AtomicRMWKind::mins: {
4279 auto intType = dyn_cast<IntegerType>(resultType);
4280 return intType && intType.isSigned();
4281 }
4282 case arith::AtomicRMWKind::maxu: {
4283 auto intType = dyn_cast<IntegerType>(resultType);
4284 return intType && intType.isUnsigned();
4285 }
4286 case arith::AtomicRMWKind::minu: {
4287 auto intType = dyn_cast<IntegerType>(resultType);
4288 return intType && intType.isUnsigned();
4289 }
4290 case arith::AtomicRMWKind::ori:
4291 case arith::AtomicRMWKind::andi:
4292 case arith::AtomicRMWKind::xori:
4293 return isa<IntegerType>(resultType);
4294 }
4295 llvm_unreachable("Unhandled atomic rmw kind");
4296}
4297
4298LogicalResult AffineParallelOp::verify() {
4299 auto numDims = getNumDims();
4300 if (getLowerBoundsGroups().getNumElements() != numDims ||
4301 getUpperBoundsGroups().getNumElements() != numDims ||
4302 getSteps().size() != numDims || getBody()->getNumArguments() != numDims) {
4303 return emitOpError() << "the number of region arguments ("
4304 << getBody()->getNumArguments()
4305 << ") and the number of map groups for lower ("
4306 << getLowerBoundsGroups().getNumElements()
4307 << ") and upper bound ("
4308 << getUpperBoundsGroups().getNumElements()
4309 << "), and the number of steps (" << getSteps().size()
4310 << ") must all match";
4311 }
4312
4313 unsigned expectedNumLBResults = 0;
4314 for (APInt v : getLowerBoundsGroups()) {
4315 unsigned results = v.getZExtValue();
4316 if (results == 0)
4317 return emitOpError()
4318 << "expected lower bound map to have at least one result";
4319 expectedNumLBResults += results;
4320 }
4321 if (expectedNumLBResults != getLowerBoundsMap().getNumResults())
4322 return emitOpError() << "expected lower bounds map to have "
4323 << expectedNumLBResults << " results";
4324 unsigned expectedNumUBResults = 0;
4325 for (APInt v : getUpperBoundsGroups()) {
4326 unsigned results = v.getZExtValue();
4327 if (results == 0)
4328 return emitOpError()
4329 << "expected upper bound map to have at least one result";
4330 expectedNumUBResults += results;
4331 }
4332 if (expectedNumUBResults != getUpperBoundsMap().getNumResults())
4333 return emitOpError() << "expected upper bounds map to have "
4334 << expectedNumUBResults << " results";
4335
4336 if (getReductions().size() != getNumResults())
4337 return emitOpError("a reduction must be specified for each output");
4338
4339 // Verify reduction ops are all valid and each result type matches reduction
4340 // ops
4341 for (auto it : llvm::enumerate((getReductions()))) {
4342 Attribute attr = it.value();
4343 auto intAttr = dyn_cast<IntegerAttr>(attr);
4344 if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt()))
4345 return emitOpError("invalid reduction attribute");
4346 auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value();
4347 if (!isResultTypeMatchAtomicRMWKind(getResult(it.index()).getType(), kind))
4348 return emitOpError("result type cannot match reduction attribute");
4349 }
4350
4351 // Verify that the bound operands are valid dimension/symbols.
4352 /// Lower bounds.
4353 if (failed(verifyDimAndSymbolIdentifiers(*this, getLowerBoundsOperands(),
4354 getLowerBoundsMap().getNumDims())))
4355 return failure();
4356 /// Upper bounds.
4357 if (failed(verifyDimAndSymbolIdentifiers(*this, getUpperBoundsOperands(),
4358 getUpperBoundsMap().getNumDims())))
4359 return failure();
4360 return success();
4361}
4362
4364 SmallVector<Value, 4> newOperands{operands};
4365 auto newMap = getAffineMap();
4366 composeAffineMapAndOperands(&newMap, &newOperands);
4367 if (newMap == getAffineMap() && newOperands == operands)
4368 return failure();
4369 reset(newMap, newOperands);
4370 return success();
4371}
4372
4373/// Canonicalize the bounds of the given loop.
4374static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
4375 AffineValueMap lb = op.getLowerBoundsValueMap();
4376 bool lbCanonicalized = succeeded(lb.canonicalize());
4377
4378 AffineValueMap ub = op.getUpperBoundsValueMap();
4379 bool ubCanonicalized = succeeded(ub.canonicalize());
4380
4381 // Any canonicalization change always leads to updated map(s).
4382 if (!lbCanonicalized && !ubCanonicalized)
4383 return failure();
4384
4385 if (lbCanonicalized)
4386 op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
4387 if (ubCanonicalized)
4388 op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
4389
4390 return success();
4391}
4392
4393LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor,
4394 SmallVectorImpl<OpFoldResult> &results) {
4395 return canonicalizeLoopBounds(*this);
4396}
4397
4398/// Prints a lower(upper) bound of an affine parallel loop with max(min)
4399/// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
4400/// identifies which of the those expressions form max/min groups. `operands`
4401/// are the SSA values of dimensions and symbols and `keyword` is either "min"
4402/// or "max".
4403static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
4404 DenseIntElementsAttr group, ValueRange operands,
4405 StringRef keyword) {
4406 AffineMap map = mapAttr.getValue();
4407 unsigned numDims = map.getNumDims();
4408 ValueRange dimOperands = operands.take_front(numDims);
4409 ValueRange symOperands = operands.drop_front(numDims);
4410 unsigned start = 0;
4411 for (llvm::APInt groupSize : group) {
4412 if (start != 0)
4413 p << ", ";
4414
4415 unsigned size = groupSize.getZExtValue();
4416 if (size == 1) {
4417 p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
4418 ++start;
4419 } else {
4420 p << keyword << '(';
4421 AffineMap submap = map.getSliceMap(start, size);
4422 p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
4423 p << ')';
4424 start += size;
4425 }
4426 }
4427}
4428
4429void AffineParallelOp::print(OpAsmPrinter &p) {
4430 p << " (" << getBody()->getArguments() << ") = (";
4431 printMinMaxBound(p, getLowerBoundsMapAttr(), getLowerBoundsGroupsAttr(),
4432 getLowerBoundsOperands(), "max");
4433 p << ") to (";
4434 printMinMaxBound(p, getUpperBoundsMapAttr(), getUpperBoundsGroupsAttr(),
4435 getUpperBoundsOperands(), "min");
4436 p << ')';
4437 SmallVector<int64_t, 8> steps = getSteps();
4438 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
4439 if (!elideSteps) {
4440 p << " step (";
4441 llvm::interleaveComma(steps, p);
4442 p << ')';
4443 }
4444 if (getNumResults()) {
4445 p << " reduce (";
4446 llvm::interleaveComma(getReductions(), p, [&](auto &attr) {
4447 arith::AtomicRMWKind sym = *arith::symbolizeAtomicRMWKind(
4448 llvm::cast<IntegerAttr>(attr).getInt());
4449 p << "\"" << arith::stringifyAtomicRMWKind(sym) << "\"";
4450 });
4451 p << ") -> (" << getResultTypes() << ")";
4452 }
4453
4454 p << ' ';
4455 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4456 /*printBlockTerminators=*/getNumResults());
4458 (*this)->getAttrs(),
4459 /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrStrName(),
4460 AffineParallelOp::getLowerBoundsMapAttrStrName(),
4461 AffineParallelOp::getLowerBoundsGroupsAttrStrName(),
4462 AffineParallelOp::getUpperBoundsMapAttrStrName(),
4463 AffineParallelOp::getUpperBoundsGroupsAttrStrName(),
4464 AffineParallelOp::getStepsAttrStrName()});
4465}
4466
4467/// Given a list of lists of parsed operands, populates `uniqueOperands` with
4468/// unique operands. Also populates `replacements with affine expressions of
4469/// `kind` that can be used to update affine maps previously accepting a
4470/// `operands` to accept `uniqueOperands` instead.
4471static ParseResult deduplicateAndResolveOperands(
4472 OpAsmParser &parser,
4473 ArrayRef<SmallVector<OpAsmParser::UnresolvedOperand>> operands,
4474 SmallVectorImpl<Value> &uniqueOperands,
4475 SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
4476 assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
4477 "expected operands to be dim or symbol expression");
4478
4479 Type indexType = parser.getBuilder().getIndexType();
4480 for (const auto &list : operands) {
4481 SmallVector<Value> valueOperands;
4482 if (parser.resolveOperands(list, indexType, valueOperands))
4483 return failure();
4484 for (Value operand : valueOperands) {
4485 unsigned pos = std::distance(uniqueOperands.begin(),
4486 llvm::find(uniqueOperands, operand));
4487 if (pos == uniqueOperands.size())
4488 uniqueOperands.push_back(operand);
4489 replacements.push_back(
4490 kind == AffineExprKind::DimId
4491 ? getAffineDimExpr(pos, parser.getContext())
4492 : getAffineSymbolExpr(pos, parser.getContext()));
4493 }
4494 }
4495 return success();
4496}
4497
namespace {
/// Selects which keyword a bound group is parsed with.
enum class MinMaxKind {
  Min, ///< `min` groups.
  Max  ///< `max` groups.
};
} // namespace
4501
4502/// Parses an affine map that can contain a min/max for groups of its results,
4503/// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
4504/// `result` attributes with the map (flat list of expressions) and the grouping
4505/// (list of integers that specify how many expressions to put into each
4506/// min/max) attributes. Deduplicates repeated operands.
4507///
4508/// parallel-bound ::= `(` parallel-group-list `)`
4509/// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
4510/// parallel-group ::= simple-group | min-max-group
4511/// simple-group ::= expr-of-ssa-ids
4512/// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
4513/// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
4514///
4515/// Examples:
4516/// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
4517/// (%0, max(%1 - 2 * %2))
4518static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
4519 OperationState &result,
4520 MinMaxKind kind) {
4521 // Using `const` not `constexpr` below to workaround a MSVC optimizer bug,
4522 // see: https://reviews.llvm.org/D134227#3821753
4523 const llvm::StringLiteral tmpAttrStrName = "__pseudo_bound_map";
4524
4525 StringRef mapName = kind == MinMaxKind::Min
4526 ? AffineParallelOp::getUpperBoundsMapAttrStrName()
4527 : AffineParallelOp::getLowerBoundsMapAttrStrName();
4528 StringRef groupsName =
4529 kind == MinMaxKind::Min
4530 ? AffineParallelOp::getUpperBoundsGroupsAttrStrName()
4531 : AffineParallelOp::getLowerBoundsGroupsAttrStrName();
4532
4533 if (failed(parser.parseLParen()))
4534 return failure();
4535
4536 if (succeeded(parser.parseOptionalRParen())) {
4537 result.addAttribute(
4538 mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
4539 result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
4540 return success();
4541 }
4542
4543 SmallVector<AffineExpr> flatExprs;
4544 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatDimOperands;
4545 SmallVector<SmallVector<OpAsmParser::UnresolvedOperand>> flatSymOperands;
4546 SmallVector<int32_t> numMapsPerGroup;
4547 SmallVector<OpAsmParser::UnresolvedOperand> mapOperands;
4548 auto parseOperands = [&]() {
4549 if (succeeded(parser.parseOptionalKeyword(
4550 kind == MinMaxKind::Min ? "min" : "max"))) {
4551 mapOperands.clear();
4552 AffineMapAttr map;
4553 if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrStrName,
4554 result.attributes,
4556 return failure();
4557 result.attributes.erase(tmpAttrStrName);
4558 llvm::append_range(flatExprs, map.getValue().getResults());
4559 auto operandsRef = llvm::ArrayRef(mapOperands);
4560 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
4561 SmallVector<OpAsmParser::UnresolvedOperand> dims(dimsRef);
4562 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
4563 SmallVector<OpAsmParser::UnresolvedOperand> syms(symsRef);
4564 flatDimOperands.append(map.getValue().getNumResults(), dims);
4565 flatSymOperands.append(map.getValue().getNumResults(), syms);
4566 numMapsPerGroup.push_back(map.getValue().getNumResults());
4567 } else {
4568 if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
4569 flatSymOperands.emplace_back(),
4570 flatExprs.emplace_back())))
4571 return failure();
4572 numMapsPerGroup.push_back(1);
4573 }
4574 return success();
4575 };
4576 if (parser.parseCommaSeparatedList(parseOperands) || parser.parseRParen())
4577 return failure();
4578
4579 unsigned totalNumDims = 0;
4580 unsigned totalNumSyms = 0;
4581 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
4582 unsigned numDims = flatDimOperands[i].size();
4583 unsigned numSyms = flatSymOperands[i].size();
4584 flatExprs[i] = flatExprs[i]
4585 .shiftDims(numDims, totalNumDims)
4586 .shiftSymbols(numSyms, totalNumSyms);
4587 totalNumDims += numDims;
4588 totalNumSyms += numSyms;
4589 }
4590
4591 // Deduplicate map operands.
4592 SmallVector<Value> dimOperands, symOperands;
4593 SmallVector<AffineExpr> dimRplacements, symRepacements;
4594 if (deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
4595 dimRplacements, AffineExprKind::DimId) ||
4596 deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
4597 symRepacements, AffineExprKind::SymbolId))
4598 return failure();
4599
4600 result.operands.append(dimOperands.begin(), dimOperands.end());
4601 result.operands.append(symOperands.begin(), symOperands.end());
4602
4603 Builder &builder = parser.getBuilder();
4604 auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
4605 parser.getContext());
4606 flatMap = flatMap.replaceDimsAndSymbols(
4607 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
4608
4609 result.addAttribute(mapName, AffineMapAttr::get(flatMap));
4610 result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
4611 return success();
4612}
4613
4614//
4615// operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
4616// `to` parallel-bound steps? region attr-dict?
4617// steps ::= `steps` `(` integer-literals `)`
4618//
4619ParseResult AffineParallelOp::parse(OpAsmParser &parser,
4620 OperationState &result) {
4621 auto &builder = parser.getBuilder();
4622 auto indexType = builder.getIndexType();
4623 SmallVector<OpAsmParser::Argument, 4> ivs;
4625 parser.parseEqual() ||
4626 parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
4627 parser.parseKeyword("to") ||
4628 parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
4629 return failure();
4630
4631 AffineMapAttr stepsMapAttr;
4632 NamedAttrList stepsAttrs;
4633 SmallVector<OpAsmParser::UnresolvedOperand, 4> stepsMapOperands;
4634 if (failed(parser.parseOptionalKeyword("step"))) {
4635 SmallVector<int64_t, 4> steps(ivs.size(), 1);
4636 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4637 builder.getI64ArrayAttr(steps));
4638 } else {
4639 if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
4640 AffineParallelOp::getStepsAttrStrName(),
4641 stepsAttrs,
4643 return failure();
4644
4645 // Convert steps from an AffineMap into an I64ArrayAttr.
4646 SmallVector<int64_t, 4> steps;
4647 auto stepsMap = stepsMapAttr.getValue();
4648 for (const auto &result : stepsMap.getResults()) {
4649 auto constExpr = dyn_cast<AffineConstantExpr>(result);
4650 if (!constExpr)
4651 return parser.emitError(parser.getNameLoc(),
4652 "steps must be constant integers");
4653 steps.push_back(constExpr.getValue());
4654 }
4655 result.addAttribute(AffineParallelOp::getStepsAttrStrName(),
4656 builder.getI64ArrayAttr(steps));
4657 }
4658
4659 // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
4660 // quoted strings are a member of the enum AtomicRMWKind.
4661 SmallVector<Attribute, 4> reductions;
4662 if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4663 if (parser.parseLParen())
4664 return failure();
4665 auto parseAttributes = [&]() -> ParseResult {
4666 // Parse a single quoted string via the attribute parsing, and then
4667 // verify it is a member of the enum and convert to it's integer
4668 // representation.
4669 StringAttr attrVal;
4670 NamedAttrList attrStorage;
4671 auto loc = parser.getCurrentLocation();
4672 if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
4673 attrStorage))
4674 return failure();
4675 std::optional<arith::AtomicRMWKind> reduction =
4676 arith::symbolizeAtomicRMWKind(attrVal.getValue());
4677 if (!reduction)
4678 return parser.emitError(loc, "invalid reduction value: ") << attrVal;
4679 reductions.push_back(
4680 builder.getI64IntegerAttr(static_cast<int64_t>(reduction.value())));
4681 // While we keep getting commas, keep parsing.
4682 return success();
4683 };
4684 if (parser.parseCommaSeparatedList(parseAttributes) || parser.parseRParen())
4685 return failure();
4686 }
4687 result.addAttribute(AffineParallelOp::getReductionsAttrStrName(),
4688 builder.getArrayAttr(reductions));
4689
4690 // Parse return types of reductions (if any)
4691 if (parser.parseOptionalArrowTypeList(result.types))
4692 return failure();
4693
4694 // Now parse the body.
4695 Region *body = result.addRegion();
4696 for (auto &iv : ivs)
4697 iv.type = indexType;
4698 if (parser.parseRegion(*body, ivs) ||
4699 parser.parseOptionalAttrDict(result.attributes))
4700 return failure();
4701
4702 // Add a terminator if none was parsed.
4703 AffineParallelOp::ensureTerminator(*body, builder, result.location);
4704 return success();
4705}
4706
4707//===----------------------------------------------------------------------===//
4708// AffineYieldOp
4709//===----------------------------------------------------------------------===//
4710
4711LogicalResult AffineYieldOp::verify() {
4712 auto *parentOp = (*this)->getParentOp();
4713 auto results = parentOp->getResults();
4714 auto operands = getOperands();
4715
4716 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
4717 return emitOpError() << "only terminates affine.if/for/parallel regions";
4718 if (parentOp->getNumResults() != getNumOperands())
4719 return emitOpError() << "parent of yield must have same number of "
4720 "results as the yield operands";
4721 for (auto it : llvm::zip(results, operands)) {
4722 if (std::get<0>(it).getType() != std::get<1>(it).getType())
4723 return emitOpError() << "types mismatch between yield op and its parent";
4724 }
4725
4726 return success();
4727}
4728
4729//===----------------------------------------------------------------------===//
4730// AffineVectorLoadOp
4731//===----------------------------------------------------------------------===//
4732
4733void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4734 VectorType resultType, AffineMap map,
4735 ValueRange operands) {
4736 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
4737 result.addOperands(operands);
4738 if (map)
4739 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4740 result.types.push_back(resultType);
4741}
4742
4743void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4744 VectorType resultType, Value memref,
4745 AffineMap map, ValueRange mapOperands) {
4746 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4747 result.addOperands(memref);
4748 result.addOperands(mapOperands);
4749 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4750 result.types.push_back(resultType);
4751}
4752
4753void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
4754 VectorType resultType, Value memref,
4756 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4757 int64_t rank = memrefType.getRank();
4758 // Create identity map for memrefs with at least one dimension or () -> ()
4759 // for zero-dimensional memrefs.
4760 auto map =
4761 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4762 build(builder, result, resultType, memref, map, indices);
4763}
4764
4765void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4766 MLIRContext *context) {
4767 results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
4768}
4769
4770ParseResult AffineVectorLoadOp::parse(OpAsmParser &parser,
4771 OperationState &result) {
4772 auto &builder = parser.getBuilder();
4773 auto indexTy = builder.getIndexType();
4774
4775 MemRefType memrefType;
4776 VectorType resultType;
4777 OpAsmParser::UnresolvedOperand memrefInfo;
4778 AffineMapAttr mapAttr;
4779 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4780 return failure(
4781 parser.parseOperand(memrefInfo) ||
4782 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4783 AffineVectorLoadOp::getMapAttrStrName(),
4784 result.attributes) ||
4785 parser.parseOptionalAttrDict(result.attributes) ||
4786 parser.parseColonType(memrefType) || parser.parseComma() ||
4787 parser.parseType(resultType) ||
4788 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4789 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
4790 parser.addTypeToList(resultType, result.types));
4791}
4792
4793void AffineVectorLoadOp::print(OpAsmPrinter &p) {
4794 p << " " << getMemRef() << '[';
4795 if (AffineMapAttr mapAttr =
4796 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4797 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4798 p << ']';
4799 p.printOptionalAttrDict((*this)->getAttrs(),
4800 /*elidedAttrs=*/{getMapAttrStrName()});
4801 p << " : " << getMemRefType() << ", " << getType();
4802}
4803
4804/// Verify common invariants of affine.vector_load and affine.vector_store.
4805static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
4806 VectorType vectorType) {
4807 // Check that memref and vector element types match.
4808 if (memrefType.getElementType() != vectorType.getElementType())
4809 return op->emitOpError(
4810 "requires memref and vector types of the same elemental type");
4811 return success();
4812}
4813
4814LogicalResult AffineVectorLoadOp::verify() {
4815 MemRefType memrefType = getMemRefType();
4817 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4818 getMapOperands(), memrefType,
4819 /*numIndexOperands=*/getNumOperands() - 1)))
4820 return failure();
4821
4822 if (failed(verifyVectorMemoryOp(getOperation(), memrefType, getVectorType())))
4823 return failure();
4824
4825 return success();
4826}
4827
4828//===----------------------------------------------------------------------===//
4829// AffineVectorStoreOp
4830//===----------------------------------------------------------------------===//
4831
4832void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4833 Value valueToStore, Value memref, AffineMap map,
4834 ValueRange mapOperands) {
4835 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
4836 result.addOperands(valueToStore);
4837 result.addOperands(memref);
4838 result.addOperands(mapOperands);
4839 result.addAttribute(getMapAttrStrName(), AffineMapAttr::get(map));
4840}
4841
4842// Use identity map.
4843void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
4844 Value valueToStore, Value memref,
4846 auto memrefType = llvm::cast<MemRefType>(memref.getType());
4847 int64_t rank = memrefType.getRank();
4848 // Create identity map for memrefs with at least one dimension or () -> ()
4849 // for zero-dimensional memrefs.
4850 auto map =
4851 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
4852 build(builder, result, valueToStore, memref, map, indices);
4853}
4854void AffineVectorStoreOp::getCanonicalizationPatterns(
4855 RewritePatternSet &results, MLIRContext *context) {
4856 results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
4857}
4858
4859ParseResult AffineVectorStoreOp::parse(OpAsmParser &parser,
4860 OperationState &result) {
4861 auto indexTy = parser.getBuilder().getIndexType();
4862
4863 MemRefType memrefType;
4864 VectorType resultType;
4865 OpAsmParser::UnresolvedOperand storeValueInfo;
4866 OpAsmParser::UnresolvedOperand memrefInfo;
4867 AffineMapAttr mapAttr;
4868 SmallVector<OpAsmParser::UnresolvedOperand, 1> mapOperands;
4869 return failure(
4870 parser.parseOperand(storeValueInfo) || parser.parseComma() ||
4871 parser.parseOperand(memrefInfo) ||
4872 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
4873 AffineVectorStoreOp::getMapAttrStrName(),
4874 result.attributes) ||
4875 parser.parseOptionalAttrDict(result.attributes) ||
4876 parser.parseColonType(memrefType) || parser.parseComma() ||
4877 parser.parseType(resultType) ||
4878 parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
4879 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
4880 parser.resolveOperands(mapOperands, indexTy, result.operands));
4881}
4882
4883void AffineVectorStoreOp::print(OpAsmPrinter &p) {
4884 p << " " << getValueToStore();
4885 p << ", " << getMemRef() << '[';
4886 if (AffineMapAttr mapAttr =
4887 (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()))
4888 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
4889 p << ']';
4890 p.printOptionalAttrDict((*this)->getAttrs(),
4891 /*elidedAttrs=*/{getMapAttrStrName()});
4892 p << " : " << getMemRefType() << ", " << getValueToStore().getType();
4893}
4894
4895LogicalResult AffineVectorStoreOp::verify() {
4896 MemRefType memrefType = getMemRefType();
4898 *this, (*this)->getAttrOfType<AffineMapAttr>(getMapAttrStrName()),
4899 getMapOperands(), memrefType,
4900 /*numIndexOperands=*/getNumOperands() - 2)))
4901 return failure();
4902
4903 if (failed(verifyVectorMemoryOp(*this, memrefType, getVectorType())))
4904 return failure();
4905
4906 return success();
4907}
4908
4909//===----------------------------------------------------------------------===//
4910// DelinearizeIndexOp
4911//===----------------------------------------------------------------------===//
4912
4913void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4914 OperationState &odsState,
4915 Value linearIndex, ValueRange dynamicBasis,
4916 ArrayRef<int64_t> staticBasis,
4917 bool hasOuterBound) {
4918 SmallVector<Type> returnTypes(hasOuterBound ? staticBasis.size()
4919 : staticBasis.size() + 1,
4920 linearIndex.getType());
4921 build(odsBuilder, odsState, returnTypes, linearIndex, dynamicBasis,
4922 staticBasis);
4923}
4924
4925void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4926 OperationState &odsState,
4927 Value linearIndex, ValueRange basis,
4928 bool hasOuterBound) {
4929 if (hasOuterBound && !basis.empty() && basis.front() == nullptr) {
4930 hasOuterBound = false;
4931 basis = basis.drop_front();
4932 }
4933 SmallVector<Value> dynamicBasis;
4934 SmallVector<int64_t> staticBasis;
4935 dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
4936 staticBasis);
4937 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4938 hasOuterBound);
4939}
4940
4941void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4942 OperationState &odsState,
4943 Value linearIndex,
4944 ArrayRef<OpFoldResult> basis,
4945 bool hasOuterBound) {
4946 if (hasOuterBound && !basis.empty() && basis.front() == OpFoldResult()) {
4947 hasOuterBound = false;
4948 basis = basis.drop_front();
4949 }
4950 SmallVector<Value> dynamicBasis;
4951 SmallVector<int64_t> staticBasis;
4952 dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
4953 build(odsBuilder, odsState, linearIndex, dynamicBasis, staticBasis,
4954 hasOuterBound);
4955}
4956
4957void AffineDelinearizeIndexOp::build(OpBuilder &odsBuilder,
4958 OperationState &odsState,
4959 Value linearIndex, ArrayRef<int64_t> basis,
4960 bool hasOuterBound) {
4961 build(odsBuilder, odsState, linearIndex, ValueRange{}, basis, hasOuterBound);
4962}
4963
4964LogicalResult AffineDelinearizeIndexOp::verify() {
4965 ArrayRef<int64_t> staticBasis = getStaticBasis();
4966 if (getNumResults() != staticBasis.size() &&
4967 getNumResults() != staticBasis.size() + 1)
4968 return emitOpError("should return an index for each basis element and up "
4969 "to one extra index");
4970
4971 auto dynamicMarkersCount = llvm::count_if(staticBasis, ShapedType::isDynamic);
4972 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
4973 return emitOpError(
4974 "mismatch between dynamic and static basis (kDynamic marker but no "
4975 "corresponding dynamic basis entry) -- this can only happen due to an "
4976 "incorrect fold/rewrite");
4977
4978 if (!llvm::all_of(staticBasis, [](int64_t v) {
4979 return v > 0 || ShapedType::isDynamic(v);
4980 }))
4981 return emitOpError("no basis element may be statically non-positive");
4982
4983 return success();
4984}
4985
4986/// Given mixed basis of affine.delinearize_index/linearize_index replace
4987/// constant SSA values with the constant integer value and return the new
4988/// static basis. In case no such candidate for replacement exists, this utility
4989/// returns std::nullopt.
4990static std::optional<SmallVector<int64_t>>
4992 MutableOperandRange mutableDynamicBasis,
4993 ArrayRef<Attribute> dynamicBasis) {
4994 uint64_t dynamicBasisIndex = 0;
4995 for (Attribute basis : dynamicBasis) {
4996 // Skip poison values: they don't have a concrete integer value, so erasing
4997 // them from the dynamic operands would create an inconsistency between
4998 // the static basis (which would still hold kDynamic) and the dynamic
4999 // operand list (which would be one element shorter).
5000 if (basis && isa<IntegerAttr>(basis)) {
5001 mutableDynamicBasis.erase(dynamicBasisIndex);
5002 } else {
5003 ++dynamicBasisIndex;
5004 }
5005 }
5006
5007 // No constant SSA value exists.
5008 if (dynamicBasisIndex == dynamicBasis.size())
5009 return std::nullopt;
5010
5011 SmallVector<int64_t> staticBasis;
5012 for (OpFoldResult basis : mixedBasis) {
5013 std::optional<int64_t> basisVal = getConstantIntValue(basis);
5014 if (!basisVal)
5015 staticBasis.push_back(ShapedType::kDynamic);
5016 else
5017 staticBasis.push_back(*basisVal);
5018 }
5019
5020 return staticBasis;
5021}
5022
5023LogicalResult
5024AffineDelinearizeIndexOp::fold(FoldAdaptor adaptor,
5025 SmallVectorImpl<OpFoldResult> &result) {
5026 std::optional<SmallVector<int64_t>> maybeStaticBasis =
5027 foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
5028 adaptor.getDynamicBasis());
5029 if (maybeStaticBasis) {
5030 setStaticBasis(*maybeStaticBasis);
5031 return success();
5032 }
5033 // If we won't be doing any division or modulo (no basis or the one basis
5034 // element is purely advisory), simply return the input value.
5035 if (getNumResults() == 1) {
5036 result.push_back(getLinearIndex());
5037 return success();
5038 }
5039
5040 if (adaptor.getLinearIndex() == nullptr)
5041 return failure();
5042
5043 if (!adaptor.getDynamicBasis().empty())
5044 return failure();
5045
5046 int64_t highPart = cast<IntegerAttr>(adaptor.getLinearIndex()).getInt();
5047 Type attrType = getLinearIndex().getType();
5048
5049 ArrayRef<int64_t> staticBasis = getStaticBasis();
5050 if (hasOuterBound())
5051 staticBasis = staticBasis.drop_front();
5052 for (int64_t modulus : llvm::reverse(staticBasis)) {
5053 result.push_back(IntegerAttr::get(attrType, llvm::mod(highPart, modulus)));
5054 highPart = llvm::divideFloorSigned(highPart, modulus);
5055 }
5056 result.push_back(IntegerAttr::get(attrType, highPart));
5057 std::reverse(result.begin(), result.end());
5058 return success();
5059}
5060
5061SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getEffectiveBasis() {
5062 OpBuilder builder(getContext());
5063 if (hasOuterBound()) {
5064 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5065 return getMixedValues(getStaticBasis().drop_front(),
5066 getDynamicBasis().drop_front(), builder);
5067
5068 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5069 builder);
5070 }
5071
5072 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5073}
5074
5075SmallVector<OpFoldResult> AffineDelinearizeIndexOp::getPaddedBasis() {
5076 SmallVector<OpFoldResult> ret = getMixedBasis();
5077 if (!hasOuterBound())
5078 ret.insert(ret.begin(), OpFoldResult());
5079 return ret;
5080}
5081
5082namespace {
5083
5084// Drops delinearization indices that correspond to unit-extent basis
5085struct DropUnitExtentBasis
5086 : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5088
5089 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5090 PatternRewriter &rewriter) const override {
5091 SmallVector<Value> replacements(delinearizeOp->getNumResults(), nullptr);
5092 std::optional<Value> zero = std::nullopt;
5093 Location loc = delinearizeOp->getLoc();
5094 Type indexType = delinearizeOp.getLinearIndex().getType();
5095 auto getZero = [&]() -> Value {
5096 if (!zero)
5097 zero = arith::ConstantOp::create(rewriter, loc,
5098 rewriter.getZeroAttr(indexType));
5099 return zero.value();
5100 };
5101
5102 // Replace all indices corresponding to unit-extent basis with 0.
5103 // Remaining basis can be used to get a new `affine.delinearize_index` op.
5104 SmallVector<OpFoldResult> newBasis;
5105 for (auto [index, basis] :
5106 llvm::enumerate(delinearizeOp.getPaddedBasis())) {
5107 std::optional<int64_t> basisVal =
5108 basis ? getConstantIntValue(basis) : std::nullopt;
5109 if (basisVal == 1)
5110 replacements[index] = getZero();
5111 else
5112 newBasis.push_back(basis);
5113 }
5114
5115 if (newBasis.size() == delinearizeOp.getNumResults())
5116 return rewriter.notifyMatchFailure(delinearizeOp,
5117 "no unit basis elements");
5118
5119 if (!newBasis.empty()) {
5120 // Will drop the leading nullptr from `basis` if there was no outer bound.
5121 auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
5122 rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
5123 int newIndex = 0;
5124 // Map back the new delinearized indices to the values they replace.
5125 for (auto &replacement : replacements) {
5126 if (replacement)
5127 continue;
5128 replacement = newDelinearizeOp->getResult(newIndex++);
5129 }
5130 }
5131
5132 rewriter.replaceOp(delinearizeOp, replacements);
5133 return success();
5134 }
5135};
5136
5137/// If a `affine.delinearize_index`'s input is a `affine.linearize_index
5138/// disjoint` and the two operations end with the same basis elements,
5139/// cancel those parts of the operations out because they are inverses
5140/// of each other.
5141///
5142/// If the operations have the same basis, cancel them entirely.
5143///
5144/// The `disjoint` flag is needed on the `affine.linearize_index` because
5145/// otherwise, there is no guarantee that the inputs to the linearization are
5146/// in-bounds the way the outputs of the delinearization would be.
5147struct CancelDelinearizeOfLinearizeDisjointExactTail
5148 : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5150
5151 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5152 PatternRewriter &rewriter) const override {
5153 auto linearizeOp = delinearizeOp.getLinearIndex()
5154 .getDefiningOp<affine::AffineLinearizeIndexOp>();
5155 if (!linearizeOp)
5156 return rewriter.notifyMatchFailure(delinearizeOp,
5157 "index doesn't come from linearize");
5158
5159 if (!linearizeOp.getDisjoint())
5160 return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
5161
5162 ValueRange linearizeIns = linearizeOp.getMultiIndex();
5163 // Note: we use the full basis so we don't lose outer bounds later.
5164 SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
5165 SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
5166 size_t numMatches = 0;
5167 for (auto [linSize, delinSize] : llvm::zip(
5168 llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
5169 if (linSize != delinSize)
5170 break;
5171 ++numMatches;
5172 }
5173
5174 if (numMatches == 0)
5175 return rewriter.notifyMatchFailure(
5176 delinearizeOp, "final basis element doesn't match linearize");
5177
5178 // The easy case: everything lines up and the basis match sup completely.
5179 if (numMatches == linearizeBasis.size() &&
5180 numMatches == delinearizeBasis.size() &&
5181 linearizeIns.size() == delinearizeOp.getNumResults()) {
5182 rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
5183 return success();
5184 }
5185
5186 Value newLinearize = affine::AffineLinearizeIndexOp::create(
5187 rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
5188 ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
5189 linearizeOp.getDisjoint());
5190 auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
5191 rewriter, delinearizeOp.getLoc(), newLinearize,
5192 ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
5193 delinearizeOp.hasOuterBound());
5194 SmallVector<Value> mergedResults(newDelinearize.getResults());
5195 mergedResults.append(linearizeIns.take_back(numMatches).begin(),
5196 linearizeIns.take_back(numMatches).end());
5197 rewriter.replaceOp(delinearizeOp, mergedResults);
5198 return success();
5199 }
5200};
5201
5202/// If the input to a delinearization is a disjoint linearization, and the
5203/// last k > 1 components of the delinearization basis multiply to the
5204/// last component of the linearization basis, break the linearization and
5205/// delinearization into two parts, peeling off the last input to linearization.
5206///
5207/// For example:
5208/// %0 = affine.linearize_index [%z, %y, %x] by (3, 2, 32) : index
5209/// %1:4 = affine.delinearize_index %0 by (2, 3, 8, 4) : index, ...
5210/// becomes
5211/// %0 = affine.linearize_index [%z, %y] by (3, 2) : index
5212/// %1:2 = affine.delinearize_index %0 by (2, 3) : index
5213/// %2:2 = affine.delinearize_index %x by (8, 4) : index
5214/// where the original %1:4 is replaced by %1:2 ++ %2:2
5215struct SplitDelinearizeSpanningLastLinearizeArg final
5216 : OpRewritePattern<affine::AffineDelinearizeIndexOp> {
5218
5219 LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp delinearizeOp,
5220 PatternRewriter &rewriter) const override {
5221 auto linearizeOp = delinearizeOp.getLinearIndex()
5222 .getDefiningOp<affine::AffineLinearizeIndexOp>();
5223 if (!linearizeOp)
5224 return rewriter.notifyMatchFailure(delinearizeOp,
5225 "index doesn't come from linearize");
5226
5227 if (!linearizeOp.getDisjoint())
5228 return rewriter.notifyMatchFailure(linearizeOp,
5229 "linearize isn't disjoint");
5230
5231 // A linearize with no inputs has an empty basis and folds to a constant
5232 // zero; there is nothing to split, and reading its last basis element
5233 // below would be out of bounds.
5234 if (linearizeOp.getStaticBasis().empty())
5235 return rewriter.notifyMatchFailure(
5236 linearizeOp, "linearize has no basis elements (no inputs)");
5237
5238 int64_t target = linearizeOp.getStaticBasis().back();
5239 if (ShapedType::isDynamic(target))
5240 return rewriter.notifyMatchFailure(
5241 linearizeOp, "linearize ends with dynamic basis value");
5242
5243 int64_t sizeToSplit = 1;
5244 size_t elemsToSplit = 0;
5245 ArrayRef<int64_t> basis = delinearizeOp.getStaticBasis();
5246 for (int64_t basisElem : llvm::reverse(basis)) {
5247 if (ShapedType::isDynamic(basisElem))
5248 return rewriter.notifyMatchFailure(
5249 delinearizeOp, "dynamic basis element while scanning for split");
5250 sizeToSplit *= basisElem;
5251 elemsToSplit += 1;
5252
5253 if (sizeToSplit > target)
5254 return rewriter.notifyMatchFailure(delinearizeOp,
5255 "overshot last argument size");
5256 if (sizeToSplit == target)
5257 break;
5258 }
5259
5260 if (sizeToSplit < target)
5261 return rewriter.notifyMatchFailure(
5262 delinearizeOp, "product of known basis elements doesn't exceed last "
5263 "linearize argument");
5264
5265 if (elemsToSplit < 2)
5266 return rewriter.notifyMatchFailure(
5267 delinearizeOp,
5268 "need at least two elements to form the basis product");
5269
5270 Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
5271 rewriter, linearizeOp.getLoc(), linearizeOp.getLinearIndex().getType(),
5272 linearizeOp.getMultiIndex().drop_back(), linearizeOp.getDynamicBasis(),
5273 linearizeOp.getStaticBasis().drop_back(), linearizeOp.getDisjoint());
5274 auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
5275 rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
5276 delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
5277 delinearizeOp.hasOuterBound());
5278 auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
5279 rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
5280 basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
5281 SmallVector<Value> results = llvm::to_vector(
5282 llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
5283 delinearizeBack.getResults()));
5284 rewriter.replaceOp(delinearizeOp, results);
5285
5286 return success();
5287 }
5288};
5289} // namespace
5290
5291void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
5292 RewritePatternSet &patterns, MLIRContext *context) {
5293 patterns
5294 .insert<CancelDelinearizeOfLinearizeDisjointExactTail,
5295 DropUnitExtentBasis, SplitDelinearizeSpanningLastLinearizeArg>(
5296 context);
5297}
5298
5299//===----------------------------------------------------------------------===//
5300// LinearizeIndexOp
5301//===----------------------------------------------------------------------===//
5302
5303/// Infer the index type from a set of multi-index values. Returns the common
5304/// type (index or vector<...xindex>), or IndexType if the set is empty.
5305static Type inferIndexType(MLIRContext *ctx, ValueRange multiIndex) {
5306 if (multiIndex.empty())
5307 return IndexType::get(ctx);
5308 return multiIndex.front().getType();
5309}
5310
5311void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5312 OperationState &odsState,
5313 ValueRange multiIndex, ValueRange basis,
5314 bool disjoint) {
5315 if (!basis.empty() && basis.front() == Value())
5316 basis = basis.drop_front();
5317 SmallVector<Value> dynamicBasis;
5318 SmallVector<int64_t> staticBasis;
5319 dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis,
5320 staticBasis);
5321 Type resultType = inferIndexType(odsBuilder.getContext(), multiIndex);
5322 build(odsBuilder, odsState, resultType, multiIndex, dynamicBasis, staticBasis,
5323 disjoint);
5324}
5325
5326void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5327 OperationState &odsState,
5328 ValueRange multiIndex,
5329 ArrayRef<OpFoldResult> basis,
5330 bool disjoint) {
5331 if (!basis.empty() && basis.front() == OpFoldResult())
5332 basis = basis.drop_front();
5333 SmallVector<Value> dynamicBasis;
5334 SmallVector<int64_t> staticBasis;
5335 dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis);
5336 Type resultType = inferIndexType(odsBuilder.getContext(), multiIndex);
5337 build(odsBuilder, odsState, resultType, multiIndex, dynamicBasis, staticBasis,
5338 disjoint);
5339}
5340
5341void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder,
5342 OperationState &odsState,
5343 ValueRange multiIndex,
5344 ArrayRef<int64_t> basis, bool disjoint) {
5345 Type resultType = inferIndexType(odsBuilder.getContext(), multiIndex);
5346 build(odsBuilder, odsState, resultType, multiIndex, ValueRange{}, basis,
5347 disjoint);
5348}
5349
5350LogicalResult AffineLinearizeIndexOp::verify() {
5351 size_t numIndexes = getMultiIndex().size();
5352 size_t numBasisElems = getStaticBasis().size();
5353 if (numIndexes != numBasisElems && numIndexes != numBasisElems + 1)
5354 return emitOpError("should be passed a basis element for each index except "
5355 "possibly the first");
5356
5357 auto dynamicMarkersCount =
5358 llvm::count_if(getStaticBasis(), ShapedType::isDynamic);
5359 if (static_cast<size_t>(dynamicMarkersCount) != getDynamicBasis().size())
5360 return emitOpError(
5361 "mismatch between dynamic and static basis (kDynamic marker but no "
5362 "corresponding dynamic basis entry) -- this can only happen due to an "
5363 "incorrect fold/rewrite");
5364
5365 return success();
5366}
5367
5368OpFoldResult AffineLinearizeIndexOp::fold(FoldAdaptor adaptor) {
5369 std::optional<SmallVector<int64_t>> maybeStaticBasis =
5370 foldCstValueToCstAttrBasis(getMixedBasis(), getDynamicBasisMutable(),
5371 adaptor.getDynamicBasis());
5372 if (maybeStaticBasis) {
5373 setStaticBasis(*maybeStaticBasis);
5374 return getResult();
5375 }
5376 // No indices linearizes to zero.
5377 if (getMultiIndex().empty())
5378 return IntegerAttr::get(getResult().getType(), 0);
5379
5380 // One single index linearizes to itself.
5381 if (getMultiIndex().size() == 1)
5382 return getMultiIndex().front();
5383
5384 // Return nullptr if any multi-index attribute has not been folded to a
5385 // concrete integer (e.g. it is still a runtime value or has folded to a
5386 // non-integer attribute such as #ub.poison).
5387 if (llvm::any_of(adaptor.getMultiIndex(), [](Attribute a) {
5388 return !isa_and_nonnull<IntegerAttr>(a);
5389 }))
5390 return nullptr;
5391
5392 if (!adaptor.getDynamicBasis().empty())
5393 return nullptr;
5394
5395 int64_t result = 0;
5396 int64_t stride = 1;
5397 for (auto [length, indexAttr] :
5398 llvm::zip_first(llvm::reverse(getStaticBasis()),
5399 llvm::reverse(adaptor.getMultiIndex()))) {
5400 result = result + cast<IntegerAttr>(indexAttr).getInt() * stride;
5401 stride = stride * length;
5402 }
5403 // Handle the index element with no basis element.
5404 if (!hasOuterBound())
5405 result =
5406 result +
5407 cast<IntegerAttr>(adaptor.getMultiIndex().front()).getInt() * stride;
5408
5409 return IntegerAttr::get(getResult().getType(), result);
5410}
5411
5412SmallVector<OpFoldResult> AffineLinearizeIndexOp::getEffectiveBasis() {
5413 OpBuilder builder(getContext());
5414 if (hasOuterBound()) {
5415 if (getStaticBasis().front() == ::mlir::ShapedType::kDynamic)
5416 return getMixedValues(getStaticBasis().drop_front(),
5417 getDynamicBasis().drop_front(), builder);
5418
5419 return getMixedValues(getStaticBasis().drop_front(), getDynamicBasis(),
5420 builder);
5421 }
5422
5423 return getMixedValues(getStaticBasis(), getDynamicBasis(), builder);
5424}
5425
5426SmallVector<OpFoldResult> AffineLinearizeIndexOp::getPaddedBasis() {
5427 SmallVector<OpFoldResult> ret = getMixedBasis();
5428 if (!hasOuterBound())
5429 ret.insert(ret.begin(), OpFoldResult());
5430 return ret;
5431}
5432
5433namespace {
5434/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1,
5435/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c,
5436/// %...d)`.
5437
5438/// Note that `disjoint` is required here, because, without it, we could have
5439/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)`
5440/// is a valid operation where the `%c64` cannot be trivially dropped.
5441///
5442/// Alternatively, if `%x` in the above is a known constant 0, remove it even if
5443/// the operation isn't asserted to be `disjoint`.
5444struct DropLinearizeUnitComponentsIfDisjointOrZero final
5445 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5447
5448 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5449 PatternRewriter &rewriter) const override {
5450 ValueRange multiIndex = op.getMultiIndex();
5451 size_t numIndices = multiIndex.size();
5452 SmallVector<Value> newIndices;
5453 newIndices.reserve(numIndices);
5454 SmallVector<OpFoldResult> newBasis;
5455 newBasis.reserve(numIndices);
5456
5457 if (!op.hasOuterBound()) {
5458 newIndices.push_back(multiIndex.front());
5459 multiIndex = multiIndex.drop_front();
5460 }
5461
5462 SmallVector<OpFoldResult> basis = op.getMixedBasis();
5463 for (auto [index, basisElem] : llvm::zip_equal(multiIndex, basis)) {
5464 std::optional<int64_t> basisEntry = getConstantIntValue(basisElem);
5465 if (!basisEntry || *basisEntry != 1) {
5466 newIndices.push_back(index);
5467 newBasis.push_back(basisElem);
5468 continue;
5469 }
5470
5471 std::optional<int64_t> indexValue = getConstantIntValue(index);
5472 if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) {
5473 newIndices.push_back(index);
5474 newBasis.push_back(basisElem);
5475 continue;
5476 }
5477 }
5478 if (newIndices.size() == numIndices)
5479 return rewriter.notifyMatchFailure(op,
5480 "no unit basis entries to replace");
5481
5482 if (newIndices.empty()) {
5483 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
5484 op, rewriter.getZeroAttr(op.getLinearIndex().getType()));
5485 return success();
5486 }
5487 rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5488 op, newIndices, newBasis, op.getDisjoint());
5489 return success();
5490 }
5491};
5492
5493OpFoldResult computeProduct(Location loc, OpBuilder &builder,
5494 ArrayRef<OpFoldResult> terms) {
5495 int64_t nDynamic = 0;
5496 SmallVector<Value> dynamicPart;
5497 AffineExpr result = builder.getAffineConstantExpr(1);
5498 for (OpFoldResult term : terms) {
5499 if (!term)
5500 return term;
5501 std::optional<int64_t> maybeConst = getConstantIntValue(term);
5502 if (maybeConst) {
5503 result = result * builder.getAffineConstantExpr(*maybeConst);
5504 } else {
5505 dynamicPart.push_back(cast<Value>(term));
5506 result = result * builder.getAffineSymbolExpr(nDynamic++);
5507 }
5508 }
5509 if (auto constant = dyn_cast<AffineConstantExpr>(result))
5510 return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
5511 return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult();
5512}
5513
5514/// If conseceutive outputs of a delinearize_index are linearized with the same
5515/// bounds, canonicalize away the redundant arithmetic.
5516///
5517/// That is, if we have
5518/// ```
5519/// %s:N = affine.delinearize_index %x into (...a, B1, B2, ... BK, ...b)
5520/// %t = affine.linearize_index [...c, %s#I, %s#(I + 1), ... %s#(I+K-1), ...d]
5521/// by (...e, B1, B2, ..., BK, ...f)
5522/// ```
5523///
5524/// We can rewrite this to
5525/// ```
5526/// B = B1 * B2 ... BK
5527/// %sMerged:(N-K+1) affine.delinearize_index %x into (...a, B, ...b)
5528/// %t = affine.linearize_index [...c, %s#I, ...d] by (...e, B, ...f)
5529/// ```
5530/// where we replace all results of %s unaffected by the change with results
5531/// from %sMerged.
5532///
5533/// As a special case, if all results of the delinearize are merged in this way
5534/// we can replace those usages with %x, thus cancelling the delinearization
5535/// entirely, as in
5536/// ```
5537/// %s:3 = affine.delinearize_index %x into (2, 4, 8)
5538/// %t = affine.linearize_index [%s#0, %s#1, %s#2, %c0] by (2, 4, 8, 16)
5539/// ```
5540/// becoming `%t = affine.linearize_index [%x, %c0] by (64, 16)`
5541struct CancelLinearizeOfDelinearizePortion final
5542 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5544
5545private:
5546 // Struct representing a case where the cancellation pattern
5547 // applies. A `Match` means that `length` inputs to the linearize operation
5548 // starting at `linStart` can be cancelled with `length` outputs of
5549 // `delinearize`, starting from `delinStart`.
5550 struct Match {
5551 AffineDelinearizeIndexOp delinearize;
5552 unsigned linStart = 0;
5553 unsigned delinStart = 0;
5554 unsigned length = 0;
5555 };
5556
5557public:
5558 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp linearizeOp,
5559 PatternRewriter &rewriter) const override {
5560 SmallVector<Match> matches;
5561
5562 const SmallVector<OpFoldResult> linBasis = linearizeOp.getPaddedBasis();
5563 ArrayRef<OpFoldResult> linBasisRef = linBasis;
5564
5565 ValueRange multiIndex = linearizeOp.getMultiIndex();
5566 unsigned numLinArgs = multiIndex.size();
5567 unsigned linArgIdx = 0;
5568 // We only want to replace one run from the same delinearize op per
5569 // pattern invocation lest we run into invalidation issues.
5570 llvm::SmallPtrSet<Operation *, 2> alreadyMatchedDelinearize;
5571 while (linArgIdx < numLinArgs) {
5572 auto asResult = dyn_cast<OpResult>(multiIndex[linArgIdx]);
5573 if (!asResult) {
5574 linArgIdx++;
5575 continue;
5576 }
5577
5578 auto delinearizeOp =
5579 dyn_cast<AffineDelinearizeIndexOp>(asResult.getOwner());
5580 if (!delinearizeOp) {
5581 linArgIdx++;
5582 continue;
5583 }
5584
5585 /// Result 0 of the delinearize and argument 0 of the linearize can
5586 /// leave their maximum value unspecified. However, even if this happens
5587 /// we can still sometimes start the match process. Specifically, if
5588 /// - The argument we're matching is result 0 and argument 0 (so the
5589 /// bounds don't matter). For example,
5590 ///
5591 /// %0:2 = affine.delinearize_index %x into (8) : index, index
5592 /// %1 = affine.linearize_index [%s#0, %s#1, ...] (8, ...)
5593 /// allows cancellation
5594 /// - The delinearization doesn't specify a bound, but the linearization
5595 /// is `disjoint`, which asserts that the bound on the linearization is
5596 /// correct.
5597 unsigned delinArgIdx = asResult.getResultNumber();
5598 SmallVector<OpFoldResult> delinBasis = delinearizeOp.getPaddedBasis();
5599 OpFoldResult firstDelinBound = delinBasis[delinArgIdx];
5600 OpFoldResult firstLinBound = linBasis[linArgIdx];
5601 bool boundsMatch = firstDelinBound == firstLinBound;
5602 bool bothAtFront = linArgIdx == 0 && delinArgIdx == 0;
5603 bool knownByDisjoint =
5604 linearizeOp.getDisjoint() && delinArgIdx == 0 && !firstDelinBound;
5605 if (!boundsMatch && !bothAtFront && !knownByDisjoint) {
5606 linArgIdx++;
5607 continue;
5608 }
5609
5610 unsigned j = 1;
5611 unsigned numDelinOuts = delinearizeOp.getNumResults();
5612 for (; j + linArgIdx < numLinArgs && j + delinArgIdx < numDelinOuts;
5613 ++j) {
5614 if (multiIndex[linArgIdx + j] !=
5615 delinearizeOp.getResult(delinArgIdx + j))
5616 break;
5617 if (linBasis[linArgIdx + j] != delinBasis[delinArgIdx + j])
5618 break;
5619 }
5620 // If there're multiple matches against the same delinearize_index,
5621 // only rewrite the first one we find to prevent invalidations. The next
5622 // ones will be taken care of by subsequent pattern invocations.
5623 if (j <= 1 || !alreadyMatchedDelinearize.insert(delinearizeOp).second) {
5624 linArgIdx++;
5625 continue;
5626 }
5627 matches.push_back(Match{delinearizeOp, linArgIdx, delinArgIdx, j});
5628 linArgIdx += j;
5629 }
5630
5631 if (matches.empty())
5632 return rewriter.notifyMatchFailure(
5633 linearizeOp, "no run of delinearize outputs to deal with");
5634
5635 // Record all the delinearize replacements so we can do them after creating
5636 // the new linearization operation, since the new operation might use
5637 // outputs of something we're replacing.
5638 SmallVector<SmallVector<Value>> delinearizeReplacements;
5639
5640 SmallVector<Value> newIndex;
5641 newIndex.reserve(numLinArgs);
5642 SmallVector<OpFoldResult> newBasis;
5643 newBasis.reserve(numLinArgs);
5644 unsigned prevMatchEnd = 0;
5645 for (Match m : matches) {
5646 unsigned gap = m.linStart - prevMatchEnd;
5647 llvm::append_range(newIndex, multiIndex.slice(prevMatchEnd, gap));
5648 llvm::append_range(newBasis, linBasisRef.slice(prevMatchEnd, gap));
5649 // Update here so we don't forget this during early continues
5650 prevMatchEnd = m.linStart + m.length;
5651
5652 PatternRewriter::InsertionGuard g(rewriter);
5653 rewriter.setInsertionPoint(m.delinearize);
5654
5655 ArrayRef<OpFoldResult> basisToMerge =
5656 linBasisRef.slice(m.linStart, m.length);
5657 // We use the slice from the linearize's basis above because of the
5658 // "bounds inferred from `disjoint`" case above.
5659 OpFoldResult newSize =
5660 computeProduct(linearizeOp.getLoc(), rewriter, basisToMerge);
5661
5662 // Trivial case where we can just skip past the delinearize all together
5663 if (m.length == m.delinearize.getNumResults()) {
5664 newIndex.push_back(m.delinearize.getLinearIndex());
5665 newBasis.push_back(newSize);
5666 // Pad out set of replacements so we don't do anything with this one.
5667 delinearizeReplacements.push_back(SmallVector<Value>());
5668 continue;
5669 }
5670
5671 SmallVector<Value> newDelinResults;
5672 SmallVector<OpFoldResult> newDelinBasis = m.delinearize.getPaddedBasis();
5673 newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
5674 newDelinBasis.begin() + m.delinStart + m.length);
5675 newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
5676 auto newDelinearize = AffineDelinearizeIndexOp::create(
5677 rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5678 newDelinBasis);
5679
5680 // Since there may be other uses of the indices we just merged together,
5681 // create a residual affine.delinearize_index that delinearizes the
5682 // merged output into its component parts.
5683 Value combinedElem = newDelinearize.getResult(m.delinStart);
5684 auto residualDelinearize = AffineDelinearizeIndexOp::create(
5685 rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);
5686
5687 // Swap all the uses of the unaffected delinearize outputs to the new
5688 // delinearization so that the old code can be removed if this
5689 // linearize_index is the only user of the merged results.
5690 llvm::append_range(newDelinResults,
5691 newDelinearize.getResults().take_front(m.delinStart));
5692 llvm::append_range(newDelinResults, residualDelinearize.getResults());
5693 llvm::append_range(
5694 newDelinResults,
5695 newDelinearize.getResults().drop_front(m.delinStart + 1));
5696
5697 delinearizeReplacements.push_back(newDelinResults);
5698 newIndex.push_back(combinedElem);
5699 newBasis.push_back(newSize);
5700 }
5701 llvm::append_range(newIndex, multiIndex.drop_front(prevMatchEnd));
5702 llvm::append_range(newBasis, linBasisRef.drop_front(prevMatchEnd));
5703 rewriter.replaceOpWithNewOp<AffineLinearizeIndexOp>(
5704 linearizeOp, newIndex, newBasis, linearizeOp.getDisjoint());
5705
5706 for (auto [m, newResults] :
5707 llvm::zip_equal(matches, delinearizeReplacements)) {
5708 if (newResults.empty())
5709 continue;
5710 rewriter.replaceOp(m.delinearize, newResults);
5711 }
5712
5713 return success();
5714 }
5715};
5716
5717/// Strip leading zero from affine.linearize_index.
5718///
5719/// `affine.linearize_index [%c0, ...a] by (%x, ...b)` can be rewritten
5720/// to `affine.linearize_index [...a] by (...b)` in all cases.
5721struct DropLinearizeLeadingZero final
5722 : OpRewritePattern<affine::AffineLinearizeIndexOp> {
5724
5725 LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
5726 PatternRewriter &rewriter) const override {
5727 Value leadingIdx = op.getMultiIndex().front();
5728 if (!matchPattern(leadingIdx, m_Zero()))
5729 return failure();
5730
5731 if (op.getMultiIndex().size() == 1) {
5732 rewriter.replaceOp(op, leadingIdx);
5733 return success();
5734 }
5735
5736 SmallVector<OpFoldResult> mixedBasis = op.getMixedBasis();
5737 ArrayRef<OpFoldResult> newMixedBasis = mixedBasis;
5738 if (op.hasOuterBound())
5739 newMixedBasis = newMixedBasis.drop_front();
5740
5741 rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
5742 op, op.getMultiIndex().drop_front(), newMixedBasis, op.getDisjoint());
5743 return success();
5744 }
5745};
5746} // namespace
5747
5748void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
5749 RewritePatternSet &patterns, MLIRContext *context) {
5750 patterns.add<CancelLinearizeOfDelinearizePortion, DropLinearizeLeadingZero,
5751 DropLinearizeUnitComponentsIfDisjointOrZero>(context);
5752}
5753
5754//===----------------------------------------------------------------------===//
5755// TableGen'd op method definitions
5756//===----------------------------------------------------------------------===//
5757
5758#define GET_OP_CLASSES
5759#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
return success()
static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds known to be constants.
static bool hasTrivialZeroTripCount(AffineForOp op)
Returns true if the affine.for has zero iterations in trivial cases.
static Type inferIndexType(MLIRContext *ctx, ValueRange multiIndex)
Infer the index type from a set of multi-index values. Returns the common type (index or vector<....
static LogicalResult verifyMemoryOpIndexing(AffineMemOpTy op, AffineMapAttr mapAttr, Operation::operand_range mapOperands, MemRefType memrefType, unsigned numIndexOperands)
Verify common indexing invariants of affine.load, affine.store, affine.vector_load and affine....
static void printAffineMinMaxOp(OpAsmPrinter &p, T op)
static bool isResultTypeMatchAtomicRMWKind(Type resultType, arith::AtomicRMWKind op)
static bool remainsLegalAfterInline(Value value, Region *src, Region *dest, const IRMapping &mapping, function_ref< bool(Value, Region *)> legalityCheck)
Checks if value known to be a legal affine dimension or symbol in src region remains legal if the ope...
Definition AffineOps.cpp:62
static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr, DenseIntElementsAttr group, ValueRange operands, StringRef keyword)
Prints a lower(upper) bound of an affine parallel loop with max(min) conditions in it.
static OpFoldResult foldMinMaxOp(T op, ArrayRef< Attribute > operands)
Fold an affine min or max operation with the given operands.
static bool isTopLevelValueOrAbove(Value value, Region *region)
A utility function to check if a value is defined at the top level of region or is an argument of reg...
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp)
Canonicalize the bounds of the given loop.
static void simplifyExprAndOperands(AffineExpr &expr, unsigned numDims, unsigned numSymbols, ArrayRef< Value > operands)
Simplify expr while exploiting information from the values in operands.
static bool isValidAffineIndexOperand(Value value, Region *region)
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p)
Parse a for operation loop bounds.
static std::optional< SmallVector< int64_t > > foldCstValueToCstAttrBasis(ArrayRef< OpFoldResult > mixedBasis, MutableOperandRange mutableDynamicBasis, ArrayRef< Attribute > dynamicBasis)
Given mixed basis of affine.delinearize_index/linearize_index replace constant SSA values with the co...
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl< Value > *operands)
static void simplifyMinOrMaxExprWithOperands(AffineMap &map, ArrayRef< Value > operands, bool isMax)
Simplify the expressions in map while making use of lower or upper bounds of its operands.
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result)
static LogicalResult replaceAffineDelinearizeIndexInverseExpression(AffineDelinearizeIndexOp delinOp, Value resultToReplace, AffineMap *map, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms)
If this map contains the expression x_1 + x_1 * C_1 + ... x_n * C_N + ... (not necessarily in or...
static void composeSetAndOperands(IntegerSet &set, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Compose any affine.apply ops feeding into operands of the integer set set by composing the maps of su...
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region)
Returns true if the 'index' dimension of the memref defined by memrefDefOp is a statically shaped one...
static bool isNonNegativeBoundedBy(AffineExpr e, ArrayRef< Value > operands, int64_t k)
Check if e is known to be: 0 <= e < k.
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn)
Creates an affine loop from the bounds that may or may not be constants.
static void simplifyMapWithOperands(AffineMap &map, ArrayRef< Value > operands)
Simplify the map while exploiting information on the values in operands.
static void printDimAndSymbolList(Operation::operand_iterator begin, Operation::operand_iterator end, unsigned numDims, OpAsmPrinter &printer)
Prints dimension and symbol list.
static int64_t getLargestKnownDivisor(AffineExpr e, ArrayRef< Value > operands)
Returns the largest known divisor of e.
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Iterate over operands and fold away all those produced by an AffineApplyOp iteratively.
static void legalizeDemotedDims(MapOrSet &mapOrSet, SmallVectorImpl< Value > &operands)
A valid affine dimension may appear as a symbol in affine.apply operations.
static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static void buildAffineLoopNestImpl(OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn, LoopCreatorTy &&loopCreatorFn)
Builds an affine loop nest, using "loopCreatorFn" to create individual loop operations.
static LogicalResult foldLoopBounds(AffineForOp forOp)
Fold the constant bounds of a loop.
return success()
static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map, ValueRange dims, ValueRange syms)
Assuming dimOrSym is a quantity in the apply op map map and defined by minOp = affine_min(x_1,...
static SmallVector< OpFoldResult > AffineForEmptyLoopFolder(AffineForOp forOp)
Fold the empty loop.
static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims)
Utility function to verify that a set of operands are valid dimension and symbol identifiers.
static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region)
Returns true if the result of the dim op is a valid symbol for region.
static bool isQTimesDPlusR(AffineExpr e, ArrayRef< Value > operands, int64_t &div, AffineExpr &quotientTimesDiv, AffineExpr &rem)
Check if expression e is of the form d*e_1 + e_2 where 0 <= e_2 < d.
static LogicalResult replaceDimOrSym(AffineMap *map, unsigned dimOrSymbolPosition, SmallVectorImpl< Value > &dims, SmallVectorImpl< Value > &syms, bool replaceAffineMin)
Replace all occurrences of AffineExpr at position pos in map by the defining AffineApplyOp expression...
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static std::optional< uint64_t > getTrivialConstantTripCount(AffineForOp forOp)
Returns constant trip count in trivial cases.
static LogicalResult verifyAffineMinMaxOp(T op)
static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p)
static void shortenAddChainsContainingAll(AffineExpr e, const llvm::SmallDenseSet< AffineExpr, 4 > &exprsToRemove, AffineExpr newVal, DenseMap< AffineExpr, AffineExpr > &replacementsMap)
Recursively traverse e.
static void composeMultiResultAffineMap(AffineMap &map, SmallVectorImpl< Value > &operands, bool composeAffineMin=false)
Composes the given affine map with the given list of operands, pulling in the maps from any affine....
static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map)
Canonicalize the result expression order of an affine map and return success if the order changed.
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value getMemRef(Operation *memOp)
Returns the memref being read/written by a memref/affine load/store op.
Definition Utils.cpp:247
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getI64ArrayAttr(paddingDimensions)
b getContext())
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Operation::operand_range getLowerBoundOperands(AffineForOp forOp)
Definition SCFToGPU.cpp:75
static Operation::operand_range getUpperBoundOperands(AffineForOp forOp)
Definition SCFToGPU.cpp:80
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
#define div(a, b)
#define rem(a, b)
RetTy walkPostOrder(AffineExpr expr)
Base type for affine expression.
Definition AffineExpr.h:68
AffineExpr shiftDims(unsigned numDims, unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
AffineExpr shiftSymbols(unsigned numSymbols, unsigned shift, unsigned offset=0) const
Replace symbols[offset ... numSymbols) by symbols[offset + shift ... shift + numSymbols).
AffineExpr floorDiv(uint64_t v) const
AffineExprKind getKind() const
Return the classification for this type.
int64_t getLargestKnownDivisor() const
Returns the greatest known integral divisor of this affine expression.
MLIRContext * getContext() const
AffineExpr replace(AffineExpr expr, AffineExpr replacement) const
Sparse replace method.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
AffineMap getSliceMap(unsigned start, unsigned length) const
Returns the map consisting of length expressions starting from start.
MLIRContext * getContext() const
bool isFunctionOfDim(unsigned position) const
Return true if any affine expression involves AffineDimExpr position.
Definition AffineMap.h:221
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims).
Definition AffineMap.h:267
unsigned getNumSymbols() const
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
bool isFunctionOfSymbol(unsigned position) const
Return true if any affine expression involves AffineSymbolExpr position.
Definition AffineMap.h:228
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
AffineMap replaceDimsAndSymbols(ArrayRef< AffineExpr > dimReplacements, ArrayRef< AffineExpr > symReplacements, unsigned numResultDims, unsigned numResultSyms) const
This method substitutes any uses of dimensions and symbols (e.g.
unsigned getNumInputs() const
AffineMap shiftSymbols(unsigned shift, unsigned offset=0) const
Replace symbols[offset ... numSymbols) by symbols[offset + shift ... shift + numSymbols).
Definition AffineMap.h:280
AffineExpr getResult(unsigned idx) const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
static AffineMap getConstantMap(int64_t val, MLIRContext *context)
Returns a single constant result affine map.
AffineMap getSubMap(ArrayRef< unsigned > resultPos) const
Returns the map consisting of the resultPos subset.
LogicalResult constantFold(ArrayRef< Attribute > operandConstants, SmallVectorImpl< Attribute > &results, bool *hasPoison=nullptr) const
Folds the results of the application of an affine map on the provided operands to a constant if possi...
@ Paren
Parens surrounding zero or more operands.
@ OptionalSquare
Square brackets supporting zero or more ops, or nothing.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
virtual ParseResult parseArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
BlockArgListType getArguments()
Definition Block.h:97
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:233
AffineMap getDimIdentityMap()
Definition Builders.cpp:388
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition Builders.cpp:392
AffineExpr getAffineSymbolExpr(unsigned position)
Definition Builders.cpp:373
AffineExpr getAffineConstantExpr(int64_t constant)
Definition Builders.cpp:377
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition Builders.cpp:183
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
NoneType getNoneType()
Definition Builders.cpp:92
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
AffineMap getEmptyAffineMap()
Returns a zero result affine map with no dimensions or symbols: () -> ().
Definition Builders.cpp:381
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
Definition Builders.cpp:383
AffineMap getSymbolIdentityMap()
Definition Builders.cpp:401
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:286
IndexType getIndexType()
Definition Builders.cpp:55
An attribute that represents a reference to a dense integer vector or tensor object.
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
auto lookup(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:72
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition IntegerSet.h:44
unsigned getNumDims() const
static IntegerSet get(unsigned dimCount, unsigned symbolCount, ArrayRef< AffineExpr > constraints, ArrayRef< bool > eqFlags)
MLIRContext * getContext() const
unsigned getNumInputs() const
ArrayRef< AffineExpr > getConstraints() const
ArrayRef< bool > getEqFlags() const
Returns the equality bits, which specify whether each of the constraints is an equality or inequality...
unsigned getNumSymbols() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
void erase(unsigned subStart, unsigned subLen=1)
Erase the operands within the given sub-range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surround...
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl< UnresolvedOperand > &operands, Attribute &map, StringRef attrName, NamedAttrList &attrs, Delimiter delimiter=Delimiter::Square)=0
Parses an affine map attribute where dims and symbols are SSA operands.
ParseResult parseAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)
Parse a list of assignments of the form (x1 = y1, x2 = y2, ...)
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseAffineExprOfSSAIds(SmallVectorImpl< UnresolvedOperand > &dimOperands, SmallVectorImpl< UnresolvedOperand > &symbOperands, AffineExpr &expr)=0
Parses an affine expression where dims and symbols are SSA operands.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, ValueRange symOperands)=0
Prints an affine expression of SSA ids with SSA id names used instead of dims and symbols.
virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, ValueRange operands)=0
Prints an affine map of SSA ids, where SSA id names are used in place of dims/symbols.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
This class represents a single result from folding an operation.
A trait of region holding operations that defines a new scope for polyhedral optimization purposes.
This class provides the API for ops that are known to be isolated from above.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
OperandRange operand_range
Definition Operation.h:396
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:247
operand_range::iterator operand_iterator
Definition Operation.h:397
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
RegionBranchTerminatorOpInterface getTerminatorPredecessorOrNull() const
Returns the terminator if branching from a region.
This class represents a successor of a region.
Region * getSuccessor() const
Return the given region successor.
bool isOperation() const
Return true if the successor is an operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
bool empty()
Definition Region.h:60
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceUsesWithIf(Value from, Value to, function_ref< bool(OpOperand &)> functor, bool *allUsesReplaced=nullptr)
Find uses of from and replace them with to if the functor returns true.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
std::vector< SmallVector< int64_t, 8 > > operandExprStack
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
A variable that can be added to the constraint set as a "column".
static bool compare(const Variable &lhs, ComparisonOperator cmp, const Variable &rhs)
Return "true" if "lhs cmp rhs" was proven to hold.
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
Region * getParentRegion()
Return the Region in which this Value is defined.
Definition Value.cpp:39
AffineBound represents a lower or upper bound in the for operation.
Definition AffineOps.h:217
An AffineValueMap is an affine map plus its ML value operands and results for analysis purposes.
LogicalResult canonicalize()
Attempts to canonicalize the map and operands.
ArrayRef< Value > getOperands() const
AffineExpr getResult(unsigned i)
void reset(AffineMap map, ValueRange operands, ValueRange results={})
static void difference(const AffineValueMap &a, const AffineValueMap &b, AffineValueMap *res)
Return the value map that is the difference of value maps 'a' and 'b', represented as an affine map a...
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
void buildAffineLoopNest(OpBuilder &builder, Location loc, ArrayRef< int64_t > lbs, ArrayRef< int64_t > ubs, ArrayRef< int64_t > steps, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuilderFn=nullptr)
Builds a perfect nest of affine.for loops, i.e., each loop except the innermost one contains only ano...
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
void extractForInductionVars(ArrayRef< AffineForOp > forInsts, SmallVectorImpl< Value > *ivs)
Extracts the induction variables from a list of AffineForOps and places them in the output argument i...
bool isValidDim(Value value)
Returns true if the given Value can be used as a dimension id in the region of the closest surroundin...
bool isAffineInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp or AffineParallelOp.
SmallVector< OpFoldResult > makeComposedFoldedMultiResultAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Variant of makeComposedFoldedAffineApply suitable for multi-result maps.
OpFoldResult computeProduct(Location loc, OpBuilder &builder, ArrayRef< OpFoldResult > terms)
Return the product of terms, creating an affine.apply if any of them are non-constant values.
AffineForOp getForInductionVarOwner(Value val)
Returns the loop parent of an induction variable.
void canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands)
Modifies both map and operands in-place so as to:
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMaxOp that computes a maximum across the results of applying map to operands,...
bool isAffineForInductionVar(Value val)
Returns true if the provided value is the induction variable of an AffineForOp.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineMinOp that computes a minimum across the results of applying map to operands,...
bool isTopLevelValue(Value value)
A utility function to check if a value is defined at the top level of an op with trait AffineScope or...
Region * getAffineAnalysisScope(Operation *op)
Returns the closest region enclosing op that is held by a non-affine operation; nullptr if there is n...
void fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl< Value > *operands, bool composeAffineMin=false)
Given an affine map map and its input operands, this method composes into map, maps of AffineApplyOps...
void canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl< Value > *operands)
Canonicalizes an integer set the same way canonicalizeMapAndOperands does for affine maps.
void extractInductionVars(ArrayRef< Operation * > affineOps, SmallVectorImpl< Value > &ivs)
Extracts the induction variables from a list of either AffineForOp or AffineParallelOp and places the...
bool isValidSymbol(Value value)
Returns true if the given value can be used as a symbol in the region of the closest surrounding op t...
AffineParallelOp getAffineParallelInductionVarOwner(Value val)
Returns the AffineParallelOp parent of an induction variable.
Region * getAffineScope(Operation *op)
Returns the closest region enclosing op that is held by an operation with trait AffineScope; nullptr ...
ParseResult parseDimAndSymbolList(OpAsmParser &parser, SmallVectorImpl< Value > &operands, unsigned &numDims)
Parses dimension and symbol list.
bool isAffineParallelInductionVar(Value val)
Returns true if val is the induction variable of an AffineParallelOp.
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns an AffineMinOp obtained by composing map and operands with AffineApplyOps supplying those ope...
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
AffineMap simplifyAffineMap(AffineMap map)
Simplifies an affine map by simplifying its underlying AffineExpr results.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
AffineMap removeDuplicateExprs(AffineMap map)
Returns a map with the same dimension and symbol count as map, but whose results are the unique affin...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
std::function< SmallVector< Value >( OpBuilder &b, Location loc, ArrayRef< BlockArgument > newBbArgs)> NewYieldValuesFn
A function that returns the additional yielded values during replaceWithAdditionalYields.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
std::optional< int64_t > getBoundForAffineExpr(AffineExpr expr, unsigned numDims, unsigned numSymbols, ArrayRef< std::optional< int64_t > > constLowerBounds, ArrayRef< std::optional< int64_t > > constUpperBounds, bool isUpper)
Get a lower or upper (depending on isUpper) bound for expr while using the constant lower and upper b...
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., it is speculatable and does not touch memory.
AffineExprKind
Definition AffineExpr.h:40
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
Definition AffineExpr.h:50
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
Definition AffineExpr.h:46
@ DimId
Dimensional identifier.
Definition AffineExpr.h:59
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
Definition AffineExpr.h:48
@ SymbolId
Symbolic identifier.
Definition AffineExpr.h:61
AffineExpr getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs)
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
AffineMap foldAttributesIntoMap(Builder &b, AffineMap map, ArrayRef< OpFoldResult > operands, SmallVector< Value > &remainingValues)
Fold all attributes among the given operands into the affine map.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
AffineExpr getAffineSymbolExpr(unsigned position, MLIRContext *context)
Canonicalize the affine map result expression order of an affine min/max operation.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Remove duplicated expressions in affine min/max ops.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
Merge an affine min/max op to its consumers if its consumer is also an affine min/max op.
LogicalResult matchAndRewrite(T affineOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
This class represents a listener that may be used to hook into various actions within an OpBuilder.
Definition Builders.h:287
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.