18#define GEN_PASS_DEF_EXPANDOPSPASS
19#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
29struct MemRefReshapeOpConverter :
public OpRewritePattern<memref::ReshapeOp> {
33 LogicalResult matchAndRewrite(memref::ReshapeOp op,
35 auto shapeType = cast<MemRefType>(op.getShape().getType());
36 if (!shapeType.hasStaticShape())
39 int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
45 Value stride =
nullptr;
47 for (
int i = rank - 1; i >= 0; --i) {
50 if (op.getType().isDynamicDim(i)) {
52 size = memref::LoadOp::create(rewriter, loc, op.getShape(),
index);
53 if (!isa<IndexType>(size.
getType()))
54 size = arith::IndexCastOp::create(rewriter, loc,
55 rewriter.getIndexType(), size);
58 auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
59 size = arith::ConstantOp::create(rewriter, loc, sizeAttr);
65 strides[i] = rewriter.getIndexAttr(staticStride);
69 stride = arith::MulIOp::create(rewriter, loc, stride, size);
70 }
else if (op.getType().isDynamicDim(i)) {
71 stride = arith::MulIOp::create(
76 staticStride *= op.getType().getDimSize(i);
80 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
81 op, op.getType(), op.getSource(), rewriter.getIndexAttr(0),
88 void runOnOperation()
override {
95 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
96 target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
97 return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
99 if (failed(applyPartialConversion(getOperation(),
target,
100 std::move(patterns))))
108 patterns.
add<MemRefReshapeOpConverter>(patterns.
getContext());
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...