MLIR  22.0.0git
ConvertToLLVMPass.cpp
Go to the documentation of this file.
1 //===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/PatternMatch.h"
17 #include "llvm/Support/DebugLog.h"
18 #include <memory>
19 
20 #define DEBUG_TYPE "convert-to-llvm"
21 
22 namespace mlir {
23 #define GEN_PASS_DEF_CONVERTTOLLVMPASS
24 #include "mlir/Conversion/Passes.h.inc"
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 /// Base class for creating the internal implementation of `convert-to-llvm`
31 /// passes.
32 class ConvertToLLVMPassInterface {
33 public:
34  ConvertToLLVMPassInterface(MLIRContext *context,
35  ArrayRef<std::string> filterDialects,
36  bool allowPatternRollback = true);
37  virtual ~ConvertToLLVMPassInterface() = default;
38 
39  /// Get the dependent dialects used by `convert-to-llvm`.
40  static void getDependentDialects(DialectRegistry &registry);
41 
42  /// Initialize the internal state of the `convert-to-llvm` pass
43  /// implementation. This method is invoked by `ConvertToLLVMPass::initialize`.
44  /// This method returns whether the initialization process failed.
45  virtual LogicalResult initialize() = 0;
46 
47  /// Transform `op` to LLVM with the conversions available in the pass. The
48  /// analysis manager can be used to query analyzes like `DataLayoutAnalysis`
49  /// to further configure the conversion process. This method is invoked by
50  /// `ConvertToLLVMPass::runOnOperation`. This method returns whether the
51  /// transformation process failed.
52  virtual LogicalResult transform(Operation *op,
53  AnalysisManager manager) const = 0;
54 
55 protected:
56  /// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call
57  /// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
58  /// then `visitor` is invoked only with the dialects in the `filterDialects`
59  /// list.
60  LogicalResult visitInterfaces(
62  MLIRContext *context;
63  /// List of dialects names to use as filters.
64  ArrayRef<std::string> filterDialects;
65  /// An experimental flag to disallow pattern rollback. This is more efficient
66  /// but not supported by all lowering patterns.
67  bool allowPatternRollback;
68 };
69 
70 /// This DialectExtension can be attached to the context, which will invoke the
71 /// `apply()` method for every loaded dialect. If a dialect implements the
72 /// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
73 /// through the interface. This extension is loaded in the context before
74 /// starting a pass pipeline that involves dialect conversion to LLVM.
75 class LoadDependentDialectExtension : public DialectExtensionBase {
76 public:
77  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
78 
79  LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
80 
81  void apply(MLIRContext *context,
82  MutableArrayRef<Dialect *> dialects) const final {
83  LDBG() << "Convert to LLVM extension load";
84  for (Dialect *dialect : dialects) {
85  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
86  if (!iface)
87  continue;
88  LDBG() << "Convert to LLVM found dialect interface for "
89  << dialect->getNamespace();
90  iface->loadDependentDialects(context);
91  }
92  }
93 
94  /// Return a copy of this extension.
95  std::unique_ptr<DialectExtensionBase> clone() const final {
96  return std::make_unique<LoadDependentDialectExtension>(*this);
97  }
98 };
99 
100 //===----------------------------------------------------------------------===//
101 // StaticConvertToLLVM
102 //===----------------------------------------------------------------------===//
103 
104 /// Static implementation of the `convert-to-llvm` pass. This version only looks
105 /// at dialect interfaces to configure the conversion process.
106 struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
107  /// Pattern set with conversions to LLVM.
108  std::shared_ptr<const FrozenRewritePatternSet> patterns;
109  /// The conversion target.
110  std::shared_ptr<const ConversionTarget> target;
111  /// The LLVM type converter.
112  std::shared_ptr<const LLVMTypeConverter> typeConverter;
113  using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
114 
115  /// Configure the conversion to LLVM at pass initialization.
116  LogicalResult initialize() final {
117  auto target = std::make_shared<ConversionTarget>(*context);
118  auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
119  RewritePatternSet tempPatterns(context);
120  target->addLegalDialect<LLVM::LLVMDialect>();
121  // Populate the patterns with the dialect interface.
122  if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
124  *target, *typeConverter, tempPatterns);
125  })))
126  return failure();
127  this->patterns =
128  std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
129  this->target = target;
130  this->typeConverter = typeConverter;
131  return success();
132  }
133 
134  /// Apply the conversion driver.
135  LogicalResult transform(Operation *op, AnalysisManager manager) const final {
137  config.allowPatternRollback = allowPatternRollback;
138  if (failed(applyPartialConversion(op, *target, *patterns, config)))
139  return failure();
140  return success();
141  }
142 };
143 
144 //===----------------------------------------------------------------------===//
145 // DynamicConvertToLLVM
146 //===----------------------------------------------------------------------===//
147 
148 /// Dynamic implementation of the `convert-to-llvm` pass. This version inspects
149 /// the IR to configure the conversion to LLVM.
150 struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
151  /// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used
152  /// to partially configure the conversion process.
153  std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
154  interfaces;
155  using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
156 
157  /// Collect the dialect interfaces used to configure the conversion process.
158  LogicalResult initialize() final {
159  auto interfaces =
160  std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
161  // Collect the interfaces.
162  if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
163  interfaces->push_back(iface);
164  })))
165  return failure();
166  this->interfaces = interfaces;
167  return success();
168  }
169 
170  /// Configure the conversion process and apply the conversion driver.
171  LogicalResult transform(Operation *op, AnalysisManager manager) const final {
172  RewritePatternSet patterns(context);
173  ConversionTarget target(*context);
174  target.addLegalDialect<LLVM::LLVMDialect>();
175  // Get the data layout analysis.
176  const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>();
177  LLVMTypeConverter typeConverter(context, &dlAnalysis);
178 
179  // Configure the conversion with dialect level interfaces.
180  for (ConvertToLLVMPatternInterface *iface : *interfaces)
181  iface->populateConvertToLLVMConversionPatterns(target, typeConverter,
182  patterns);
183 
184  // Configure the conversion attribute interfaces.
185  populateOpConvertToLLVMConversionPatterns(op, target, typeConverter,
186  patterns);
187 
188  // Apply the conversion.
190  config.allowPatternRollback = allowPatternRollback;
191  if (failed(applyPartialConversion(op, target, std::move(patterns), config)))
192  return failure();
193  return success();
194  }
195 };
196 
197 //===----------------------------------------------------------------------===//
198 // ConvertToLLVMPass
199 //===----------------------------------------------------------------------===//
200 
201 /// This is a generic pass to convert to LLVM, it uses the
202 /// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
203 /// the injection of conversion patterns.
204 class ConvertToLLVMPass
205  : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
206  std::shared_ptr<const ConvertToLLVMPassInterface> impl;
207 
208 public:
209  using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
210  void getDependentDialects(DialectRegistry &registry) const final {
211  ConvertToLLVMPassInterface::getDependentDialects(registry);
212  }
213 
214  LogicalResult initialize(MLIRContext *context) final {
215  std::shared_ptr<ConvertToLLVMPassInterface> impl;
216  // Choose the pass implementation.
217  if (useDynamic)
218  impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
219  allowPatternRollback);
220  else
221  impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
222  allowPatternRollback);
223  if (failed(impl->initialize()))
224  return failure();
225  this->impl = impl;
226  return success();
227  }
228 
229  void runOnOperation() final {
230  if (failed(impl->transform(getOperation(), getAnalysisManager())))
231  return signalPassFailure();
232  }
233 };
234 
235 } // namespace
236 
237 //===----------------------------------------------------------------------===//
238 // ConvertToLLVMPassInterface
239 //===----------------------------------------------------------------------===//
240 
241 ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
242  MLIRContext *context, ArrayRef<std::string> filterDialects,
243  bool allowPatternRollback)
244  : context(context), filterDialects(filterDialects),
245  allowPatternRollback(allowPatternRollback) {}
246 
247 void ConvertToLLVMPassInterface::getDependentDialects(
248  DialectRegistry &registry) {
249  registry.insert<LLVM::LLVMDialect>();
250  registry.addExtensions<LoadDependentDialectExtension>();
251 }
252 
253 LogicalResult ConvertToLLVMPassInterface::visitInterfaces(
255  if (!filterDialects.empty()) {
256  // Test mode: Populate only patterns from the specified dialects. Produce
257  // an error if the dialect is not loaded or does not implement the
258  // interface.
259  for (StringRef dialectName : filterDialects) {
260  Dialect *dialect = context->getLoadedDialect(dialectName);
261  if (!dialect)
262  return emitError(UnknownLoc::get(context))
263  << "dialect not loaded: " << dialectName << "\n";
264  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
265  if (!iface)
266  return emitError(UnknownLoc::get(context))
267  << "dialect does not implement ConvertToLLVMPatternInterface: "
268  << dialectName << "\n";
269  visitor(iface);
270  }
271  } else {
272  // Normal mode: Populate all patterns from all dialects that implement the
273  // interface.
274  for (Dialect *dialect : context->getLoadedDialects()) {
275  // First time we encounter this dialect: if it implements the interface,
276  // let's populate patterns !
277  auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
278  if (!iface)
279  continue;
280  visitor(iface);
281  }
282  }
283  return success();
284 }
285 
286 //===----------------------------------------------------------------------===//
287 // API
288 //===----------------------------------------------------------------------===//
289 
291  DialectRegistry &registry) {
292  registry.addExtensions<LoadDependentDialectExtension>();
293 }
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
This class represents an analysis manager for a particular operation instance.
This class describes a specific conversion target.
Base class for dialect interfaces providing conversion patterns to the LLVM dialect.
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conversion target.
Stores data layout objects for each operation that specifies the data layout above and below the given operation.
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Definition: Dialect.h:38
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
void populateOpConvertToLLVMConversionPatterns(Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Helper function for populating LLVM conversion patterns.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
Dialect conversion configuration.