MLIR 23.0.0git
TosaToLinalgPass.cpp
Go to the documentation of this file.
1//===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass legalizes Tosa operations to the Linalg dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
29
30namespace mlir {
31#define GEN_PASS_DEF_TOSATOLINALG
32#include "mlir/Conversion/Passes.h.inc"
33} // namespace mlir
34
35using namespace mlir;
36
37namespace {
38struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
39public:
40 void getDependentDialects(DialectRegistry &registry) const override {
41 registry
42 .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
43 index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>();
44 }
45
46 void runOnOperation() override {
47 RewritePatternSet patterns(&getContext());
48 ConversionTarget target(getContext());
49 target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect,
50 scf::SCFDialect>();
51 target.addIllegalDialect<tosa::TosaDialect>();
52
53 // Not every TOSA op can be legalized to linalg.
54 target.addLegalOp<tosa::ApplyScaleOp>();
55 target.addLegalOp<tosa::IfOp>();
56 target.addLegalOp<tosa::ConstOp>();
57 target.addLegalOp<tosa::ConstShapeOp>();
58 target.addLegalOp<tosa::WhileOp>();
59 target.addLegalOp<tosa::ConcatOp>();
60 target.addLegalOp<tosa::SliceOp>();
61 target.addLegalOp<tosa::ReshapeOp>();
62 target.addLegalOp<tosa::PadOp>();
63
64 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
65
66 TypeConverter converter;
68
69 FunctionOpInterface func = getOperation();
71 if (failed(applyFullConversion(func, target, std::move(patterns))))
72 signalPassFailure();
73 }
74};
75} // namespace
76
77std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
78 return std::make_unique<TosaToLinalg>();
79}
80
82 OpPassManager &pm, const TosaToLinalgOptions &options,
83 const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
84 std::optional<tosa::TosaValidationOptions> validationOptions,
85 std::optional<TosaAttachTargetOptions> attachTargetOptions) {
86 // Optional decompositions are designed to benefit linalg.
87 if (!options.disableTosaDecompositions)
88 pm.addNestedPass<func::FuncOp>(
89 tosa::createTosaOptionalDecompositionsPass());
90 pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
91
92 pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
93 pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
94 pm.addNestedPass<func::FuncOp>(
95 tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
96 pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
97 // TODO: Remove pass that operates on const tensor and enable optionality
98 pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
99 {options.aggressiveReduceConstant}));
100 pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
101 if (!attachTargetOptions) {
102 attachTargetOptions = TosaAttachTargetOptions();
103 attachTargetOptions->profiles = {"pro_int", "pro_fp"};
104 // TODO: populate with all the extensions that the tosa->linalg conversion
105 // supports
106 attachTargetOptions->extensions = {"doubleround"};
107 }
108 pm.addPass(tosa::createTosaAttachTarget(*attachTargetOptions));
109 if (validationOptions)
110 pm.addPass(tosa::createTosaValidation(*validationOptions));
111 pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
112}
113
114//===----------------------------------------------------------------------===//
115// Pipeline registration.
116//===----------------------------------------------------------------------===//
117
120 "tosa-to-linalg-pipeline",
121 "The default pipeline for converting TOSA operators to the equivalent "
122 "operations using the tensor operations in LinAlg as well as LinAlg "
123 "named operations.",
124 [](OpPassManager &pm) {
125 TosaToLinalgOptions tosaToLinalgOptions;
126 TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
127 TosaValidationOptions validationOptions;
128 validationOptions.strictOpSpecAlignment = false;
129 validationOptions.allowInvalidOpDatatypeCombinations = false;
130 tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
131 tosaToLinalgNamedOptions,
132 validationOptions);
133 });
134}
b getContext())
static llvm::ManagedStatic< PassManagerOptions > options
This class represents a pass manager that runs passes on either a specific operation type,...
Definition PassManager.h:46
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Definition Pass.cpp:392
void addNestedPass(std::unique_ptr< Pass > pass)
Add the given pass to a nested pass manager for the given operation kind OpT.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Populates conversion patterns from the TOSA dialect to the Linalg dialect.
std::unique_ptr< Pass > createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options=TosaToLinalgNamedOptions())
void addTosaToLinalgPasses(OpPassManager &pm, const TosaToLinalgOptions &options, const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions=TosaToLinalgNamedOptions(), std::optional< tosa::TosaValidationOptions > validationOptions=tosa::TosaValidationOptions{false, false}, std::optional< TosaAttachTargetOptions > attachTargetOptions=std::nullopt)
Populates passes to convert from TOSA to Linalg.
std::unique_ptr< Pass > createTosaToLinalg()
void populateTosaTypeConversion(TypeConverter &converter)
void registerTosaToLinalgPipelines()
Registers the TOSA-to-Linalg pipelines. Currently, this includes only the "tosa-to-linalg-pipeline".
Include the generated interface declarations.
std::unique_ptr< Pass > createCanonicalizerPass(const GreedyRewriteConfig &config, ArrayRef< std::string > disabledPatterns={}, ArrayRef< std::string > enabledPatterns={})
Creates an instance of the Canonicalizer pass with the specified config.
PassPipelineRegistration provides a global initializer that registers a Pass pipeline builder routine...