23 #include "llvm/ADT/STLExtras.h"
27 #define GEN_PASS_DEF_EXPANDOPS
28 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
57 auto genericOp = rewriter.
create<memref::GenericAtomicRMWOp>(
58 loc, op.getMemref(), op.getIndices());
62 Value lhs = genericOp.getCurrentValue();
63 Value rhs = op.getValue();
67 bodyBuilder.create<memref::AtomicYieldOp>(loc, arithOp);
76 struct MemRefReshapeOpConverter :
public OpRewritePattern<memref::ReshapeOp> {
82 auto shapeType = cast<MemRefType>(op.getShape().getType());
83 if (!shapeType.hasStaticShape())
86 int64_t rank = cast<MemRefType>(shapeType).getDimSize(0);
92 Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
93 for (
int i = rank - 1; i >= 0; --i) {
96 if (op.getType().isDynamicDim(i)) {
97 Value index = rewriter.create<arith::ConstantIndexOp>(loc, i);
98 size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index);
99 if (!isa<IndexType>(size.getType()))
100 size = rewriter.create<arith::IndexCastOp>(
101 loc, rewriter.getIndexType(), size);
104 auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i));
105 size = rewriter.create<arith::ConstantOp>(loc, sizeAttr);
110 stride = rewriter.create<arith::MulIOp>(loc, stride, size);
112 rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
113 op, op.getType(), op.getSource(), rewriter.getIndexAttr(0),
119 struct ExpandOpsPass :
public memref::impl::ExpandOpsBase<ExpandOpsPass> {
120 void runOnOperation()
override {
127 target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
128 target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
129 [](memref::AtomicRMWOp op) {
130 constexpr std::array shouldBeExpandedKinds = {
131 arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
132 arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
133 return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
135 target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
136 return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();
139 std::move(patterns))))
147 patterns.
add<AtomicRMWOpConverter, MemRefReshapeOpConverter>(
152 return std::make_unique<ExpandOpsPass>();
static MLIRContext * getContext(OpFoldResult val)
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Value getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs)
Returns the value obtained by applying the reduction operation kind associated with a binary AtomicRM...
std::unique_ptr< Pass > createExpandOpsPass()
Creates an instance of the ExpandOps pass that legalizes memref dialect ops to be convertible to LLVM...
void populateExpandOpsPatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the memref dialect.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...