26#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPSASAFFINE
27#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
38struct LowerDelinearizeIndexOps
40 using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
41 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
42 PatternRewriter &rewriter)
const override {
43 Location loc = op.getLoc();
44 Value linearIndex = op.getLinearIndex();
45 auto vecTy = dyn_cast<VectorType>(linearIndex.
getType());
49 FailureOr<SmallVector<Value>> multiIndex =
50 delinearizeIndex(rewriter, loc, linearIndex, op.getEffectiveBasis(),
60 if (vecTy.isScalable())
63 unsigned numResults = op.getNumResults();
64 ArrayRef<int64_t> shape = vecTy.getShape();
65 SmallVector<int64_t> tileShape(shape.size(), 1);
67 SmallVector<Value> resultVecs(numResults);
68 Value poison = ub::PoisonOp::create(rewriter, loc, vecTy);
69 for (
unsigned r = 0; r < numResults; ++r)
70 resultVecs[r] = poison;
72 for (SmallVector<int64_t> pos : StaticTileOffsetRange(shape, tileShape)) {
73 Value scalar = vector::ExtractOp::create(rewriter, loc, linearIndex, pos);
75 FailureOr<SmallVector<Value>> scalarResults =
76 delinearizeIndex(rewriter, loc, scalar, op.getEffectiveBasis(),
81 for (
unsigned r = 0; r < numResults; ++r)
82 resultVecs[r] = vector::InsertOp::create(
83 rewriter, loc, (*scalarResults)[r], resultVecs[r], pos);
94struct LowerLinearizeIndexOps final :
OpRewritePattern<AffineLinearizeIndexOp> {
96 LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
97 PatternRewriter &rewriter)
const override {
98 Location loc = op.getLoc();
99 auto vecTy = dyn_cast<VectorType>(op.getLinearIndex().getType());
104 if (op.getMultiIndex().empty()) {
109 SmallVector<OpFoldResult> multiIndex =
111 OpFoldResult linearIndex =
113 Value linearIndexValue =
115 rewriter.
replaceOp(op, linearIndexValue);
121 if (vecTy.isScalable())
124 ArrayRef<int64_t> shape = vecTy.getShape();
125 SmallVector<int64_t> tileShape(shape.size(), 1);
128 Value
result = ub::PoisonOp::create(rewriter, loc, vecTy);
130 for (SmallVector<int64_t> pos : StaticTileOffsetRange(shape, tileShape)) {
131 SmallVector<OpFoldResult> scalarIndices;
132 for (Value vec : multiIndex)
133 scalarIndices.push_back(
134 vector::ExtractOp::create(rewriter, loc, vec, pos).getResult());
136 OpFoldResult linearIndex =
142 vector::InsertOp::create(rewriter, loc, scalarResult,
result, pos);
150class ExpandAffineIndexOpsAsAffinePass
152 ExpandAffineIndexOpsAsAffinePass> {
154 ExpandAffineIndexOpsAsAffinePass() =
default;
156 void runOnOperation()
override {
158 RewritePatternSet patterns(context);
161 return signalPassFailure();
169 patterns.
insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
174 return std::make_unique<ExpandAffineIndexOpsAsAffinePass>();
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type getType() const
Return the type of this value.
std::unique_ptr< Pass > createAffineExpandIndexOpsAsAffinePass()
Creates a pass to expand affine index operations into affine.apply operations.
void populateAffineExpandIndexOpsAsAffinePatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into their equivalent affine....
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...