28#define GEN_PASS_DEF_LOWERAFFINEPASS
29#include "mlir/Conversion/Passes.h.inc"
43 arith::CmpIPredicate predicate,
45 assert(!values.empty() &&
"empty min/max chain");
46 assert(predicate == arith::CmpIPredicate::sgt ||
47 predicate == arith::CmpIPredicate::slt);
49 auto valueIt = values.begin();
50 Value value = *valueIt++;
51 for (; valueIt != values.end(); ++valueIt) {
52 if (predicate == arith::CmpIPredicate::sgt)
53 value = arith::MaxSIOp::create(builder, loc, value, *valueIt);
55 value = arith::MinSIOp::create(builder, loc, value, *valueIt);
65 if (
auto values = expandAffineMap(builder, loc, map, operands))
75 if (
auto values = expandAffineMap(builder, loc, map, operands))
86 op.getUpperBoundOperands());
94 op.getLowerBoundOperands());
100 using OpRewritePattern<AffineMinOp>::OpRewritePattern;
102 LogicalResult matchAndRewrite(AffineMinOp op,
103 PatternRewriter &rewriter)
const override {
116 using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
118 LogicalResult matchAndRewrite(AffineMaxOp op,
119 PatternRewriter &rewriter)
const override {
133 using OpRewritePattern<AffineYieldOp>::OpRewritePattern;
135 LogicalResult matchAndRewrite(AffineYieldOp op,
136 PatternRewriter &rewriter)
const override {
137 if (isa<scf::ParallelOp, AffineParallelOp>(op->getParentOp())) {
149 using OpRewritePattern<AffineForOp>::OpRewritePattern;
151 LogicalResult matchAndRewrite(AffineForOp op,
152 PatternRewriter &rewriter)
const override {
153 Location loc = op.getLoc();
158 auto scfForOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound,
159 step, op.getInits());
162 scfForOp.getRegion().end());
163 rewriter.
replaceOp(op, scfForOp.getResults());
172 using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
174 LogicalResult matchAndRewrite(AffineParallelOp op,
175 PatternRewriter &rewriter)
const override {
176 Location loc = op.getLoc();
177 SmallVector<Value, 8> steps;
178 SmallVector<Value, 8> upperBoundTuple;
179 SmallVector<Value, 8> lowerBoundTuple;
180 SmallVector<Value, 8> identityVals;
183 lowerBoundTuple.reserve(op.getNumDims());
184 upperBoundTuple.reserve(op.getNumDims());
185 for (
unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
187 op.getLowerBoundsOperands());
190 lowerBoundTuple.push_back(lower);
193 op.getUpperBoundsOperands());
196 upperBoundTuple.push_back(upper);
198 steps.reserve(op.getSteps().size());
199 for (int64_t step : op.getSteps())
203 auto affineParOpTerminator =
204 cast<AffineYieldOp>(op.getBody()->getTerminator());
205 scf::ParallelOp parOp;
206 if (op.getResults().empty()) {
208 parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
209 upperBoundTuple, steps,
213 parOp.getRegion().end());
214 rewriter.
replaceOp(op, parOp.getResults());
222 ArrayRef<Attribute> reductions = op.getReductions().getValue();
223 for (
auto pair : llvm::zip(reductions, op.getResultTypes())) {
226 Attribute reduction = std::get<0>(pair);
227 Type resultType = std::get<1>(pair);
228 std::optional<arith::AtomicRMWKind> reductionOp =
229 arith::symbolizeAtomicRMWKind(
230 static_cast<uint64_t
>(cast<IntegerAttr>(reduction).getInt()));
231 assert(reductionOp &&
"Reduction operation cannot be of None Type");
232 arith::AtomicRMWKind reductionOpValue = *reductionOp;
234 arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc);
237 op,
"unsupported reduction kind for identity value");
238 identityVals.push_back(identityVal);
240 parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
241 upperBoundTuple, steps, identityVals,
247 parOp.getRegion().end());
248 assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
249 "Unequal number of reductions and operands.");
254 affineParOpTerminator, affineParOpTerminator->getOperands());
255 for (
unsigned i = 0, end = reductions.size(); i < end; i++) {
257 std::optional<arith::AtomicRMWKind> reductionOp =
258 arith::symbolizeAtomicRMWKind(
259 cast<IntegerAttr>(reductions[i]).getInt());
260 assert(reductionOp &&
"Reduction Operation cannot be of None Type");
261 arith::AtomicRMWKind reductionOpValue = *reductionOp;
263 Block &reductionBody = reduceOp.getReductions()[i].front();
265 Value reductionResult = arith::getReductionOp(
266 reductionOpValue, rewriter, loc, reductionBody.
getArgument(0),
268 scf::ReduceReturnOp::create(rewriter, loc, reductionResult);
270 rewriter.
replaceOp(op, parOp.getResults());
277 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
279 LogicalResult matchAndRewrite(AffineIfOp op,
280 PatternRewriter &rewriter)
const override {
281 auto loc = op.getLoc();
284 auto integerSet = op.getIntegerSet();
286 SmallVector<Value, 8> operands(op.getOperands());
287 auto operandsRef = llvm::ArrayRef(operands);
290 Value cond =
nullptr;
291 for (
unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
292 AffineExpr constraintExpr = integerSet.getConstraint(i);
293 bool isEquality = integerSet.isEq(i);
296 auto numDims = integerSet.getNumDims();
297 Value affResult = expandAffineExpr(rewriter, loc, constraintExpr,
298 operandsRef.take_front(numDims),
299 operandsRef.drop_front(numDims));
303 isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
305 arith::CmpIOp::create(rewriter, loc, pred, affResult, zeroConstant);
307 cond ? arith::AndIOp::create(rewriter, loc, cond, cmpVal).getResult()
314 bool hasElseRegion = !op.getElseRegion().empty();
315 auto ifOp = scf::IfOp::create(rewriter, loc, op.getResultTypes(), cond,
318 &ifOp.getThenRegion().back());
319 rewriter.
eraseBlock(&ifOp.getThenRegion().back());
322 &ifOp.getElseRegion().back());
323 rewriter.
eraseBlock(&ifOp.getElseRegion().back());
327 rewriter.
replaceOp(op, ifOp.getResults());
336 using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
338 LogicalResult matchAndRewrite(AffineApplyOp op,
339 PatternRewriter &rewriter)
const override {
340 auto maybeExpandedMap =
341 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
342 llvm::to_vector<8>(op.getOperands()));
343 if (!maybeExpandedMap)
345 rewriter.
replaceOp(op, *maybeExpandedMap);
355 using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
357 LogicalResult matchAndRewrite(AffineLoadOp op,
358 PatternRewriter &rewriter)
const override {
360 SmallVector<Value, 8>
indices(op.getMapOperands());
361 auto resultOperands =
362 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
indices);
378 using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
380 LogicalResult matchAndRewrite(AffinePrefetchOp op,
381 PatternRewriter &rewriter)
const override {
383 SmallVector<Value, 8>
indices(op.getMapOperands());
384 auto resultOperands =
385 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
indices);
391 op, op.getMemref(), *resultOperands, op.getIsWrite(),
392 op.getLocalityHint(), op.getIsDataCache());
402 using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
404 LogicalResult matchAndRewrite(AffineStoreOp op,
405 PatternRewriter &rewriter)
const override {
407 SmallVector<Value, 8>
indices(op.getMapOperands());
408 auto maybeExpandedMap =
409 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
indices);
410 if (!maybeExpandedMap)
415 op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
426 using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
428 LogicalResult matchAndRewrite(AffineDmaStartOp op,
429 PatternRewriter &rewriter)
const override {
430 SmallVector<Value, 8> operands(op.getOperands());
431 auto operandsRef = llvm::ArrayRef(operands);
434 auto maybeExpandedSrcMap = expandAffineMap(
435 rewriter, op.getLoc(), op.getSrcMap(),
436 operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
437 if (!maybeExpandedSrcMap)
440 auto maybeExpandedDstMap = expandAffineMap(
441 rewriter, op.getLoc(), op.getDstMap(),
442 operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
443 if (!maybeExpandedDstMap)
446 auto maybeExpandedTagMap = expandAffineMap(
447 rewriter, op.getLoc(), op.getTagMap(),
448 operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
449 if (!maybeExpandedTagMap)
454 op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
455 *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
456 *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
466 using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
468 LogicalResult matchAndRewrite(AffineDmaWaitOp op,
469 PatternRewriter &rewriter)
const override {
471 SmallVector<Value, 8>
indices(op.getTagIndices());
472 auto maybeExpandedTagMap =
473 expandAffineMap(rewriter, op.getLoc(), op.getTagMap(),
indices);
474 if (!maybeExpandedTagMap)
479 op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
487class AffineVectorLoadLowering :
public OpRewritePattern<AffineVectorLoadOp> {
489 using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
491 LogicalResult matchAndRewrite(AffineVectorLoadOp op,
492 PatternRewriter &rewriter)
const override {
494 SmallVector<Value, 8>
indices(op.getMapOperands());
495 auto resultOperands =
496 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
indices);
502 op, op.getVectorType(), op.getMemRef(), *resultOperands);
510class AffineVectorStoreLowering :
public OpRewritePattern<AffineVectorStoreOp> {
512 using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
514 LogicalResult matchAndRewrite(AffineVectorStoreOp op,
515 PatternRewriter &rewriter)
const override {
517 SmallVector<Value, 8>
indices(op.getMapOperands());
518 auto maybeExpandedMap =
519 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
indices);
520 if (!maybeExpandedMap)
524 op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
535 AffineDmaStartLowering,
536 AffineDmaWaitLowering,
540 AffineParallelLowering,
541 AffinePrefetchLowering,
545 AffineYieldOpLowering>(patterns.
getContext());
553 AffineVectorLoadLowering,
554 AffineVectorStoreLowering>(patterns.
getContext());
559class LowerAffine :
public impl::LowerAffinePassBase<LowerAffine> {
560 void runOnOperation()
override {
566 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
567 scf::SCFDialect, VectorDialect>();
568 if (failed(applyPartialConversion(getOperation(),
target,
569 std::move(patterns))))
static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands)
Emit instructions that correspond to computing the maximum value among the values of a (potentially) ...
static Value buildMinMaxReductionSeq(Location loc, arith::CmpIPredicate predicate, ValueRange values, OpBuilder &builder)
Given a range of values, emit the code that reduces them with "min" or "max" depending on the provide...
static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands)
Emit instructions that correspond to computing the minimum value among the values of a (potentially) ...
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
BlockArgument getArgument(unsigned i)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
Include the generated interface declarations.
Value lowerAffineUpperBound(affine::AffineForOp op, OpBuilder &builder)
Emit code that computes the upper bound of the given affine loop using standard arithmetic operations...
void populateAffineToVectorConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert vector-related Affine ops to the Vector dialect.
void populateAffineToStdConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the Affine dialect to the Standard dialect,...
Value lowerAffineLowerBound(affine::AffineForOp op, OpBuilder &builder)
Emit code that computes the lower bound of the given affine loop using standard arithmetic operations...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...