17#include "llvm/Support/DebugLog.h"
20#define DEBUG_TYPE "convert-to-llvm"
23#define GEN_PASS_DEF_CONVERTTOLLVMPASS
24#include "mlir/Conversion/Passes.h.inc"
32class ConvertToLLVMPassInterface {
/// Construct the interface.
///
/// \param context               Stored in the `context` member by the
///                              out-of-line definition in this file.
/// \param filterDialects        Dialect names; when non-empty,
///                              `visitInterfaces` iterates over these names.
/// \param allowPatternRollback  Stored and later copied into
///                              `config.allowPatternRollback` by the
///                              `transform` implementations. Defaults to true.
///
/// NOTE(review): `filterDialects` is kept in an `ArrayRef` member, i.e. it is
/// not copied here; the caller's storage presumably has to outlive this
/// object - confirm.
34 ConvertToLLVMPassInterface(MLIRContext *context,
35 ArrayRef<std::string> filterDialects,
36 bool allowPatternRollback =
true);
/// Defaulted virtual destructor: this class is a polymorphic base (it
/// declares the pure-virtual `transform`).
37 virtual ~ConvertToLLVMPassInterface() =
default;
/// Register the dialects this conversion depends on into `registry`.
/// The definition in this file inserts `LLVM::LLVMDialect`.
40 static void getDependentDialects(DialectRegistry &registry);
/// Conversion entry point implemented by the concrete strategies in this
/// file (`StaticConvertToLLVM` and `DynamicConvertToLLVM`).
///
/// The pass's `runOnOperation` calls this with the current operation and its
/// analysis manager, and signals pass failure when the result is a failure.
52 virtual LogicalResult transform(Operation *op,
53 AnalysisManager manager)
const = 0;
/// Hand `ConvertToLLVMPatternInterface` pointers to `visitor`.
///
/// Callers in this file use it either to collect the interfaces or to
/// populate conversion patterns from them. When `filterDialects` is
/// non-empty the definition walks those dialect names and may return an
/// error diagnostic ("dialect not loaded: ..." or "dialect does not
/// implement ConvertToLLVMPatternInterface: ...").
60 LogicalResult visitInterfaces(
61 llvm::function_ref<
void(ConvertToLLVMPatternInterface *)> visitor);
/// Dialect names supplied at construction; `visitInterfaces` iterates over
/// them when the list is non-empty.
64 ArrayRef<std::string> filterDialects;
/// Copied into `config.allowPatternRollback` by the `transform`
/// implementations.
67 bool allowPatternRollback;
/// Construct the extension, passing an empty list to `DialectExtensionBase`.
/// NOTE(review): an empty list presumably means the extension is not tied to
/// specific dialect names - confirm against `DialectExtensionBase`.
79 LoadDependentDialectExtension() : DialectExtensionBase({}) {}
81 void apply(MLIRContext *context,
82 MutableArrayRef<Dialect *> dialects)
const final {
83 LDBG() <<
"Convert to LLVM extension load";
84 for (Dialect *dialect : dialects) {
85 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
88 LDBG() <<
"Convert to LLVM found dialect interface for "
89 << dialect->getNamespace();
90 iface->loadDependentDialects(context);
95 std::unique_ptr<DialectExtensionBase>
clone() const final {
96 return std::make_unique<LoadDependentDialectExtension>(*
this);
106struct StaticConvertToLLVM :
public ConvertToLLVMPassInterface {
/// Frozen pattern set passed to `applyPartialConversion` in `transform`.
108 std::shared_ptr<const FrozenRewritePatternSet> patterns;
/// Conversion target; set up during initialization with
/// `LLVM::LLVMDialect` marked legal.
110 std::shared_ptr<const ConversionTarget> target;
/// LLVM type converter created during initialization.
112 std::shared_ptr<const LLVMTypeConverter> typeConverter;
/// Inherit the base-class constructors.
113 using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
117 auto target = std::make_shared<ConversionTarget>(*context);
118 auto typeConverter = std::make_shared<LLVMTypeConverter>(context);
119 RewritePatternSet tempPatterns(context);
120 target->addLegalDialect<LLVM::LLVMDialect>();
122 if (
failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
124 *target, *typeConverter, tempPatterns);
128 std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
129 this->target = target;
130 this->typeConverter = typeConverter;
135 LogicalResult transform(Operation *op, AnalysisManager manager)
const final {
137 config.allowPatternRollback = allowPatternRollback;
138 if (
failed(applyPartialConversion(op, *target, *patterns,
config)))
150struct DynamicConvertToLLVM :
public ConvertToLLVMPassInterface {
153 std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
155 using ConvertToLLVMPassInterface::ConvertToLLVMPassInterface;
160 std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
162 if (
failed(visitInterfaces([&](ConvertToLLVMPatternInterface *iface) {
163 interfaces->push_back(iface);
166 this->interfaces = interfaces;
171 LogicalResult transform(Operation *op, AnalysisManager manager)
const final {
172 RewritePatternSet
patterns(context);
173 ConversionTarget
target(*context);
174 target.addLegalDialect<LLVM::LLVMDialect>();
176 const auto &dlAnalysis = manager.getAnalysis<DataLayoutAnalysis>();
177 const DataLayout &dl = dlAnalysis.getAtOrAbove(op);
178 LowerToLLVMOptions
options(context, dl);
179 LLVMTypeConverter typeConverter(context,
options, &dlAnalysis);
182 for (ConvertToLLVMPatternInterface *iface : *interfaces)
192 config.allowPatternRollback = allowPatternRollback;
206class ConvertToLLVMPass
/// Conversion strategy used by `runOnOperation`, which calls
/// `impl->transform(...)`.
/// NOTE(review): `initialize` builds a local variable of the same name
/// (`DynamicConvertToLLVM` or `StaticConvertToLLVM`); presumably it is
/// assigned to this member afterwards - the assignment is not visible here.
208 std::shared_ptr<const ConvertToLLVMPassInterface> impl;
211 using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
/// Declare the pass's dependent dialects by forwarding to
/// `ConvertToLLVMPassInterface::getDependentDialects`.
212 void getDependentDialects(DialectRegistry &registry)
const final {
213 ConvertToLLVMPassInterface::getDependentDialects(registry);
216 LogicalResult
initialize(MLIRContext *context)
final {
217 std::shared_ptr<ConvertToLLVMPassInterface> impl;
220 impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
221 allowPatternRollback);
223 impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
224 allowPatternRollback);
225 if (
failed(impl->initialize()))
231 void runOnOperation() final {
232 if (
failed(impl->transform(getOperation(), getAnalysisManager())))
233 return signalPassFailure();
243ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
245 bool allowPatternRollback)
246 : context(context), filterDialects(filterDialects),
247 allowPatternRollback(allowPatternRollback) {}
249void ConvertToLLVMPassInterface::getDependentDialects(
251 registry.
insert<LLVM::LLVMDialect>();
255LogicalResult ConvertToLLVMPassInterface::visitInterfaces(
256 llvm::function_ref<
void(ConvertToLLVMPatternInterface *)> visitor) {
257 if (!filterDialects.empty()) {
261 for (StringRef dialectName : filterDialects) {
264 return emitError(UnknownLoc::get(context))
265 <<
"dialect not loaded: " << dialectName <<
"\n";
266 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
268 return emitError(UnknownLoc::get(context))
269 <<
"dialect does not implement ConvertToLLVMPatternInterface: "
270 << dialectName <<
"\n";
279 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
static llvm::ManagedStatic< PassManagerOptions > options
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
virtual void populateConvertToLLVMConversionPatterns(ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conversion target.
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry)
Register the extension that will load dependent dialects for LLVM conversion.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
void populateOpConvertToLLVMConversionPatterns(Operation *op, ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)
Helper function for populating LLVM conversion patterns.