MLIR 22.0.0git
ExpandOps.cpp
Go to the documentation of this file.
1//===- ExpandOps.cpp - Expansion patterns for MemRef operations -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15
16namespace mlir {
17namespace memref {
18#define GEN_PASS_DEF_EXPANDOPSPASS
19#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
20} // namespace memref
21} // namespace mlir
22
23using namespace mlir;
24
25namespace {
26
27/// Converts `memref.reshape` that has a target shape of a statically-known
28/// size to `memref.reinterpret_cast`.
29struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
30public:
32
33 LogicalResult matchAndRewrite(memref::ReshapeOp op,
34 PatternRewriter &rewriter) const final {
35 auto shapeType = cast<MemRefType>(op.getShape().getType());
36 if (!shapeType.hasStaticShape())
37 return failure();
39 int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
41 sizes.resize(rank);
42 strides.resize(rank);
44 Location loc = op.getLoc();
45 Value stride = nullptr;
46 int64_t staticStride = 1;
47 for (int i = rank - 1; i >= 0; --i) {
48 Value size;
49 // Load dynamic sizes from the shape input, use constants for static dims.
50 if (op.getType().isDynamicDim(i)) {
52 size = memref::LoadOp::create(rewriter, loc, op.getShape(), index);
53 if (!isa<IndexType>(size.getType()))
54 size = arith::IndexCastOp::create(rewriter, loc,
55 rewriter.getIndexType(), size);
56 sizes[i] = size;
57 } else {
58 auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
59 size = arith::ConstantOp::create(rewriter, loc, sizeAttr);
60 sizes[i] = sizeAttr;
61 }
62 if (stride)
63 strides[i] = stride;
64 else
65 strides[i] = rewriter.getIndexAttr(staticStride);
66
67 if (i > 0) {
68 if (stride) {
69 stride = arith::MulIOp::create(rewriter, loc, stride, size);
70 } else if (op.getType().isDynamicDim(i)) {
71 stride = arith::MulIOp::create(
72 rewriter, loc,
73 arith::ConstantIndexOp::create(rewriter, loc, staticStride),
74 size);
75 } else {
76 staticStride *= op.getType().getDimSize(i);
77 }
78 }
79 }
80 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
81 op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
82 sizes, strides);
83 return success();
84 }
85};
86
87struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
88 void runOnOperation() override {
89 MLIRContext &ctx = getContext();
90
94
95 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
96 target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
97 return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
98 });
99 if (failed(applyPartialConversion(getOperation(), target,
100 std::move(patterns))))
101 signalPassFailure();
102 }
103};
104
105} // namespace
106
108 patterns.add<MemRefReshapeOpConverter>(patterns.getContext());
109}
return success()
b getContext())
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...