MLIR  20.0.0git
ConvertShapeConstraints.cpp
Go to the documentation of this file.
1 //===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19 
20 namespace mlir {
21 #define GEN_PASS_DEF_CONVERTSHAPECONSTRAINTS
22 #include "mlir/Conversion/Passes.h.inc"
23 } // namespace mlir
24 
25 using namespace mlir;
26 
27 namespace {
28 #include "ShapeToStandard.cpp.inc"
29 } // namespace
30 
31 namespace {
32 class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
33 public:
35  LogicalResult matchAndRewrite(shape::CstrRequireOp op,
36  PatternRewriter &rewriter) const override {
37  rewriter.create<cf::AssertOp>(op.getLoc(), op.getPred(), op.getMsgAttr());
38  rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
39  return success();
40  }
41 };
42 } // namespace
43 
45  RewritePatternSet &patterns) {
46  patterns.add<CstrBroadcastableToRequire>(patterns.getContext());
47  patterns.add<CstrEqToRequire>(patterns.getContext());
48  patterns.add<ConvertCstrRequireOp>(patterns.getContext());
49 }
50 
51 namespace {
52 // This pass eliminates shape constraints from the program, converting them to
53 // eager (side-effecting) error handling code. After eager error handling code
54 // is emitted, witnesses are satisfied, so they are replace with
55 // `shape.const_witness true`.
56 class ConvertShapeConstraints
57  : public impl::ConvertShapeConstraintsBase<ConvertShapeConstraints> {
58  void runOnOperation() override {
59  auto *func = getOperation();
60  auto *context = &getContext();
61 
62  RewritePatternSet patterns(context);
64 
65  if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
66  return signalPassFailure();
67  }
68 };
69 } // namespace
70 
71 std::unique_ptr<Pass> mlir::createConvertShapeConstraintsPass() {
72  return std::make_unique<ConvertShapeConstraints>();
73 }
static MLIRContext * getContext(OpFoldResult val)
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Include the generated interface declarations.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
std::unique_ptr< Pass > createConvertShapeConstraintsPass()
void populateConvertShapeConstraintsConversionPatterns(RewritePatternSet &patterns)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362