MLIR  22.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1 //===- ExpandOps.cpp - Expansion patterns for MemRef operations -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/IR/TypeUtilities.h"
15 
16 namespace mlir {
17 namespace memref {
18 #define GEN_PASS_DEF_EXPANDOPSPASS
19 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
20 } // namespace memref
21 } // namespace mlir
22 
23 using namespace mlir;
24 
25 namespace {
26 
27 /// Converts `memref.reshape` that has a target shape of a statically-known
28 /// size to `memref.reinterpret_cast`.
29 struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
30 public:
32 
33  LogicalResult matchAndRewrite(memref::ReshapeOp op,
34  PatternRewriter &rewriter) const final {
35  auto shapeType = cast<MemRefType>(op.getShape().getType());
36  if (!shapeType.hasStaticShape())
37  return failure();
38 
39  int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
40  SmallVector<OpFoldResult, 4> sizes, strides;
41  sizes.resize(rank);
42  strides.resize(rank);
43 
44  Location loc = op.getLoc();
45  Value stride = nullptr;
46  int64_t staticStride = 1;
47  for (int i = rank - 1; i >= 0; --i) {
48  Value size;
49  // Load dynamic sizes from the shape input, use constants for static dims.
50  if (op.getType().isDynamicDim(i)) {
51  Value index = arith::ConstantIndexOp::create(rewriter, loc, i);
52  size = memref::LoadOp::create(rewriter, loc, op.getShape(), index);
53  if (!isa<IndexType>(size.getType()))
54  size = arith::IndexCastOp::create(rewriter, loc,
55  rewriter.getIndexType(), size);
56  sizes[i] = size;
57  } else {
58  auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
59  size = arith::ConstantOp::create(rewriter, loc, sizeAttr);
60  sizes[i] = sizeAttr;
61  }
62  if (stride)
63  strides[i] = stride;
64  else
65  strides[i] = rewriter.getIndexAttr(staticStride);
66 
67  if (i > 0) {
68  if (stride) {
69  stride = arith::MulIOp::create(rewriter, loc, stride, size);
70  } else if (op.getType().isDynamicDim(i)) {
71  stride = arith::MulIOp::create(
72  rewriter, loc,
73  arith::ConstantIndexOp::create(rewriter, loc, staticStride),
74  size);
75  } else {
76  staticStride *= op.getType().getDimSize(i);
77  }
78  }
79  }
80  rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
81  op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
82  sizes, strides);
83  return success();
84  }
85 };
86 
87 struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
88  void runOnOperation() override {
89  MLIRContext &ctx = getContext();
90 
93  ConversionTarget target(ctx);
94 
95  target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
96  target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
97  return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
98  });
99  if (failed(applyPartialConversion(getOperation(), target,
100  std::move(patterns))))
101  signalPassFailure();
102  }
103 };
104 
105 } // namespace
106 
// Body of populateExpandOpsPatterns(RewritePatternSet &patterns): registers
// the memref.reshape -> memref.reinterpret_cast pattern, constructed with the
// pattern set's own MLIRContext.
108  patterns.add<MemRefReshapeOpConverter>(patterns.getContext());
109 }
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
IndexType getIndexType()
Definition: Builders.cpp:50
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Definition: ExpandOps.cpp:107
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319