20 #define GEN_PASS_DEF_SHAPETOSHAPELOWERING
21 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
33 LogicalResult matchAndRewrite(NumElementsOp op,
39 NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
47 ReduceOp
reduce = rewriter.
create<ReduceOp>(loc, op.getShape(), init);
53 body->getArgument(2));
61 struct ShapeToShapeLowering
62 :
public impl::ShapeToShapeLoweringBase<ShapeToShapeLowering> {
63 void runOnOperation()
override;
67 void ShapeToShapeLowering::runOnOperation() {
74 target.addLegalDialect<arith::ArithDialect, ShapeDialect>();
75 target.addIllegalOp<NumElementsOp>();
77 std::move(patterns))))
86 return std::make_unique<ShapeToShapeLowering>();
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Block represents an ordered list of Operations.
IntegerAttr getIndexAttr(int64_t value)
This class describes a specific conversion target.
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Include the generated interface declarations.
std::unique_ptr< Pass > createShapeToShapeLowering()
Creates an instance of the ShapeToShapeLowering pass that legalizes Shape dialect to be convertible t...
void populateShapeRewritePatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the Shape dialect.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...