MLIR  20.0.0git
AffineExpandIndexOps.cpp
Go to the documentation of this file.
1 //===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to expand affine index ops into one or more more
10 // fundamental operations.
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

22 namespace mlir {
23 namespace affine {
24 #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
25 #include "mlir/Dialect/Affine/Passes.h.inc"
26 } // namespace affine
27 } // namespace mlir
28 
29 using namespace mlir;
30 using namespace mlir::affine;
31 
32 /// Given a basis (in static and dynamic components), return the sequence of
33 /// suffix products of the basis, including the product of the entire basis,
34 /// which must **not** contain an outer bound.
35 ///
36 /// If excess dynamic values are provided, the values at the beginning
37 /// will be ignored. This allows for dropping the outer bound without
38 /// needing to manipulate the dynamic value array.
40  ValueRange dynamicBasis,
41  ArrayRef<int64_t> staticBasis) {
42  if (staticBasis.empty())
43  return {};
44 
45  SmallVector<Value> result;
46  result.reserve(staticBasis.size());
47  size_t dynamicIndex = dynamicBasis.size();
48  Value dynamicPart = nullptr;
49  int64_t staticPart = 1;
50  for (int64_t elem : llvm::reverse(staticBasis)) {
51  if (ShapedType::isDynamic(elem)) {
52  if (dynamicPart)
53  dynamicPart = rewriter.create<arith::MulIOp>(
54  loc, dynamicPart, dynamicBasis[dynamicIndex - 1]);
55  else
56  dynamicPart = dynamicBasis[dynamicIndex - 1];
57  --dynamicIndex;
58  } else {
59  staticPart *= elem;
60  }
61 
62  if (dynamicPart && staticPart == 1) {
63  result.push_back(dynamicPart);
64  } else {
65  Value stride =
66  rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
67  if (dynamicPart)
68  stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride);
69  result.push_back(stride);
70  }
71  }
72  std::reverse(result.begin(), result.end());
73  return result;
74 }
75 
76 namespace {
77 /// Lowers `affine.delinearize_index` into a sequence of division and remainder
78 /// operations.
79 struct LowerDelinearizeIndexOps
80  : public OpRewritePattern<AffineDelinearizeIndexOp> {
82  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
83  PatternRewriter &rewriter) const override {
84  Location loc = op.getLoc();
85  Value linearIdx = op.getLinearIndex();
86  unsigned numResults = op.getNumResults();
87  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
88  if (numResults == staticBasis.size())
89  staticBasis = staticBasis.drop_front();
90 
91  if (numResults == 1) {
92  rewriter.replaceOp(op, linearIdx);
93  return success();
94  }
95 
96  SmallVector<Value> results;
97  results.reserve(numResults);
98  SmallVector<Value> strides =
99  computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
100 
101  Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
102 
103  Value initialPart =
104  rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
105  results.push_back(initialPart);
106 
107  auto emitModTerm = [&](Value stride) -> Value {
108  Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
109  Value remainderNegative = rewriter.create<arith::CmpIOp>(
110  loc, arith::CmpIPredicate::slt, remainder, zero);
111  Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride);
112  Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
113  corrected, remainder);
114  return mod;
115  };
116 
117  // Generate all the intermediate parts
118  for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
119  Value thisStride = strides[i];
120  Value nextStride = strides[i + 1];
121  Value modulus = emitModTerm(thisStride);
122  // We know both inputs are positive, so floorDiv == div.
123  // This could potentially be a divui, but it's not clear if that would
124  // cause issues.
125  Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
126  results.push_back(divided);
127  }
128 
129  results.push_back(emitModTerm(strides.back()));
130 
131  rewriter.replaceOp(op, results);
132  return success();
133  }
134 };
135 
136 /// Lowers `affine.linearize_index` into a sequence of multiplications and
137 /// additions. Make a best effort to sort the input indices so that
138 /// the most loop-invariant terms are at the left of the additions
139 /// to enable loop-invariant code motion.
140 struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
142  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
143  PatternRewriter &rewriter) const override {
144  // Should be folded away, included here for safety.
145  if (op.getMultiIndex().empty()) {
146  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
147  return success();
148  }
149 
150  Location loc = op.getLoc();
151  ValueRange multiIndex = op.getMultiIndex();
152  size_t numIndexes = multiIndex.size();
153  ArrayRef<int64_t> staticBasis = op.getStaticBasis();
154  if (numIndexes == staticBasis.size())
155  staticBasis = staticBasis.drop_front();
156 
157  SmallVector<Value> strides =
158  computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
160  scaledValues.reserve(numIndexes);
161 
162  // Note: strides doesn't contain a value for the final element (stride 1)
163  // and everything else lines up. We use the "mutable" accessor so we can get
164  // our hands on an `OpOperand&` for the loop invariant counting function.
165  for (auto [stride, idxOp] :
166  llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
167  Value scaledIdx =
168  rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride);
169  int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
170  scaledValues.emplace_back(scaledIdx, numHoistableLoops);
171  }
172  scaledValues.emplace_back(
173  multiIndex.back(),
174  numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
175 
176  // Sort by how many enclosing loops there are, ties implicitly broken by
177  // size of the stride.
178  llvm::stable_sort(scaledValues,
179  [&](auto l, auto r) { return l.second > r.second; });
180 
181  Value result = scaledValues.front().first;
182  for (auto [scaledValue, numHoistableLoops] :
183  llvm::drop_begin(scaledValues)) {
184  std::ignore = numHoistableLoops;
185  result = rewriter.create<arith::AddIOp>(loc, result, scaledValue);
186  }
187  rewriter.replaceOp(op, result);
188  return success();
189  }
190 };
191 
192 class ExpandAffineIndexOpsPass
193  : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
194 public:
195  ExpandAffineIndexOpsPass() = default;
196 
197  void runOnOperation() override {
198  MLIRContext *context = &getContext();
199  RewritePatternSet patterns(context);
201  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
202  return signalPassFailure();
203  }
204 };
205 
206 } // namespace
207 
210  patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
211  patterns.getContext());
212 }
213 
215  return std::make_unique<ExpandAffineIndexOpsPass>();
216 }
static MLIRContext * getContext(OpFoldResult val)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:529
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
std::unique_ptr< Pass > createAffineExpandIndexOpsPass()
Creates a pass to expand affine index operations into more fundamental operations (not necessarily re...
int64_t numEnclosingInvariantLoops(OpOperand &operand)
Count the number of loops surrounding operand such that operand could be hoisted above.
Definition: LoopUtils.cpp:2787
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
SmallVector< int64_t > computeStrides(ArrayRef< int64_t > sizes)
Definition: IndexingUtils.h:47
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362