MLIR 22.0.0git
ConvertToEmitCPass.cpp
Go to the documentation of this file.
//===- ConvertToEmitCPass.cpp - Conversion to EmitC pass --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
10
13#include "mlir/Pass/Pass.h"
15#include "llvm/Support/Debug.h"
16
17#include <memory>
18
19#define DEBUG_TYPE "convert-to-emitc"
20
21namespace mlir {
22#define GEN_PASS_DEF_CONVERTTOEMITC
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27
28namespace {
29/// Base class for creating the internal implementation of `convert-to-emitc`
30/// passes.
31class ConvertToEmitCPassInterface {
32public:
33 ConvertToEmitCPassInterface(MLIRContext *context,
34 ArrayRef<std::string> filterDialects);
35 virtual ~ConvertToEmitCPassInterface() = default;
36
37 /// Get the dependent dialects used by `convert-to-emitc`.
38 static void getDependentDialects(DialectRegistry &registry);
39
40 /// Initialize the internal state of the `convert-to-emitc` pass
41 /// implementation. This method is invoked by `ConvertToEmitC::initialize`.
42 /// This method returns whether the initialization process failed.
43 virtual LogicalResult initialize() = 0;
44
45 /// Transform `op` to the EmitC dialect with the conversions available in the
46 /// pass. The analysis manager can be used to query analyzes like
47 /// `DataLayoutAnalysis` to further configure the conversion process. This
48 /// method is invoked by `ConvertToEmitC::runOnOperation`. This method returns
49 /// whether the transformation process failed.
50 virtual LogicalResult transform(Operation *op,
51 AnalysisManager manager) const = 0;
52
53protected:
54 /// Visit the `ConvertToEmitCPatternInterface` dialect interfaces and call
55 /// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
56 /// then `visitor` is invoked only with the dialects in the `filterDialects`
57 /// list.
58 LogicalResult visitInterfaces(
59 llvm::function_ref<void(ConvertToEmitCPatternInterface *)> visitor);
60 MLIRContext *context;
61 /// List of dialects names to use as filters.
62 ArrayRef<std::string> filterDialects;
63};
64
65/// This DialectExtension can be attached to the context, which will invoke the
66/// `apply()` method for every loaded dialect. If a dialect implements the
67/// `ConvertToEmitCPatternInterface` interface, we load dependent dialects
68/// through the interface. This extension is loaded in the context before
69/// starting a pass pipeline that involves dialect conversion to the EmitC
70/// dialect.
71class LoadDependentDialectExtension : public DialectExtensionBase {
72public:
73 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
74
75 LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
76
77 void apply(MLIRContext *context,
78 MutableArrayRef<Dialect *> dialects) const final {
79 LLVM_DEBUG(llvm::dbgs() << "Convert to EmitC extension load\n");
80 for (Dialect *dialect : dialects) {
81 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
82 if (!iface)
83 continue;
84 LLVM_DEBUG(llvm::dbgs() << "Convert to EmitC found dialect interface for "
85 << dialect->getNamespace() << "\n");
86 }
87 }
88
89 /// Return a copy of this extension.
90 std::unique_ptr<DialectExtensionBase> clone() const final {
91 return std::make_unique<LoadDependentDialectExtension>(*this);
92 }
93};
94
95//===----------------------------------------------------------------------===//
96// StaticConvertToEmitC
97//===----------------------------------------------------------------------===//
98
99/// Static implementation of the `convert-to-emitc` pass. This version only
100/// looks at dialect interfaces to configure the conversion process.
101struct StaticConvertToEmitC : public ConvertToEmitCPassInterface {
102 /// Pattern set with conversions to the EmitC dialect.
103 std::shared_ptr<const FrozenRewritePatternSet> patterns;
104 /// The conversion target.
105 std::shared_ptr<const ConversionTarget> target;
106 /// The type converter.
107 std::shared_ptr<const TypeConverter> typeConverter;
108 using ConvertToEmitCPassInterface::ConvertToEmitCPassInterface;
109
110 /// Configure the conversion to EmitC at pass initialization.
111 LogicalResult initialize() final {
112 auto target = std::make_shared<ConversionTarget>(*context);
113 auto typeConverter = std::make_shared<TypeConverter>();
114
115 // Add fallback identity converison.
116 typeConverter->addConversion([](Type type) -> std::optional<Type> {
118 return type;
119 return std::nullopt;
120 });
121
122 RewritePatternSet tempPatterns(context);
123 target->addLegalDialect<emitc::EmitCDialect>();
124 // Populate the patterns with the dialect interface.
125 if (failed(visitInterfaces([&](ConvertToEmitCPatternInterface *iface) {
127 *target, *typeConverter, tempPatterns);
128 })))
129 return failure();
130 this->patterns =
131 std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
132 this->target = target;
133 this->typeConverter = typeConverter;
134 return success();
135 }
136
137 /// Apply the conversion driver.
138 LogicalResult transform(Operation *op, AnalysisManager manager) const final {
139 if (failed(applyPartialConversion(op, *target, *patterns)))
140 return failure();
141 return success();
142 }
143};
144
145//===----------------------------------------------------------------------===//
146// ConvertToEmitC
147//===----------------------------------------------------------------------===//
148
149/// This is a generic pass to convert to the EmitC dialect. It uses the
150/// `ConvertToEmitCPatternInterface` dialect interface to delegate the injection
151/// of conversion patterns to dialects.
152class ConvertToEmitC : public impl::ConvertToEmitCBase<ConvertToEmitC> {
153 std::shared_ptr<const ConvertToEmitCPassInterface> impl;
154
155public:
156 using impl::ConvertToEmitCBase<ConvertToEmitC>::ConvertToEmitCBase;
157 void getDependentDialects(DialectRegistry &registry) const final {
158 ConvertToEmitCPassInterface::getDependentDialects(registry);
159 }
160
161 LogicalResult initialize(MLIRContext *context) final {
162 std::shared_ptr<ConvertToEmitCPassInterface> impl;
163 impl = std::make_shared<StaticConvertToEmitC>(context, filterDialects);
164 if (failed(impl->initialize()))
165 return failure();
166 this->impl = impl;
167 return success();
168 }
169
170 void runOnOperation() final {
171 if (failed(impl->transform(getOperation(), getAnalysisManager())))
172 return signalPassFailure();
173 }
174};
175
176} // namespace
177
178//===----------------------------------------------------------------------===//
179// ConvertToEmitCPassInterface
180//===----------------------------------------------------------------------===//
181
182ConvertToEmitCPassInterface::ConvertToEmitCPassInterface(
183 MLIRContext *context, ArrayRef<std::string> filterDialects)
184 : context(context), filterDialects(filterDialects) {}
185
186void ConvertToEmitCPassInterface::getDependentDialects(
187 DialectRegistry &registry) {
188 registry.insert<emitc::EmitCDialect>();
189 registry.addExtensions<LoadDependentDialectExtension>();
190}
191
192LogicalResult ConvertToEmitCPassInterface::visitInterfaces(
193 llvm::function_ref<void(ConvertToEmitCPatternInterface *)> visitor) {
194 if (!filterDialects.empty()) {
195 // Test mode: Populate only patterns from the specified dialects. Produce
196 // an error if the dialect is not loaded or does not implement the
197 // interface.
198 for (StringRef dialectName : filterDialects) {
199 Dialect *dialect = context->getLoadedDialect(dialectName);
200 if (!dialect)
201 return emitError(UnknownLoc::get(context))
202 << "dialect not loaded: " << dialectName << "\n";
203 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
204 if (!iface)
205 return emitError(UnknownLoc::get(context))
206 << "dialect does not implement ConvertToEmitCPatternInterface: "
207 << dialectName << "\n";
208 visitor(iface);
209 }
210 } else {
211 // Normal mode: Populate all patterns from all dialects that implement the
212 // interface.
213 for (Dialect *dialect : context->getLoadedDialects()) {
214 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
215 if (!iface)
216 continue;
217 visitor(iface);
218 }
219 }
220 return success();
221}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
virtual void populateConvertToEmitCConversionPatterns(ConversionTarget &target, TypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conv...
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Definition EmitC.cpp:61
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)