MLIR 22.0.0git
ConvertToLLVMPass.cpp
Go to the documentation of this file.
1//===- ConvertToLLVMPass.cpp - MLIR LLVM Conversion -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17#include "llvm/Support/DebugLog.h"
18#include <memory>
19
20#define DEBUG_TYPE "convert-to-llvm"
21
22namespace mlir {
23#define GEN_PASS_DEF_CONVERTTOLLVMPASS
24#include "mlir/Conversion/Passes.h.inc"
25} // namespace mlir
26
27using namespace mlir;
28
29namespace {
30/// Base class for creating the internal implementation of `convert-to-llvm`
31/// passes.
32class ConvertToLLVMPassInterface {
33public:
34 ConvertToLLVMPassInterface(MLIRContext *context,
35 ArrayRef<std::string> filterDialects,
36 bool allowPatternRollback = true);
37 virtual ~ConvertToLLVMPassInterface() = default;
38
39 /// Get the dependent dialects used by `convert-to-llvm`.
40 static void getDependentDialects(DialectRegistry &registry);
41
42 /// Initialize the internal state of the `convert-to-llvm` pass
43 /// implementation. This method is invoked by `ConvertToLLVMPass::initialize`.
44 /// This method returns whether the initialization process failed.
45 virtual LogicalResult initialize() = 0;
46
47 /// Transform `op` to LLVM with the conversions available in the pass. The
48 /// analysis manager can be used to query analyzes like `DataLayoutAnalysis`
49 /// to further configure the conversion process. This method is invoked by
50 /// `ConvertToLLVMPass::runOnOperation`. This method returns whether the
51 /// transformation process failed.
52 virtual LogicalResult transform(Operation *op,
53 AnalysisManager manager) const = 0;
54
55protected:
56 /// Visit the `ConvertToLLVMPatternInterface` dialect interfaces and call
57 /// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
58 /// then `visitor` is invoked only with the dialects in the `filterDialects`
59 /// list.
60 LogicalResult visitInterfaces(
61 llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor);
62 MLIRContext *context;
63 /// List of dialects names to use as filters.
64 ArrayRef<std::string> filterDialects;
65 /// An experimental flag to disallow pattern rollback. This is more efficient
66 /// but not supported by all lowering patterns.
67 bool allowPatternRollback;
68};
69
70/// This DialectExtension can be attached to the context, which will invoke the
71/// `apply()` method for every loaded dialect. If a dialect implements the
72/// `ConvertToLLVMPatternInterface` interface, we load dependent dialects
73/// through the interface. This extension is loaded in the context before
74/// starting a pass pipeline that involves dialect conversion to LLVM.
75class LoadDependentDialectExtension : public DialectExtensionBase {
76public:
77 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
78
79 LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
80
81 void apply(MLIRContext *context,
82 MutableArrayRef<Dialect *> dialects) const final {
83 LDBG() << "Convert to LLVM extension load";
84 for (Dialect *dialect : dialects) {
85 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
86 if (!iface)
87 continue;
88 LDBG() << "Convert to LLVM found dialect interface for "
89 << dialect->getNamespace();
90 iface->loadDependentDialects(context);
91 }
92 }
93
94 /// Return a copy of this extension.
95 std::unique_ptr<DialectExtensionBase> clone() const final {
96 return std::make_unique<LoadDependentDialectExtension>(*this);
97 }
98};
99
100//===----------------------------------------------------------------------===//
101// StaticConvertToLLVM
102//===----------------------------------------------------------------------===//
103
104/// Static implementation of the `convert-to-llvm` pass. This version only looks
105/// at dialect interfaces to configure the conversion process.
106struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
107 /// Pattern set with conversions to LLVM.
108 std::shared_ptr<const FrozenRewritePatternSet> patterns;
109 /// The conversion target.
110 std::shared_ptr<const ConversionTarget> target;
111 /// The LLVM type converter.
112 std::shared_ptr<const LLVMTypeConverter> typeConverter;
113 using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
114
115 /// Configure the conversion to LLVM at pass initialization.
116 LogicalResult initialize() final {
117 auto target = std::make_shared<ConversionTarget>(*context);
118 auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
119 RewritePatternSet tempPatterns(context);
120 target->addLegalDialect<LLVM::LLVMDialect>();
121 // Populate the patterns with the dialect interface.
122 if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
124 *target, *typeConverter, tempPatterns);
125 })))
126 return failure();
127 this->patterns =
128 std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
129 this->target = target;
130 this->typeConverter = typeConverter;
131 return success();
132 }
133
134 /// Apply the conversion driver.
135 LogicalResult transform(Operation *op, AnalysisManager manager) const final {
136 ConversionConfig config;
137 config.allowPatternRollback = allowPatternRollback;
138 if (failed(applyPartialConversion(op, *target, *patterns, config)))
139 return failure();
140 return success();
141 }
142};
143
144//===----------------------------------------------------------------------===//
145// DynamicConvertToLLVM
146//===----------------------------------------------------------------------===//
147
148/// Dynamic implementation of the `convert-to-llvm` pass. This version inspects
149/// the IR to configure the conversion to LLVM.
150struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
151 /// A list of all the `ConvertToLLVMPatternInterface` dialect interfaces used
152 /// to partially configure the conversion process.
153 std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
154 interfaces;
155 using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
156
157 /// Collect the dialect interfaces used to configure the conversion process.
158 LogicalResult initialize() final {
159 auto interfaces =
160 std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
161 // Collect the interfaces.
162 if (failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
163 interfaces->push_back(iface);
164 })))
165 return failure();
166 this->interfaces = interfaces;
167 return success();
168 }
169
170 /// Configure the conversion process and apply the conversion driver.
171 LogicalResult transform(Operation *op, AnalysisManager manager) const final {
172 RewritePatternSet patterns(context);
173 ConversionTarget target(*context);
174 target.addLegalDialect<LLVM::LLVMDialect>();
175 // Get the data layout analysis.
176 const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>();
177 const DataLayout &dl = dlAnalysis.getAtOrAbove(op);
178 LowerToLLVMOptions options(context, dl);
179 LLVMTypeConverter typeConverter(context, options, &dlAnalysis);
180
181 // Configure the conversion with dialect level interfaces.
182 for (ConvertToLLVMPatternInterface *iface : *interfaces)
184 patterns);
185
186 // Configure the conversion attribute interfaces.
188 patterns);
189
190 // Apply the conversion.
191 ConversionConfig config;
192 config.allowPatternRollback = allowPatternRollback;
193 if (failed(applyPartialConversion(op, target, std::move(patterns), config)))
194 return failure();
195 return success();
196 }
197};
198
199//===----------------------------------------------------------------------===//
200// ConvertToLLVMPass
201//===----------------------------------------------------------------------===//
202
203/// This is a generic pass to convert to LLVM, it uses the
204/// `ConvertToLLVMPatternInterface` dialect interface to delegate to dialects
205/// the injection of conversion patterns.
206class ConvertToLLVMPass
207 : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
208 std::shared_ptr<const ConvertToLLVMPassInterface> impl;
209
210public:
211 using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
212 void getDependentDialects(DialectRegistry &registry) const final {
213 ConvertToLLVMPassInterface::getDependentDialects(registry);
214 }
215
216 LogicalResult initialize(MLIRContext *context) final {
217 std::shared_ptr<ConvertToLLVMPassInterface> impl;
218 // Choose the pass implementation.
219 if (useDynamic)
220 impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
221 allowPatternRollback);
222 else
223 impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
224 allowPatternRollback);
225 if (failed(impl->initialize()))
226 return failure();
227 this->impl = impl;
228 return success();
229 }
230
231 void runOnOperation() final {
232 if (failed(impl->transform(getOperation(), getAnalysisManager())))
233 return signalPassFailure();
234 }
235};
236
237} // namespace
238
239//===----------------------------------------------------------------------===//
240// ConvertToLLVMPassInterface
241//===----------------------------------------------------------------------===//
242
243ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
244 MLIRContext *context, ArrayRef<std::string> filterDialects,
245 bool allowPatternRollback)
246 : context(context), filterDialects(filterDialects),
247 allowPatternRollback(allowPatternRollback) {}
248
249void ConvertToLLVMPassInterface::getDependentDialects(
250 DialectRegistry &registry) {
251 registry.insert<LLVM::LLVMDialect>();
252 registry.addExtensions<LoadDependentDialectExtension>();
253}
254
255LogicalResult ConvertToLLVMPassInterface::visitInterfaces(
256 llvm::function_ref<void(ConvertToLLVMPatternInterface *)> visitor) {
257 if (!filterDialects.empty()) {
258 // Test mode: Populate only patterns from the specified dialects. Produce
259 // an error if the dialect is not loaded or does not implement the
260 // interface.
261 for (StringRef dialectName : filterDialects) {
262 Dialect *dialect = context->getLoadedDialect(dialectName);
263 if (!dialect)
264 return emitError(UnknownLoc::get(context))
265 << "dialect not loaded: " << dialectName << "\n";
266 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
267 if (!iface)
268 return emitError(UnknownLoc::get(context))
269 << "dialect does not implement ConvertToLLVMPatternInterface: "
270 << dialectName << "\n";
271 visitor(iface);
272 }
273 } else {
274 // Normal mode: Populate all patterns from all dialects that implement the
275 // interface.
276 for (Dialect *dialect : context->getLoadedDialects()) {
277 // First time we encounter this dialect: if it implements the interface,
278 // let's populate patterns !
279 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
280 if (!iface)
281 continue;
282 visitor(iface);
283 }
284 }
285 return success();
286}
287
288//===----------------------------------------------------------------------===//
289// API
290//===----------------------------------------------------------------------===//
291
293 DialectRegistry &registry) {
294 registry.addExtensions<LoadDependentDialectExtension>();
295}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static llvm::ManagedStatic< PassManagerOptions > options
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void populateOpConvertToLLVMConversionPatterns(Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Helper function for populating LLVM conversion patterns.