MLIR  21.0.0git
ConvertToEmitCPass.cpp
Go to the documentation of this file.
1 //===- ConvertToEmitCPass.cpp - Conversion to EmitC pass --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/Pass/Pass.h"
17 #include "llvm/Support/Debug.h"
18 
19 #include <memory>
20 
21 #define DEBUG_TYPE "convert-to-emitc"
22 
23 namespace mlir {
24 #define GEN_PASS_DEF_CONVERTTOEMITC
25 #include "mlir/Conversion/Passes.h.inc"
26 } // namespace mlir
27 
28 using namespace mlir;
29 
30 namespace {
31 /// Base class for creating the internal implementation of `convert-to-emitc`
32 /// passes.
/// Implementations provide a one-time setup step (`initialize`) and a
/// per-operation conversion step (`transform`); both report failure through
/// their returned `LogicalResult`.
33 class ConvertToEmitCPassInterface {
34 public:
35  ConvertToEmitCPassInterface(MLIRContext *context,
36  ArrayRef<std::string> filterDialects);
37  virtual ~ConvertToEmitCPassInterface() = default;
38 
39  /// Get the dependent dialects used by `convert-to-emitc`.
40  static void getDependentDialects(DialectRegistry &registry);
41 
42  /// Initialize the internal state of the `convert-to-emitc` pass
43  /// implementation. This method is invoked by `ConvertToEmitC::initialize`.
44  /// This method returns whether the initialization process failed.
45  virtual LogicalResult initialize() = 0;
46 
47  /// Transform `op` to the EmitC dialect with the conversions available in the
48  /// pass. The analysis manager can be used to query analyses like
49  /// `DataLayoutAnalysis` to further configure the conversion process. This
50  /// method is invoked by `ConvertToEmitC::runOnOperation`. This method returns
51  /// whether the transformation process failed.
52  virtual LogicalResult transform(Operation *op,
53  AnalysisManager manager) const = 0;
54 
55 protected:
56  /// Visit the `ConvertToEmitCPatternInterface` dialect interfaces and call
57  /// `visitor` with each of the interfaces. If `filterDialects` is non-empty,
58  /// then `visitor` is invoked only with the dialects in the `filterDialects`
59  /// list.
60  LogicalResult visitInterfaces(
// NOTE(review): listing line 61 is missing, so this declaration is
// incomplete as shown. It presumably declared the `visitor` callback
// parameter mentioned in the comment above. Confirm against the original
// source.
/// Context whose loaded dialects are inspected by `visitInterfaces`.
62  MLIRContext *context;
63  /// List of dialect names to use as filters.
/// NOTE(review): `ArrayRef` is a non-owning view, so the referenced strings
/// must outlive this object. They appear to come from the pass option that
/// `ConvertToEmitC::initialize` passes in.
64  ArrayRef<std::string> filterDialects;
65 };
66 
67 /// This DialectExtension can be attached to the context, which will invoke the
68 /// `apply()` method for every loaded dialect. If a dialect implements the
69 /// `ConvertToEmitCPatternInterface` interface, we load dependent dialects
70 /// through the interface. This extension is loaded in the context before
71 /// starting a pass pipeline that involves dialect conversion to the EmitC
72 /// dialect.
73 class LoadDependentDialectExtension : public DialectExtensionBase {
74 public:
75  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LoadDependentDialectExtension)
76 
77  LoadDependentDialectExtension() : DialectExtensionBase(/*dialectNames=*/{}) {}
78 
79  void apply(MLIRContext *context,
80  MutableArrayRef<Dialect *> dialects) const final {
81  LLVM_DEBUG(llvm::dbgs() << "Convert to EmitC extension load\n");
82  for (Dialect *dialect : dialects) {
83  auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
84  if (!iface)
85  continue;
86  LLVM_DEBUG(llvm::dbgs() << "Convert to EmitC found dialect interface for "
87  << dialect->getNamespace() << "\n");
88  }
89  }
90 
91  /// Return a copy of this extension.
92  std::unique_ptr<DialectExtensionBase> clone() const final {
93  return std::make_unique<LoadDependentDialectExtension>(*this);
94  }
95 };
96 
97 //===----------------------------------------------------------------------===//
98 // StaticConvertToEmitC
99 //===----------------------------------------------------------------------===//
100 
101 /// Static implementation of the `convert-to-emitc` pass. This version only
102 /// looks at dialect interfaces to configure the conversion process.
103 struct StaticConvertToEmitC : public ConvertToEmitCPassInterface {
104  /// Pattern set with conversions to the EmitC dialect.
105  std::shared_ptr<const FrozenRewritePatternSet> patterns;
106  /// The conversion target.
107  std::shared_ptr<const ConversionTarget> target;
108  /// The type converter.
/// `transform` never reads this member directly. It is presumably retained
/// so the converter outlives the patterns populated with `*typeConverter` in
/// `initialize`. Confirm before removing.
109  std::shared_ptr<const TypeConverter> typeConverter;
110  using ConvertToEmitCPassInterface::ConvertToEmitCPassInterface;
111 
112  /// Configure the conversion to EmitC at pass initialization.
/// The target, type converter and patterns are built in locals that shadow
/// the members. They are only published to the members if `visitInterfaces`
/// succeeds; on failure the members are left untouched.
113  LogicalResult initialize() final {
114  auto target = std::make_shared<ConversionTarget>(*context);
115  auto typeConverter = std::make_shared<TypeConverter>();
116 
117  // Add fallback identity conversion. Types EmitC already supports map to
// themselves; any other type yields std::nullopt, which lets other
// registered conversions handle it.
118  typeConverter->addConversion([](Type type) -> std::optional<Type> {
119  if (emitc::isSupportedEmitCType(type))
120  return type;
121  return std::nullopt;
122  });
123 
124  RewritePatternSet tempPatterns(context);
125  target->addLegalDialect<emitc::EmitCDialect>();
126  // Populate the patterns with the dialect interface.
127  if (failed(visitInterfaces([&](ConvertToEmitCPatternInterface *iface) {
// NOTE(review): listing line 128 is missing here. Given the arguments on the
// next line, it presumably calls
// `iface->populateConvertToEmitCConversionPatterns(`. Confirm against the
// original source.
129  *target, *typeConverter, tempPatterns);
130  })))
131  return failure();
132  this->patterns =
133  std::make_unique<FrozenRewritePatternSet>(std::move(tempPatterns));
134  this->target = target;
135  this->typeConverter = typeConverter;
136  return success();
137  }
138 
139  /// Apply the conversion driver.
/// Runs `applyPartialConversion` on `op` with the target and frozen patterns
/// built by `initialize`. Both are dereferenced here, so `initialize` must
/// have succeeded first. `manager` is unused by this implementation.
140  LogicalResult transform(Operation *op, AnalysisManager manager) const final {
141  if (failed(applyPartialConversion(op, *target, *patterns)))
142  return failure();
143  return success();
144  }
145 };
146 
147 //===----------------------------------------------------------------------===//
148 // ConvertToEmitC
149 //===----------------------------------------------------------------------===//
150 
151 /// This is a generic pass to convert to the EmitC dialect. It uses the
152 /// `ConvertToEmitCPatternInterface` dialect interface to delegate the injection
153 /// of conversion patterns to dialects.
154 class ConvertToEmitC : public impl::ConvertToEmitCBase<ConvertToEmitC> {
155  std::shared_ptr<const ConvertToEmitCPassInterface> impl;
156 
157 public:
158  using impl::ConvertToEmitCBase<ConvertToEmitC>::ConvertToEmitCBase;
159  void getDependentDialects(DialectRegistry &registry) const final {
160  ConvertToEmitCPassInterface::getDependentDialects(registry);
161  }
162 
163  LogicalResult initialize(MLIRContext *context) final {
164  std::shared_ptr<ConvertToEmitCPassInterface> impl;
165  impl = std::make_shared<StaticConvertToEmitC>(context, filterDialects);
166  if (failed(impl->initialize()))
167  return failure();
168  this->impl = impl;
169  return success();
170  }
171 
172  void runOnOperation() final {
173  if (failed(impl->transform(getOperation(), getAnalysisManager())))
174  return signalPassFailure();
175  }
176 };
177 
178 } // namespace
179 
180 //===----------------------------------------------------------------------===//
181 // ConvertToEmitCPassInterface
182 //===----------------------------------------------------------------------===//
183 
184 ConvertToEmitCPassInterface::ConvertToEmitCPassInterface(
185  MLIRContext *context, ArrayRef<std::string> filterDialects)
186  : context(context), filterDialects(filterDialects) {}
187 
188 void ConvertToEmitCPassInterface::getDependentDialects(
189  DialectRegistry &registry) {
190  registry.insert<emitc::EmitCDialect>();
191  registry.addExtensions<LoadDependentDialectExtension>();
192 }
193 
/// Call `visitor` on `ConvertToEmitCPatternInterface` dialect interfaces.
/// - With an empty `filterDialects`, every loaded dialect that implements the
///   interface is visited, and the result is always success.
/// - Otherwise, the listed dialects are visited in list order. The first name
///   that is not loaded, or whose dialect lacks the interface, produces an
///   error at an unknown location and a failure result. Dialects listed
///   before it have already been passed to `visitor` by then.
194 LogicalResult ConvertToEmitCPassInterface::visitInterfaces(
// NOTE(review): listing line 195 is missing. It presumably holds the
// remainder of the signature (the `visitor` parameter and the opening
// brace). Confirm against the original source.
196  if (!filterDialects.empty()) {
197  // Test mode: Populate only patterns from the specified dialects. Produce
198  // an error if the dialect is not loaded or does not implement the
199  // interface.
200  for (StringRef dialectName : filterDialects) {
201  Dialect *dialect = context->getLoadedDialect(dialectName);
// Returning the in-flight diagnostic yields a failure result.
// NOTE(review): the trailing "\n" in these diagnostic messages is probably
// unnecessary for MLIR diagnostics; consider dropping it.
202  if (!dialect)
203  return emitError(UnknownLoc::get(context))
204  << "dialect not loaded: " << dialectName << "\n";
205  auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
206  if (!iface)
207  return emitError(UnknownLoc::get(context))
208  << "dialect does not implement ConvertToEmitCPatternInterface: "
209  << dialectName << "\n";
210  visitor(iface);
211  }
212  } else {
213  // Normal mode: Populate all patterns from all dialects that implement the
214  // interface.
215  for (Dialect *dialect : context->getLoadedDialects()) {
216  auto *iface = dyn_cast<ConvertToEmitCPatternInterface>(dialect);
// Dialects without the interface are silently skipped in this mode.
217  if (!iface)
218  continue;
219  visitor(iface);
220  }
221  }
222  return success();
223 }
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
Definition: TypeID.h:331
This class represents an analysis manager for a particular operation instance.
virtual void populateConvertToEmitCConversionPatterns(ConversionTarget &target, TypeConverter &typeConverter, RewritePatternSet &patterns) const =0
Hook for derived dialect interface to provide conversion patterns and mark dialect legal for the conversion target.
This class represents an opaque dialect extension.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtensions()
Add the given extensions to the registry.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
Definition: Dialect.h:38
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Definition: EmitC.cpp:62
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.