MLIR 23.0.0git
ExpandOps.cpp
Go to the documentation of this file.
//===- ExpandOps.cpp - Expansion patterns for MemRef operations -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15
16namespace mlir {
17namespace memref {
18#define GEN_PASS_DEF_EXPANDOPSPASS
19#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
20} // namespace memref
21} // namespace mlir
22
23using namespace mlir;
24
25namespace {
26
27/// Converts `memref.reshape` that has a target shape of a statically-known
28/// size to `memref.reinterpret_cast`.
29struct MemRefReshapeOpConverter : public OpRewritePattern<memref::ReshapeOp> {
30public:
32
33 LogicalResult matchAndRewrite(memref::ReshapeOp op,
34 PatternRewriter &rewriter) const final {
35 auto shapeType = cast<MemRefType>(op.getShape().getType());
36 if (!shapeType.hasStaticShape())
37 return failure();
39 int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
41 sizes.resize(rank);
42 strides.resize(rank);
44 Location loc = op.getLoc();
45 Value stride = nullptr;
46 int64_t staticStride = 1;
47 for (int i = rank - 1; i >= 0; --i) {
48 Value size;
49 // Load dynamic sizes from the shape input, use constants for static dims.
50 if (op.getType().isDynamicDim(i)) {
52 size = memref::LoadOp::create(rewriter, loc, op.getShape(), index);
53 if (!isa<IndexType>(size.getType()))
54 size = arith::IndexCastOp::create(rewriter, loc,
55 rewriter.getIndexType(), size);
56 sizes[i] = size;
57 } else {
58 auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
59 size = arith::ConstantOp::create(rewriter, loc, sizeAttr);
60 sizes[i] = sizeAttr;
61 }
62 if (stride)
63 strides[i] = stride;
64 else
65 strides[i] = rewriter.getIndexAttr(staticStride);
66
67 if (i > 0) {
68 if (stride) {
69 stride = arith::MulIOp::create(rewriter, loc, stride, size);
70 } else if (op.getType().isDynamicDim(i)) {
71 stride = arith::MulIOp::create(
72 rewriter, loc,
73 arith::ConstantIndexOp::create(rewriter, loc, staticStride),
74 size);
75 } else {
76 staticStride *= op.getType().getDimSize(i);
77 }
78 }
79 }
80 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
81 op, op.getType(), op.getSource(), /*offset=*/rewriter.getIndexAttr(0),
82 sizes, strides);
83 return success();
84 }
85};
86
87struct ExpandOpsPass : public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
88 void runOnOperation() override {
89 MLIRContext &ctx = getContext();
90
91 RewritePatternSet patterns(&ctx);
94
95 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
96 target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
97 return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
98 });
99 if (failed(applyPartialConversion(getOperation(), target,
100 std::move(patterns))))
101 signalPassFailure();
102 }
103};
104
105} // namespace
106
108 patterns.add<MemRefReshapeOpConverter>(patterns.getContext());
109}
return success()
b getContext())
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...