MLIR 22.0.0git
ShapeToShapeLowering.cpp
Go to the documentation of this file.
1//===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
14#include "mlir/IR/Builders.h"
17
18namespace mlir {
19#define GEN_PASS_DEF_SHAPETOSHAPELOWERINGPASS
20#include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
21} // namespace mlir
22
23using namespace mlir;
24using namespace mlir::shape;
25
26namespace {
27/// Converts `shape.num_elements` to `shape.reduce`.
28struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
29public:
31
32 LogicalResult matchAndRewrite(NumElementsOp op,
33 PatternRewriter &rewriter) const final;
34};
35} // namespace
36
37LogicalResult
38NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
39 PatternRewriter &rewriter) const {
40 auto loc = op.getLoc();
41 Type valueType = op.getResult().getType();
42 Value init = op->getDialect()
43 ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
44 valueType, loc)
45 ->getResult(0);
46 ReduceOp reduce = ReduceOp::create(rewriter, loc, op.getShape(), init);
47
48 // Generate reduce operator.
49 Block *body = reduce.getBody();
50 OpBuilder b = OpBuilder::atBlockEnd(body);
51 Value product = MulOp::create(b, loc, valueType, body->getArgument(1),
52 body->getArgument(2));
53 shape::YieldOp::create(b, loc, product);
54
55 rewriter.replaceOp(op, reduce.getResult());
56 return success();
57}
58
59namespace {
60struct ShapeToShapeLowering
61 : public impl::ShapeToShapeLoweringPassBase<ShapeToShapeLowering> {
62 void runOnOperation() override;
63};
64} // namespace
65
66void ShapeToShapeLowering::runOnOperation() {
67 MLIRContext &ctx = getContext();
68
69 RewritePatternSet patterns(&ctx);
71
72 ConversionTarget target(getContext());
73 target.addLegalDialect<arith::ArithDialect, ShapeDialect>();
74 target.addIllegalOp<NumElementsOp>();
75 if (failed(mlir::applyPartialConversion(getOperation(), target,
76 std::move(patterns))))
77 signalPassFailure();
78}
79
81 patterns.add<NumElementsOpConverter>(patterns.getContext());
82}
return success()
static int64_t product(ArrayRef< int64_t > vals)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
BlockArgument getArgument(unsigned i)
Definition Block.h:129
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still insid...
Definition Builders.h:246
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
void populateShapeRewritePatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the Shape dialect.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...