28#define GEN_PASS_DEF_LOWERAFFINEPASS
29#include "mlir/Conversion/Passes.h.inc"
43 arith::CmpIPredicate predicate,
45 assert(!values.empty() &&
"empty min/max chain");
46 assert(predicate == arith::CmpIPredicate::sgt ||
47 predicate == arith::CmpIPredicate::slt);
49 auto valueIt = values.begin();
50 Value value = *valueIt++;
51 for (; valueIt != values.end(); ++valueIt) {
52 if (predicate == arith::CmpIPredicate::sgt)
53 value = arith::MaxSIOp::create(builder, loc, value, *valueIt);
55 value = arith::MinSIOp::create(builder, loc, value, *valueIt);
65 if (
auto values = expandAffineMap(builder, loc, map, operands))
75 if (
auto values = expandAffineMap(builder, loc, map, operands))
86 op.getUpperBoundOperands());
94 op.getLowerBoundOperands());
100 using OpRewritePattern<AffineMinOp>::OpRewritePattern;
102 LogicalResult matchAndRewrite(AffineMinOp op,
103 PatternRewriter &rewriter)
const override {
116 using OpRewritePattern<AffineMaxOp>::OpRewritePattern;
118 LogicalResult matchAndRewrite(AffineMaxOp op,
119 PatternRewriter &rewriter)
const override {
133 using OpRewritePattern<AffineYieldOp>::OpRewritePattern;
135 LogicalResult matchAndRewrite(AffineYieldOp op,
136 PatternRewriter &rewriter)
const override {
137 if (isa<scf::ParallelOp>(op->getParentOp())) {
149 using OpRewritePattern<AffineForOp>::OpRewritePattern;
151 LogicalResult matchAndRewrite(AffineForOp op,
152 PatternRewriter &rewriter)
const override {
153 Location loc = op.getLoc();
158 auto scfForOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound,
159 step, op.getInits());
162 scfForOp.getRegion().end());
163 rewriter.
replaceOp(op, scfForOp.getResults());
172 using OpRewritePattern<AffineParallelOp>::OpRewritePattern;
174 LogicalResult matchAndRewrite(AffineParallelOp op,
175 PatternRewriter &rewriter)
const override {
176 Location loc = op.getLoc();
177 SmallVector<Value, 8> steps;
178 SmallVector<Value, 8> upperBoundTuple;
179 SmallVector<Value, 8> lowerBoundTuple;
180 SmallVector<Value, 8> identityVals;
183 lowerBoundTuple.reserve(op.getNumDims());
184 upperBoundTuple.reserve(op.getNumDims());
185 for (
unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
187 op.getLowerBoundsOperands());
190 lowerBoundTuple.push_back(lower);
193 op.getUpperBoundsOperands());
196 upperBoundTuple.push_back(upper);
198 steps.reserve(op.getSteps().size());
199 for (int64_t step : op.getSteps())
203 auto affineParOpTerminator =
204 cast<AffineYieldOp>(op.getBody()->getTerminator());
205 scf::ParallelOp parOp;
206 if (op.getResults().empty()) {
208 parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
209 upperBoundTuple, steps,
213 parOp.getRegion().end());
214 rewriter.
replaceOp(op, parOp.getResults());
222 ArrayRef<Attribute> reductions = op.getReductions().getValue();
223 for (
auto pair : llvm::zip(reductions, op.getResultTypes())) {
226 Attribute reduction = std::get<0>(pair);
227 Type resultType = std::get<1>(pair);
228 std::optional<arith::AtomicRMWKind> reductionOp =
229 arith::symbolizeAtomicRMWKind(
230 static_cast<uint64_t
>(cast<IntegerAttr>(reduction).getInt()));
231 assert(reductionOp &&
"Reduction operation cannot be of None Type");
232 arith::AtomicRMWKind reductionOpValue = *reductionOp;
233 identityVals.push_back(
234 arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
236 parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
237 upperBoundTuple, steps, identityVals,
243 parOp.getRegion().end());
244 assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
245 "Unequal number of reductions and operands.");
250 affineParOpTerminator, affineParOpTerminator->getOperands());
251 for (
unsigned i = 0, end = reductions.size(); i < end; i++) {
253 std::optional<arith::AtomicRMWKind> reductionOp =
254 arith::symbolizeAtomicRMWKind(
255 cast<IntegerAttr>(reductions[i]).getInt());
256 assert(reductionOp &&
"Reduction Operation cannot be of None Type");
257 arith::AtomicRMWKind reductionOpValue = *reductionOp;
259 Block &reductionBody = reduceOp.getReductions()[i].front();
261 Value reductionResult = arith::getReductionOp(
262 reductionOpValue, rewriter, loc, reductionBody.
getArgument(0),
264 scf::ReduceReturnOp::create(rewriter, loc, reductionResult);
266 rewriter.
replaceOp(op, parOp.getResults());
273 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
275 LogicalResult matchAndRewrite(AffineIfOp op,
276 PatternRewriter &rewriter)
const override {
277 auto loc = op.getLoc();
280 auto integerSet = op.getIntegerSet();
282 SmallVector<Value, 8> operands(op.getOperands());
283 auto operandsRef = llvm::ArrayRef(operands);
286 Value cond =
nullptr;
287 for (
unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
288 AffineExpr constraintExpr = integerSet.getConstraint(i);
289 bool isEquality = integerSet.isEq(i);
292 auto numDims = integerSet.getNumDims();
293 Value affResult = expandAffineExpr(rewriter, loc, constraintExpr,
294 operandsRef.take_front(numDims),
295 operandsRef.drop_front(numDims));
299 isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
301 arith::CmpIOp::create(rewriter, loc, pred, affResult, zeroConstant);
303 cond ? arith::AndIOp::create(rewriter, loc, cond, cmpVal).getResult()
310 bool hasElseRegion = !op.getElseRegion().empty();
311 auto ifOp = scf::IfOp::create(rewriter, loc, op.getResultTypes(), cond,
314 &ifOp.getThenRegion().back());
315 rewriter.
eraseBlock(&ifOp.getThenRegion().back());
318 &ifOp.getElseRegion().back());
319 rewriter.
eraseBlock(&ifOp.getElseRegion().back());
323 rewriter.
replaceOp(op, ifOp.getResults());
332 using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
334 LogicalResult matchAndRewrite(AffineApplyOp op,
335 PatternRewriter &rewriter)
const override {
336 auto maybeExpandedMap =
337 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
338 llvm::to_vector<8>(op.getOperands()));
339 if (!maybeExpandedMap)
341 rewriter.
replaceOp(op, *maybeExpandedMap);
351 using OpRewritePattern<AffineLoadOp>::OpRewritePattern;
353 LogicalResult matchAndRewrite(AffineLoadOp op,
354 PatternRewriter &rewriter)
const override {
356 SmallVector<Value, 8>
indices(op.getMapOperands());
357 auto resultOperands =
358 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
indices);
374 using OpRewritePattern<AffinePrefetchOp>::OpRewritePattern;
376 LogicalResult matchAndRewrite(AffinePrefetchOp op,
377 PatternRewriter &rewriter)
const override {
379 SmallVector<Value, 8>
indices(op.getMapOperands());
380 auto resultOperands =
381 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
indices);
387 op, op.getMemref(), *resultOperands, op.getIsWrite(),
388 op.getLocalityHint(), op.getIsDataCache());
398 using OpRewritePattern<AffineStoreOp>::OpRewritePattern;
400 LogicalResult matchAndRewrite(AffineStoreOp op,
401 PatternRewriter &rewriter)
const override {
403 SmallVector<Value, 8>
indices(op.getMapOperands());
404 auto maybeExpandedMap =
405 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
indices);
406 if (!maybeExpandedMap)
411 op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
422 using OpRewritePattern<AffineDmaStartOp>::OpRewritePattern;
424 LogicalResult matchAndRewrite(AffineDmaStartOp op,
425 PatternRewriter &rewriter)
const override {
426 SmallVector<Value, 8> operands(op.getOperands());
427 auto operandsRef = llvm::ArrayRef(operands);
430 auto maybeExpandedSrcMap = expandAffineMap(
433 if (!maybeExpandedSrcMap)
436 auto maybeExpandedDstMap = expandAffineMap(
439 if (!maybeExpandedDstMap)
442 auto maybeExpandedTagMap = expandAffineMap(
445 if (!maybeExpandedTagMap)
450 op, op.
getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
451 *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
452 *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
462 using OpRewritePattern<AffineDmaWaitOp>::OpRewritePattern;
464 LogicalResult matchAndRewrite(AffineDmaWaitOp op,
465 PatternRewriter &rewriter)
const override {
468 auto maybeExpandedTagMap =
470 if (!maybeExpandedTagMap)
475 op, op.
getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
483class AffineVectorLoadLowering :
public OpRewritePattern<AffineVectorLoadOp> {
485 using OpRewritePattern<AffineVectorLoadOp>::OpRewritePattern;
487 LogicalResult matchAndRewrite(AffineVectorLoadOp op,
488 PatternRewriter &rewriter)
const override {
490 SmallVector<Value, 8>
indices(op.getMapOperands());
491 auto resultOperands =
492 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
indices);
498 op, op.getVectorType(), op.getMemRef(), *resultOperands);
506class AffineVectorStoreLowering :
public OpRewritePattern<AffineVectorStoreOp> {
508 using OpRewritePattern<AffineVectorStoreOp>::OpRewritePattern;
510 LogicalResult matchAndRewrite(AffineVectorStoreOp op,
511 PatternRewriter &rewriter)
const override {
513 SmallVector<Value, 8>
indices(op.getMapOperands());
514 auto maybeExpandedMap =
515 expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
indices);
516 if (!maybeExpandedMap)
520 op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
531 AffineDmaStartLowering,
532 AffineDmaWaitLowering,
536 AffineParallelLowering,
537 AffinePrefetchLowering,
541 AffineYieldOpLowering>(
patterns.getContext());
549 AffineVectorLoadLowering,
550 AffineVectorStoreLowering>(
patterns.getContext());
556 void runOnOperation()
override {
562 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
563 scf::SCFDialect, VectorDialect>();
564 if (failed(applyPartialConversion(getOperation(),
target,
static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands)
Emit instructions that correspond to computing the maximum value among the values of a (potentially) ...
static Value buildMinMaxReductionSeq(Location loc, arith::CmpIPredicate predicate, ValueRange values, OpBuilder &builder)
Given a range of values, emit the code that reduces them with "min" or "max" depending on the provide...
static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands)
Emit instructions that correspond to computing the minimum value among the values of a (potentially) ...
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
BlockArgument getArgument(unsigned i)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Location getLoc()
The source location the operation was defined or derived from.
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
AffineMap getDstMap()
Returns the affine map used to access the destination memref.
unsigned getSrcMemRefOperandIndex()
Returns the operand index of the source memref.
unsigned getTagMemRefOperandIndex()
Returns the operand index of the tag memref.
AffineMap getSrcMap()
Returns the affine map used to access the source memref.
AffineMap getTagMap()
Returns the affine map used to access the tag memref.
unsigned getDstMemRefOperandIndex()
Returns the operand index of the destination memref.
Value getSrcMemRef()
Returns the source memref value for this DMA operation.
Value getTagMemRef()
Returns the Tag MemRef associated with the DMA operation being waited on.
AffineMap getTagMap()
Returns the affine map used to access the tag memref.
operand_range getTagIndices()
Returns the tag memref index for this DMA operation.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
Include the generated interface declarations.
Value lowerAffineUpperBound(affine::AffineForOp op, OpBuilder &builder)
Emit code that computes the upper bound of the given affine loop using standard arithmetic operations...
const FrozenRewritePatternSet & patterns
void populateAffineToVectorConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert vector-related Affine ops to the Vector dialect.
void populateAffineToStdConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the Affine dialect to the Standard dialect,...
Value lowerAffineLowerBound(affine::AffineForOp op, OpBuilder &builder)
Emit code that computes the lower bound of the given affine loop using standard arithmetic operations...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...