MLIR  22.0.0git
ShapeToShapeLowering.cpp
Go to the documentation of this file.
1 //===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/Shape/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

18 namespace mlir {
19 #define GEN_PASS_DEF_SHAPETOSHAPELOWERINGPASS
20 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
21 } // namespace mlir
22 
23 using namespace mlir;
24 using namespace mlir::shape;
25 
26 namespace {
27 /// Converts `shape.num_elements` to `shape.reduce`.
28 struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
29 public:
31 
32  LogicalResult matchAndRewrite(NumElementsOp op,
33  PatternRewriter &rewriter) const final;
34 };
35 } // namespace
36 
37 LogicalResult
38 NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
39  PatternRewriter &rewriter) const {
40  auto loc = op.getLoc();
41  Type valueType = op.getResult().getType();
42  Value init = op->getDialect()
43  ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
44  valueType, loc)
45  ->getResult(0);
46  ReduceOp reduce = ReduceOp::create(rewriter, loc, op.getShape(), init);
47 
48  // Generate reduce operator.
49  Block *body = reduce.getBody();
51  Value product = MulOp::create(b, loc, valueType, body->getArgument(1),
52  body->getArgument(2));
53  shape::YieldOp::create(b, loc, product);
54 
55  rewriter.replaceOp(op, reduce.getResult());
56  return success();
57 }
58 
59 namespace {
60 struct ShapeToShapeLowering
61  : public impl::ShapeToShapeLoweringPassBase<ShapeToShapeLowering> {
62  void runOnOperation() override;
63 };
64 } // namespace
65 
66 void ShapeToShapeLowering::runOnOperation() {
67  MLIRContext &ctx = getContext();
68 
71 
72  ConversionTarget target(getContext());
73  target.addLegalDialect<arith::ArithDialect, ShapeDialect>();
74  target.addIllegalOp<NumElementsOp>();
75  if (failed(mlir::applyPartialConversion(getOperation(), target,
76  std::move(patterns))))
77  signalPassFailure();
78 }
79 
81  patterns.add<NumElementsOpConverter>(patterns.getContext());
82 }
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2915
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
This class describes a specific conversion target.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
static OpBuilder atBlockEnd(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to after the last operation in the block but still inside the block.
Definition: Builders.h:244
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
void populateShapeRewritePatterns(RewritePatternSet &patterns)
Collects a set of patterns to rewrite ops within the Shape dialect.
const FrozenRewritePatternSet & patterns
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319