17#include "llvm/Support/DebugLog.h"
20#define DEBUG_TYPE "convert-to-llvm"
23#define GEN_PASS_DEF_CONVERTTOLLVMPASS
24#include "mlir/Conversion/Passes.h.inc"
32class ConvertToLLVMPassInterface {
34 ConvertToLLVMPassInterface(MLIRContext *context,
35 ArrayRef<std::string> filterDialects,
36 bool allowPatternRollback =
true);
37 virtual ~ConvertToLLVMPassInterface() =
default;
40 static void getDependentDialects(DialectRegistry &registry);
52 virtual LogicalResult transform(Operation *op,
53 AnalysisManager manager)
const = 0;
60 LogicalResult visitInterfaces(
61 llvm::function_ref<
void(ConvertToLLVMPatternInterface *)> visitor);
64 ArrayRef<std::string> filterDialects;
67 bool allowPatternRollback;
79 LoadDependentDialectExtension() : DialectExtensionBase({}) {}
81 void apply(MLIRContext *context,
82 MutableArrayRef<Dialect *> dialects)
const final {
83 LDBG() <<
"Convert to LLVM extension load";
84 for (Dialect *dialect : dialects) {
85 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
88 LDBG() <<
"Convert to LLVM found dialect interface for "
89 << dialect->getNamespace();
90 iface->loadDependentDialects(context);
95 std::unique_ptr<DialectExtensionBase>
clone() const final {
96 return std::make_unique<LoadDependentDialectExtension>(*
this);
106struct StaticConvertToLLVM :
public ConvertToLLVMPassInterface {
108 std::shared_ptr<const FrozenRewritePatternSet> patterns;
110 std::shared_ptr<const ConversionTarget> target;
112 std::shared_ptr<const LLVMTypeConverter> typeConverter;
113 using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
117 auto target = std::make_shared<ConversionTarget>(*context);
118 auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
119 RewritePatternSet tempPatterns(context);
120 target->addLegalDialect<LLVM::LLVMDialect>();
122 if (
failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
124 *target, *typeConverter, tempPatterns);
128 std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
129 this->target = target;
130 this->typeConverter = typeConverter;
135 LogicalResult transform(Operation *op, AnalysisManager manager)
const final {
136 ConversionConfig config;
137 config.allowPatternRollback = allowPatternRollback;
138 if (
failed(applyPartialConversion(op, *target, *patterns, config)))
150struct DynamicConvertToLLVM :
public ConvertToLLVMPassInterface {
153 std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
155 using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
160 std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
162 if (
failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
163 interfaces->push_back(iface);
166 this->interfaces = interfaces;
171 LogicalResult transform(Operation *op, AnalysisManager manager)
const final {
172 RewritePatternSet patterns(context);
173 ConversionTarget
target(*context);
174 target.addLegalDialect<LLVM::LLVMDialect>();
176 const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>();
177 const DataLayout &dl = dlAnalysis.getAtOrAbove(op);
178 LowerToLLVMOptions
options(context, dl);
179 LLVMTypeConverter typeConverter(context,
options, &dlAnalysis);
182 for (ConvertToLLVMPatternInterface *iface : *interfaces)
191 ConversionConfig config;
192 config.allowPatternRollback = allowPatternRollback;
193 if (
failed(applyPartialConversion(op,
target, std::move(patterns), config)))
206class ConvertToLLVMPass
208 std::shared_ptr<const ConvertToLLVMPassInterface> impl;
211 using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
212 void getDependentDialects(DialectRegistry &registry)
const final {
213 ConvertToLLVMPassInterface::getDependentDialects(registry);
216 LogicalResult
initialize(MLIRContext *context)
final {
217 std::shared_ptr<ConvertToLLVMPassInterface> impl;
220 impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
221 allowPatternRollback);
223 impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
224 allowPatternRollback);
225 if (
failed(impl->initialize()))
231 void runOnOperation() final {
232 if (
failed(impl->transform(getOperation(), getAnalysisManager())))
233 return signalPassFailure();
243ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
245 bool allowPatternRollback)
246 : context(context), filterDialects(filterDialects),
247 allowPatternRollback(allowPatternRollback) {}
249void ConvertToLLVMPassInterface::getDependentDialects(
251 registry.
insert<LLVM::LLVMDialect>();
255LogicalResult ConvertToLLVMPassInterface::visitInterfaces(
256 llvm::function_ref<
void(ConvertToLLVMPatternInterface *)> visitor) {
257 if (!filterDialects.empty()) {
261 for (StringRef dialectName : filterDialects) {
264 return emitError(UnknownLoc::get(context))
265 <<
"dialect not loaded: " << dialectName <<
"\n";
266 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
268 return emitError(UnknownLoc::get(context))
269 <<
"dialect does not implement ConvertToLLVMPatternInterface: "
270 << dialectName <<
"\n";
279 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static llvm::ManagedStatic< PassManagerOptions > options
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void populateOpConvertToLLVMConversionPatterns(Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Helper function for populating LLVM conversion patterns.