18 #define GEN_PASS_DEF_EXPANDOPSPASS
19 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
29 struct MemRefReshapeOpConverter :
public OpRewritePattern<memref::ReshapeOp> {
33 LogicalResult matchAndRewrite(memref::ReshapeOp op,
35 auto shapeType = cast<MemRefType>(op.getShape().getType());
36 if (!shapeType.hasStaticShape())
39 int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
45 Value stride =
nullptr;
46 int64_t staticStride = 1;
47 for (
int i = rank - 1; i >= 0; --i) {
50 if (op.getType().isDynamicDim(i)) {
52 size = memref::LoadOp::create(rewriter, loc, op.getShape(), index);
53 if (!isa<IndexType>(size.
getType()))
54 size = arith::IndexCastOp::create(rewriter, loc,
58 auto sizeAttr = rewriter.
getIndexAttr(op.getType().getDimSize(i));
59 size = arith::ConstantOp::create(rewriter, loc, sizeAttr);
69 stride = arith::MulIOp::create(rewriter, loc, stride, size);
70 }
else if (op.getType().isDynamicDim(i)) {
71 stride = arith::MulIOp::create(
76 staticStride *= op.getType().getDimSize(i);
81 op, op.getType(), op.getSource(), rewriter.
getIndexAttr(0),
87 struct ExpandOpsPass :
public memref::impl::ExpandOpsPassBase<ExpandOpsPass> {
88 void runOnOperation()
override {
95 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
96 target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
97 return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getIndexAttr(int64_t value)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...