15#include "llvm/Support/Debug.h"
19#define DEBUG_TYPE "convert-to-emitc"
22#define GEN_PASS_DEF_CONVERTTOEMITC
23#include "mlir/Conversion/Passes.h.inc"
31class ConvertToEmitCPassInterface {
33 ConvertToEmitCPassInterface(MLIRContext *context,
34 ArrayRef<std::string> filterDialects);
35 virtual ~ConvertToEmitCPassInterface() =
default;
38 static void getDependentDialects(DialectRegistry ®istry);
50 virtual LogicalResult transform(Operation *op,
51 AnalysisManager manager)
const = 0;
58 LogicalResult visitInterfaces(
59 llvm::function_ref<
void(ConvertToEmitCPatternInterface *)> visitor);
62 ArrayRef<std::string> filterDialects;
75 LoadDependentDialectExtension() : DialectExtensionBase({}) {}
77 void apply(MLIRContext *context,
78 MutableArrayRef<Dialect *> dialects)
const final {
79 LLVM_DEBUG(llvm::dbgs() <<
"Convert to EmitC extension load\n");
80 for (Dialect *dialect : dialects) {
81 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
84 LLVM_DEBUG(llvm::dbgs() <<
"Convert to EmitC found dialect interface for "
85 << dialect->getNamespace() <<
"\n");
90 std::unique_ptr<DialectExtensionBase>
clone() const final {
91 return std::make_unique<LoadDependentDialectExtension>(*
this);
101struct StaticConvertToEmitC :
public ConvertToEmitCPassInterface {
103 std::shared_ptr<const FrozenRewritePatternSet> patterns;
105 std::shared_ptr<const ConversionTarget> target;
107 std::shared_ptr<const TypeConverter> typeConverter;
108 using ConvertToEmitCPassInterface::ConvertToEmitCPassInterface;
112 auto target = std::make_shared<ConversionTarget>(*context);
113 auto typeConverter = std::make_shared<TypeConverter>();
116 typeConverter->addConversion([](Type type) -> std::optional<Type> {
122 RewritePatternSet tempPatterns(context);
123 target->addLegalDialect<emitc::EmitCDialect>();
125 if (
failed(visitInterfaces([&](ConvertToEmitCPatternInterface *iface) {
127 *target, *typeConverter, tempPatterns);
131 std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
132 this->target = target;
133 this->typeConverter = typeConverter;
138 LogicalResult transform(Operation *op, AnalysisManager manager)
const final {
139 if (
failed(applyPartialConversion(op, *target, *patterns)))
153 std::shared_ptr<const ConvertToEmitCPassInterface> impl;
156 using impl::ConvertToEmitCBase<ConvertToEmitC>::ConvertToEmitCBase;
157 void getDependentDialects(DialectRegistry ®istry)
const final {
158 ConvertToEmitCPassInterface::getDependentDialects(registry);
161 LogicalResult
initialize(MLIRContext *context)
final {
162 std::shared_ptr<ConvertToEmitCPassInterface> impl;
163 impl = std::make_shared<StaticConvertToEmitC>(context, filterDialects);
164 if (
failed(impl->initialize()))
170 void runOnOperation() final {
171 if (
failed(impl->transform(getOperation(), getAnalysisManager())))
172 return signalPassFailure();
182ConvertToEmitCPassInterface::ConvertToEmitCPassInterface(
184 : context(context), filterDialects(filterDialects) {}
186void ConvertToEmitCPassInterface::getDependentDialects(
188 registry.
insert<emitc::EmitCDialect>();
192LogicalResult ConvertToEmitCPassInterface::visitInterfaces(
193 llvm::function_ref<
void(ConvertToEmitCPatternInterface *)> visitor) {
194 if (!filterDialects.empty()) {
198 for (StringRef dialectName : filterDialects) {
201 return emitError(UnknownLoc::get(context))
202 <<
"dialect not loaded: " << dialectName <<
"\n";
203 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
205 return emitError(UnknownLoc::get(context))
206 <<
"dialect does not implement ConvertToEmitCPatternInterface: "
207 << dialectName <<
"\n";
214 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
virtual void populateConvertToEmitCConversionPatterns(ConversionTarget &target, TypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)