MLIR  22.0.0git
TosaToLinalgPass.cpp
Go to the documentation of this file.
1 //===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass legalizes Tosa operations to the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
24 #include "mlir/IR/PatternMatch.h"
25 #include "mlir/Pass/PassManager.h"
27 #include "mlir/Transforms/Passes.h"
28 
29 namespace mlir {
30 #define GEN_PASS_DEF_TOSATOLINALG
31 #include "mlir/Conversion/Passes.h.inc"
32 } // namespace mlir
33 
34 using namespace mlir;
35 
36 namespace {
37 struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
38 public:
39  void getDependentDialects(DialectRegistry &registry) const override {
40  registry
41  .insert<arith::ArithDialect, linalg::LinalgDialect, math::MathDialect,
42  index::IndexDialect, tensor::TensorDialect, scf::SCFDialect>();
43  }
44 
45  void runOnOperation() override {
47  ConversionTarget target(getContext());
48  target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect,
49  scf::SCFDialect>();
50  target.addIllegalDialect<tosa::TosaDialect>();
51 
52  // Not every TOSA op can be legalized to linalg.
53  target.addLegalOp<tosa::ApplyScaleOp>();
54  target.addLegalOp<tosa::IfOp>();
55  target.addLegalOp<tosa::ConstOp>();
56  target.addLegalOp<tosa::ConstShapeOp>();
57  target.addLegalOp<tosa::WhileOp>();
58  target.addLegalOp<tosa::ConcatOp>();
59  target.addLegalOp<tosa::SliceOp>();
60  target.addLegalOp<tosa::ReshapeOp>();
61  target.addLegalOp<tosa::PadOp>();
62 
63  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
64 
65  TypeConverter converter;
67 
68  FunctionOpInterface func = getOperation();
70  if (failed(applyFullConversion(func, target, std::move(patterns))))
71  signalPassFailure();
72  }
73 };
74 } // namespace
75 
76 std::unique_ptr<Pass> mlir::tosa::createTosaToLinalg() {
77  return std::make_unique<TosaToLinalg>();
78 }
79 
81  OpPassManager &pm, const TosaToLinalgOptions &options,
82  const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions,
83  std::optional<tosa::TosaValidationOptions> validationOptions) {
84  // Optional decompositions are designed to benefit linalg.
85  if (!options.disableTosaDecompositions)
86  pm.addNestedPass<func::FuncOp>(
87  tosa::createTosaOptionalDecompositionsPass());
88  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
89 
90  pm.addNestedPass<func::FuncOp>(tosa::createTosaInferShapesPass());
91  pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
92  pm.addNestedPass<func::FuncOp>(
93  tosa::createTosaToLinalgNamed(tosaToLinalgNamedOptions));
94  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
95  // TODO: Remove pass that operates on const tensor and enable optionality
96  pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass(
97  {options.aggressiveReduceConstant}));
98  pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
99  if (validationOptions)
100  pm.addPass(tosa::createTosaValidation(*validationOptions));
101  pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
102 }
103 
104 //===----------------------------------------------------------------------===//
105 // Pipeline registration.
106 //===----------------------------------------------------------------------===//
107 
110  "tosa-to-linalg-pipeline",
111  "The default pipeline for converting TOSA operators to the equivalent "
112  "operations using the tensor operations in LinAlg as well as LinAlg "
113  "named operations.",
114  [](OpPassManager &pm) {
115  TosaToLinalgOptions tosaToLinalgOptions;
116  TosaToLinalgNamedOptions tosaToLinalgNamedOptions;
117  TosaValidationOptions validationOptions;
118  validationOptions.profile = {"none"};
119  validationOptions.extension = {"none"};
120  validationOptions.strictOpSpecAlignment = false;
121  validationOptions.allowInvalidOpDatatypeCombinations = false;
122  validationOptions.level = tosa::TosaLevelEnum::EightK;
123  tosa::addTosaToLinalgPasses(pm, tosaToLinalgOptions,
124  tosaToLinalgNamedOptions,
125  validationOptions);
126  });
127 }
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addLegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as legal.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
This class represents a pass manager that runs passes on either a specific operation type,...
Definition: PassManager.h:46
void addPass(std::unique_ptr< Pass > pass)
Add the given pass to this pass manager.
Definition: Pass.cpp:367
void addNestedPass(std::unique_ptr< Pass > pass)
Add the given pass to a nested pass manager for the given operation kind OpT.
Definition: PassManager.h:115
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Type conversion class.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Populates conversion patterns from the TOSA dialect to the Linalg dialect.
std::unique_ptr< Pass > createTosaToLinalgNamed(const TosaToLinalgNamedOptions &options=TosaToLinalgNamedOptions())
std::unique_ptr< Pass > createTosaToLinalg()
void addTosaToLinalgPasses(OpPassManager &pm, const TosaToLinalgOptions &options, const TosaToLinalgNamedOptions &tosaToLinalgNamedOptions=TosaToLinalgNamedOptions(), std::optional< tosa::TosaValidationOptions > validationOptions=tosa::TosaValidationOptions{ {"none"}, {"none"}, false, false, tosa::TosaLevelEnum::None})
Populates passes to convert from TOSA to Linalg on tensors.
void populateTosaTypeConversion(TypeConverter &converter)
void registerTosaToLinalgPipelines()
Registers the TOSA to Linalg pipelines. Currently, this includes only the "tosa-to-linalg-pipeline".
Include the generated interface declarations.
LogicalResult applyFullConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Apply a complete conversion on the given operations, and all nested operations.
const FrozenRewritePatternSet & patterns
std::unique_ptr< Pass > createCanonicalizerPass()
Creates an instance of the Canonicalizer pass, configured with default settings (which can be overrid...
PassPipelineRegistration provides a global initializer that registers a Pass pipeline builder routine...
Definition: PassRegistry.h:177