29 #define GEN_PASS_DEF_CONVERTAFFINETOSTANDARD
30 #include "mlir/Conversion/Passes.h.inc"
44 arith::CmpIPredicate predicate,
46 assert(!values.empty() &&
"empty min/max chain");
47 assert(predicate == arith::CmpIPredicate::sgt ||
48 predicate == arith::CmpIPredicate::slt);
50 auto valueIt = values.begin();
51 Value value = *valueIt++;
52 for (; valueIt != values.end(); ++valueIt) {
53 if (predicate == arith::CmpIPredicate::sgt)
54 value = builder.
create<arith::MaxSIOp>(loc, value, *valueIt);
56 value = builder.
create<arith::MinSIOp>(loc, value, *valueIt);
87 op.getUpperBoundOperands());
95 op.getLowerBoundOperands());
103 LogicalResult matchAndRewrite(AffineMinOp op,
119 LogicalResult matchAndRewrite(AffineMaxOp op,
136 LogicalResult matchAndRewrite(AffineYieldOp op,
152 LogicalResult matchAndRewrite(AffineForOp op,
158 rewriter.
create<arith::ConstantIndexOp>(loc, op.getStepAsInt());
159 auto scfForOp = rewriter.
create<scf::ForOp>(loc, lowerBound, upperBound,
160 step, op.getInits());
163 scfForOp.getRegion().
end());
164 rewriter.
replaceOp(op, scfForOp.getResults());
175 LogicalResult matchAndRewrite(AffineParallelOp op,
184 lowerBoundTuple.reserve(op.getNumDims());
185 upperBoundTuple.reserve(op.getNumDims());
186 for (
unsigned i = 0, e = op.getNumDims(); i < e; ++i) {
188 op.getLowerBoundsOperands());
191 lowerBoundTuple.push_back(lower);
194 op.getUpperBoundsOperands());
197 upperBoundTuple.push_back(upper);
199 steps.reserve(op.getSteps().size());
200 for (int64_t step : op.getSteps())
201 steps.push_back(rewriter.
create<arith::ConstantIndexOp>(loc, step));
204 auto affineParOpTerminator =
205 cast<AffineYieldOp>(op.getBody()->getTerminator());
206 scf::ParallelOp parOp;
209 parOp = rewriter.
create<scf::ParallelOp>(loc, lowerBoundTuple,
210 upperBoundTuple, steps,
214 parOp.getRegion().
end());
215 rewriter.
replaceOp(op, parOp.getResults());
228 Type resultType = std::get<1>(pair);
229 std::optional<arith::AtomicRMWKind> reductionOp =
230 arith::symbolizeAtomicRMWKind(
231 static_cast<uint64_t
>(cast<IntegerAttr>(reduction).getInt()));
232 assert(reductionOp &&
"Reduction operation cannot be of None Type");
233 arith::AtomicRMWKind reductionOpValue = *reductionOp;
234 identityVals.push_back(
237 parOp = rewriter.
create<scf::ParallelOp>(
238 loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
244 parOp.getRegion().
end());
245 assert(reductions.size() == affineParOpTerminator->getNumOperands() &&
246 "Unequal number of reductions and operands.");
251 affineParOpTerminator, affineParOpTerminator->getOperands());
252 for (
unsigned i = 0, end = reductions.size(); i < end; i++) {
254 std::optional<arith::AtomicRMWKind> reductionOp =
255 arith::symbolizeAtomicRMWKind(
256 cast<IntegerAttr>(reductions[i]).getInt());
257 assert(reductionOp &&
"Reduction Operation cannot be of None Type");
258 arith::AtomicRMWKind reductionOpValue = *reductionOp;
260 Block &reductionBody = reduceOp.getReductions()[i].
front();
263 reductionOpValue, rewriter, loc, reductionBody.
getArgument(0),
265 rewriter.
create<scf::ReduceReturnOp>(loc, reductionResult);
267 rewriter.
replaceOp(op, parOp.getResults());
276 LogicalResult matchAndRewrite(AffineIfOp op,
281 auto integerSet = op.getIntegerSet();
282 Value zeroConstant = rewriter.
create<arith::ConstantIndexOp>(loc, 0);
287 Value cond =
nullptr;
288 for (
unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
289 AffineExpr constraintExpr = integerSet.getConstraint(i);
290 bool isEquality = integerSet.isEq(i);
293 auto numDims = integerSet.getNumDims();
295 operandsRef.take_front(numDims),
296 operandsRef.drop_front(numDims));
300 isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
302 rewriter.
create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
304 ? rewriter.
create<arith::AndIOp>(loc, cond, cmpVal).getResult()
308 : rewriter.
create<arith::ConstantIntOp>(loc, 1,
311 bool hasElseRegion = !op.getElseRegion().empty();
315 &ifOp.getThenRegion().back());
316 rewriter.
eraseBlock(&ifOp.getThenRegion().back());
319 &ifOp.getElseRegion().back());
320 rewriter.
eraseBlock(&ifOp.getElseRegion().back());
324 rewriter.
replaceOp(op, ifOp.getResults());
335 LogicalResult matchAndRewrite(AffineApplyOp op,
337 auto maybeExpandedMap =
340 if (!maybeExpandedMap)
342 rewriter.
replaceOp(op, *maybeExpandedMap);
354 LogicalResult matchAndRewrite(AffineLoadOp op,
358 auto resultOperands =
377 LogicalResult matchAndRewrite(AffinePrefetchOp op,
381 auto resultOperands =
388 op, op.getMemref(), *resultOperands, op.getIsWrite(),
389 op.getLocalityHint(), op.getIsDataCache());
401 LogicalResult matchAndRewrite(AffineStoreOp op,
405 auto maybeExpandedMap =
407 if (!maybeExpandedMap)
412 op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
432 rewriter, op.
getLoc(), op.getSrcMap(),
433 operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1));
434 if (!maybeExpandedSrcMap)
438 rewriter, op.
getLoc(), op.getDstMap(),
439 operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1));
440 if (!maybeExpandedDstMap)
444 rewriter, op.
getLoc(), op.getTagMap(),
445 operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1));
446 if (!maybeExpandedTagMap)
451 op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(),
452 *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(),
453 *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride());
469 auto maybeExpandedTagMap =
471 if (!maybeExpandedTagMap)
476 op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements());
484 class AffineVectorLoadLowering :
public OpRewritePattern<AffineVectorLoadOp> {
488 LogicalResult matchAndRewrite(AffineVectorLoadOp op,
492 auto resultOperands =
499 op, op.getVectorType(), op.getMemRef(), *resultOperands);
507 class AffineVectorStoreLowering :
public OpRewritePattern<AffineVectorStoreOp> {
511 LogicalResult matchAndRewrite(AffineVectorStoreOp op,
515 auto maybeExpandedMap =
517 if (!maybeExpandedMap)
521 op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap);
532 AffineDmaStartLowering,
533 AffineDmaWaitLowering,
537 AffineParallelLowering,
538 AffinePrefetchLowering,
542 AffineYieldOpLowering>(patterns.
getContext());
550 AffineVectorLoadLowering,
551 AffineVectorStoreLowering>(patterns.
getContext());
556 class LowerAffinePass
557 :
public impl::ConvertAffineToStandardBase<LowerAffinePass> {
558 void runOnOperation()
override {
564 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
565 scf::SCFDialect, VectorDialect>();
567 std::move(patterns))))
576 return std::make_unique<LowerAffinePass>();
static Value lowerAffineMapMax(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands)
Emit instructions that correspond to computing the maximum value among the values of a (potentially) ...
static Value buildMinMaxReductionSeq(Location loc, arith::CmpIPredicate predicate, ValueRange values, OpBuilder &builder)
Given a range of values, emit the code that reduces them with "min" or "max" depending on the provide...
static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map, ValueRange operands)
Emit instructions that correspond to computing the minimum value among the values of a (potentially) ...
static MLIRContext * getContext(OpFoldResult val)
Base type for affine expression.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
AffineDmaStartOp starts a non-blocking DMA operation that transfers data from a source memref to a de...
AffineDmaWaitOp blocks until the completion of a DMA operation associated with the tag element 'tag[i...
std::optional< SmallVector< Value, 8 > > expandAffineMap(OpBuilder &builder, Location loc, AffineMap affineMap, ValueRange operands)
Create a sequence of operations that implement the affineMap applied to the given operands (as it it ...
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, ValueRange symbolValues)
Emit code that computes the given affine expression using standard arithmetic operations applied to t...
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Include the generated interface declarations.
Value lowerAffineUpperBound(affine::AffineForOp op, OpBuilder &builder)
Emit code that computes the upper bound of the given affine loop using standard arithmetic operations...
void populateAffineToVectorConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert vector-related Affine ops to the Vector dialect.
std::unique_ptr< Pass > createLowerAffinePass()
Lowers affine control flow operations (ForStmt, IfStmt and AffineApplyOp) to equivalent lower-level c...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void populateAffineToStdConversionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to convert from the Affine dialect to the Standard dialect,...
Value lowerAffineLowerBound(affine::AffineForOp op, OpBuilder &builder)
Emit code that computes the lower bound of the given affine loop using standard arithmetic operations...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...