MLIR 23.0.0git
AffineExpandIndexOpsAsAffine.cpp
Go to the documentation of this file.
1//===- AffineExpandIndexOpsAsAffine.cpp - Expand index ops to apply pass --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass that expands affine index ops into one or more
// simpler, more fundamental operations.
11//===----------------------------------------------------------------------===//
12
14
23
24namespace mlir {
25namespace affine {
26#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPSASAFFINE
27#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
28} // namespace affine
29} // namespace mlir
30
31using namespace mlir;
32using namespace mlir::affine;
33
34namespace {
35/// Lowers `affine.delinearize_index` into a sequence of division and remainder
36/// operations via affine.apply. For vector types, unrolls to per-element
37/// scalar affine.apply operations.
38struct LowerDelinearizeIndexOps
39 : public OpRewritePattern<AffineDelinearizeIndexOp> {
40 using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
41 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
42 PatternRewriter &rewriter) const override {
43 Location loc = op.getLoc();
44 Value linearIndex = op.getLinearIndex();
45 auto vecTy = dyn_cast<VectorType>(linearIndex.getType());
46
47 // Scalar case: use the existing affine lowering path.
48 if (!vecTy) {
49 FailureOr<SmallVector<Value>> multiIndex =
50 delinearizeIndex(rewriter, loc, linearIndex, op.getEffectiveBasis(),
51 /*hasOuterBound=*/false);
52 if (failed(multiIndex))
53 return failure();
54 rewriter.replaceOp(op, *multiIndex);
55 return success();
56 }
57
58 // Vector case: unroll to per-element scalar affine.apply operations
59 // using StaticTileOffsetRange for multi-dimensional vector support.
60 if (vecTy.isScalable())
61 return rewriter.notifyMatchFailure(op, "scalable vectors not supported");
62
63 unsigned numResults = op.getNumResults();
64 ArrayRef<int64_t> shape = vecTy.getShape();
65 SmallVector<int64_t> tileShape(shape.size(), 1);
66
67 SmallVector<Value> resultVecs(numResults);
68 Value poison = ub::PoisonOp::create(rewriter, loc, vecTy);
69 for (unsigned r = 0; r < numResults; ++r)
70 resultVecs[r] = poison;
71
72 for (SmallVector<int64_t> pos : StaticTileOffsetRange(shape, tileShape)) {
73 Value scalar = vector::ExtractOp::create(rewriter, loc, linearIndex, pos);
74
75 FailureOr<SmallVector<Value>> scalarResults =
76 delinearizeIndex(rewriter, loc, scalar, op.getEffectiveBasis(),
77 /*hasOuterBound=*/false);
78 if (failed(scalarResults))
79 return failure();
80
81 for (unsigned r = 0; r < numResults; ++r)
82 resultVecs[r] = vector::InsertOp::create(
83 rewriter, loc, (*scalarResults)[r], resultVecs[r], pos);
84 }
85
86 rewriter.replaceOp(op, resultVecs);
87 return success();
88 }
89};
90
91/// Lowers `affine.linearize_index` into a sequence of multiplications and
92/// additions via affine.apply. For vector types, unrolls to per-element
93/// scalar affine.apply operations.
94struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
96 LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
97 PatternRewriter &rewriter) const override {
98 Location loc = op.getLoc();
99 auto vecTy = dyn_cast<VectorType>(op.getLinearIndex().getType());
100
101 // Scalar case: use the existing affine lowering path.
102 if (!vecTy) {
103 // Should be folded away, included here for safety.
104 if (op.getMultiIndex().empty()) {
105 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
106 return success();
107 }
108
109 SmallVector<OpFoldResult> multiIndex =
110 getAsOpFoldResult(op.getMultiIndex());
111 OpFoldResult linearIndex =
112 linearizeIndex(rewriter, loc, multiIndex, op.getMixedBasis());
113 Value linearIndexValue =
114 getValueOrCreateConstantIntOp(rewriter, loc, linearIndex);
115 rewriter.replaceOp(op, linearIndexValue);
116 return success();
117 }
118
119 // Vector case: unroll to per-element scalar affine.apply operations
120 // using StaticTileOffsetRange for multi-dimensional vector support.
121 if (vecTy.isScalable())
122 return rewriter.notifyMatchFailure(op, "scalable vectors not supported");
123
124 ArrayRef<int64_t> shape = vecTy.getShape();
125 SmallVector<int64_t> tileShape(shape.size(), 1);
126 ValueRange multiIndex = op.getMultiIndex();
127
128 Value result = ub::PoisonOp::create(rewriter, loc, vecTy);
129
130 for (SmallVector<int64_t> pos : StaticTileOffsetRange(shape, tileShape)) {
131 SmallVector<OpFoldResult> scalarIndices;
132 for (Value vec : multiIndex)
133 scalarIndices.push_back(
134 vector::ExtractOp::create(rewriter, loc, vec, pos).getResult());
135
136 OpFoldResult linearIndex =
137 linearizeIndex(rewriter, loc, scalarIndices, op.getMixedBasis());
138 Value scalarResult =
139 getValueOrCreateConstantIntOp(rewriter, loc, linearIndex);
140
141 result =
142 vector::InsertOp::create(rewriter, loc, scalarResult, result, pos);
143 }
144
145 rewriter.replaceOp(op, result);
146 return success();
147 }
148};
149
150class ExpandAffineIndexOpsAsAffinePass
152 ExpandAffineIndexOpsAsAffinePass> {
153public:
154 ExpandAffineIndexOpsAsAffinePass() = default;
155
156 void runOnOperation() override {
157 MLIRContext *context = &getContext();
158 RewritePatternSet patterns(context);
160 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
161 return signalPassFailure();
162 }
163};
164
165} // namespace
166
168 RewritePatternSet &patterns) {
169 patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
170 patterns.getContext());
171}
172
174 return std::make_unique<ExpandAffineIndexOpsAsAffinePass>();
175}
return success()
b getContext())
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Type getType() const
Return the type of this value.
Definition Value.h:105
std::unique_ptr< Pass > createAffineExpandIndexOpsAsAffinePass()
Creates a pass to expand affine index operations into affine.apply operations.
void populateAffineExpandIndexOpsAsAffinePatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into their equivalent affine....
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:105
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...