23 #include "llvm/ADT/STLExtras.h"
27 #define GEN_PASS_DEF_EXPANDOPS
28 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
54 LogicalResult matchAndRewrite(memref::AtomicRMWOp op,
56 auto loc = op.getLoc();
57 auto genericOp = rewriter.
create<memref::GenericAtomicRMWOp>(
58 loc, op.getMemref(), op.getIndices());
62 Value lhs = genericOp.getCurrentValue();
63 Value rhs = op.getValue();
67 bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);
69 rewriter.
replaceOp(op, genericOp.getResult());
76 struct MemRefReshapeOpConverter :
public OpRewritePattern<memref::ReshapeOp> {
80 LogicalResult matchAndRewrite(memref::ReshapeOp op,
82 auto shapeType = cast<MemRefType>(op.getShape().getType());
83 if (!shapeType.hasStaticShape())
86 int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
92 Value stride =
nullptr;
93 int64_t staticStride = 1;
94 for (
int i = rank - 1; i >= 0; --i) {
97 if (op.getType().isDynamicDim(i)) {
98 Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
99 size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
100 if (!isa<IndexType>(size.getType()))
101 size = rewriter.create<arith::IndexCastOp>(
102 loc, rewriter.getIndexType(), size);
105 auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
106 size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
112 strides[i] = rewriter.getIndexAttr(staticStride);
116 stride = rewriter.create<arith::MulIOp>(loc, stride, size);
117 }
else if (op.getType().isDynamicDim(i)) {
118 stride = rewriter.create<arith::MulIOp>(
119 loc, rewriter.create<arith::ConstantIndexOp>(loc, staticStride),
122 staticStride *= op.getType().getDimSize(i);
126 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
127 op, op.getType(), op.getSource(), rewriter.getIndexAttr(0),
133 struct ExpandOpsPass :
public memref::impl::ExpandOpsBase<ExpandOpsPass> {
134 void runOnOperation()
override {
141 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
142 target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
143 [](memref::AtomicRMWOp op) {
144 constexpr std::array shouldBeExpandedKinds = {
145 arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
146 arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
147 return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
149 target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
150 return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
153 std::move(patterns))))
161 patterns.
add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
166 return std::make_unique<ExpandOpsPass>();
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRMWKind op to lhs and rhs.
std::unique_ptr< Pass > createExpandOpsPass()
Creates an instance of the ExpandOps pass that legalizes memref dialect ops to be convertible to LLVM.
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
Include the generated interface declarations.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of the pattern matching and a list of generated ops.