MLIR  19.0.0git
ConvertToLLVMPass.cpp
Go to the documentation of this file.
1 //===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Pass/Pass.h"
18 #include <memory>
19 
20 #define DEBUG_TYPE "convert-to-llvm"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_CONVERTTOLLVMPASS
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 
31 /// This DialectExtension can be attached to the context, which will invoke the
32 /// `apply()` method for every loaded dialect. If a dialect implements the
33 /// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
34 /// through the interface. This extension is loaded in the context before
35 /// starting a pass pipeline that involves dialect conversion to LLVM.
36 class LoadDependentDialectExtension : public DialectExtensionBase {
37 public:
38  LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
39 
40  void apply(MLIRContext *context,
41  MutableArrayRef<Dialect *> dialects) const final {
42  LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
43  for (Dialect *dialect : dialects) {
44  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
45  if (!iface)
46  continue;
47  LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for "
48  << dialect->getNamespace() << "\n");
49  iface->loadDependentDialects(context);
50  }
51  }
52 
53  /// Return a copy of this extension.
54  std::unique_ptr<DialectExtensionBase> clone() const final {
55  return std::make_unique<LoadDependentDialectExtension>(*this);
56  }
57 };
58 
59 /// This is a generic pass to convert to LLVM, it uses the
60 /// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
61 /// the injection of conversion patterns.
62 class ConvertToLLVMPass
63  : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
64  std::shared_ptr<const FrozenRewritePatternSet> patterns;
65  std::shared_ptr<const ConversionTarget> target;
66  std::shared_ptr<const LLVMTypeConverter> typeConverter;
67 
68 public:
69  using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
70  void getDependentDialects(DialectRegistry &registry) const final {
71  registry.insert<LLVM::LLVMDialect>();
72  registry.addExtensions<LoadDependentDialectExtension>();
73  }
74 
75  LogicalResult initialize(MLIRContext *context) final {
76  RewritePatternSet tempPatterns(context);
77  auto target = std::make_shared<ConversionTarget>(*context);
78  target->addLegalDialect<LLVM::LLVMDialect>();
79  auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
80 
81  if (!filterDialects.empty()) {
82  // Test mode: Populate only patterns from the specified dialects. Produce
83  // an error if the dialect is not loaded or does not implement the
84  // interface.
85  for (std::string &dialectName : filterDialects) {
86  Dialect *dialect = context->getLoadedDialect(dialectName);
87  if (!dialect)
88  return emitError(UnknownLoc::get(context))
89  << "dialect not loaded: " << dialectName << "\n";
90  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
91  if (!iface)
92  return emitError(UnknownLoc::get(context))
93  << "dialect does not implement ConvertToLLVMPatternInterface: "
94  << dialectName << "\n";
95  iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
96  tempPatterns);
97  }
98  } else {
99  // Normal mode: Populate all patterns from all dialects that implement the
100  // interface.
101  for (Dialect *dialect : context->getLoadedDialects()) {
102  // First time we encounter this dialect: if it implements the interface,
103  // let's populate patterns !
104  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
105  if (!iface)
106  continue;
107  iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
108  tempPatterns);
109  }
110  }
111 
112  this->patterns =
113  std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
114  this->target = target;
115  this->typeConverter = typeConverter;
116  return success();
117  }
118 
119  void runOnOperation() final {
120  if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
121  signalPassFailure();
122  }
123 };
124 
125 } // namespace
126 
128  DialectRegistry &registry) {
129  registry.addExtensions<LoadDependentDialectExtension>();
130 }
131 
132 std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
133  return std::make_unique<ConvertToLLVMPass>();
134 }
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Definition: Dialect.h:41
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::unique_ptr< Pass > createConvertToLLVMPass()
Create a pass that performs dialect conversion to LLVM for all dialects implementing ConvertToLLVMPatternInterface.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26