MLIR  20.0.0git
TosaToMLProgram.cpp
Go to the documentation of this file.
1 //===- TosaToMLProgram.cpp - Lowering Tosa to MLProgram Dialect------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // These rewriters lower from the TOSA dialect to the MLProgram dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/IRMapping.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 using namespace tosa;
21 namespace {
22 
23 class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
24 public:
26 
27  LogicalResult matchAndRewrite(tosa::VariableOp op,
28  PatternRewriter &rewriter) const final {
29  auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
30  op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
31  op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
32  newVariable.setPrivate();
33  rewriter.replaceOp(op, newVariable);
34  return success();
35  }
36 };
37 
38 class VariableWriteOpConverter
39  : public OpRewritePattern<tosa::VariableWriteOp> {
40 public:
42 
43  LogicalResult matchAndRewrite(tosa::VariableWriteOp op,
44  PatternRewriter &rewriter) const final {
45  auto globalSymbolRef =
46  SymbolRefAttr::get(rewriter.getContext(), op.getName());
47  auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
48  op.getLoc(), globalSymbolRef, op.getValue());
49  rewriter.replaceOp(op, newVariableWrite);
50  return success();
51  }
52 };
53 
54 class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
55 public:
57 
58  LogicalResult matchAndRewrite(tosa::VariableReadOp op,
59  PatternRewriter &rewriter) const final {
60  auto globalSymbolRef =
61  SymbolRefAttr::get(rewriter.getContext(), op.getName());
62  auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>(
63  op.getLoc(), op.getType(), globalSymbolRef);
64  rewriter.replaceOp(op, newVariableRead);
65 
66  return success();
67  }
68 };
69 
70 } // namespace
71 
73  RewritePatternSet *patterns) {
74  patterns->add<VariableOpConverter, VariableWriteOpConverter,
75  VariableReadOpConverter>(patterns->getContext());
76 }
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
static std::unique_ptr< T > create(Args &&...args)
This method provides a convenient interface for creating and initializing derived rewrite patterns of...
Definition: PatternMatch.h:276
void populateTosaToMLProgramConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358