MLIR  20.0.0git
AffineExpandIndexOps.cpp
Go to the documentation of this file.
1 //===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to expand affine index ops into one or more
10 // simpler, more fundamental operations.
11 //===----------------------------------------------------------------------===//
12 
14 
20 
21 namespace mlir {
22 namespace affine {
23 #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
24 #include "mlir/Dialect/Affine/Passes.h.inc"
25 } // namespace affine
26 } // namespace mlir
27 
28 using namespace mlir;
29 using namespace mlir::affine;
30 
31 namespace {
32 /// Lowers `affine.delinearize_index` into a sequence of division and remainder
33 /// operations.
34 struct LowerDelinearizeIndexOps
35  : public OpRewritePattern<AffineDelinearizeIndexOp> {
37  LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
38  PatternRewriter &rewriter) const override {
39  FailureOr<SmallVector<Value>> multiIndex =
40  delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(),
41  op.getEffectiveBasis(), /*hasOuterBound=*/false);
42  if (failed(multiIndex))
43  return failure();
44  rewriter.replaceOp(op, *multiIndex);
45  return success();
46  }
47 };
48 
49 /// Lowers `affine.linearize_index` into a sequence of multiplications and
50 /// additions.
51 struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
53  LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
54  PatternRewriter &rewriter) const override {
55  // Should be folded away, included here for safety.
56  if (op.getMultiIndex().empty()) {
57  rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
58  return success();
59  }
60 
61  SmallVector<OpFoldResult> multiIndex =
62  getAsOpFoldResult(op.getMultiIndex());
63  OpFoldResult linearIndex =
64  linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis());
65  Value linearIndexValue =
66  getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex);
67  rewriter.replaceOp(op, linearIndexValue);
68  return success();
69  }
70 };
71 
72 class ExpandAffineIndexOpsPass
73  : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
74 public:
75  ExpandAffineIndexOpsPass() = default;
76 
77  void runOnOperation() override {
78  MLIRContext *context = &getContext();
79  RewritePatternSet patterns(context);
81  if (failed(
82  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
83  return signalPassFailure();
84  }
85 };
86 
87 } // namespace
88 
90  RewritePatternSet &patterns) {
91  patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
92  patterns.getContext());
93 }
94 
96  return std::make_unique<ExpandAffineIndexOpsPass>();
97 }
static MLIRContext * getContext(OpFoldResult val)
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
RewritePatternSet & insert(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:937
MLIRContext * getContext() const
Definition: PatternMatch.h:829
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
std::unique_ptr< Pass > createAffineExpandIndexOpsPass()
Creates a pass to expand affine index operations into more fundamental operations (not necessarily re...
FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis, bool hasOuterBound=true)
Generate the IR to delinearize linearIndex given the basis and return the multi-index.
Definition: Utils.cpp:1946
void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns)
Populate patterns that expand affine index operations into more fundamental operations (not necessari...
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
Definition: Utils.cpp:2006
Include the generated interface declarations.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:103
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362