MLIR  20.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1 //===- ExpandOps.cpp - Expand MemRef ops for lowering to LLVM -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements transformations that expand MemRef dialect ops to help
10 // the lowering to LLVM. Currently implemented: `memref.atomic_rmw` with the
11 // floating-point min/max kinds, and `memref.reshape` with a static shape operand.
12 //
13 //===----------------------------------------------------------------------===//
14 
16 
21 #include "mlir/IR/TypeUtilities.h"
23 #include "llvm/ADT/STLExtras.h"
24 
25 namespace mlir {
26 namespace memref {
27 #define GEN_PASS_DEF_EXPANDOPS
28 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
29 } // namespace memref
30 } // namespace mlir
31 
32 using namespace mlir;
33 
34 namespace {
35 
36 /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
37 /// AtomicRMWOpLowering pattern, such as minimum and maximum operations for
38 /// floating-point numbers, to `memref.generic_atomic_rmw` with the expanded
39 /// code.
40 ///
41 /// %x = atomic_rmw maximumf %fval, %F[%i] : (f32, memref<10xf32>) -> f32
42 ///
43 /// will be lowered to
44 ///
45 /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
46 /// ^bb0(%current: f32):
47 /// %1 = arith.maximumf %current, %fval : f32
48 /// memref.atomic_yield %1 : f32
49 /// }
50 struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
51 public:
53 
54  LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
55  PatternRewriter &rewriter) const final {
56  auto loc = op.getLoc();
57  auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
58  loc, op.getMemref(), op.getIndices());
59  OpBuilder bodyBuilder =
60  OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
61 
62  Value lhs = genericOp.getCurrentValue();
63  Value rhs = op.getValue();
64 
65  Value arithOp =
66  mlir::arith::getReductionOp(op.getKind(), bodyBuilder, loc, lhs, rhs);
67  bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);
68 
69  rewriter.replaceOp(op, genericOp.getResult());
70  return success();
71  }
72 };
73 
74 /// Converts `memref.reshape` that has a target shape of a statically-known
75 /// size to `memref.reinterpret_cast`.
76 struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
77 public:
79 
80  LogicalResult matchAndRewrite(memref::ReshapeOp op,
81  PatternRewriter &rewriter) const final {
82  auto shapeType = cast<MemRefType>(op.getShape().getType());
83  if (!shapeType.hasStaticShape())
84  return failure();
85 
86  int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
87  SmallVector<OpFoldResult, 4> sizes, strides;
88  sizes.resize(rank);
89  strides.resize(rank);
90 
91  Location loc = op.getLoc();
92  Value stride = nullptr;
93  int64_t staticStride = 1;
94  for (int i = rank - 1; i >= 0; --i) {
95  Value size;
96  // Load dynamic sizes from the shape input, use constants for static dims.
97  if (op.getType().isDynamicDim(i)) {
98  Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
99  size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
100  if (!isa<IndexType>(size.getType()))
101  size = rewriter.create<arith::IndexCastOp>(
102  loc, rewriter.getIndexType(), size);
103  sizes[i] = size;
104  } else {
105  auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
106  size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
107  sizes[i] = sizeAttr;
108  }
109  if (stride)
110  strides[i] = stride;
111  else
112  strides[i] = rewriter.getIndexAttr(staticStride);
113 
114  if (i > 0) {
115  if (stride) {
116  stride = rewriter.create<arith::MulIOp>(loc, stride, size);
117  } else if (op.getType().isDynamicDim(i)) {
118  stride = rewriter.create<arith::MulIOp>(
119  loc, rewriter.create<arith::ConstantIndexOp>(loc, staticStride),
120  size);
121  } else {
122  staticStride *= op.getType().getDimSize(i);
123  }
124  }
125  }
126  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
127  op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
128  sizes, strides);
129  return success();
130  }
131 };
132 
133 struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
134  void runOnOperation() override {
135  MLIRContext &ctx = getContext();
136 
137  RewritePatternSet patterns(&ctx);
139  ConversionTarget target(ctx);
140 
141  target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
142  target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
143  [](memref::AtomicRMWOp op) {
144  constexpr std::array shouldBeExpandedKinds = {
145  arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
146  arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
147  return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
148  });
149  target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
150  return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
151  });
152  if (failed(applyPartialConversion(getOperation(), target,
153  std::move(patterns))))
154  signalPassFailure();
155  }
156 };
157 
158 } // namespace
159 
161  patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
162  patterns.getContext());
163 }
164 
165 std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
166  return std::make_unique<ExpandOpsPass>();
167 }
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition: Builders.h:255
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:329
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
Definition: ArithOps.cpp:2608
std::unique_ptr< Pass > createExpandOpsPass()
Creates an instance of the ExpandOps pass that legalizes memref dialect ops to be convertible to LLVM...
Definition: ExpandOps.cpp:165
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Definition: ExpandOps.cpp:160
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362