MLIR  22.0.0git
AffineExpandIndexOps.cpp
Go to the documentation of this file.
1 //===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to expand affine index ops into one or more,
10 // more fundamental operations.
11 //===----------------------------------------------------------------------===//
12 
15 
19 
20 namespace mlir {
21 namespace affine {
22 #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
23 #include "mlir/Dialect/Affine/Passes.h.inc"
24 } // namespace affine
25 } // namespace mlir
26 
27 using namespace mlir;
28 using namespace mlir::affine;
29 
30 /// Given a basis (in static and dynamic components), return the sequence of
31 /// suffix products of the basis, including the product of the entire basis,
32 /// which must **not** contain an outer bound.
33 ///
34 /// If excess dynamic values are provided, the values at the beginning
35 /// will be ignored. This allows for dropping the outer bound without
36 /// needing to manipulate the dynamic value array. `knownPositive`
37 /// indicases that the values being used to compute the strides are known
38 /// to be non-negative.
40  ValueRange dynamicBasis,
41  ArrayRef<int64_t> staticBasis,
42  bool knownNonNegative) {
43  if (staticBasis.empty())
44  return {};
45 
46  SmallVector<Value> result;
47  result.reserve(staticBasis.size());
48  size_t dynamicIndex = dynamicBasis.size();
49  Value dynamicPart = nullptr;
50  int64_t staticPart = 1;
51  // The products of the strides can't have overflow by definition of
52  // affine.*_index.
53  arith::IntegerOverflowFlags ovflags = arith::IntegerOverflowFlags::nsw;
54  if (knownNonNegative)
55  ovflags = ovflags | arith::IntegerOverflowFlags::nuw;
56  for (int64_t elem : llvm::reverse(staticBasis)) {
57  if (ShapedType::isDynamic(elem)) {
58  // Note: basis elements and their products are, definitionally,
59  // non-negative, so `nuw` is justified.
60  if (dynamicPart)
61  dynamicPart =
62  arith::MulIOp::create(rewriter, loc, dynamicPart,
63  dynamicBasis[dynamicIndex - 1], ovflags);
64  else
65  dynamicPart = dynamicBasis[dynamicIndex - 1];
66  --dynamicIndex;
67  } else {
68  staticPart *= elem;
69  }
70 
71  if (dynamicPart && staticPart == 1) {
72  result.push_back(dynamicPart);
73  } else {
74  Value stride =
75  rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
76  if (dynamicPart)
77  stride =
78  arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags);
79  result.push_back(stride);
80  }
81  }
82  std::reverse(result.begin(), result.end());
83  return result;
84 }
85 
86 LogicalResult
88  AffineDelinearizeIndexOp op) {
89  Location loc = op.getLoc();
90  Value linearIdx = op.getLinearIndex();
91  unsigned numResults = op.getNumResults();
92  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
93  if (numResults == staticBasis.size())
94  staticBasis = staticBasis.drop_front();
95 
96  if (numResults == 1) {
97  rewriter.replaceOp(op, linearIdx);
98  return success();
99  }
100 
101  SmallVector<Value> results;
102  results.reserve(numResults);
103  SmallVector<Value> strides =
104  computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
105  /*knownNonNegative=*/true);
106 
107  Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
108 
109  Value initialPart =
110  arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front());
111  results.push_back(initialPart);
112 
113  auto emitModTerm = [&](Value stride) -> Value {
114  Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride);
115  Value remainderNegative = arith::CmpIOp::create(
116  rewriter, loc, arith::CmpIPredicate::slt, remainder, zero);
117  // If the correction is relevant, this term is <= stride, which is known
118  // to be positive in `index`. Otherwise, while 2 * stride might overflow,
119  // this branch won't be taken, so the risk of `poison` is fine.
120  Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride,
121  arith::IntegerOverflowFlags::nsw);
122  Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative,
123  corrected, remainder);
124  return mod;
125  };
126 
127  // Generate all the intermediate parts
128  for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
129  Value thisStride = strides[i];
130  Value nextStride = strides[i + 1];
131  Value modulus = emitModTerm(thisStride);
132  // We know both inputs are positive, so floorDiv == div.
133  // This could potentially be a divui, but it's not clear if that would
134  // cause issues.
135  Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride);
136  results.push_back(divided);
137  }
138 
139  results.push_back(emitModTerm(strides.back()));
140 
141  rewriter.replaceOp(op, results);
142  return success();
143 }
144 
146  AffineLinearizeIndexOp op) {
147  // Should be folded away, included here for safety.
148  if (op.getMultiIndex().empty()) {
150  return success();
151  }
152 
153  Location loc = op.getLoc();
154  ValueRange multiIndex = op.getMultiIndex();
155  size_t numIndexes = multiIndex.size();
156  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
157  if (numIndexes == staticBasis.size())
158  staticBasis = staticBasis.drop_front();
159 
160  SmallVector<Value> strides =
161  computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis,
162  /*knownNonNegative=*/op.getDisjoint());
164  scaledValues.reserve(numIndexes);
165 
166  // Note: strides doesn't contain a value for the final element (stride 1)
167  // and everything else lines up. We use the "mutable" accessor so we can get
168  // our hands on an `OpOperand&` for the loop invariant counting function.
169  for (auto [stride, idxOp] :
170  llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
171  Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride,
172  arith::IntegerOverflowFlags::nsw);
173  int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
174  scaledValues.emplace_back(scaledIdx, numHoistableLoops);
175  }
176  scaledValues.emplace_back(
177  multiIndex.back(),
178  numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
179 
180  // Sort by how many enclosing loops there are, ties implicitly broken by
181  // size of the stride.
182  llvm::stable_sort(scaledValues,
183  [&](auto l, auto r) { return l.second > r.second; });
184 
185  Value result = scaledValues.front().first;
186  for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
187  std::ignore = numHoistableLoops;
188  result = arith::AddIOp::create(rewriter, loc, result, scaledValue,
189  arith::IntegerOverflowFlags::nsw);
190  }
191  rewriter.replaceOp(op, result);
192  return success();
193 }
194 
195 namespace {
196 struct LowerDelinearizeIndexOps
197  : public OpRewritePattern<AffineDelinearizeIndexOp> {
199  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
200  PatternRewriter &rewriter) const override {
201  return affine::lowerAffineDelinearizeIndexOp(rewriter, op);
202  }
203 };
204 
205 struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
207  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
208  PatternRewriter &rewriter) const override {
209  return affine::lowerAffineLinearizeIndexOp(rewriter, op);
210  }
211 };
212 
213 class ExpandAffineIndexOpsPass
214  : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
215 public:
216  ExpandAffineIndexOpsPass() = default;
217 
218  void runOnOperation() override {
219  MLIRContext *context = &getContext();
220  RewritePatternSet patterns(context);
222  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
223  return signalPassFailure();
224  }
225 };
226 
227 } // namespace
228 
231  patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
232  patterns.getContext());
233 }
234 
236  return std::make_unique<ExpandAffineIndexOpsPass>();
237 }
static MLIRContext * getContext(OpFoldResult val)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:113
LogicalResult lowerAffineDelinearizeIndexOp(RewriterBase &rewriter, AffineDelinearizeIndexOp op)
Lowers affine.delinearize_index into a sequence of division and remainder operations.
LogicalResult lowerAffineLinearizeIndexOp(RewriterBase &rewriter, AffineLinearizeIndexOp op)
Lowers affine.linearize_index into a sequence of multiplications and additions.
std::unique_ptr< Pass > createAffineExpandIndexOpsPass()
Creates a pass to expand affine index operations into more fundamental operations (not necessarily re...
int64_t numEnclosingInvariantLoops(OpOperand &operand)
Count the number of loops surrounding operand such that operand could be hoisted above.
Definition: LoopUtils.cpp:2817
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319