MLIR 22.0.0git
AffineExpandIndexOps.cpp
Go to the documentation of this file.
1//===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements a pass to expand affine index ops into one or more
// simpler, more fundamental operations.
11//===----------------------------------------------------------------------===//
12
15
19
20namespace mlir {
21namespace affine {
22#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
23#include "mlir/Dialect/Affine/Passes.h.inc"
24} // namespace affine
25} // namespace mlir
26
27using namespace mlir;
28using namespace mlir::affine;
29
30/// Given a basis (in static and dynamic components), return the sequence of
31/// suffix products of the basis, including the product of the entire basis,
32/// which must **not** contain an outer bound.
33///
34/// If excess dynamic values are provided, the values at the beginning
35/// will be ignored. This allows for dropping the outer bound without
36/// needing to manipulate the dynamic value array. `knownPositive`
37/// indicases that the values being used to compute the strides are known
38/// to be non-negative.
40 ValueRange dynamicBasis,
41 ArrayRef<int64_t> staticBasis,
42 bool knownNonNegative) {
43 if (staticBasis.empty())
44 return {};
45
47 result.reserve(staticBasis.size());
48 size_t dynamicIndex = dynamicBasis.size();
49 Value dynamicPart = nullptr;
50 int64_t staticPart = 1;
51 // The products of the strides can't have overflow by definition of
52 // affine.*_index.
53 arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
54 if (knownNonNegative)
55 ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
56 for (int64_t elem : llvm::reverse(staticBasis)) {
57 if (ShapedType::isDynamic(elem)) {
58 // Note: basis elements and their products are, definitionally,
59 // non-negative, so `nuw` is justified.
60 if (dynamicPart)
61 dynamicPart =
62 arith::MulIOp::create(rewriter, loc, dynamicPart,
63 dynamicBasis[dynamicIndex - 1], ovflags);
64 else
65 dynamicPart = dynamicBasis[dynamicIndex - 1];
66 --dynamicIndex;
67 } else {
68 staticPart *= elem;
69 }
70
71 if (dynamicPart && staticPart == 1) {
72 result.push_back(dynamicPart);
73 } else {
74 Value stride =
75 rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
76 if (dynamicPart)
77 stride =
78 arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags);
79 result.push_back(stride);
80 }
81 }
82 std::reverse(result.begin(), result.end());
83 return result;
84}
85
86LogicalResult
88 AffineDelinearizeIndexOp op) {
89 Location loc = op.getLoc();
90 Value linearIdx = op.getLinearIndex();
91 unsigned numResults = op.getNumResults();
92 ArrayRef<int64_t> staticBasis = op.getStaticBasis();
93 if (numResults == staticBasis.size())
94 staticBasis = staticBasis.drop_front();
95
96 if (numResults == 1) {
97 rewriter.replaceOp(op, linearIdx);
98 return success();
99 }
100
101 SmallVector<Value> results;
102 results.reserve(numResults);
103 SmallVector<Value> strides =
104 computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
105 /*knownNonNegative=*/true);
106
107 Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
108
109 Value initialPart =
110 arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front());
111 results.push_back(initialPart);
112
113 auto emitModTerm = [&](Value stride) -> Value {
114 Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride);
115 Value remainderNegative = arith::CmpIOp::create(
116 rewriter, loc, arith::CmpIPredicate::slt, remainder, zero);
117 // If the correction is relevant, this term is <= stride, which is known
118 // to be positive in `index`. Otherwise, while 2 * stride might overflow,
119 // this branch won't be taken, so the risk of `poison` is fine.
120 Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride,
121 arith::IntegerOverflowFlags::nsw);
122 Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative,
123 corrected, remainder);
124 return mod;
125 };
126
127 // Generate all the intermediate parts
128 for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
129 Value thisStride = strides[i];
130 Value nextStride = strides[i + 1];
131 Value modulus = emitModTerm(thisStride);
132 // We know both inputs are positive, so floorDiv == div.
133 // This could potentially be a divui, but it's not clear if that would
134 // cause issues.
135 Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride);
136 results.push_back(divided);
137 }
139 results.push_back(emitModTerm(strides.back()));
141 rewriter.replaceOp(op, results);
142 return success();
146 AffineLinearizeIndexOp op) {
147 // Should be folded away, included here for safety.
148 if (op.getMultiIndex().empty()) {
150 return success();
151 }
152
153 Location loc = op.getLoc();
154 ValueRange multiIndex = op.getMultiIndex();
155 size_t numIndexes = multiIndex.size();
156 ArrayRef<int64_t> staticBasis = op.getStaticBasis();
157 if (numIndexes == staticBasis.size())
158 staticBasis = staticBasis.drop_front();
160 SmallVector<Value> strides =
161 computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
162 /*knownNonNegative=*/op.getDisjoint());
164 scaledValues.reserve(numIndexes);
165
166 // Note: strides doesn't contain a value for the final element (stride 1)
167 // and everything else lines up. We use the "mutable" accessor so we can get
168 // our hands on an `OpOperand&` for the loop invariant counting function.
169 for (auto [stride, idxOp] :
170 llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
171 Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride,
172 arith::IntegerOverflowFlags::nsw);
173 int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
174 scaledValues.emplace_back(scaledIdx, numHoistableLoops);
175 }
176 scaledValues.emplace_back(
177 multiIndex.back(),
178 numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
179
180 // Sort by how many enclosing loops there are, ties implicitly broken by
181 // size of the stride.
182 llvm::stable_sort(scaledValues,
183 [&](auto l, auto r) { return l.second > r.second; });
184
185 Value result = scaledValues.front().first;
186 for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
187 std::ignore = numHoistableLoops;
188 result = arith::AddIOp::create(rewriter, loc, result, scaledValue,
189 arith::IntegerOverflowFlags::nsw);
190 }
191 rewriter.replaceOp(op, result);
192 return success();
193}
194
195namespace {
196struct LowerDelinearizeIndexOps
197 : public OpRewritePattern<AffineDelinearizeIndexOp> {
198 using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
199 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
200 PatternRewriter &rewriter) const override {
201 return affine::lowerAffineDelinearizeIndexOp(rewriter, op);
202 }
203};
204
205struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
207 LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
208 PatternRewriter &rewriter) const override {
209 return affine::lowerAffineLinearizeIndexOp(rewriter, op);
210 }
211};
212
213class ExpandAffineIndexOpsPass
214 : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
215public:
216 ExpandAffineIndexOpsPass() = default;
217
218 void runOnOperation() override {
219 MLIRContext *context = &getContext();
220 RewritePatternSet patterns(context);
222 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
223 return signalPassFailure();
224 }
225};
226
227} // namespace
228
231 patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
232 patterns.getContext());
233}
234
236 return std::make_unique<ExpandAffineIndexOpsPass>();
237}
return success()
b getContext())
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter, AffineDelinearizeIndexOp op)
Lowers affine.delinearize_index into a sequence of division and remainder operations.
LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter, AffineLinearizeIndexOp op)
Lowers affine.linearize_index into a sequence of multiplications and additions.
std::unique_ptr< Pass > createAffineExpandIndexOpsPass()
Creates a pass to expand affine index operations into more fundamental operations (not necessarily re...
int64_t numEnclosingInvariantLoops(OpOperand &operand)
Counts the enclosing loops in which the given operand is loop-invariant; the linearize lowering uses this count as the number of loops a term could be hoisted out of.
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...