MLIR  22.0.0git
ConvertToLLVMPass.cpp
Go to the documentation of this file.
1 //===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/PatternMatch.h"
17 #include "llvm/Support/DebugLog.h"
18 #include <memory>
19 
20 #define DEBUG_TYPE "convert-to-llvm"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_CONVERTTOLLVMPASS
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 /// Base class for creating the internal implementation of `convert-to-llvm`
31 /// passes.
32 class ConvertToLLVMPassInterface {
33 public:
34  ConvertToLLVMPassInterface(MLIRContext *context,
35  ArrayRef<std::string> filterDialects,
36  bool allowPatternRollback = true);
37  virtual ~ConvertToLLVMPassInterface() = default;
38 
39  /// Get the dependent dialects used by `convert-to-llvm`.
40  static void getDependentDialects(DialectRegistry &registry);
41 
42  /// Initialize the internal state of the `convert-to-llvm` pass
43  /// implementation. This method is invoked by `ConvertToLLVMPass::initialize`.
44  /// This method returns whether the initialization process failed.
45  virtual LogicalResult initialize() = 0;
46 
47  /// Transform `op` to LLVM with the conversions available in the pass. The
48  /// analysis manager can be used to query analyzes like `DataLayoutAnalysis`
49  /// to further configure the conversion process. This method is invoked by
50  /// `ConvertToLLVMPass::runOnOperation`. This method returns whether the
51  /// transformation process failed.
52  virtual LogicalResult transform(Operation *op,
53  AnalysisManager manager) const = 0;
54 
55 protected:
56  /// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call
57  /// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
58  /// then `visitor` is invoked only with the dialects in the `filterDialects`
59  /// list.
60  LogicalResult visitInterfaces(
62  MLIRContext *context;
63  /// List of dialects names to use as filters.
64  ArrayRef<std::string> filterDialects;
65  /// An experimental flag to disallow pattern rollback. This is more efficient
66  /// but not supported by all lowering patterns.
67  bool allowPatternRollback;
68 };
69 
70 /// This DialectExtension can be attached to the context, which will invoke the
71 /// `apply()` method for every loaded dialect. If a dialect implements the
72 /// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
73 /// through the interface. This extension is loaded in the context before
74 /// starting a pass pipeline that involves dialect conversion to LLVM.
75 class LoadDependentDialectExtension : public DialectExtensionBase {
76 public:
77  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
78 
79  LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
80 
81  void apply(MLIRContext *context,
82  MutableArrayRef<Dialect *> dialects) const final {
83  LDBG() << "Convert to LLVM extension load";
84  for (Dialect *dialect : dialects) {
85  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
86  if (!iface)
87  continue;
88  LDBG() << "Convert to LLVM found dialect interface for "
89  << dialect->getNamespace();
90  iface->loadDependentDialects(context);
91  }
92  }
93 
94  /// Return a copy of this extension.
95  std::unique_ptr<DialectExtensionBase> clone() const final {
96  return std::make_unique<LoadDependentDialectExtension>(*this);
97  }
98 };
99 
100 //===----------------------------------------------------------------------===//
101 // StaticConvertToLLVM
102 //===----------------------------------------------------------------------===//
103 
104 /// Static implementation of the `convert-to-llvm` pass. This version only looks
105 /// at dialect interfaces to configure the conversion process.
106 struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
107  /// Pattern set with conversions to LLVM.
108  std::shared_ptr<const FrozenRewritePatternSet> patterns;
109  /// The conversion target.
110  std::shared_ptr<const ConversionTarget> target;
111  /// The LLVM type converter.
112  std::shared_ptr<const LLVMTypeConverter> typeConverter;
113  using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
114 
115  /// Configure the conversion to LLVM at pass initialization.
116  LogicalResult initialize() final {
117  auto target = std::make_shared<ConversionTarget>(*context);
118  auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
119  RewritePatternSet tempPatterns(context);
120  target->addLegalDialect<LLVM::LLVMDialect>();
121  // Populate the patterns with the dialect interface.
122  if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
124  *target, *typeConverter, tempPatterns);
125  })))
126  return failure();
127  this->patterns =
128  std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
129  this->target = target;
130  this->typeConverter = typeConverter;
131  return success();
132  }
133 
134  /// Apply the conversion driver.
135  LogicalResult transform(Operation *op, AnalysisManager manager) const final {
137  config.allowPatternRollback = allowPatternRollback;
138  if (failed(applyPartialConversion(op, *target, *patterns, config)))
139  return failure();
140  return success();
141  }
142 };
143 
144 //===----------------------------------------------------------------------===//
145 // DynamicConvertToLLVM
146 //===----------------------------------------------------------------------===//
147 
148 /// Dynamic implementation of the `convert-to-llvm` pass. This version inspects
149 /// the IR to configure the conversion to LLVM.
150 struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
151  /// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used
152  /// to partially configure the conversion process.
153  std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
154  interfaces;
155  using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
156 
157  /// Collect the dialect interfaces used to configure the conversion process.
158  LogicalResult initialize() final {
159  auto interfaces =
160  std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
161  // Collect the interfaces.
162  if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
163  interfaces->push_back(iface);
164  })))
165  return failure();
166  this->interfaces = interfaces;
167  return success();
168  }
169 
170  /// Configure the conversion process and apply the conversion driver.
171  LogicalResult transform(Operation *op, AnalysisManager manager) const final {
172  RewritePatternSet patterns(context);
173  ConversionTarget target(*context);
174  target.addLegalDialect<LLVM::LLVMDialect>();
175  // Get the data layout analysis.
176  const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>();
177  const DataLayout &dl = dlAnalysis.getAtOrAbove(op);
178  LowerToLLVMOptions options(context, dl);
179  LLVMTypeConverter typeConverter(context, options, &dlAnalysis);
180 
181  // Configure the conversion with dialect level interfaces.
182  for (ConvertToLLVMPatternInterface *iface : *interfaces)
183  iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
184  patterns);
185 
186  // Configure the conversion attribute interfaces.
187  populateOpConvertToLLVMConversionPatterns(op, target, typeConverter,
188  patterns);
189 
190  // Apply the conversion.
192  config.allowPatternRollback = allowPatternRollback;
193  if (failed(applyPartialConversion(op, target, std::move(patterns), config)))
194  return failure();
195  return success();
196  }
197 };
198 
199 //===----------------------------------------------------------------------===//
200 // ConvertToLLVMPass
201 //===----------------------------------------------------------------------===//
202 
203 /// This is a generic pass to convert to LLVM, it uses the
204 /// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
205 /// the injection of conversion patterns.
206 class ConvertToLLVMPass
207  : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
208  std::shared_ptr<const ConvertToLLVMPassInterface> impl;
209 
210 public:
211  using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
212  void getDependentDialects(DialectRegistry &registry) const final {
213  ConvertToLLVMPassInterface::getDependentDialects(registry);
214  }
215 
216  LogicalResult initialize(MLIRContext *context) final {
217  std::shared_ptr<ConvertToLLVMPassInterface> impl;
218  // Choose the pass implementation.
219  if (useDynamic)
220  impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
221  allowPatternRollback);
222  else
223  impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
224  allowPatternRollback);
225  if (failed(impl->initialize()))
226  return failure();
227  this->impl = impl;
228  return success();
229  }
230 
231  void runOnOperation() final {
232  if (failed(impl->transform(getOperation(), getAnalysisManager())))
233  return signalPassFailure();
234  }
235 };
236 
237 } // namespace
238 
239 //===----------------------------------------------------------------------===//
240 // ConvertToLLVMPassInterface
241 //===----------------------------------------------------------------------===//
242 
243 ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
244  MLIRContext *context, ArrayRef<std::string> filterDialects,
245  bool allowPatternRollback)
246  : context(context), filterDialects(filterDialects),
247  allowPatternRollback(allowPatternRollback) {}
248 
249 void ConvertToLLVMPassInterface::getDependentDialects(
250  DialectRegistry &registry) {
251  registry.insert<LLVM::LLVMDialect>();
252  registry.addExtensions<LoadDependentDialectExtension>();
253 }
254 
255 LogicalResult ConvertToLLVMPassInterface::visitInterfaces(
257  if (!filterDialects.empty()) {
258  // Test mode: Populate only patterns from the specified dialects. Produce
259  // an error if the dialect is not loaded or does not implement the
260  // interface.
261  for (StringRef dialectName : filterDialects) {
262  Dialect *dialect = context->getLoadedDialect(dialectName);
263  if (!dialect)
264  return emitError(UnknownLoc::get(context))
265  << "dialect not loaded: " << dialectName << "\n";
266  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
267  if (!iface)
268  return emitError(UnknownLoc::get(context))
269  << "dialect does not implement ConvertToLLVMPatternInterface: "
270  << dialectName << "\n";
271  visitor(iface);
272  }
273  } else {
274  // Normal mode: Populate all patterns from all dialects that implement the
275  // interface.
276  for (Dialect *dialect : context->getLoadedDialects()) {
277  // First time we encounter this dialect: if it implements the interface,
278  // let's populate patterns !
279  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
280  if (!iface)
281  continue;
282  visitor(iface);
283  }
284  }
285  return success();
286 }
287 
288 //===----------------------------------------------------------------------===//
289 // API
290 //===----------------------------------------------------------------------===//
291 
293  DialectRegistry &registry) {
294  registry.addExtensions<LoadDependentDialectExtension>();
295 }
static llvm::ManagedStatic< PassManagerOptions > options
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
This class represents an analysis manager for a particular operation instance.
This class describes a specific conversion target.
Base class for dialect interfaces providing translation to LLVM IR.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
Stores data layout objects for each operation that specifies the data layout above and below the give...
The main mechanism for performing data layout queries.
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Options to control the LLVM lowering.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateOpConvertToLLVMConversionPatterns(Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Helper function for populating LLVM conversion patterns.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.