MLIR 22.0.0git
TosaToMLProgram.cpp
Go to the documentation of this file.
1//===- TosaToMLProgram.cpp - Lowering Tosa to MLProgram Dialect------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These rewriters lower from the TOSA dialect to the MLProgram dialect.
10//
11//===----------------------------------------------------------------------===//
12
17
18using namespace mlir;
19using namespace tosa;
20namespace {
21
22class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
23public:
24 using OpRewritePattern<tosa::VariableOp>::OpRewritePattern;
25
26 LogicalResult matchAndRewrite(tosa::VariableOp op,
27 PatternRewriter &rewriter) const final {
28 auto variableType = tosa::getVariableType(op);
29 auto newVariable = mlir::ml_program::GlobalOp::create(
30 rewriter, op.getLoc(), op.getName(), variableType, /*is_mutable=*/true,
31 op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
32 newVariable.setPrivate();
33 rewriter.replaceOp(op, newVariable);
34 return success();
35 }
36};
37
38class VariableWriteOpConverter
39 : public OpRewritePattern<tosa::VariableWriteOp> {
40public:
41 using OpRewritePattern<tosa::VariableWriteOp>::OpRewritePattern;
42
43 LogicalResult matchAndRewrite(tosa::VariableWriteOp op,
44 PatternRewriter &rewriter) const final {
45 auto globalSymbolRef =
46 SymbolRefAttr::get(rewriter.getContext(), op.getName());
47 auto newVariableWrite = ml_program::GlobalStoreOp::create(
48 rewriter, op.getLoc(), globalSymbolRef, op.getInput1());
49 rewriter.replaceOp(op, newVariableWrite);
50 return success();
51 }
52};
53
54class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
55public:
56 using OpRewritePattern<tosa::VariableReadOp>::OpRewritePattern;
57
58 LogicalResult matchAndRewrite(tosa::VariableReadOp op,
59 PatternRewriter &rewriter) const final {
60 auto globalSymbolRef =
61 SymbolRefAttr::get(rewriter.getContext(), op.getName());
62 auto newVariableRead = ml_program::GlobalLoadOp::create(
63 rewriter, op.getLoc(), op.getType(), globalSymbolRef);
64 rewriter.replaceOp(op, newVariableRead);
65
66 return success();
67 }
68};
69
70} // namespace
71
74 patterns->add<VariableOpConverter, VariableWriteOpConverter,
75 VariableReadOpConverter>(patterns->getContext());
76}
return success()
RankedTensorType getVariableType(VariableOp variableOp)
void populateTosaToMLProgramConversionPatterns(RewritePatternSet *patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...