MLIR  20.0.0git
ConvertToLLVMPass.cpp
Go to the documentation of this file.
1 //===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Pass/Pass.h"
18 #include <memory>
19 
20 #define DEBUG_TYPE "convert-to-llvm"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_CONVERTTOLLVMPASS
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 
31 /// This DialectExtension can be attached to the context, which will invoke the
32 /// `apply()` method for every loaded dialect. If a dialect implements the
33 /// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
34 /// through the interface. This extension is loaded in the context before
35 /// starting a pass pipeline that involves dialect conversion to LLVM.
36 class LoadDependentDialectExtension : public DialectExtensionBase {
37 public:
38  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
39 
40  LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
41 
42  void apply(MLIRContext *context,
43  MutableArrayRef<Dialect *> dialects) const final {
44  LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
45  for (Dialect *dialect : dialects) {
46  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
47  if (!iface)
48  continue;
49  LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for "
50  << dialect->getNamespace() << "\n");
51  iface->loadDependentDialects(context);
52  }
53  }
54 
55  /// Return a copy of this extension.
56  std::unique_ptr<DialectExtensionBase> clone() const final {
57  return std::make_unique<LoadDependentDialectExtension>(*this);
58  }
59 };
60 
61 /// This is a generic pass to convert to LLVM, it uses the
62 /// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
63 /// the injection of conversion patterns.
64 class ConvertToLLVMPass
65  : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
66  std::shared_ptr<const FrozenRewritePatternSet> patterns;
67  std::shared_ptr<const ConversionTarget> target;
68  std::shared_ptr<const LLVMTypeConverter> typeConverter;
69 
70 public:
71  using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
72  void getDependentDialects(DialectRegistry &registry) const final {
73  registry.insert<LLVM::LLVMDialect>();
74  registry.addExtensions<LoadDependentDialectExtension>();
75  }
76 
77  LogicalResult initialize(MLIRContext *context) final {
78  RewritePatternSet tempPatterns(context);
79  auto target = std::make_shared<ConversionTarget>(*context);
80  target->addLegalDialect<LLVM::LLVMDialect>();
81  auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
82 
83  if (!filterDialects.empty()) {
84  // Test mode: Populate only patterns from the specified dialects. Produce
85  // an error if the dialect is not loaded or does not implement the
86  // interface.
87  for (std::string &dialectName : filterDialects) {
88  Dialect *dialect = context->getLoadedDialect(dialectName);
89  if (!dialect)
90  return emitError(UnknownLoc::get(context))
91  << "dialect not loaded: " << dialectName << "\n";
92  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
93  if (!iface)
94  return emitError(UnknownLoc::get(context))
95  << "dialect does not implement ConvertToLLVMPatternInterface: "
96  << dialectName << "\n";
97  iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
98  tempPatterns);
99  }
100  } else {
101  // Normal mode: Populate all patterns from all dialects that implement the
102  // interface.
103  for (Dialect *dialect : context->getLoadedDialects()) {
104  // First time we encounter this dialect: if it implements the interface,
105  // let's populate patterns !
106  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
107  if (!iface)
108  continue;
109  iface->populateConvertToLLVMConversionPatterns(*target, *typeConverter,
110  tempPatterns);
111  }
112  }
113 
114  this->patterns =
115  std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
116  this->target = target;
117  this->typeConverter = typeConverter;
118  return success();
119  }
120 
121  void runOnOperation() final {
122  if (failed(applyPartialConversion(getOperation(), *target, *patterns)))
123  signalPassFailure();
124  }
125 };
126 
127 } // namespace
128 
130  DialectRegistry &registry) {
131  registry.addExtensions<LoadDependentDialectExtension>();
132 }
133 
134 std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
135  return std::make_unique<ConvertToLLVMPass>();
136 }
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Definition: Dialect.h:38
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::unique_ptr< Pass > createConvertToLLVMPass()
Create a pass that performs dialect conversion to LLVM for all dialects implementing `ConvertToLLVMPatternInterface`.
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.