18#define GEN_PASS_DEF_EXPANDOPSPASS
19#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
29struct MemRefReshapeOpConverter :
public OpRewritePattern<memref::ReshapeOp> {
33 LogicalResult matchAndRewrite(memref::ReshapeOp op,
35 auto shapeType = cast<MemRefType>(op.getShape().getType());
36 if (!shapeType.hasStaticShape())
39 int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
45 Value stride =
nullptr;
47 for (
int i = rank - 1; i >= 0; --i) {
50 if (op.getType().isDynamicDim(i)) {
52 size = memref::LoadOp::create(rewriter, loc, op.getShape(),
index);
53 if (!isa<IndexType>(size.
getType()))
54 size = arith::IndexCastOp::create(rewriter, loc,
55 rewriter.getIndexType(), size);
58 auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
59 size = arith::ConstantOp::create(rewriter, loc, sizeAttr);
65 strides[i] = rewriter.getIndexAttr(staticStride);
69 stride = arith::MulIOp::create(rewriter, loc, stride, size);
70 }
else if (op.getType().isDynamicDim(i)) {
71 stride = arith::MulIOp::create(
76 staticStride *= op.getType().getDimSize(i);
80 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
81 op, op.getType(), op.getSource(), rewriter.getIndexAttr(0),
88 void runOnOperation()
override {
95 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
96 target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
97 return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
99 if (failed(applyPartialConversion(getOperation(),
target,
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...