MLIR 23.0.0git
ConvertToEmitCPass.cpp
Go to the documentation of this file.
1//===- ConvertToEmitCPass.cpp - Conversion to EmitC pass --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
14#include "mlir/Pass/Pass.h"
16#include "llvm/Support/Debug.h"
17
18#include <memory>
19
20#define DEBUG_TYPE "convert-to-emitc"
21
22namespace mlir {
23#define GEN_PASS_DEF_CONVERTTOEMITC
24#include "mlir/Conversion/Passes.h.inc"
25} // namespace mlir
26
27using namespace mlir;
28
29namespace {
30/// Base class for creating the internal implementation of `convert-to-emitc`
31/// passes.
32class ConvertToEmitCPassInterface {
33public:
34 ConvertToEmitCPassInterface(MLIRContext *context,
35 ArrayRef<std::string> filterDialects,
36 std::optional<bool> lowerToCpp);
37 virtual ~ConvertToEmitCPassInterface() = default;
38
39 /// Get the dependent dialects used by `convert-to-emitc`.
40 static void getDependentDialects(DialectRegistry &registry);
41
42 /// Initialize the internal state of the `convert-to-emitc` pass
43 /// implementation. This method is invoked by `ConvertToEmitC::initialize`.
44 /// This method returns whether the initialization process failed.
45 virtual LogicalResult initialize() = 0;
46
47 /// Transform `op` to the EmitC dialect with the conversions available in the
48 /// pass. The analysis manager can be used to query analyzes like
49 /// `DataLayoutAnalysis` to further configure the conversion process. This
50 /// method is invoked by `ConvertToEmitC::runOnOperation`. This method returns
51 /// whether the transformation process failed.
52 virtual LogicalResult transform(Operation *op,
53 AnalysisManager manager) const = 0;
54
55protected:
56 /// Visit the `ConvertToEmitCPatternInterface` dialect interfaces and call
57 /// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
58 /// then `visitor` is invoked only with the dialects in the `filterDialects`
59 /// list.
60 LogicalResult visitInterfaces(
61 llvm::function_ref<void(ConvertToEmitCPatternInterface *)> visitor);
62 MLIRContext *context;
63 /// List of dialects names to use as filters.
64 ArrayRef<std::string> filterDialects;
65 std::optional<bool> lowerToCpp;
66};
67
68/// This DialectExtension can be attached to the context, which will invoke the
69/// `apply()` method for every loaded dialect. If a dialect implements the
70/// `ConvertToEmitCPatternInterface` interface, we load dependent dialects
71/// through the interface. This extension is loaded in the context before
72/// starting a pass pipeline that involves dialect conversion to the EmitC
73/// dialect.
74class LoadDependentDialectExtension : public DialectExtensionBase {
75public:
76 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
77
78 LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
79
80 void apply(MLIRContext *context,
81 MutableArrayRef<Dialect *> dialects) const final {
82 LLVM_DEBUG(llvm::dbgs() << "Convert to EmitC extension load\n");
83 for (Dialect *dialect : dialects) {
84 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
85 if (!iface)
86 continue;
87 LLVM_DEBUG(llvm::dbgs() << "Convert to EmitC found dialect interface for "
88 << dialect->getNamespace() << "\n");
89 }
90 }
91
92 /// Return a copy of this extension.
93 std::unique_ptr<DialectExtensionBase> clone() const final {
94 return std::make_unique<LoadDependentDialectExtension>(*this);
95 }
96};
97
98//===----------------------------------------------------------------------===//
99// StaticConvertToEmitC
100//===----------------------------------------------------------------------===//
101
102/// Static implementation of the `convert-to-emitc` pass. This version only
103/// looks at dialect interfaces to configure the conversion process.
104struct StaticConvertToEmitC : public ConvertToEmitCPassInterface {
105 /// Pattern set with conversions to the EmitC dialect.
106 std::shared_ptr<const FrozenRewritePatternSet> patterns;
107 /// The conversion target.
108 std::shared_ptr<const ConversionTarget> target;
109 /// The type converter.
110 std::shared_ptr<const TypeConverter> typeConverter;
111 using ConvertToEmitCPassInterface::ConvertToEmitCPassInterface;
112
113 /// Configure the conversion to EmitC at pass initialization.
114 LogicalResult initialize() final {
115 auto target = std::make_shared<ConversionTarget>(*context);
116 auto typeConverter = std::make_shared<EmitCTypeConverter>(context);
117
118 RewritePatternSet tempPatterns(context);
119 target->addLegalDialect<emitc::EmitCDialect>();
120 // Populate the patterns with the dialect interface.
121 if (failed(visitInterfaces([&](ConvertToEmitCPatternInterface *iface) {
122 iface->populateConvertToEmitCConversionPatterns(
123 *target, *typeConverter, tempPatterns, lowerToCpp);
124 })))
125 return failure();
126 this->patterns =
127 std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
128 this->target = target;
129 this->typeConverter = typeConverter;
130 return success();
131 }
132
133 /// Apply the conversion driver.
134 LogicalResult transform(Operation *op, AnalysisManager manager) const final {
135 if (failed(applyPartialConversion(op, *target, *patterns)))
136 return failure();
137 return success();
138 }
139};
140
141//===----------------------------------------------------------------------===//
142// ConvertToEmitC
143//===----------------------------------------------------------------------===//
144
145/// This is a generic pass to convert to the EmitC dialect. It uses the
146/// `ConvertToEmitCPatternInterface` dialect interface to delegate the injection
147/// of conversion patterns to dialects.
148class ConvertToEmitC : public impl::ConvertToEmitCBase<ConvertToEmitC> {
149 std::shared_ptr<const ConvertToEmitCPassInterface> impl;
150
151public:
152 using impl::ConvertToEmitCBase<ConvertToEmitC>::ConvertToEmitCBase;
153 void getDependentDialects(DialectRegistry &registry) const final {
154 ConvertToEmitCPassInterface::getDependentDialects(registry);
155 }
156
157 LogicalResult initialize(MLIRContext *context) final {
158 std::shared_ptr<ConvertToEmitCPassInterface> impl;
159 std::optional<bool> lowerToCppOverride;
160 if (this->lowerToCpp.hasValue())
161 lowerToCppOverride = this->lowerToCpp;
162 impl = std::make_shared<StaticConvertToEmitC>(context, filterDialects,
163 lowerToCppOverride);
164 if (failed(impl->initialize()))
165 return failure();
166 this->impl = impl;
167 return success();
168 }
169
170 void runOnOperation() final {
171 if (failed(impl->transform(getOperation(), getAnalysisManager())))
172 return signalPassFailure();
173 }
174};
175
176} // namespace
177
178//===----------------------------------------------------------------------===//
179// ConvertToEmitCPassInterface
180//===----------------------------------------------------------------------===//
181
182ConvertToEmitCPassInterface::ConvertToEmitCPassInterface(
183 MLIRContext *context, ArrayRef<std::string> filterDialects,
184 std::optional<bool> lowerToCpp)
185 : context(context), filterDialects(filterDialects), lowerToCpp(lowerToCpp) {
186}
187
188void ConvertToEmitCPassInterface::getDependentDialects(
189 DialectRegistry &registry) {
190 registry.insert<emitc::EmitCDialect>();
191 registry.addExtensions<LoadDependentDialectExtension>();
192}
193
194LogicalResult ConvertToEmitCPassInterface::visitInterfaces(
195 llvm::function_ref<void(ConvertToEmitCPatternInterface *)> visitor) {
196 if (!filterDialects.empty()) {
197 // Test mode: Populate only patterns from the specified dialects. Produce
198 // an error if the dialect is not loaded or does not implement the
199 // interface.
200 for (StringRef dialectName : filterDialects) {
201 Dialect *dialect = context->getLoadedDialect(dialectName);
202 if (!dialect)
203 return emitError(UnknownLoc::get(context))
204 << "dialect not loaded: " << dialectName << "\n";
205 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
206 if (!iface)
207 return emitError(UnknownLoc::get(context))
208 << "dialect does not implement ConvertToEmitCPatternInterface: "
209 << dialectName << "\n";
210 visitor(iface);
211 }
212 } else {
213 // Normal mode: Populate all patterns from all dialects that implement the
214 // interface.
215 for (Dialect *dialect : context->getLoadedDialects()) {
216 auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
217 if (!iface)
218 continue;
219 visitor(iface);
220 }
221 }
222 return success();
223}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition TypeID.h:331
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
std::vector< Dialect * > getLoadedDialects()
Return information about all IR dialects loaded in the context.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)