MLIR 23.0.0git
AffineExpandIndexOps.cpp
Go to the documentation of this file.
1//===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to expand each affine index op into one or more
10// operations that are more fundamental.
11//===----------------------------------------------------------------------===//
12
15
20
21namespace mlir {
22namespace affine {
23#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
24#include "mlir/Dialect/Affine/Transforms/Passes.h.inc"
25} // namespace affine
26} // namespace mlir
27
28using namespace mlir;
29using namespace mlir::affine;
30
31/// Given a basis (in static and dynamic components), return the sequence of
32/// suffix products of the basis, including the product of the entire basis,
33/// which must **not** contain an outer bound.
34///
35/// If excess dynamic values are provided, the values at the beginning
36/// will be ignored. This allows for dropping the outer bound without
37/// needing to manipulate the dynamic value array. `knownNonNegative`
38/// indicates that the values being used to compute the strides are known
39/// to be non-negative; when it is set, `nuw` is added to the `nsw` flag on
/// the multiplications emitted here. An empty `staticBasis` yields an empty
/// result.
// NOTE(review): listing lines 40 and 47 are absent from this extract, so the
// opening of the signature and the declaration of `result` are not visible.
41 ValueRange dynamicBasis,
42 ArrayRef<int64_t> staticBasis,
43 bool knownNonNegative) {
44 if (staticBasis.empty())
45 return {};
46
48 result.reserve(staticBasis.size());
// Dynamic values are consumed from the back of `dynamicBasis`, so any excess
// leading entries are never read.
49 size_t dynamicIndex = dynamicBasis.size();
// Running product of the dynamic basis elements seen so far (null until the
// first dynamic element is encountered) and of the static ones.
50 Value dynamicPart = nullptr;
51 int64_t staticPart = 1;
52 // The products of the strides can't have overflow by definition of
53 // affine.*_index.
54 arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
55 if (knownNonNegative)
56 ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
// Walk the basis from its last element to its first, accumulating the suffix
// product and emitting one stride per basis element.
57 for (int64_t elem : llvm::reverse(staticBasis)) {
58 if (ShapedType::isDynamic(elem)) {
59 // Note: basis elements and their products are, definitionally,
60 // non-negative, so `nuw` is justified.
61 if (dynamicPart)
62 dynamicPart =
63 arith::MulIOp::create(rewriter, loc, dynamicPart,
64 dynamicBasis[dynamicIndex - 1], ovflags);
65 else
66 dynamicPart = dynamicBasis[dynamicIndex - 1];
67 --dynamicIndex;
68 } else {
69 staticPart *= elem;
70 }
71
// When the static factor is 1 the dynamic product is the stride itself, so no
// multiplication by a constant is emitted.
72 if (dynamicPart && staticPart == 1) {
73 result.push_back(dynamicPart);
74 } else {
75 Value stride =
76 rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
77 if (dynamicPart)
78 stride =
79 arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags);
80 result.push_back(stride);
81 }
82 }
// Strides were appended starting from the last basis element; reverse so the
// result is ordered like the basis.
83 std::reverse(result.begin(), result.end());
84 return result;
85}
86
87/// Broadcast a value to match the given type. If `value` already has type
88/// `targetType`, it is returned as-is; otherwise a vector.broadcast of
/// `value` to `targetType` is created at the rewriter's insertion point.
// NOTE(review): listing line 89 (the opening of the signature) is absent from
// this extract.
90 Value value, Type targetType) {
91 if (value.getType() == targetType)
92 return value;
93 return vector::BroadcastOp::create(rewriter, loc, targetType, value);
94}
95
/// Lowers affine.delinearize_index into a sequence of division and remainder
/// operations and replaces `op` with the computed values. Always returns
/// success().
///
/// The first result is the floor division of the linear index by the first
/// stride; each intermediate result is the non-negative remainder modulo one
/// stride divided by the next stride; the last result is the non-negative
/// remainder modulo the last stride. An op with a single result is replaced
/// by its linear index directly.
// NOTE(review): listing line 97 (part of the signature) is absent from this
// extract.
96LogicalResult
98 AffineDelinearizeIndexOp op) {
99 Location loc = op.getLoc();
100 Value linearIdx = op.getLinearIndex();
101 unsigned numResults = op.getNumResults();
102 ArrayRef<int64_t> staticBasis = op.getStaticBasis();
// A basis with as many entries as results includes a leading entry that
// computeStrides must not see, so drop it.
103 if (numResults == staticBasis.size())
104 staticBasis = staticBasis.drop_front();
105
106 if (numResults == 1) {
107 rewriter.replaceOp(op, linearIdx);
108 return success();
109 }
110
111 SmallVector<Value> results;
112 results.reserve(numResults);
113 SmallVector<Value> strides =
114 computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
115 /*knownNonNegative=*/true);
116
117 // Broadcast strides and zero to match the linear index type (needed for
118 // vector types where the strides are scalar but the index is a vector).
119 Type indexType = linearIdx.getType();
120 for (Value &stride : strides)
121 stride = broadcastToMatchType(rewriter, loc, stride, indexType);
122
123 Value zero =
124 arith::ConstantOp::create(rewriter, loc, rewriter.getZeroAttr(indexType));
125
126 Value initialPart =
127 arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front());
128 results.push_back(initialPart);
129
// Emits `linearIdx mod stride` with a non-negative result: a signed remainder
// followed by adding `stride` back when that remainder is negative.
130 auto emitModTerm = [&](Value stride) -> Value {
131 Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride);
132 Value remainderNegative = arith::CmpIOp::create(
133 rewriter, loc, arith::CmpIPredicate::slt, remainder, zero);
134 // If the correction is relevant, this term is <= stride, which is known
135 // to be positive in `index`. Otherwise, while 2 * stride might overflow,
136 // this branch won't be taken, so the risk of `poison` is fine.
137 Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride,
138 arith::IntegerOverflowFlags::nsw);
139 Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative,
140 corrected, remainder);
141 return mod;
142 };
143
144 // Generate all the intermediate parts
145 for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
146 Value thisStride = strides[i];
147 Value nextStride = strides[i + 1];
148 Value modulus = emitModTerm(thisStride);
149 // We know both inputs are positive, so floorDiv == div.
150 // This could potentially be a divui, but it's not clear if that would
151 // cause issues.
152 Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride);
153 results.push_back(divided);
154 }
155
// The innermost result is the remainder modulo the last stride, with no
// further division.
156 results.push_back(emitModTerm(strides.back()));
157
158 rewriter.replaceOp(op, results);
159 return success();
160}
161
/// Lowers affine.linearize_index into a sequence of multiplications and
/// additions and replaces `op` with the resulting value. Always returns
/// success().
///
/// Every index except the last is multiplied by its stride; the terms are then
/// ordered by the value numEnclosingInvariantLoops reports for the
/// corresponding operand (largest first) and summed in that order.
// NOTE(review): listing lines 162 (the opening of the signature) and 187 (the
// declaration of `scaledValues`) are absent from this extract. From its uses
// below, `scaledValues` holds (value, loop count) pairs.
163 AffineLinearizeIndexOp op) {
164 // Should be folded away, included here for safety.
165 if (op.getMultiIndex().empty()) {
166 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
167 op, rewriter.getZeroAttr(op.getLinearIndex().getType()));
168 return success();
169 }
170
171 Location loc = op.getLoc();
172 ValueRange multiIndex = op.getMultiIndex();
173 Type indexType = op.getLinearIndex().getType();
174 size_t numIndexes = multiIndex.size();
175 ArrayRef<int64_t> staticBasis = op.getStaticBasis();
// A basis with as many entries as indices includes a leading entry that
// computeStrides must not see, so drop it.
176 if (numIndexes == staticBasis.size())
177 staticBasis = staticBasis.drop_front();
178
// `nuw` is only requested on the stride products when the op is marked
// disjoint.
179 SmallVector<Value> strides =
180 computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
181 /*knownNonNegative=*/op.getDisjoint());
182
183 // Broadcast strides to match the index type (needed for vector types).
184 for (Value &stride : strides)
185 stride = broadcastToMatchType(rewriter, loc, stride, indexType);
186
188 scaledValues.reserve(numIndexes);
189
190 // Note: strides doesn't contain a value for the final element (stride 1)
191 // and everything else lines up. We use the "mutable" accessor so we can get
192 // our hands on an `OpOperand&` for the loop invariant counting function.
193 for (auto [stride, idxOp] :
194 llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
195 Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride,
196 arith::IntegerOverflowFlags::nsw);
197 int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
198 scaledValues.emplace_back(scaledIdx, numHoistableLoops);
199 }
// The last index has stride 1, so it is added as-is without a multiplication.
200 scaledValues.emplace_back(
201 multiIndex.back(),
202 numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
203
204 // Sort by how many enclosing loops there are, ties implicitly broken by
205 // size of the stride.
// The comparison puts larger loop counts first; the stable sort keeps the
// original index order among equal counts.
206 llvm::stable_sort(scaledValues,
207 [&](auto l, auto r) { return l.second > r.second; });
208
// Sum the terms left to right in the sorted order.
209 Value result = scaledValues.front().first;
210 for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
211 std::ignore = numHoistableLoops;
212 result = arith::AddIOp::create(rewriter, loc, result, scaledValue,
213 arith::IntegerOverflowFlags::nsw);
214 }
215 rewriter.replaceOp(op, result);
216 return success();
217}
218
219namespace {
/// Rewrite pattern for AffineDelinearizeIndexOp that forwards to
/// affine::lowerAffineDelinearizeIndexOp and returns its result.
220struct LowerDelinearizeIndexOps
221 : public OpRewritePattern<AffineDelinearizeIndexOp> {
222 using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
223 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
224 PatternRewriter &rewriter) const override {
225 return affine::lowerAffineDelinearizeIndexOp(rewriter, op);
226 }
227};
228
/// Rewrite pattern for AffineLinearizeIndexOp that forwards to
/// affine::lowerAffineLinearizeIndexOp and returns its result.
// NOTE(review): listing line 230 is absent from this extract.
229struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
231 LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
232 PatternRewriter &rewriter) const override {
233 return affine::lowerAffineLinearizeIndexOp(rewriter, op);
234 }
235};
236
/// Pass that applies a pattern set greedily to the operation it runs on and
/// signals pass failure if the greedy application fails.
237class ExpandAffineIndexOpsPass
238 : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
239public:
240 ExpandAffineIndexOpsPass() = default;
241
242 void runOnOperation() override {
243 MLIRContext *context = &getContext();
244 RewritePatternSet patterns(context);
// NOTE(review): listing line 245 is absent from this extract; as shown,
// nothing adds patterns to `patterns` before it is applied. Presumably the
// missing line populates it (e.g. via populateAffineExpandIndexOpsPatterns) -
// confirm against the original file.
246 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
247 return signalPassFailure();
248 }
249};
250
251} // namespace
252
/// Adds the LowerDelinearizeIndexOps and LowerLinearizeIndexOps patterns to
/// `patterns`, constructing them with the context of `patterns`.
// NOTE(review): listing line 253 (the opening of the signature) is absent from
// this extract.
254 RewritePatternSet &patterns) {
255 patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
256 patterns.getContext());
257}
258
/// Returns a newly constructed ExpandAffineIndexOpsPass.
// NOTE(review): listing line 259 (the signature) is absent from this extract.
260 return std::make_unique<ExpandAffineIndexOpsPass>();
261}
return success()
static Value broadcastToMatchType(RewriterBase &rewriter, Location loc, Value value, Type targetType)
Broadcast a scalar value to match the given type.
b getContext())
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
MLIRContext * getContext() const
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter, AffineDelinearizeIndexOp op)
Lowers affine.delinearize_index into a sequence of division and remainder operations.
LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter, AffineLinearizeIndexOp op)
Lowers affine.linearize_index into a sequence of multiplications and additions.
std::unique_ptr< Pass > createAffineExpandIndexOpsPass()
Creates a pass to expand affine index operations into more fundamental operations (not necessarily re...
int64_t numEnclosingInvariantLoops(OpOperand &operand)
Performs explicit copying for the contiguous sequence of operations in the block iterator range [‘beg...
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...