MLIR 22.0.0git
DialectConversion.h
Go to the documentation of this file.
1//===- DialectConversion.h - MLIR dialect conversion pass -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file declares a generic pass for converting between MLIR dialects.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14#define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
15
16#include "mlir/Config/mlir-config.h"
18#include "llvm/ADT/MapVector.h"
19#include "llvm/ADT/StringMap.h"
20#include <type_traits>
21
22namespace mlir {
23
24// Forward declarations.
25class Attribute;
26class Block;
27struct ConversionConfig;
28class ConversionPatternRewriter;
29class MLIRContext;
30class Operation;
32class Type;
33class Value;
34
35//===----------------------------------------------------------------------===//
36// Type Conversion
37//===----------------------------------------------------------------------===//
38
39/// Type conversion class. Specific conversions and materializations can be
40/// registered using addConversion and addMaterialization, respectively.
41class TypeConverter {
42public:
43 /// Type alias to allow derived classes to inherit constructors with
44 /// `using Base::Base;`.
45 using Base = TypeConverter;
46
47 virtual ~TypeConverter() = default;
48 TypeConverter() = default;
49 // Copy the registered conversions, but not the caches
50 TypeConverter(const TypeConverter &other)
51 : conversions(other.conversions),
52 sourceMaterializations(other.sourceMaterializations),
53 targetMaterializations(other.targetMaterializations),
54 typeAttributeConversions(other.typeAttributeConversions) {}
55 TypeConverter &operator=(const TypeConverter &other) {
56 conversions = other.conversions;
57 sourceMaterializations = other.sourceMaterializations;
58 targetMaterializations = other.targetMaterializations;
59 typeAttributeConversions = other.typeAttributeConversions;
60 return *this;
61 }
62
63 /// This class provides all of the information necessary to convert a type
64 /// signature.
65 class SignatureConversion {
66 public:
67 SignatureConversion(unsigned numOrigInputs)
68 : remappedInputs(numOrigInputs) {}
69
70 /// This struct represents a range of new types or a range of values that
71 /// remaps an existing signature input.
72 struct InputMapping {
73 size_t inputNo, size;
74 SmallVector<Value, 1> replacementValues;
75
76 /// Return "true" if this input was replaces with one or multiple values.
77 bool replacedWithValues() const { return !replacementValues.empty(); }
78 };
79
80 /// Return the argument types for the new signature.
81 ArrayRef<Type> getConvertedTypes() const { return argTypes; }
82
83 /// Get the input mapping for the given argument.
84 std::optional<InputMapping> getInputMapping(unsigned input) const {
85 return remappedInputs[input];
86 }
87
88 //===------------------------------------------------------------------===//
89 // Conversion Hooks
90 //===------------------------------------------------------------------===//
91
92 /// Remap an input of the original signature with a new set of types. The
93 /// new types are appended to the new signature conversion.
94 void addInputs(unsigned origInputNo, ArrayRef<Type> types);
95
96 /// Append new input types to the signature conversion, this should only be
97 /// used if the new types are not intended to remap an existing input.
98 void addInputs(ArrayRef<Type> types);
99
100 /// Remap an input of the original signature to `replacements`
101 /// values. This drops the original argument.
102 void remapInput(unsigned origInputNo, ArrayRef<Value> replacements);
103
104 private:
105 /// Remap an input of the original signature with a range of types in the
106 /// new signature.
107 void remapInput(unsigned origInputNo, unsigned newInputNo,
108 unsigned newInputCount = 1);
109
110 /// The remapping information for each of the original arguments.
111 SmallVector<std::optional<InputMapping>, 4> remappedInputs;
112
113 /// The set of new argument types.
114 SmallVector<Type, 4> argTypes;
115 };
116
117 /// The general result of a type attribute conversion callback, allowing
118 /// for early termination. The default constructor creates the na case.
119 class AttributeConversionResult {
120 public:
121 constexpr AttributeConversionResult() : impl() {}
122 AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
123
124 static AttributeConversionResult result(Attribute attr);
125 static AttributeConversionResult na();
126 static AttributeConversionResult abort();
127
128 bool hasResult() const;
129 bool isNa() const;
130 bool isAbort() const;
131
132 Attribute getResult() const;
133
134 private:
135 AttributeConversionResult(Attribute attr, unsigned tag) : impl(attr, tag) {}
136
137 llvm::PointerIntPair<Attribute, 2> impl;
138 // Note that na is 0 so that we can use PointerIntPair's default
139 // constructor.
140 static constexpr unsigned naTag = 0;
141 static constexpr unsigned resultTag = 1;
142 static constexpr unsigned abortTag = 2;
143 };
144
145 /// Register a conversion function. A conversion function must be convertible
146 /// to any of the following forms (where `T` is `Value` or a class derived
147 /// from `Type`, including `Type` itself):
148 ///
149 /// * std::optional<Type>(T)
150 /// - This form represents a 1-1 type conversion. It should return nullptr
151 /// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
152 /// the converter is allowed to try another conversion function to
153 /// perform the conversion.
154 /// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
155 /// - This form represents a 1-N type conversion. It should return
156 /// `failure` or `std::nullopt` to signify a failed conversion. If the
157 /// new set of types is empty, the type is removed and any usages of the
158 /// existing value are expected to be removed during conversion. If
159 /// `std::nullopt` is returned, the converter is allowed to try another
160 /// conversion function to perform the conversion.
161 ///
162 /// Conversion functions that accept `Value` as the first argument are
163 /// context-aware. I.e., they can take into account IR when converting the
164 /// type of the given value. Context-unaware conversion functions accept
165 /// `Type` or a derived class as the first argument.
166 ///
167 /// Note: Context-unaware conversions are cached, but context-aware
168 /// conversions are not.
169 ///
170 /// Note: When attempting to convert a type, e.g. via 'convertType', the
171 /// mostly recently added conversions will be invoked first.
172 template <typename FnT, typename T = typename llvm::function_traits<
173 std::decay_t<FnT>>::template arg_t<0>>
174 void addConversion(FnT &&callback) {
175 registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
176 }
177
178 /// All of the following materializations require function objects that are
179 /// convertible to the following form:
180 /// `Value(OpBuilder &, T, ValueRange, Location)`,
181 /// where `T` is any subclass of `Type`. This function is responsible for
182 /// creating an operation, using the OpBuilder and Location provided, that
183 /// "casts" a range of values into a single value of the given type `T`. It
184 /// must return a Value of the type `T` on success and `nullptr` if
185 /// it failed but other materialization should be attempted. Materialization
186 /// functions must be provided when a type conversion may persist after the
187 /// conversion has finished.
188 ///
189 /// Note: Target materializations may optionally accept an additional Type
190 /// parameter, which is the original type of the SSA value. Furthermore, `T`
191 /// can be a TypeRange; in that case, the function must return a
192 /// SmallVector<Value>.
193
194 /// This method registers a materialization that will be called when
195 /// converting a replacement value back to its original source type.
196 /// This is used when some uses of the original value persist beyond the main
197 /// conversion.
198 template <typename FnT, typename T = typename llvm::function_traits<
199 std::decay_t<FnT>>::template arg_t<1>>
200 void addSourceMaterialization(FnT &&callback) {
201 sourceMaterializations.emplace_back(
202 wrapSourceMaterialization<T>(std::forward<FnT>(callback)));
203 }
204
205 /// This method registers a materialization that will be called when
206 /// converting a value to a target type according to a pattern's type
207 /// converter.
208 ///
209 /// Note: Target materializations can optionally inspect the "original"
210 /// type. This type may be different from the type of the input value.
211 /// For example, let's assume that a conversion pattern "P1" replaced an SSA
212 /// value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion
213 /// pattern "P2" matches an op that has "v1" as an operand. Let's furthermore
214 /// assume that "P2" determines that the converted target type of "t1" is
215 /// "t3", which may be different from "t2". In this example, the target
216 /// materialization will be invoked with: outputType = "t3", inputs = "v2",
217 /// originalType = "t1". Note that the original type "t1" cannot be recovered
218 /// from just "t3" and "v2"; that's why the originalType parameter exists.
219 ///
220 /// Note: During a 1:N conversion, the result types can be a TypeRange. In
221 /// that case the materialization produces a SmallVector<Value>.
222 template <typename FnT, typename T = typename llvm::function_traits<
223 std::decay_t<FnT>>::template arg_t<1>>
224 void addTargetMaterialization(FnT &&callback) {
225 targetMaterializations.emplace_back(
226 wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
227 }
228
229 /// Register a conversion function for attributes within types. Type
230 /// converters may call this function in order to allow hooking into the
231 /// translation of attributes that exist within types. For example, a type
232 /// converter for the `memref` type could use these conversions to convert
233 /// memory spaces or layouts in an extensible way.
234 ///
235 /// The conversion functions take a non-null Type or subclass of Type and a
236 /// non-null Attribute (or subclass of Attribute), and returns a
237 /// `AttributeConversionResult`. This result can either contain an
238 /// `Attribute`, which may be `nullptr`, representing the conversion's
239 /// success, `AttributeConversionResult::na()` (the default empty value),
240 /// indicating that the conversion function did not apply and that further
241 /// conversion functions should be checked, or
242 /// `AttributeConversionResult::abort()` indicating that the conversion
243 /// process should be aborted.
244 ///
245 /// Registered conversion functions are callled in the reverse of the order in
246 /// which they were registered.
247 template <
248 typename FnT,
249 typename T =
250 typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
251 typename A =
252 typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
253 void addTypeAttributeConversion(FnT &&callback) {
254 registerTypeAttributeConversion(
255 wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
256 }
257
258 /// Convert the given type. This function returns failure if no valid
259 /// conversion exists, success otherwise. If the new set of types is empty,
260 /// the type is removed and any usages of the existing value are expected to
261 /// be removed during conversion.
262 ///
263 /// Note: This overload invokes only context-unaware type conversion
264 /// functions. Users should call the other overload if possible.
265 LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
266
267 /// Convert the type of the given value. This function returns failure if no
268 /// valid conversion exists, success otherwise. If the new set of types is
269 /// empty, the type is removed and any usages of the existing value are
270 /// expected to be removed during conversion.
271 ///
272 /// Note: This overload invokes both context-aware and context-unaware type
273 /// conversion functions.
274 LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;
275
276 /// This hook simplifies defining 1-1 type conversions. This function returns
277 /// the type to convert to on success, and a null type on failure.
278 Type convertType(Type t) const;
279 Type convertType(Value v) const;
280
281 /// Attempts a 1-1 type conversion, expecting the result type to be
282 /// `TargetType`. Returns the converted type cast to `TargetType` on success,
283 /// and a null type on conversion or cast failure.
284 template <typename TargetType>
285 TargetType convertType(Type t) const {
286 return dyn_cast_or_null<TargetType>(convertType(t));
287 }
288 template <typename TargetType>
289 TargetType convertType(Value v) const {
290 return dyn_cast_or_null<TargetType>(convertType(v));
291 }
292
293 /// Convert the given types, filling 'results' as necessary. This returns
294 /// "failure" if the conversion of any of the types fails, "success"
295 /// otherwise.
296 LogicalResult convertTypes(TypeRange types,
297 SmallVectorImpl<Type> &results) const;
298
299 /// Convert the types of the given values, filling 'results' as necessary.
300 /// This returns "failure" if the conversion of any of the types fails,
301 /// "success" otherwise.
302 LogicalResult convertTypes(ValueRange values,
303 SmallVectorImpl<Type> &results) const;
304
305 /// Return true if the given type is legal for this type converter, i.e. the
306 /// type converts to itself.
307 bool isLegal(Type type) const;
308 bool isLegal(Value value) const;
309
310 /// Return true if all of the given types are legal for this type converter.
311 bool isLegal(TypeRange range) const {
312 return llvm::all_of(range, [this](Type type) { return isLegal(type); });
313 }
314 bool isLegal(ValueRange range) const {
315 return llvm::all_of(range, [this](Value value) { return isLegal(value); });
316 }
317
318 /// Return true if the given operation has legal operand and result types.
319 bool isLegal(Operation *op) const;
320
321 /// Return true if the types of block arguments within the region are legal.
322 bool isLegal(Region *region) const;
323
324 /// Return true if the inputs and outputs of the given function type are
325 /// legal.
326 bool isSignatureLegal(FunctionType ty) const;
327
328 /// This method allows for converting a specific argument of a signature. It
329 /// takes as inputs the original argument input number, type.
330 /// On success, it populates 'result' with any new mappings.
331 LogicalResult convertSignatureArg(unsigned inputNo, Type type,
332 SignatureConversion &result) const;
333 LogicalResult convertSignatureArgs(TypeRange types,
334 SignatureConversion &result,
335 unsigned origInputOffset = 0) const;
336 LogicalResult convertSignatureArg(unsigned inputNo, Value value,
337 SignatureConversion &result) const;
338 LogicalResult convertSignatureArgs(ValueRange values,
339 SignatureConversion &result,
340 unsigned origInputOffset = 0) const;
341
342 /// This function converts the type signature of the given block, by invoking
343 /// 'convertSignatureArg' for each argument. This function should return a
344 /// valid conversion for the signature on success, std::nullopt otherwise.
345 std::optional<SignatureConversion> convertBlockSignature(Block *block) const;
346
347 /// Materialize a conversion from a set of types into one result type by
348 /// generating a cast sequence of some kind. See the respective
349 /// `add*Materialization` for more information on the context for these
350 /// methods.
351 Value materializeSourceConversion(OpBuilder &builder, Location loc,
352 Type resultType, ValueRange inputs) const;
353 Value materializeTargetConversion(OpBuilder &builder, Location loc,
354 Type resultType, ValueRange inputs,
355 Type originalType = {}) const;
356 SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
357 Location loc,
358 TypeRange resultType,
359 ValueRange inputs,
360 Type originalType = {}) const;
361
362 /// Convert an attribute present `attr` from within the type `type` using
363 /// the registered conversion functions. If no applicable conversion has been
364 /// registered, return std::nullopt. Note that the empty attribute/`nullptr`
365 /// is a valid return value for this function.
366 std::optional<Attribute> convertTypeAttribute(Type type,
367 Attribute attr) const;
368
369private:
370 /// The signature of the callback used to convert a type. If the new set of
371 /// types is empty, the type is removed and any usages of the existing value
372 /// are expected to be removed during conversion.
373 using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
374 PointerUnion<Type, Value>, SmallVectorImpl<Type> &)>;
375
376 /// The signature of the callback used to materialize a source conversion.
377 ///
378 /// Arguments: builder, result type, inputs, location
379 using SourceMaterializationCallbackFn =
380 std::function<Value(OpBuilder &, Type, ValueRange, Location)>;
381
382 /// The signature of the callback used to materialize a target conversion.
383 ///
384 /// Arguments: builder, result types, inputs, location, original type
385 using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
386 OpBuilder &, TypeRange, ValueRange, Location, Type)>;
387
388 /// The signature of the callback used to convert a type attribute.
389 using TypeAttributeConversionCallbackFn =
390 std::function<AttributeConversionResult(Type, Attribute)>;
391
392 /// Generate a wrapper for the given callback. This allows for accepting
393 /// different callback forms, that all compose into a single version.
394 /// With callback of form: `std::optional<Type>(T)`, where `T` can be a
395 /// `Value` or a `Type` (or a class derived from `Type`).
396 template <typename T, typename FnT>
397 std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
398 wrapCallback(FnT &&callback) {
399 return wrapCallback<T>([callback = std::forward<FnT>(callback)](
400 T typeOrValue, SmallVectorImpl<Type> &results) {
401 if (std::optional<Type> resultOpt = callback(typeOrValue)) {
402 bool wasSuccess = static_cast<bool>(*resultOpt);
403 if (wasSuccess)
404 results.push_back(*resultOpt);
405 return std::optional<LogicalResult>(success(wasSuccess));
406 }
407 return std::optional<LogicalResult>();
408 });
409 }
410 /// With callback of form: `std::optional<LogicalResult>(
411 /// T, SmallVectorImpl<Type> &)`, where `T` is a type.
412 template <typename T, typename FnT>
413 std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
414 std::is_base_of_v<Type, T>,
415 ConversionCallbackFn>
416 wrapCallback(FnT &&callback) const {
417 return [callback = std::forward<FnT>(callback)](
418 PointerUnion<Type, Value> typeOrValue,
419 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
420 T derivedType;
421 if (Type t = dyn_cast<Type>(typeOrValue)) {
422 derivedType = dyn_cast<T>(t);
423 } else if (Value v = dyn_cast<Value>(typeOrValue)) {
424 derivedType = dyn_cast<T>(v.getType());
425 } else {
426 llvm_unreachable("unexpected variant");
427 }
428 if (!derivedType)
429 return std::nullopt;
430 return callback(derivedType, results);
431 };
432 }
433 /// With callback of form: `std::optional<LogicalResult>(
434 /// T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
435 template <typename T, typename FnT>
436 std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
437 std::is_same_v<T, Value>,
438 ConversionCallbackFn>
439 wrapCallback(FnT &&callback) {
440 contextAwareTypeConversionsIndex = conversions.size();
441 return [callback = std::forward<FnT>(callback)](
442 PointerUnion<Type, Value> typeOrValue,
443 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
444 if (Type t = dyn_cast<Type>(typeOrValue)) {
445 // Context-aware type conversion was called with a type.
446 return std::nullopt;
447 } else if (Value v = dyn_cast<Value>(typeOrValue)) {
448 return callback(v, results);
449 }
450 llvm_unreachable("unexpected variant");
451 return std::nullopt;
452 };
453 }
454
455 /// Register a type conversion.
456 void registerConversion(ConversionCallbackFn callback) {
457 conversions.emplace_back(std::move(callback));
458 cachedDirectConversions.clear();
459 cachedMultiConversions.clear();
460 }
461
462 /// Generate a wrapper for the given source materialization callback. The
463 /// callback may take any subclass of `Type` and the wrapper will check for
464 /// the target type to be of the expected class before calling the callback.
465 template <typename T, typename FnT>
466 SourceMaterializationCallbackFn
467 wrapSourceMaterialization(FnT &&callback) const {
468 return [callback = std::forward<FnT>(callback)](
469 OpBuilder &builder, Type resultType, ValueRange inputs,
470 Location loc) -> Value {
471 if (T derivedType = dyn_cast<T>(resultType))
472 return callback(builder, derivedType, inputs, loc);
473 return Value();
474 };
475 }
476
477 /// Generate a wrapper for the given target materialization callback.
478 /// The callback may take any subclass of `Type` and the wrapper will check
479 /// for the target type to be of the expected class before calling the
480 /// callback.
481 ///
482 /// With callback of form:
483 /// - Value(OpBuilder &, T, ValueRange, Location, Type)
484 /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location, Type)
485 template <typename T, typename FnT>
486 std::enable_if_t<
487 std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
488 TargetMaterializationCallbackFn>
489 wrapTargetMaterialization(FnT &&callback) const {
490 return [callback = std::forward<FnT>(callback)](
491 OpBuilder &builder, TypeRange resultTypes, ValueRange inputs,
492 Location loc, Type originalType) -> SmallVector<Value> {
493 SmallVector<Value> result;
494 if constexpr (std::is_same<T, TypeRange>::value) {
495 // This is a 1:N target materialization. Return the produces values
496 // directly.
497 result = callback(builder, resultTypes, inputs, loc, originalType);
498 } else if constexpr (std::is_assignable<Type, T>::value) {
499 // This is a 1:1 target materialization. Invoke the callback only if a
500 // single SSA value is requested.
501 if (resultTypes.size() == 1) {
502 // Invoke the callback only if the type class of the callback matches
503 // the requested result type.
504 if (T derivedType = dyn_cast<T>(resultTypes.front())) {
505 // 1:1 materializations produce single values, but we store 1:N
506 // target materialization functions in the type converter. Wrap the
507 // result value in a SmallVector<Value>.
508 Value val =
509 callback(builder, derivedType, inputs, loc, originalType);
510 if (val)
511 result.push_back(val);
512 }
513 }
514 } else {
515 static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange");
516 }
517 return result;
518 };
519 }
520 /// With callback of form:
521 /// - Value(OpBuilder &, T, ValueRange, Location)
522 /// - SmallVector<Value>(OpBuilder &, TypeRange, ValueRange, Location)
523 template <typename T, typename FnT>
524 std::enable_if_t<
525 std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
526 TargetMaterializationCallbackFn>
527 wrapTargetMaterialization(FnT &&callback) const {
528 return wrapTargetMaterialization<T>(
529 [callback = std::forward<FnT>(callback)](
530 OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc,
531 Type originalType) {
532 return callback(builder, resultTypes, inputs, loc);
533 });
534 }
535
536 /// Generate a wrapper for the given memory space conversion callback. The
537 /// callback may take any subclass of `Attribute` and the wrapper will check
538 /// for the target attribute to be of the expected class before calling the
539 /// callback.
540 template <typename T, typename A, typename FnT>
541 TypeAttributeConversionCallbackFn
542 wrapTypeAttributeConversion(FnT &&callback) const {
543 return [callback = std::forward<FnT>(callback)](
544 Type type, Attribute attr) -> AttributeConversionResult {
545 if (T derivedType = dyn_cast<T>(type)) {
546 if (A derivedAttr = dyn_cast_or_null<A>(attr))
547 return callback(derivedType, derivedAttr);
548 }
549 return AttributeConversionResult::na();
550 };
551 }
552
553 /// Register a memory space conversion, clearing caches.
554 void
555 registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
556 typeAttributeConversions.emplace_back(std::move(callback));
557 // Clear type conversions in case a memory space is lingering inside.
558 cachedDirectConversions.clear();
559 cachedMultiConversions.clear();
560 }
561
562 /// Internal implementation of the type conversion.
563 LogicalResult convertTypeImpl(PointerUnion<Type, Value> t,
564 SmallVectorImpl<Type> &results) const;
565
566 /// The set of registered conversion functions.
567 SmallVector<ConversionCallbackFn, 4> conversions;
568
569 /// The list of registered materialization functions.
570 SmallVector<SourceMaterializationCallbackFn, 2> sourceMaterializations;
571 SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
572
573 /// The list of registered type attribute conversion functions.
574 SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
575
576 /// A set of cached conversions to avoid recomputing in the common case.
577 /// Direct 1-1 conversions are the most common, so this cache stores the
578 /// successful 1-1 conversions as well as all failed conversions.
579 mutable DenseMap<Type, Type> cachedDirectConversions;
580 /// This cache stores the successful 1->N conversions, where N != 1.
581 mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
582 /// A mutex used for cache access
583 mutable llvm::sys::SmartRWMutex<true> cacheMutex;
584 /// Whether the type converter has context-aware type conversions. I.e.,
585 /// conversion rules that depend on the SSA value instead of just the type.
586 /// We store here the index in the `conversions` vector of the last added
587 /// context-aware conversion, if any. This is useful because we can't cache
588 /// the result of type conversion happening after context-aware conversions,
589 /// because the type converter may return different results for the same input
590 /// type. This is why it is recommened to add context-aware conversions first,
591 /// any context-free conversions after will benefit from caching.
592 int contextAwareTypeConversionsIndex = -1;
593};
594
595//===----------------------------------------------------------------------===//
596// Conversion Patterns
597//===----------------------------------------------------------------------===//
598
599/// Base class for the conversion patterns. This pattern class enables type
600/// conversions, and other uses specific to the conversion framework. As such,
601/// patterns of this type can only be used with the 'apply*' methods below.
602class ConversionPattern : public RewritePattern {
603public:
604 using OpAdaptor = ArrayRef<Value>;
605 using OneToNOpAdaptor = ArrayRef<ValueRange>;
606
607 /// Hook for derived classes to implement combined matching and rewriting.
608 /// This overload supports only 1:1 replacements. The 1:N overload is called
609 /// by the driver. By default, it calls this 1:1 overload or fails to match
610 /// if 1:N replacements were found.
611 virtual LogicalResult
612 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
613 ConversionPatternRewriter &rewriter) const {
614 llvm_unreachable("matchAndRewrite is not implemented");
615 }
616
617 /// Hook for derived classes to implement combined matching and rewriting.
618 /// This overload supports 1:N replacements.
619 virtual LogicalResult
620 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
621 ConversionPatternRewriter &rewriter) const {
622 return dispatchTo1To1(*this, op, operands, rewriter);
623 }
624
625 /// Attempt to match and rewrite the IR root at the specified operation.
626 LogicalResult matchAndRewrite(Operation *op,
627 PatternRewriter &rewriter) const final;
628
629 /// Return the type converter held by this pattern, or nullptr if the pattern
630 /// does not require type conversion.
631 const TypeConverter *getTypeConverter() const { return typeConverter; }
632
633 template <typename ConverterTy>
634 std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
635 const ConverterTy *>
636 getTypeConverter() const {
637 return static_cast<const ConverterTy *>(typeConverter);
638 }
639
640protected:
641 /// See `RewritePattern::RewritePattern` for information on the other
642 /// available constructors.
643 using RewritePattern::RewritePattern;
644 /// Construct a conversion pattern with the given converter, and forward the
645 /// remaining arguments to RewritePattern.
646 template <typename... Args>
647 ConversionPattern(const TypeConverter &typeConverter, Args &&...args)
648 : RewritePattern(std::forward<Args>(args)...),
649 typeConverter(&typeConverter) {}
650
651 /// Given an array of value ranges, which are the inputs to a 1:N adaptor,
652 /// try to extract the single value of each range to construct a the inputs
653 /// for a 1:1 adaptor.
654 ///
655 /// Returns failure if at least one range has 0 or more than 1 value.
656 FailureOr<SmallVector<Value>>
657 getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands) const;
658
659 /// Overloaded method used to dispatch to the 1:1 'matchAndRewrite' method
660 /// if possible and emit diagnostic with a failure return value otherwise.
661 /// 'self' should be '*this' of the derived-pattern and is used to dispatch
662 /// to the correct 'matchAndRewrite' method in the derived pattern.
663 template <typename SelfPattern, typename SourceOp>
664 static LogicalResult dispatchTo1To1(const SelfPattern &self, SourceOp op,
665 ArrayRef<ValueRange> operands,
666 ConversionPatternRewriter &rewriter);
667
668 /// Same as above, but accepts an adaptor as operand.
669 template <typename SelfPattern, typename SourceOp>
670 static LogicalResult dispatchTo1To1(
671 const SelfPattern &self, SourceOp op,
672 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
673 ConversionPatternRewriter &rewriter);
674
675protected:
676 /// An optional type converter for use by this pattern.
677 const TypeConverter *typeConverter = nullptr;
678};
679
680/// OpConversionPattern is a wrapper around ConversionPattern that allows for
681/// matching and rewriting against an instance of a derived operation class as
682/// opposed to a raw Operation.
683template <typename SourceOp>
684class OpConversionPattern : public ConversionPattern {
685public:
686 /// Type alias to allow derived classes to inherit constructors with
687 /// `using Base::Base;`.
688 using Base = OpConversionPattern;
689
690 using OpAdaptor = typename SourceOp::Adaptor;
691 using OneToNOpAdaptor =
692 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
693
694 OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
695 : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
696 OpConversionPattern(const TypeConverter &typeConverter, MLIRContext *context,
697 PatternBenefit benefit = 1)
698 : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
699 context) {}
700
701 /// Wrappers around the ConversionPattern methods that pass the derived op
702 /// type.
703 LogicalResult
704 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
705 ConversionPatternRewriter &rewriter) const final {
706 auto sourceOp = cast<SourceOp>(op);
707 return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
708 }
709 LogicalResult
710 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
711 ConversionPatternRewriter &rewriter) const final {
712 auto sourceOp = cast<SourceOp>(op);
713 return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
714 rewriter);
715 }
716
717 /// Methods that operate on the SourceOp type. One of these must be
718 /// overridden by the derived pattern class.
719 virtual LogicalResult
720 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
721 ConversionPatternRewriter &rewriter) const {
722 llvm_unreachable("matchAndRewrite is not implemented");
723 }
724 virtual LogicalResult
725 matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
726 ConversionPatternRewriter &rewriter) const {
727 return dispatchTo1To1(*this, op, adaptor, rewriter);
728 }
729
730private:
731 using ConversionPattern::matchAndRewrite;
732};
733
734/// OpInterfaceConversionPattern is a wrapper around ConversionPattern that
735/// allows for matching and rewriting against an instance of an OpInterface
736/// class as opposed to a raw Operation.
737template <typename SourceOp>
738class OpInterfaceConversionPattern : public ConversionPattern {
739public:
740 /// Type alias to allow derived classes to inherit constructors with
741 /// `using Base::Base;`.
742 using Base = OpInterfaceConversionPattern;
743
744 OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
745 : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
746 SourceOp::getInterfaceID(), benefit, context) {}
747 OpInterfaceConversionPattern(const TypeConverter &typeConverter,
748 MLIRContext *context, PatternBenefit benefit = 1)
749 : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
750 SourceOp::getInterfaceID(), benefit, context) {}
751
752 /// Wrappers around the ConversionPattern methods that pass the derived op
753 /// type.
754 LogicalResult
755 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
756 ConversionPatternRewriter &rewriter) const final {
757 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
758 }
759 LogicalResult
760 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
761 ConversionPatternRewriter &rewriter) const final {
762 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
763 }
764
765 /// Methods that operate on the SourceOp type. One of these must be
766 /// overridden by the derived pattern class.
767 virtual LogicalResult
768 matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
769 ConversionPatternRewriter &rewriter) const {
770 llvm_unreachable("matchAndRewrite is not implemented");
771 }
772 virtual LogicalResult
773 matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
774 ConversionPatternRewriter &rewriter) const {
775 return dispatchTo1To1(*this, op, operands, rewriter);
776 }
777
778private:
779 using ConversionPattern::matchAndRewrite;
780};
781
782/// OpTraitConversionPattern is a wrapper around ConversionPattern that allows
783/// for matching and rewriting against instances of an operation that possess a
784/// given trait.
785template <template <typename> class TraitType>
786class OpTraitConversionPattern : public ConversionPattern {
787public:
788 /// Type alias to allow derived classes to inherit constructors with
789 /// `using Base::Base;`.
790 using Base = OpTraitConversionPattern;
791
792 OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
793 : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
794 TypeID::get<TraitType>(), benefit, context) {}
795 OpTraitConversionPattern(const TypeConverter &typeConverter,
796 MLIRContext *context, PatternBenefit benefit = 1)
797 : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
798 TypeID::get<TraitType>(), benefit, context) {}
799};
800
801/// Generic utility to convert op result types according to type converter
802/// without knowing exact op type.
803/// Clones existing op with new result types and returns it.
804FailureOr<Operation *>
805convertOpResultTypes(Operation *op, ValueRange operands,
806 const TypeConverter &converter,
807 ConversionPatternRewriter &rewriter);
808
809/// Add a pattern to the given pattern list to convert the signature of a
810/// FunctionOpInterface op with the given type converter. This only supports
811/// ops which use FunctionType to represent their type.
812void populateFunctionOpInterfaceTypeConversionPattern(
813 StringRef functionLikeOpName, RewritePatternSet &patterns,
814 const TypeConverter &converter, PatternBenefit benefit = 1);
815
816template <typename FuncOpT>
817void populateFunctionOpInterfaceTypeConversionPattern(
818 RewritePatternSet &patterns, const TypeConverter &converter,
819 PatternBenefit benefit = 1) {
820 populateFunctionOpInterfaceTypeConversionPattern(
821 FuncOpT::getOperationName(), patterns, converter, benefit);
822}
823
824void populateAnyFunctionOpInterfaceTypeConversionPattern(
825 RewritePatternSet &patterns, const TypeConverter &converter,
826 PatternBenefit benefit = 1);
827
828//===----------------------------------------------------------------------===//
829// Conversion PatternRewriter
830//===----------------------------------------------------------------------===//
831
namespace detail {
/// Private implementation of ConversionPatternRewriter. It is only
/// forward-declared here so that the rewriter can hold it through a
/// `std::unique_ptr` and expose it via `getImpl()`.
struct ConversionPatternRewriterImpl;
} // namespace detail
835
836/// This class implements a pattern rewriter for use with ConversionPatterns. It
837/// extends the base PatternRewriter and provides special conversion specific
838/// hooks.
839class ConversionPatternRewriter final : public PatternRewriter {
840public:
841 ~ConversionPatternRewriter() override;
842
843 /// Return the configuration of the current dialect conversion.
844 const ConversionConfig &getConfig() const;
845
846 /// Apply a signature conversion to given block. This replaces the block with
847 /// a new block containing the updated signature. The operations of the given
848 /// block are inlined into the newly-created block, which is returned.
849 ///
850 /// If no block argument types are changing, the original block will be
851 /// left in place and returned.
852 ///
853 /// A signature converison must be provided. (Type converters can construct
854 /// a signature conversion with `convertBlockSignature`.)
855 ///
856 /// Optionally, a type converter can be provided to build materializations.
857 /// Note: If no type converter was provided or the type converter does not
858 /// specify any suitable source/target materialization rules, the dialect
859 /// conversion may fail to legalize unresolved materializations.
860 Block *
861 applySignatureConversion(Block *block,
862 TypeConverter::SignatureConversion &conversion,
863 const TypeConverter *converter = nullptr);
864
865 /// Apply a signature conversion to each block in the given region. This
866 /// replaces each block with a new block containing the updated signature. If
867 /// an updated signature would match the current signature, the respective
868 /// block is left in place as is. (See `applySignatureConversion` for
869 /// details.) The new entry block of the region is returned.
870 ///
871 /// SignatureConversions are computed with the specified type converter.
872 /// This function returns "failure" if the type converter failed to compute
873 /// a SignatureConversion for at least one block.
874 ///
875 /// Optionally, a special SignatureConversion can be specified for the entry
876 /// block. This is because the types of the entry block arguments are often
877 /// tied semantically to the operation.
878 FailureOr<Block *> convertRegionTypes(
879 Region *region, const TypeConverter &converter,
880 TypeConverter::SignatureConversion *entryConversion = nullptr);
881
882 /// Replace all the uses of `from` with `to`. The type of `from` and `to` is
883 /// allowed to differ. The conversion driver will try to reconcile all type
884 /// mismatches that still exist at the end of the conversion with
885 /// materializations. This function supports both 1:1 and 1:N replacements.
886 ///
887 /// Note: If `allowPatternRollback` is set to "true", this function behaves
888 /// slightly different:
889 ///
890 /// 1. All current and future uses of `from` are replaced. The same value must
891 /// not be replaced multiple times. That's an API violation.
892 /// 2. Uses are not replaced immediately but in a delayed fashion. Patterns
893 /// may still see the original uses when inspecting IR.
894 /// 3. Uses within the same block that appear before the defining operation
895 /// of the replacement value are not replaced. This allows users to
896 /// perform certain replaceAllUsesExcept-style replacements, even though
897 /// such API is not directly supported.
898 ///
899 /// Note: In an attempt to align the ConversionPatternRewriter and
900 /// RewriterBase APIs, (3) may be removed in the future.
901 void replaceAllUsesWith(Value from, ValueRange to);
902 void replaceAllUsesWith(Value from, Value to) override {
903 replaceAllUsesWith(from, ValueRange{to});
904 }
905
906 /// Replace the uses of `from` with `to` for which the `functor` returns
907 /// "true". The conversion driver will try to reconcile all type mismatches
908 /// that still exist at the end of the conversion with materializations.
909 /// This function supports both 1:1 and 1:N replacements.
910 ///
911 /// Note: The functor is also applied to builtin.unrealized_conversion_cast
912 /// ops that may have been inserted by the conversion driver. Some uses may
913 /// have been wrapped in unrealized_conversion_cast ops due to type changes.
914 ///
915 /// Note: This function is not supported in rollback mode. Calling it in
916 /// rollback mode will trigger an assertion. Furthermore, the
917 /// `allUsesReplaced` flag is not supported yet.
918 void replaceUsesWithIf(Value from, Value to,
919 function_ref<bool(OpOperand &)> functor,
920 bool *allUsesReplaced = nullptr) override {
921 replaceUsesWithIf(from, ValueRange{to}, functor, allUsesReplaced);
922 }
923 void replaceUsesWithIf(Value from, ValueRange to,
924 function_ref<bool(OpOperand &)> functor,
925 bool *allUsesReplaced = nullptr);
926
927 /// Return the converted value of 'key' with a type defined by the type
928 /// converter of the currently executing pattern. Return nullptr in the case
929 /// of failure, the remapped value otherwise.
930 Value getRemappedValue(Value key);
931
932 /// Return the converted values that replace 'keys' with types defined by the
933 /// type converter of the currently executing pattern. Returns failure if the
934 /// remap failed, success otherwise.
935 LogicalResult getRemappedValues(ValueRange keys,
936 SmallVectorImpl<Value> &results);
937
938 //===--------------------------------------------------------------------===//
939 // PatternRewriter Hooks
940 //===--------------------------------------------------------------------===//
941
942 /// Indicate that the conversion rewriter can recover from rewrite failure.
943 /// Recovery is supported via rollback, allowing for continued processing of
944 /// patterns even if a failure is encountered during the rewrite step.
945 bool canRecoverFromRewriteFailure() const override { return true; }
946
947 /// Replace the given operation with the new values. The number of op results
948 /// and replacement values must match. The types may differ: the dialect
949 /// conversion driver will reconcile any surviving type mismatches at the end
950 /// of the conversion process with source materializations. The given
951 /// operation is erased.
952 void replaceOp(Operation *op, ValueRange newValues) override;
953
954 /// Replace the given operation with the results of the new op. The number of
955 /// op results must match. The types may differ: the dialect conversion
956 /// driver will reconcile any surviving type mismatches at the end of the
957 /// conversion process with source materializations. The original operation
958 /// is erased.
959 void replaceOp(Operation *op, Operation *newOp) override;
960
961 /// Replace the given operation with the new value ranges. The number of op
962 /// results and value ranges must match. The given operation is erased.
963 void replaceOpWithMultiple(Operation *op,
964 SmallVector<SmallVector<Value>> &&newValues);
965 template <typename RangeT = ValueRange>
966 void replaceOpWithMultiple(Operation *op, ArrayRef<RangeT> newValues) {
967 replaceOpWithMultiple(op,
968 llvm::to_vector_of<SmallVector<Value>>(newValues));
969 }
970 template <typename RangeT>
971 void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
972 replaceOpWithMultiple(op,
973 ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
974 }
975
976 /// PatternRewriter hook for erasing a dead operation. The uses of this
977 /// operation *must* be made dead by the end of the conversion process,
978 /// otherwise an assert will be issued.
979 void eraseOp(Operation *op) override;
980
981 /// PatternRewriter hook for erase all operations in a block. This is not yet
982 /// implemented for dialect conversion.
983 void eraseBlock(Block *block) override;
984
985 /// PatternRewriter hook for inlining the ops of a block into another block.
986 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
987 ValueRange argValues = {}) override;
988 using PatternRewriter::inlineBlockBefore;
989
990 /// PatternRewriter hook for updating the given operation in-place.
991 /// Note: These methods only track updates to the given operation itself,
992 /// and not nested regions. Updates to regions will still require notification
993 /// through other more specific hooks above.
994 void startOpModification(Operation *op) override;
995
996 /// PatternRewriter hook for updating the given operation in-place.
997 void finalizeOpModification(Operation *op) override;
998
999 /// PatternRewriter hook for updating the given operation in-place.
1000 void cancelOpModification(Operation *op) override;
1001
1002 /// Return a reference to the internal implementation.
1003 detail::ConversionPatternRewriterImpl &getImpl();
1004
1005 /// Attempt to legalize the given operation. This can be used within
1006 /// conversion patterns to change the default pre-order legalization order.
1007 /// Returns "success" if the operation was legalized, "failure" otherwise.
1008 ///
1009 /// Note: In a partial conversion, this function returns "success" even if
1010 /// the operation could not be legalized, as long as it was not explicitly
1011 /// marked as illegal in the conversion target.
1012 LogicalResult legalize(Operation *op);
1013
1014 /// Attempt to legalize the given region. This can be used within
1015 /// conversion patterns to change the default pre-order legalization order.
1016 /// Returns "success" if the region was legalized, "failure" otherwise.
1017 ///
1018 /// If the current pattern runs with a type converter, the entry block
1019 /// signature will be converted before legalizing the operations in the
1020 /// region.
1021 ///
1022 /// Note: In a partial conversion, this function returns "success" even if
1023 /// an operation could not be legalized, as long as it was not explicitly
1024 /// marked as illegal in the conversion target.
1025 LogicalResult legalize(Region *r);
1026
1027private:
1028 // Allow OperationConverter to construct new rewriters.
1029 friend struct OperationConverter;
1030
1031 /// Conversion pattern rewriters must not be used outside of dialect
1032 /// conversions. They apply some IR rewrites in a delayed fashion and could
1033 /// bring the IR into an inconsistent state when used standalone.
1034 explicit ConversionPatternRewriter(MLIRContext *ctx,
1035 const ConversionConfig &config,
1036 OperationConverter &converter);
1037
1038 // Hide unsupported pattern rewriter API.
1039 using OpBuilder::setListener;
1040
1041 std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
1042};
1043
1044template <typename SelfPattern, typename SourceOp>
1045LogicalResult
1046ConversionPattern::dispatchTo1To1(const SelfPattern &self, SourceOp op,
1047 ArrayRef<ValueRange> operands,
1048 ConversionPatternRewriter &rewriter) {
1049 FailureOr<SmallVector<Value>> oneToOneOperands =
1050 self.getOneToOneAdaptorOperands(operands);
1051 if (failed(oneToOneOperands))
1052 return rewriter.notifyMatchFailure(op,
1053 "pattern '" + self.getDebugName() +
1054 "' does not support 1:N conversion");
1055 return self.matchAndRewrite(op, *oneToOneOperands, rewriter);
1056}
1057
1058template <typename SelfPattern, typename SourceOp>
1059LogicalResult ConversionPattern::dispatchTo1To1(
1060 const SelfPattern &self, SourceOp op,
1061 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
1062 ConversionPatternRewriter &rewriter) {
1063 FailureOr<SmallVector<Value>> oneToOneOperands =
1064 self.getOneToOneAdaptorOperands(adaptor.getOperands());
1065 if (failed(oneToOneOperands))
1066 return rewriter.notifyMatchFailure(op,
1067 "pattern '" + self.getDebugName() +
1068 "' does not support 1:N conversion");
1069 return self.matchAndRewrite(
1070 op, typename SourceOp::Adaptor(*oneToOneOperands, adaptor), rewriter);
1071}
1072
1073//===----------------------------------------------------------------------===//
1074// ConversionTarget
1075//===----------------------------------------------------------------------===//
1076
1077/// This class describes a specific conversion target.
1078class ConversionTarget {
1079public:
1080 /// This enumeration corresponds to the specific action to take when
1081 /// considering an operation legal for this conversion target.
1082 enum class LegalizationAction {
1083 /// The target supports this operation.
1084 Legal,
1085
1086 /// This operation has dynamic legalization constraints that must be checked
1087 /// by the target.
1088 Dynamic,
1089
1090 /// The target explicitly does not support this operation.
1091 Illegal,
1092 };
1093
1094 /// A structure containing additional information describing a specific legal
1095 /// operation instance.
1096 struct LegalOpDetails {
1097 /// A flag that indicates if this operation is 'recursively' legal. This
1098 /// means that if an operation is legal, either statically or dynamically,
1099 /// all of the operations nested within are also considered legal.
1100 bool isRecursivelyLegal = false;
1101 };
1102
1103 /// The signature of the callback used to determine if an operation is
1104 /// dynamically legal on the target.
1105 using DynamicLegalityCallbackFn =
1106 std::function<std::optional<bool>(Operation *)>;
1107
1108 ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
1109 virtual ~ConversionTarget() = default;
1110
1111 //===--------------------------------------------------------------------===//
1112 // Legality Registration
1113 //===--------------------------------------------------------------------===//
1114
1115 /// Register a legality action for the given operation.
1116 void setOpAction(OperationName op, LegalizationAction action);
1117 template <typename OpT>
1118 void setOpAction(LegalizationAction action) {
1119 setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
1120 }
1121
1122 /// Register the given operations as legal.
1123 void addLegalOp(OperationName op) {
1124 setOpAction(op, LegalizationAction::Legal);
1125 }
1126 template <typename OpT>
1127 void addLegalOp() {
1128 addLegalOp(OperationName(OpT::getOperationName(), &ctx));
1129 }
1130 template <typename OpT, typename OpT2, typename... OpTs>
1131 void addLegalOp() {
1132 addLegalOp<OpT>();
1133 addLegalOp<OpT2, OpTs...>();
1134 }
1135
1136 /// Register the given operation as dynamically legal and set the dynamic
1137 /// legalization callback to the one provided.
1138 void addDynamicallyLegalOp(OperationName op,
1139 const DynamicLegalityCallbackFn &callback) {
1140 setOpAction(op, LegalizationAction::Dynamic);
1141 setLegalityCallback(op, callback);
1142 }
1143 template <typename OpT>
1144 void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
1145 addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
1146 callback);
1147 }
1148 template <typename OpT, typename OpT2, typename... OpTs>
1149 void addDynamicallyLegalOp(const DynamicLegalityCallbackFn &callback) {
1150 addDynamicallyLegalOp<OpT>(callback);
1151 addDynamicallyLegalOp<OpT2, OpTs...>(callback);
1152 }
1153 template <typename OpT, class Callable>
1154 std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1155 addDynamicallyLegalOp(Callable &&callback) {
1156 addDynamicallyLegalOp<OpT>(
1157 [=](Operation *op) { return callback(cast<OpT>(op)); });
1158 }
1159
1160 /// Register the given operation as illegal, i.e. this operation is known to
1161 /// not be supported by this target.
1162 void addIllegalOp(OperationName op) {
1163 setOpAction(op, LegalizationAction::Illegal);
1164 }
1165 template <typename OpT>
1166 void addIllegalOp() {
1167 addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
1168 }
1169 template <typename OpT, typename OpT2, typename... OpTs>
1170 void addIllegalOp() {
1171 addIllegalOp<OpT>();
1172 addIllegalOp<OpT2, OpTs...>();
1173 }
1174
1175 /// Mark an operation, that *must* have either been set as `Legal` or
1176 /// `DynamicallyLegal`, as being recursively legal. This means that in
1177 /// addition to the operation itself, all of the operations nested within are
1178 /// also considered legal. An optional dynamic legality callback may be
1179 /// provided to mark subsets of legal instances as recursively legal.
1180 void markOpRecursivelyLegal(OperationName name,
1181 const DynamicLegalityCallbackFn &callback);
1182 template <typename OpT>
1183 void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
1184 OperationName opName(OpT::getOperationName(), &ctx);
1185 markOpRecursivelyLegal(opName, callback);
1186 }
1187 template <typename OpT, typename OpT2, typename... OpTs>
1188 void markOpRecursivelyLegal(const DynamicLegalityCallbackFn &callback = {}) {
1189 markOpRecursivelyLegal<OpT>(callback);
1190 markOpRecursivelyLegal<OpT2, OpTs...>(callback);
1191 }
1192 template <typename OpT, class Callable>
1193 std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1194 markOpRecursivelyLegal(Callable &&callback) {
1195 markOpRecursivelyLegal<OpT>(
1196 [=](Operation *op) { return callback(cast<OpT>(op)); });
1197 }
1198
1199 /// Register a legality action for the given dialects.
1200 void setDialectAction(ArrayRef<StringRef> dialectNames,
1201 LegalizationAction action);
1202
1203 /// Register the operations of the given dialects as legal.
1204 template <typename... Names>
1205 void addLegalDialect(StringRef name, Names... names) {
1206 SmallVector<StringRef, 2> dialectNames({name, names...});
1207 setDialectAction(dialectNames, LegalizationAction::Legal);
1208 }
1209 template <typename... Args>
1210 void addLegalDialect() {
1211 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1212 setDialectAction(dialectNames, LegalizationAction::Legal);
1213 }
1214
1215 /// Register the operations of the given dialects as dynamically legal, i.e.
1216 /// requiring custom handling by the callback.
1217 template <typename... Names>
1218 void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback,
1219 StringRef name, Names... names) {
1220 SmallVector<StringRef, 2> dialectNames({name, names...});
1221 setDialectAction(dialectNames, LegalizationAction::Dynamic);
1222 setLegalityCallback(dialectNames, callback);
1223 }
1224 template <typename... Args>
1225 void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
1226 addDynamicallyLegalDialect(std::move(callback),
1227 Args::getDialectNamespace()...);
1228 }
1229
1230 /// Register unknown operations as dynamically legal. For operations(and
1231 /// dialects) that do not have a set legalization action, treat them as
1232 /// dynamically legal and invoke the given callback.
1233 void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) {
1234 setLegalityCallback(fn);
1235 }
1236
1237 /// Register the operations of the given dialects as illegal, i.e.
1238 /// operations of this dialect are not supported by the target.
1239 template <typename... Names>
1240 void addIllegalDialect(StringRef name, Names... names) {
1241 SmallVector<StringRef, 2> dialectNames({name, names...});
1242 setDialectAction(dialectNames, LegalizationAction::Illegal);
1243 }
1244 template <typename... Args>
1245 void addIllegalDialect() {
1246 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1247 setDialectAction(dialectNames, LegalizationAction::Illegal);
1248 }
1249
1250 //===--------------------------------------------------------------------===//
1251 // Legality Querying
1252 //===--------------------------------------------------------------------===//
1253
1254 /// Get the legality action for the given operation.
1255 std::optional<LegalizationAction> getOpAction(OperationName op) const;
1256
1257 /// If the given operation instance is legal on this target, a structure
1258 /// containing legality information is returned. If the operation is not
1259 /// legal, std::nullopt is returned. Also returns std::nullopt if operation
1260 /// legality wasn't registered by user or dynamic legality callbacks returned
1261 /// None.
1262 ///
1263 /// Note: Legality is actually a 4-state: Legal(recursive=true),
1264 /// Legal(recursive=false), Illegal or Unknown, where Unknown is treated
1265 /// either as Legal or Illegal depending on context.
1266 std::optional<LegalOpDetails> isLegal(Operation *op) const;
1267
1268 /// Returns true is operation instance is illegal on this target. Returns
1269 /// false if operation is legal, operation legality wasn't registered by user
1270 /// or dynamic legality callbacks returned None.
1271 bool isIllegal(Operation *op) const;
1272
1273private:
1274 /// Set the dynamic legality callback for the given operation.
1275 void setLegalityCallback(OperationName name,
1276 const DynamicLegalityCallbackFn &callback);
1277
1278 /// Set the dynamic legality callback for the given dialects.
1279 void setLegalityCallback(ArrayRef<StringRef> dialects,
1280 const DynamicLegalityCallbackFn &callback);
1281
1282 /// Set the dynamic legality callback for the unknown ops.
1283 void setLegalityCallback(const DynamicLegalityCallbackFn &callback);
1284
1285 /// The set of information that configures the legalization of an operation.
1286 struct LegalizationInfo {
1287 /// The legality action this operation was given.
1288 LegalizationAction action = LegalizationAction::Illegal;
1289
1290 /// If some legal instances of this operation may also be recursively legal.
1291 bool isRecursivelyLegal = false;
1292
1293 /// The legality callback if this operation is dynamically legal.
1294 DynamicLegalityCallbackFn legalityFn;
1295 };
1296
1297 /// Get the legalization information for the given operation.
1298 std::optional<LegalizationInfo> getOpInfo(OperationName op) const;
1299
1300 /// A deterministic mapping of operation name and its respective legality
1301 /// information.
1302 llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
1303
1304 /// A set of legality callbacks for given operation names that are used to
1305 /// check if an operation instance is recursively legal.
1306 DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1307
1308 /// A deterministic mapping of dialect name to the specific legality action to
1309 /// take.
1310 llvm::StringMap<LegalizationAction> legalDialects;
1311
1312 /// A set of dynamic legality callbacks for given dialect names.
1313 llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1314
1315 /// An optional legality callback for unknown operations.
1316 DynamicLegalityCallbackFn unknownLegalityFn;
1317
1318 /// The current context this target applies to.
1319 MLIRContext &ctx;
1320};
1321
1322#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1323//===----------------------------------------------------------------------===//
1324// PDL Configuration
1325//===----------------------------------------------------------------------===//
1326
1327/// A PDL configuration that is used to supported dialect conversion
1328/// functionality.
1329class PDLConversionConfig final
1330 : public PDLPatternConfigBase<PDLConversionConfig> {
1331public:
1332 PDLConversionConfig(const TypeConverter *converter) : converter(converter) {}
1333 ~PDLConversionConfig() final = default;
1334
1335 /// Return the type converter used by this configuration, which may be nullptr
1336 /// if no type conversions are expected.
1337 const TypeConverter *getTypeConverter() const { return converter; }
1338
1339 /// Hooks that are invoked at the beginning and end of a rewrite of a matched
1340 /// pattern.
1341 void notifyRewriteBegin(PatternRewriter &rewriter) final;
1342 void notifyRewriteEnd(PatternRewriter &rewriter) final;
1343
1344private:
1345 /// An optional type converter to use for the pattern.
1346 const TypeConverter *converter;
1347};
1348
1349/// Register the dialect conversion PDL functions with the given pattern set.
1350void registerConversionPDLFunctions(RewritePatternSet &patterns);
1351
1352#else
1353
1354// Stubs for when PDL in rewriting is not enabled.
1355
1356inline void registerConversionPDLFunctions(RewritePatternSet &patterns) {}
1357
1358class PDLConversionConfig final {
1359public:
1360 PDLConversionConfig(const TypeConverter * /*converter*/) {}
1361};
1362
1363#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
1364
1365//===----------------------------------------------------------------------===//
1366// ConversionConfig
1367//===----------------------------------------------------------------------===//
1368
/// Selects whether, and at which point, the dialect conversion driver attempts
/// to fold operations that are not legal.
enum class DialectConversionFoldingMode {
  /// Folding is never attempted.
  Never = 0,
  /// Fold not-legal operations only before patterns are applied.
  BeforePatterns = 1,
  /// Fold not-legal operations only after patterns are applied.
  AfterPatterns = 2,
};
1378
1379/// Dialect conversion configuration.
1380struct ConversionConfig {
1381 /// An optional callback used to notify about match failure diagnostics during
1382 /// the conversion. Diagnostics reported to this callback may only be
1383 /// available in debug mode.
1384 function_ref<void(Diagnostic &)> notifyCallback = nullptr;
1385
1386 /// Partial conversion only. All operations that are found not to be
1387 /// legalizable are placed in this set. (Note that if there is an op
1388 /// explicitly marked as illegal, the conversion terminates and the set will
1389 /// not necessarily be complete.)
1390 DenseSet<Operation *> *unlegalizedOps = nullptr;
1391
1392 /// Analysis conversion only. All operations that are found to be legalizable
1393 /// are placed in this set. Note that no actual rewrites are applied to the
1394 /// IR during an analysis conversion and only pre-existing operations are
1395 /// added to the set.
1396 DenseSet<Operation *> *legalizableOps = nullptr;
1397
1398 /// An optional listener that is notified about all IR modifications in case
1399 /// dialect conversion succeeds. If the dialect conversion fails and no IR
1400 /// modifications are visible (i.e., they were all rolled back), or if the
1401 /// dialect conversion is an "analysis conversion", no notifications are
1402 /// sent (apart from `notifyPatternBegin`/`notifyPatternEnd`).
1403 ///
1404 /// Note: Notifications are sent in a delayed fashion, when the dialect
1405 /// conversion is guaranteed to succeed. At that point, some IR modifications
1406 /// may already have been materialized. Consequently, operations/blocks that
1407 /// are passed to listener callbacks should not be accessed. (Ops/blocks are
1408 /// guaranteed to be valid pointers and accessing op names is allowed. But
1409 /// there are no guarantees about the state of ops/blocks at the time that a
1410 /// callback is triggered.)
1411 ///
1412 /// Example: Consider a dialect conversion in which a new op ("test.foo") is
1413 /// created and inserted, and later moved to another block. (Moving ops also
1414 /// triggers "notifyOperationInserted".)
1415 ///
1416 /// (1) notifyOperationInserted: "test.foo" (into block "b1")
1417 /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
1418 ///
1419 /// When querying "op->getBlock()" during the first "notifyOperationInserted",
1420 /// "b2" would be returned because "moving an op" is a kind of rewrite that is
1421 /// immediately performed by the dialect conversion (and rolled back upon
1422 /// failure).
1423 //
1424 // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
1425 // callback, the previous region/block is provided to the callback, but not
1426 // the iterator pointing to the exact location within the region/block. That
1427 // is because these notifications are sent with a delay (after the IR has
1428 // already been modified) and iterators into past IR state cannot be
1429 // represented at the moment.
1430 RewriterBase::Listener *listener = nullptr;
1431
1432 /// If set to "true", the dialect conversion attempts to build source/target
1433 /// materializations through the type converter API in lieu of
1434 /// "builtin.unrealized_conversion_cast" ops. The conversion process fails if
1435 /// at least one materialization could not be built.
1436 ///
1437 /// If set to "false", the dialect conversion does not build any custom
1438 /// materializations and instead inserts "builtin.unrealized_conversion_cast"
1439 /// ops to ensure that the resulting IR is valid.
1440 bool buildMaterializations = true;
1441
1442 /// If set to "true", pattern rollback is allowed. The conversion driver
1443 /// rolls back IR modifications in the following situations.
1444 ///
1445 /// 1. Pattern implementation returns "failure" after modifying IR.
1446 /// 2. Pattern produces IR (in-place modification or new IR) that is illegal
1447 /// and cannot be legalized by subsequent foldings / pattern applications.
1448 ///
1449 /// Experimental: If set to "false", the conversion driver will produce an
1450 /// LLVM fatal error instead of rolling back IR modifications. Moreover, in
1451 /// case of a failed conversion, the original IR is not restored. The
1452 /// resulting IR may be a mix of original and rewritten IR. (Same as a failed
1453 /// greedy pattern rewrite.) Use the cmake build option
1454 /// `-DMLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS=ON` (ideally together with
1455 /// ASAN) to detect invalid pattern API usage.
1456 ///
1457 /// When pattern rollback is disabled, the conversion driver has to maintain
1458 /// less internal state. This is more efficient, but not supported by all
1459 /// lowering patterns. For details, see
1460 /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
1461 bool allowPatternRollback = true;
1462
1463 /// The folding mode to use during conversion.
1464 DialectConversionFoldingMode foldingMode =
1465 DialectConversionFoldingMode::BeforePatterns;
1466
1467 /// If set to "true", the materialization kind ("source" or "target") will be
1468 /// attached to "builtin.unrealized_conversion_cast" ops. This flag is useful
1469 /// for debugging, to find out what kind of materialization rule may be
1470 /// missing.
1471 bool attachDebugMaterializationKind = false;
1472};
1473
1474//===----------------------------------------------------------------------===//
1475// Reconcile Unrealized Casts
1476//===----------------------------------------------------------------------===//
1477
1478/// Try to reconcile all given UnrealizedConversionCastOps and store the
1479/// left-over ops in `remainingCastOps` (if non-null).
1480///
1481/// This function processes cast ops in a worklist-driven fashion. For each
1482/// cast op, if the chain of input casts eventually reaches a cast op whose
1483/// input types match the output types of the starting op, the starting op is
1484/// replaced with the inputs of that reached cast op.
1485///
1486/// Example:
1487/// %1 = unrealized_conversion_cast %0 : !A to !B
1488/// %2 = unrealized_conversion_cast %1 : !B to !C
1489/// %3 = unrealized_conversion_cast %2 : !C to !A
1490///
1491/// In the above example, %0 can be used instead of %3 and all cast ops are
1492/// folded away.
1495 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1498 SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
1499
1500//===----------------------------------------------------------------------===//
1501// Op Conversion Entry Points
1502//===----------------------------------------------------------------------===//
1503
1504/// Below we define several entry points for operation conversion. It is
1505/// important to note that the patterns provided to the conversion framework may
1506/// have additional constraints. See the `PatternRewriter Hooks` section of
1507/// ConversionPatternRewriter for the additional constraints imposed on the
1508/// use of the PatternRewriter.
1509
1510/// Apply a partial conversion on the given operations and all nested
1511/// operations. This method converts as many operations to the target as
1512/// possible, ignoring operations that fail to legalize. This method only
1513/// returns failure if there are ops explicitly marked as illegal.
1514LogicalResult
1515applyPartialConversion(ArrayRef<Operation *> ops,
1516 const ConversionTarget &target,
1518 ConversionConfig config = ConversionConfig());
1519LogicalResult
1520applyPartialConversion(Operation *op, const ConversionTarget &target,
1522 ConversionConfig config = ConversionConfig());
1523
1524/// Apply a complete conversion on the given operations and all nested
1525/// operations. This method returns failure if the conversion of any operation
1526/// fails, or if there are unreachable blocks in any of the regions nested
1527/// within the given operation(s).
1528LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
1529 const ConversionTarget &target,
1531 ConversionConfig config = ConversionConfig());
1532LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
1534 ConversionConfig config = ConversionConfig());
1535
1536/// Apply an analysis conversion on the given operations and all nested
1537/// operations. This method analyzes which operations would be successfully
1538/// converted to the target if a conversion was applied. All operations that
1539/// were found to be legalizable to the given `target` are placed within the
1540/// provided `config.legalizableOps` set; note that no actual rewrites are
1541/// applied to the IR during an analysis conversion. It only returns failure
1542/// if there are unreachable blocks in regions nested within the given op(s).
1543LogicalResult
1544applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
1546 ConversionConfig config = ConversionConfig());
1547LogicalResult
1548applyAnalysisConversion(Operation *op, ConversionTarget &target,
1550 ConversionConfig config = ConversionConfig());
1551} // namespace mlir
1552
1553#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
return success()
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
This class represents a frozen set of patterns that can be processed by a pattern applicator.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
AttrTypeReplacer.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152