28 #define GEN_PASS_DEF_LOWERAFFINEPASS
29 #include "mlir/Conversion/Passes.h.inc"
43 arith::CmpIPredicate predicate,
45 assert(!values.empty() &&
"empty min/max chain");
46 assert(predicate == arith::CmpIPredicate::sgt ||
47 predicate == arith::CmpIPredicate::slt);
49 auto valueIt = values.begin();
50 Value value = *valueIt++;
51 for (; valueIt != values.end(); ++valueIt) {
52 if (predicate == arith::CmpIPredicate::sgt)
53 value = builder.
create<arith::MaxSIOp>(loc, value, *valueIt);
55 value = builder.
create<arith::MinSIOp>(loc, value, *valueIt);
86 op.getUpperBoundOperands());
94 op.getLowerBoundOperands());
102 LogicalResult matchAndRewrite(AffineMinOp op,
118 LogicalResult matchAndRewrite(AffineMaxOp op,
135 LogicalResult matchAndRewrite(AffineYieldOp op,
137 if (isa<scf::ParallelOp>(op->getParentOp())) {
151 LogicalResult matchAndRewrite(AffineForOp op,
157 rewriter.
create<arith::ConstantIndexOp>(loc, op.getStepAsInt());
158 auto scfForOp = rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound,
159 step, op.getInits());
162 scfForOp.getRegion().end());
163 rewriter.
replaceOp(op, scfForOp.getResults());
174 LogicalResult matchAndRewrite(AffineParallelOp op,
183 lowerBoundTuple.reserve(op.getNumDims());
184 upperBoundTuple.reserve(op.getNumDims());
185 for (
unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
187 op.getLowerBoundsOperands());
190 lowerBoundTuple.push_back(lower);
193 op.getUpperBoundsOperands());
196 upperBoundTuple.push_back(upper);
198 steps.reserve(op.getSteps().size());
199 for (int64_t step : op.getSteps())
200 steps.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, step));
203 auto affineParOpTerminator =
204 cast<AffineYieldOp>(op.getBody()->getTerminator());
205 scf::ParallelOp parOp;
206 if (op.getResults().empty()) {
208 parOp = rewriter.
create<scf::ParallelOp>(loc, lowerBoundTuple,
209 upperBoundTuple, steps,
213 parOp.getRegion().end());
214 rewriter.
replaceOp(op, parOp.getResults());
223 for (
auto pair : llvm::zip(reductions, op.getResultTypes())) {
227 Type resultType = std::get<1>(pair);
228 std::optional<arith::AtomicRMWKind> reductionOp =
229 arith::symbolizeAtomicRMWKind(
230 static_cast<uint64_t
>(cast<IntegerAttr>(reduction).getInt()));
231 assert(reductionOp &&
"Reduction operation cannot be of None Type");
232 arith::AtomicRMWKind reductionOpValue = *reductionOp;
233 identityVals.push_back(
236 parOp = rewriter.
create<scf::ParallelOp>(
237 loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
243 parOp.getRegion().end());
244 assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
245 "Unequal number of reductions and operands.");
250 affineParOpTerminator, affineParOpTerminator->getOperands());
251 for (
unsigned i = 0, end = reductions.size(); i < end; i++) {
253 std::optional<arith::AtomicRMWKind> reductionOp =
254 arith::symbolizeAtomicRMWKind(
255 cast<IntegerAttr>(reductions[i]).getInt());
256 assert(reductionOp &&
"Reduction Operation cannot be of None Type");
257 arith::AtomicRMWKind reductionOpValue = *reductionOp;
259 Block &reductionBody = reduceOp.getReductions()[i].
front();
262 reductionOpValue, rewriter, loc, reductionBody.
getArgument(0),
264 rewriter.
create<scf::ReduceReturnOp>(loc, reductionResult);
266 rewriter.
replaceOp(op, parOp.getResults());
275 LogicalResult matchAndRewrite(AffineIfOp op,
277 auto loc = op.getLoc();
280 auto integerSet = op.getIntegerSet();
281 Value zeroConstant = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
286 Value cond =
nullptr;
287 for (
unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
288 AffineExpr constraintExpr = integerSet.getConstraint(i);
289 bool isEquality = integerSet.isEq(i);
292 auto numDims = integerSet.getNumDims();
294 operandsRef.take_front(numDims),
295 operandsRef.drop_front(numDims));
299 isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
301 rewriter.
create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
303 ? rewriter.
create<arith::AndIOp>(loc, cond, cmpVal).getResult()
307 : rewriter.
create<arith::ConstantIntOp>(loc, 1,
310 bool hasElseRegion = !op.getElseRegion().empty();
311 auto ifOp = rewriter.
create<scf::IfOp>(loc, op.getResultTypes(), cond,
314 &ifOp.getThenRegion().back());
315 rewriter.
eraseBlock(&ifOp.getThenRegion().back());
318 &ifOp.getElseRegion().back());
319 rewriter.
eraseBlock(&ifOp.getElseRegion().back());
323 rewriter.
replaceOp(op, ifOp.getResults());
334 LogicalResult matchAndRewrite(AffineApplyOp op,
336 auto maybeExpandedMap =
338 llvm::to_vector<8>(op.getOperands()));
339 if (!maybeExpandedMap)
341 rewriter.
replaceOp(op, *maybeExpandedMap);
353 LogicalResult matchAndRewrite(AffineLoadOp op,
357 auto resultOperands =
376 LogicalResult matchAndRewrite(AffinePrefetchOp op,
380 auto resultOperands =
387 op, op.getMemref(), *resultOperands, op.getIsWrite(),
388 op.getLocalityHint(), op.getIsDataCache());
400 LogicalResult matchAndRewrite(AffineStoreOp op,
404 auto maybeExpandedMap =
406 if (!maybeExpandedMap)
411 op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
433 if (!maybeExpandedSrcMap)
439 if (!maybeExpandedDstMap)
445 if (!maybeExpandedTagMap)
450 op, op.
getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
451 *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
452 *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
468 auto maybeExpandedTagMap =
470 if (!maybeExpandedTagMap)
475 op, op.
getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
483 class AffineVectorLoadLowering :
public OpRewritePattern<AffineVectorLoadOp> {
487 LogicalResult matchAndRewrite(AffineVectorLoadOp op,
491 auto resultOperands =
498 op, op.getVectorType(), op.getMemRef(), *resultOperands);
506 class AffineVectorStoreLowering :
public OpRewritePattern<AffineVectorStoreOp> {
510 LogicalResult matchAndRewrite(AffineVectorStoreOp op,
514 auto maybeExpandedMap =
516 if (!maybeExpandedMap)
520 op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
531 AffineDmaStartLowering,
532 AffineDmaWaitLowering,
536 AffineParallelLowering,
537 AffinePrefetchLowering,
541 AffineYieldOpLowering>(
patterns.getContext());
549 AffineVectorLoadLowering,
550 AffineVectorStoreLowering>(
patterns.getContext());
555 class LowerAffine :
public impl::LowerAffinePassBase<LowerAffine> {
556 void runOnOperation()
override {
562 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
563 scf::SCFDialect, VectorDialect>();
static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands)
Emit instructions that correspond to computing the maximum value among the values of a (potentially) ...
static Value buildMinMaxReductionSeq(Location loc, arith::CmpIPredicate predicate, ValueRange values, OpBuilder &builder)
Given a range of values, emit the code that reduces them with "min" or "max" depending on the provide...
static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands)
Emit instructions that correspond to computing the minimum value among the values of a (potentially) ...
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
AffineMap getDstMap()
Returns the affine map used to access the destination memref.
unsigned getSrcMemRefOperandIndex()
Returns the operand index of the source memref.
unsigned getTagMemRefOperandIndex()
Returns the operand index of the tag memref.
AffineMap getSrcMap()
Returns the affine map used to access the source memref.
AffineMap getTagMap()
Returns the affine map used to access the tag memref.
unsigned getDstMemRefOperandIndex()
Returns the operand index of the destination memref.
Value getSrcMemRef()
Returns the source MemRefType for this DMA operation.
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
Value getTagMemRef()
Returns the Tag MemRef associated with the DMA operation being waited on.
AffineMap getTagMap()
Returns the affine map used to access the tag memref.
operand_range getTagIndices()
Returns the tag memref index for this DMA operation.
std::optional< SmallVector< Value, 8 > > expandAffineMap(OpBuilder &builder, Location loc, AffineMap affineMap, ValueRange operands)
Create a sequence of operations that implement the affineMap applied to the given operands (as if it ...
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Include the generated interface declarations.
Value lowerAffineUpperBound(affine::AffineForOp op, OpBuilder &builder)
Emit code that computes the upper bound of the given affine loop using standard arithmetic operations...
const FrozenRewritePatternSet & patterns
void populateAffineToVectorConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert vector-related Affine ops to the Vector dialect.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateAffineToStdConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the Affine dialect to the Standard dialect,...
Value lowerAffineLowerBound(affine::AffineForOp op, OpBuilder &builder)
Emit code that computes the lower bound of the given affine loop using standard arithmetic operations...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...