MLIR  19.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1 //===- ExpandOps.cpp - Expand MemRef ops for lowering to LLVM -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MemRef transformations that expand ops to help the
10 // lowering to LLVM. Currently implemented: float min/max atomic_rmw to
11 // generic_atomic_rmw, and static-rank reshape to reinterpret_cast.
12 //
13 //===----------------------------------------------------------------------===//
14 
16 
21 #include "mlir/IR/TypeUtilities.h"
23 #include "llvm/ADT/STLExtras.h"
24 
25 namespace mlir {
26 namespace memref {
27 #define GEN_PASS_DEF_EXPANDOPS
28 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
29 } // namespace memref
30 } // namespace mlir
31 
32 using namespace mlir;
33 
34 namespace {
35 
36 /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
37 /// AtomicRMWOpLowering pattern, such as minimum and maximum operations for
38 /// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
39 /// code.
40 ///
41 /// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
42 ///
43 /// will be lowered to
44 ///
45 /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
46 /// ^bb0(%current: f32):
47 /// %1 = arith.maximumf %current, %fval : f32
48 /// memref.atomic_yield %1 : f32
49 /// }
50 struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
51 public:
53 
54  LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
55  PatternRewriter &rewriter) const final {
56  auto loc = op.getLoc();
57  auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
58  loc, op.getMemref(), op.getIndices());
59  OpBuilder bodyBuilder =
60  OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
61 
62  Value lhs = genericOp.getCurrentValue();
63  Value rhs = op.getValue();
64 
65  Value arithOp =
66  mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs);
67  bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);
68 
69  rewriter.replaceOp(op, genericOp.getResult());
70  return success();
71  }
72 };
73 
74 /// Converts `memref.reshape` that has a target shape of a statically-known
75 /// size to `memref.reinterpret_cast`.
76 struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
77 public:
79 
80  LogicalResult matchAndRewrite(memref::ReshapeOp op,
81  PatternRewriter &rewriter) const final {
82  auto shapeType = cast<MemRefType>(op.getShape().getType());
83  if (!shapeType.hasStaticShape())
84  return failure();
85 
86  int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
87  SmallVector<OpFoldResult, 4> sizes, strides;
88  sizes.resize(rank);
89  strides.resize(rank);
90 
91  Location loc = op.getLoc();
92  Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
93  for (int i = rank - 1; i >= 0; --i) {
94  Value size;
95  // Load dynamic sizes from the shape input, use constants for static dims.
96  if (op.getType().isDynamicDim(i)) {
97  Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
98  size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
99  if (!isa<IndexType>(size.getType()))
100  size = rewriter.create<arith::IndexCastOp>(
101  loc, rewriter.getIndexType(), size);
102  sizes[i] = size;
103  } else {
104  auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
105  size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
106  sizes[i] = sizeAttr;
107  }
108  strides[i] = stride;
109  if (i > 0)
110  stride = rewriter.create<arith::MulIOp>(loc, stride, size);
111  }
112  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
113  op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
114  sizes, strides);
115  return success();
116  }
117 };
118 
119 struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
120  void runOnOperation() override {
121  MLIRContext &ctx = getContext();
122 
123  RewritePatternSet patterns(&ctx);
125  ConversionTarget target(ctx);
126 
127  target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
128  target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
129  [](memref::AtomicRMWOp op) {
130  constexpr std::array shouldBeExpandedKinds = {
131  arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
132  arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
133  return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
134  });
135  target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
136  return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
137  });
138  if (failed(applyPartialConversion(getOperation(), target,
139  std::move(patterns))))
140  signalPassFailure();
141  }
142 };
143 
144 } // namespace
145 
147  patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
148  patterns.getContext());
149 }
150 
151 std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
152  return std::make_unique<ExpandOpsPass>();
153 }
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
Definition: Builders.h:248
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:322
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements). The original operation is erased.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRMWKind op to lhs and rhs.
Definition: ArithOps.cpp:2540
std::unique_ptr< Pass > createExpandOpsPass()
Creates an instance of the ExpandOps pass that legalizes memref dialect ops to be convertible to LLVM.
Definition: ExpandOps.cpp:151
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Definition: ExpandOps.cpp:146
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.
Definition: PatternMatch.h:362