//===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to expand affine index ops into sequences of
// more fundamental operations.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;

/// Given a basis (in static and dynamic components), return the sequence of
/// suffix products of the basis, including the product of the entire basis,
/// which must **not** contain an outer bound.
///
/// If excess dynamic values are provided, the values at the beginning
/// will be ignored. This allows for dropping the outer bound without
/// needing to manipulate the dynamic value array. `knownNonNegative`
/// indicates that the values being used to compute the strides are known
/// to be non-negative.
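///
/// For example, for the static basis [3, 4, 5], the returned strides are
/// [60, 20, 5]: the trailing stride of 1 is never materialized, while the
/// product of the whole basis is. If the 4 were a dynamic value %d, the
/// result would instead be [%d * 15, %d * 5, 5].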
static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
                                         ValueRange dynamicBasis,
                                         ArrayRef<int64_t> staticBasis,
                                         bool knownNonNegative) {
  if (staticBasis.empty())
    return {};

  SmallVector<Value> result;
  result.reserve(staticBasis.size());
  size_t dynamicIndex = dynamicBasis.size();
  Value dynamicPart = nullptr;
  int64_t staticPart = 1;
  // The products of the strides can't have overflow by definition of
  // affine.*_index.
  arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
  if (knownNonNegative)
    ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
  for (int64_t elem : llvm::reverse(staticBasis)) {
    if (ShapedType::isDynamic(elem)) {
      // Note: basis elements and their products are, definitionally,
      // non-negative, so `nuw` is justified.
      if (dynamicPart)
        dynamicPart = rewriter.create<arith::MulIOp>(
            loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
      else
        dynamicPart = dynamicBasis[dynamicIndex - 1];
      --dynamicIndex;
    } else {
      staticPart *= elem;
    }

    if (dynamicPart && staticPart == 1) {
      result.push_back(dynamicPart);
    } else {
      Value stride =
          rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
      if (dynamicPart)
        stride =
            rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
      result.push_back(stride);
    }
  }
  std::reverse(result.begin(), result.end());
  return result;
}

namespace {
/// Lowers `affine.delinearize_index` into a sequence of division and remainder
/// operations.
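///
/// For example (a sketch; SSA names and constant folding are illustrative),
/// `%a:3 = affine.delinearize_index %lin into (2, 3, 5)` lowers roughly to
///   %c5 = arith.constant 5 : index
///   %c15 = arith.constant 15 : index
///   %0 = arith.floordivsi %lin, %c15   // outermost result
///   %1 = <mod(%lin, %c15) via remsi plus a negative-remainder fixup>
///   %2 = arith.divsi %1, %c5           // middle result
///   %3 = <mod(%lin, %c5) via remsi plus a negative-remainder fixup>
/// yielding the results (%0, %2, %3).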
struct LowerDelinearizeIndexOps
    : public OpRewritePattern<AffineDelinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value linearIdx = op.getLinearIndex();
    unsigned numResults = op.getNumResults();
    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
    if (numResults == staticBasis.size())
      staticBasis = staticBasis.drop_front();

    if (numResults == 1) {
      rewriter.replaceOp(op, linearIdx);
      return success();
    }

    SmallVector<Value> results;
    results.reserve(numResults);
    SmallVector<Value> strides =
        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                       /*knownNonNegative=*/true);

    Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);

    Value initialPart =
        rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
    results.push_back(initialPart);

    auto emitModTerm = [&](Value stride) -> Value {
      Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
      Value remainderNegative = rewriter.create<arith::CmpIOp>(
          loc, arith::CmpIPredicate::slt, remainder, zero);
      // If the correction is relevant, this term is <= stride, which is known
      // to be positive in `index`. Otherwise, while 2 * stride might overflow,
      // this branch won't be taken, so the risk of `poison` is fine.
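      // For example (illustrative numbers): with stride 5 and linear index
      // -7, `remsi` yields -2; the fixup adds 5 to produce 3, the floor-mod
      // result, and the select keeps the plain remainder when it is already
      // non-negative.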
      Value corrected = rewriter.create<arith::AddIOp>(
          loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
      Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
                                                   corrected, remainder);
      return mod;
    };

    // Generate all the intermediate parts.
    for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
      Value thisStride = strides[i];
      Value nextStride = strides[i + 1];
      Value modulus = emitModTerm(thisStride);
      // We know both inputs are positive, so floorDiv == div.
      // This could potentially be a divui, but it's not clear if that would
      // cause issues.
      Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
      results.push_back(divided);
    }

    results.push_back(emitModTerm(strides.back()));

    rewriter.replaceOp(op, results);
    return success();
  }
};

/// Lowers `affine.linearize_index` into a sequence of multiplications and
/// additions. Makes a best effort to sort the input indices so that
/// the most loop-invariant terms are at the left of the additions
/// to enable loop-invariant code motion.
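///
/// For example (a sketch with illustrative SSA names),
/// `%l = affine.linearize_index [%i, %j, %k] by (2, 3, 5)` lowers roughly to
///   %c15 = arith.constant 15 : index
///   %c5 = arith.constant 5 : index
///   %0 = arith.muli %i, %c15
///   %1 = arith.muli %j, %c5
///   %2 = arith.addi %0, %1   // terms summed in an LICM-friendly order
///   %3 = arith.addi %2, %k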
struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
                                PatternRewriter &rewriter) const override {
    // Should be folded away, included here for safety.
    if (op.getMultiIndex().empty()) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
      return success();
    }

    Location loc = op.getLoc();
    ValueRange multiIndex = op.getMultiIndex();
    size_t numIndexes = multiIndex.size();
    ArrayRef<int64_t> staticBasis = op.getStaticBasis();
    if (numIndexes == staticBasis.size())
      staticBasis = staticBasis.drop_front();

    SmallVector<Value> strides =
        computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
                       /*knownNonNegative=*/op.getDisjoint());
    // Pairs of (scaled index, number of enclosing loops it is invariant to).
    SmallVector<std::pair<Value, int64_t>> scaledValues;
    scaledValues.reserve(numIndexes);

    // Note: strides doesn't contain a value for the final element (stride 1)
    // and everything else lines up. We use the "mutable" accessor so we can get
    // our hands on an `OpOperand&` for the loop invariant counting function.
    for (auto [stride, idxOp] :
         llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
      Value scaledIdx = rewriter.create<arith::MulIOp>(
          loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
      int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
      scaledValues.emplace_back(scaledIdx, numHoistableLoops);
    }
    scaledValues.emplace_back(
        multiIndex.back(),
        numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));

    // Sort by how many enclosing loops there are, ties implicitly broken by
    // size of the stride.
    llvm::stable_sort(scaledValues,
                      [&](auto l, auto r) { return l.second > r.second; });

    Value result = scaledValues.front().first;
    for (auto [scaledValue, numHoistableLoops] :
         llvm::drop_begin(scaledValues)) {
      std::ignore = numHoistableLoops;
      result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
                                              arith::IntegerOverflowFlags::nsw);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

class ExpandAffineIndexOpsPass
    : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
public:
  ExpandAffineIndexOpsPass() = default;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    populateAffineExpandIndexOpsPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::affine::populateAffineExpandIndexOpsPatterns(
    RewritePatternSet &patterns) {
  patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
      patterns.getContext());
}

std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {
  return std::make_unique<ExpandAffineIndexOpsPass>();
}
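
// Usage sketch: the expansion can be run as a standalone pass or by applying
// the patterns directly. Assuming a PassManager `pm`, or an op `root` and
// context `ctx`:
//   pm.addPass(affine::createAffineExpandIndexOpsPass());
// or
//   RewritePatternSet patterns(ctx);
//   affine::populateAffineExpandIndexOpsPatterns(patterns);
//   (void)applyPatternsGreedily(root, std::move(patterns));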