MLIR  16.0.0git
ExpandOps.cpp
Go to the documentation of this file.
//===- ExpandOps.cpp - Code to legalize memref ops for LLVM lowering ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements MemRef dialect transformations that expand ops to help
// their lowering to LLVM. Currently implemented transformations rewrite
// `memref.atomic_rmw` with the "maxf"/"minf" kinds and `memref.reshape` with a
// statically-sized shape operand.
12 //
13 //===----------------------------------------------------------------------===//
14 
16 
20 #include "mlir/IR/TypeUtilities.h"
22 
namespace mlir {
namespace memref {
// Instantiate the tablegen-generated pass definition. This is expected to
// provide memref::impl::ExpandOpsBase, which ExpandOpsPass below derives from.
#define GEN_PASS_DEF_EXPANDOPS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir

using namespace mlir;
31 
32 namespace {
33 
34 /// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
35 /// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
36 /// `memref.generic_atomic_rmw` with the expanded code.
37 ///
38 /// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
39 ///
40 /// will be lowered to
41 ///
42 /// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
43 /// ^bb0(%current: f32):
44 /// %cmp = arith.cmpf "ogt", %current, %fval : f32
45 /// %new_value = select %cmp, %current, %fval : f32
46 /// memref.atomic_yield %new_value : f32
47 /// }
48 struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
49 public:
51 
52  LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
53  PatternRewriter &rewriter) const final {
54  arith::CmpFPredicate predicate;
55  switch (op.getKind()) {
56  case arith::AtomicRMWKind::maxf:
57  predicate = arith::CmpFPredicate::OGT;
58  break;
59  case arith::AtomicRMWKind::minf:
60  predicate = arith::CmpFPredicate::OLT;
61  break;
62  default:
63  return failure();
64  }
65 
66  auto loc = op.getLoc();
67  auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
68  loc, op.getMemref(), op.getIndices());
69  OpBuilder bodyBuilder =
70  OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());
71 
72  Value lhs = genericOp.getCurrentValue();
73  Value rhs = op.getValue();
74  Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
75  Value select = bodyBuilder.create<arith::SelectOp>(loc, cmp, lhs, rhs);
76  bodyBuilder.create<memref::AtomicYieldOp>(loc, select);
77 
78  rewriter.replaceOp(op, genericOp.getResult());
79  return success();
80  }
81 };
82 
83 /// Converts `memref.reshape` that has a target shape of a statically-known
84 /// size to `memref.reinterpret_cast`.
85 struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
86 public:
88 
89  LogicalResult matchAndRewrite(memref::ReshapeOp op,
90  PatternRewriter &rewriter) const final {
91  auto shapeType = op.getShape().getType().cast<MemRefType>();
92  if (!shapeType.hasStaticShape())
93  return failure();
94 
95  int64_t rank = shapeType.cast<MemRefType>().getDimSize(0);
96  SmallVector<OpFoldResult, 4> sizes, strides;
97  sizes.resize(rank);
98  strides.resize(rank);
99 
100  Location loc = op.getLoc();
101  Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
102  for (int i = rank - 1; i >= 0; --i) {
103  Value size;
104  // Load dynamic sizes from the shape input, use constants for static dims.
105  if (op.getType().isDynamicDim(i)) {
106  Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
107  size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
108  if (!size.getType().isa<IndexType>())
109  size = rewriter.create<arith::IndexCastOp>(
110  loc, rewriter.getIndexType(), size);
111  sizes[i] = size;
112  } else {
113  sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i));
114  size =
115  rewriter.create<arith::ConstantOp>(loc, sizes[i].get<Attribute>());
116  }
117  strides[i] = stride;
118  if (i > 0)
119  stride = rewriter.create<arith::MulIOp>(loc, stride, size);
120  }
121  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
122  op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
123  sizes, strides);
124  return success();
125  }
126 };
127 
128 struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
129  void runOnOperation() override {
130  MLIRContext &ctx = getContext();
131 
132  RewritePatternSet patterns(&ctx);
134  ConversionTarget target(ctx);
135 
136  target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
137  target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
138  [](memref::AtomicRMWOp op) {
139  return op.getKind() != arith::AtomicRMWKind::maxf &&
140  op.getKind() != arith::AtomicRMWKind::minf;
141  });
142  target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
143  return !op.getShape().getType().cast<MemRefType>().hasStaticShape();
144  });
145  if (failed(applyPartialConversion(getOperation(), target,
146  std::move(patterns))))
147  signalPassFailure();
148  }
149 };
150 
151 } // namespace
152 
154  patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
155  patterns.getContext());
156 }
157 
158 std::unique_ptr<Pass> mlir::memref::createExpandOpsPass() {
159  return std::make_unique<ExpandOpsPass>();
160 }
Include the generated interface declarations.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
LogicalResult applyPartialConversion(ArrayRef< Operation *> ops, ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation *> *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Definition: ExpandOps.cpp:153
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition: Builders.h:234
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:414
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:360
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener...
Definition: Builders.h:270
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
IndexType getIndexType()
Definition: Builders.cpp:48
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:80
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
This class describes a specific conversion target.
std::unique_ptr< Pass > createExpandOpsPass()
Creates an instance of the ExpandOps pass that legalizes memref dialect ops to be convertible to LLVM...
Definition: ExpandOps.cpp:158
This class helps build Operations.
Definition: Builders.h:196
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:101
MLIRContext * getContext() const