16#include "llvm/Support/Debug.h"
20#define DEBUG_TYPE "convert-to-emitc"
23#define GEN_PASS_DEF_CONVERTTOEMITC
24#include "mlir/Conversion/Passes.h.inc"
32class ConvertToEmitCPassInterface {
34 ConvertToEmitCPassInterface(MLIRContext *context,
35 ArrayRef<std::string> filterDialects,
36 std::optional<bool> lowerToCpp);
37 virtual ~ConvertToEmitCPassInterface() =
default;
40 static void getDependentDialects(DialectRegistry ®istry);
52 virtual LogicalResult transform(Operation *op,
53 AnalysisManager manager)
const = 0;
60 LogicalResult visitInterfaces(
61 llvm::function_ref<
void(ConvertToEmitCPatternInterface *)> visitor);
64 ArrayRef<std::string> filterDialects;
65 std::optional<bool> lowerToCpp;
78 LoadDependentDialectExtension() : DialectExtensionBase({}) {}
80 void apply(MLIRContext *context,
81 MutableArrayRef<Dialect *> dialects)
const final {
82 LLVM_DEBUG(llvm::dbgs() <<
"Convert to EmitC extension load\n");
83 for (Dialect *dialect : dialects) {
84 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
87 LLVM_DEBUG(llvm::dbgs() <<
"Convert to EmitC found dialect interface for "
88 << dialect->getNamespace() <<
"\n");
93 std::unique_ptr<DialectExtensionBase>
clone() const final {
94 return std::make_unique<LoadDependentDialectExtension>(*
this);
104struct StaticConvertToEmitC :
public ConvertToEmitCPassInterface {
106 std::shared_ptr<const FrozenRewritePatternSet> patterns;
108 std::shared_ptr<const ConversionTarget> target;
110 std::shared_ptr<const TypeConverter> typeConverter;
111 using ConvertToEmitCPassInterface::ConvertToEmitCPassInterface;
115 auto target = std::make_shared<ConversionTarget>(*context);
116 auto typeConverter = std::make_shared<EmitCTypeConverter>(context);
118 RewritePatternSet tempPatterns(context);
119 target->addLegalDialect<emitc::EmitCDialect>();
121 if (
failed(visitInterfaces([&](ConvertToEmitCPatternInterface *iface) {
122 iface->populateConvertToEmitCConversionPatterns(
123 *target, *typeConverter, tempPatterns, lowerToCpp);
127 std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
128 this->target = target;
129 this->typeConverter = typeConverter;
134 LogicalResult transform(Operation *op, AnalysisManager manager)
const final {
135 if (
failed(applyPartialConversion(op, *target, *patterns)))
149 std::shared_ptr<const ConvertToEmitCPassInterface> impl;
152 using impl::ConvertToEmitCBase<ConvertToEmitC>::ConvertToEmitCBase;
153 void getDependentDialects(DialectRegistry ®istry)
const final {
154 ConvertToEmitCPassInterface::getDependentDialects(registry);
157 LogicalResult
initialize(MLIRContext *context)
final {
158 std::shared_ptr<ConvertToEmitCPassInterface> impl;
159 std::optional<bool> lowerToCppOverride;
160 if (this->lowerToCpp.hasValue())
161 lowerToCppOverride = this->lowerToCpp;
162 impl = std::make_shared<StaticConvertToEmitC>(context, filterDialects,
164 if (
failed(impl->initialize()))
170 void runOnOperation() final {
171 if (
failed(impl->transform(getOperation(), getAnalysisManager())))
172 return signalPassFailure();
182ConvertToEmitCPassInterface::ConvertToEmitCPassInterface(
184 std::optional<bool> lowerToCpp)
185 : context(context), filterDialects(filterDialects), lowerToCpp(lowerToCpp) {
188void ConvertToEmitCPassInterface::getDependentDialects(
190 registry.
insert<emitc::EmitCDialect>();
194LogicalResult ConvertToEmitCPassInterface::visitInterfaces(
195 llvm::function_ref<
void(ConvertToEmitCPatternInterface *)> visitor) {
196 if (!filterDialects.empty()) {
200 for (StringRef dialectName : filterDialects) {
203 return emitError(UnknownLoc::get(context))
204 <<
"dialect not loaded: " << dialectName <<
"\n";
205 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
207 return emitError(UnknownLoc::get(context))
208 <<
"dialect does not implement ConvertToEmitCPatternInterface: "
209 << dialectName <<
"\n";
216 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)