MLIR  20.0.0git
ConvertToLLVMPass.cpp
Go to the documentation of this file.
1 //===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
19 #include <memory>
20 
21 #define DEBUG_TYPE "convert-to-llvm"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTTOLLVMPASS
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 /// Base class for creating the internal implementation of `convert-to-llvm`
32 /// passes.
33 class ConvertToLLVMPassInterface {
34 public:
35  ConvertToLLVMPassInterface(MLIRContext *context,
36  ArrayRef<std::string> filterDialects);
37  virtual ~ConvertToLLVMPassInterface() = default;
38 
39  /// Get the dependent dialects used by `convert-to-llvm`.
40  static void getDependentDialects(DialectRegistry &registry);
41 
42  /// Initialize the internal state of the `convert-to-llvm` pass
43  /// implementation. This method is invoked by `ConvertToLLVMPass::initialize`.
44  /// This method returns whether the initialization process failed.
45  virtual LogicalResult initialize() = 0;
46 
47  /// Transform `op` to LLVM with the conversions available in the pass. The
48  /// analysis manager can be used to query analyzes like `DataLayoutAnalysis`
49  /// to further configure the conversion process. This method is invoked by
50  /// `ConvertToLLVMPass::runOnOperation`. This method returns whether the
51  /// transformation process failed.
52  virtual LogicalResult transform(Operation *op,
53  AnalysisManager manager) const = 0;
54 
55 protected:
56  /// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call
57  /// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
58  /// then `visitor` is invoked only with the dialects in the `filterDialects`
59  /// list.
60  LogicalResult visitInterfaces(
62  MLIRContext *context;
63  /// List of dialects names to use as filters.
64  ArrayRef<std::string> filterDialects;
65 };
66 
67 /// This DialectExtension can be attached to the context, which will invoke the
68 /// `apply()` method for every loaded dialect. If a dialect implements the
69 /// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
70 /// through the interface. This extension is loaded in the context before
71 /// starting a pass pipeline that involves dialect conversion to LLVM.
72 class LoadDependentDialectExtension : public DialectExtensionBase {
73 public:
74  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
75 
76  LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
77 
78  void apply(MLIRContext *context,
79  MutableArrayRef<Dialect *> dialects) const final {
80  LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n");
81  for (Dialect *dialect : dialects) {
82  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
83  if (!iface)
84  continue;
85  LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for "
86  << dialect->getNamespace() << "\n");
87  iface->loadDependentDialects(context);
88  }
89  }
90 
91  /// Return a copy of this extension.
92  std::unique_ptr<DialectExtensionBase> clone() const final {
93  return std::make_unique<LoadDependentDialectExtension>(*this);
94  }
95 };
96 
97 //===----------------------------------------------------------------------===//
98 // StaticConvertToLLVM
99 //===----------------------------------------------------------------------===//
100 
101 /// Static implementation of the `convert-to-llvm` pass. This version only looks
102 /// at dialect interfaces to configure the conversion process.
103 struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
104  /// Pattern set with conversions to LLVM.
105  std::shared_ptr<const FrozenRewritePatternSet> patterns;
106  /// The conversion target.
107  std::shared_ptr<const ConversionTarget> target;
108  /// The LLVM type converter.
109  std::shared_ptr<const LLVMTypeConverter> typeConverter;
110  using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
111 
112  /// Configure the conversion to LLVM at pass initialization.
113  LogicalResult initialize() final {
114  auto target = std::make_shared<ConversionTarget>(*context);
115  auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
116  RewritePatternSet tempPatterns(context);
117  target->addLegalDialect<LLVM::LLVMDialect>();
118  // Populate the patterns with the dialect interface.
119  if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
121  *target, *typeConverter, tempPatterns);
122  })))
123  return failure();
124  this->patterns =
125  std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
126  this->target = target;
127  this->typeConverter = typeConverter;
128  return success();
129  }
130 
131  /// Apply the conversion driver.
132  LogicalResult transform(Operation *op, AnalysisManager manager) const final {
133  if (failed(applyPartialConversion(op, *target, *patterns)))
134  return failure();
135  return success();
136  }
137 };
138 
139 //===----------------------------------------------------------------------===//
140 // DynamicConvertToLLVM
141 //===----------------------------------------------------------------------===//
142 
143 /// Dynamic implementation of the `convert-to-llvm` pass. This version inspects
144 /// the IR to configure the conversion to LLVM.
145 struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
146  /// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used
147  /// to partially configure the conversion process.
148  std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
149  interfaces;
150  using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
151 
152  /// Collect the dialect interfaces used to configure the conversion process.
153  LogicalResult initialize() final {
154  auto interfaces =
155  std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
156  // Collect the interfaces.
157  if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
158  interfaces->push_back(iface);
159  })))
160  return failure();
161  this->interfaces = interfaces;
162  return success();
163  }
164 
165  /// Configure the conversion process and apply the conversion driver.
166  LogicalResult transform(Operation *op, AnalysisManager manager) const final {
167  RewritePatternSet patterns(context);
168  ConversionTarget target(*context);
169  target.addLegalDialect<LLVM::LLVMDialect>();
170  // Get the data layout analysis.
171  const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>();
172  LLVMTypeConverter typeConverter(context, &dlAnalysis);
173 
174  // Configure the conversion with dialect level interfaces.
175  for (ConvertToLLVMPatternInterface *iface : *interfaces)
176  iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
177  patterns);
178 
179  // Configure the conversion attribute interfaces.
180  populateOpConvertToLLVMConversionPatterns(op, target, typeConverter,
181  patterns);
182 
183  // Apply the conversion.
184  if (failed(applyPartialConversion(op, target, std::move(patterns))))
185  return failure();
186  return success();
187  }
188 };
189 
190 //===----------------------------------------------------------------------===//
191 // ConvertToLLVMPass
192 //===----------------------------------------------------------------------===//
193 
194 /// This is a generic pass to convert to LLVM, it uses the
195 /// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
196 /// the injection of conversion patterns.
197 class ConvertToLLVMPass
198  : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
199  std::shared_ptr<const ConvertToLLVMPassInterface> impl;
200 
201 public:
202  using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
203  void getDependentDialects(DialectRegistry &registry) const final {
204  ConvertToLLVMPassInterface::getDependentDialects(registry);
205  }
206 
207  LogicalResult initialize(MLIRContext *context) final {
208  std::shared_ptr<ConvertToLLVMPassInterface> impl;
209  // Choose the pass implementation.
210  if (useDynamic)
211  impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
212  else
213  impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
214  if (failed(impl->initialize()))
215  return failure();
216  this->impl = impl;
217  return success();
218  }
219 
220  void runOnOperation() final {
221  if (failed(impl->transform(getOperation(), getAnalysisManager())))
222  return signalPassFailure();
223  }
224 };
225 
226 } // namespace
227 
228 //===----------------------------------------------------------------------===//
229 // ConvertToLLVMPassInterface
230 //===----------------------------------------------------------------------===//
231 
232 ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
233  MLIRContext *context, ArrayRef<std::string> filterDialects)
234  : context(context), filterDialects(filterDialects) {}
235 
236 void ConvertToLLVMPassInterface::getDependentDialects(
237  DialectRegistry &registry) {
238  registry.insert<LLVM::LLVMDialect>();
239  registry.addExtensions<LoadDependentDialectExtension>();
240 }
241 
242 LogicalResult ConvertToLLVMPassInterface::visitInterfaces(
244  if (!filterDialects.empty()) {
245  // Test mode: Populate only patterns from the specified dialects. Produce
246  // an error if the dialect is not loaded or does not implement the
247  // interface.
248  for (StringRef dialectName : filterDialects) {
249  Dialect *dialect = context->getLoadedDialect(dialectName);
250  if (!dialect)
251  return emitError(UnknownLoc::get(context))
252  << "dialect not loaded: " << dialectName << "\n";
253  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
254  if (!iface)
255  return emitError(UnknownLoc::get(context))
256  << "dialect does not implement ConvertToLLVMPatternInterface: "
257  << dialectName << "\n";
258  visitor(iface);
259  }
260  } else {
261  // Normal mode: Populate all patterns from all dialects that implement the
262  // interface.
263  for (Dialect *dialect : context->getLoadedDialects()) {
264  // First time we encounter this dialect: if it implements the interface,
265  // let's populate patterns !
266  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
267  if (!iface)
268  continue;
269  visitor(iface);
270  }
271  }
272  return success();
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // API
277 //===----------------------------------------------------------------------===//
278 
280  DialectRegistry &registry) {
281  registry.addExtensions<LoadDependentDialectExtension>();
282 }
283 
284 std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
285  return std::make_unique<ConvertToLLVMPass>();
286 }
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:274
This class represents an analysis manager for a particular operation instance.
This class describes a specific conversion target.
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
Stores data layout objects for each operation that specifies the data layout above and below the give...
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::unique_ptr< Pass > createConvertToLLVMPass()
Create a pass that performs dialect conversion to LLVM for all dialects implementing ConvertToLLVMPat...
const FrozenRewritePatternSet & patterns
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateOpConvertToLLVMConversionPatterns(Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Helper function for populating LLVM conversion patterns.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.