13#ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14#define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
16#include "mlir/Config/mlir-config.h"
18#include "llvm/ADT/MapVector.h"
19#include "llvm/ADT/StringMap.h"
27struct ConversionConfig;
28class ConversionPatternRewriter;
/// Alias that names TypeConverter itself as `Base`.
45 using Base = TypeConverter;
/// Virtual, compiler-generated destructor.
47 virtual ~TypeConverter() =
default;
/// Compiler-generated default constructor.
48 TypeConverter() =
default;
/// Copy constructor. Copies the four registered callback lists from `other`:
/// type conversions, source materializations, target materializations and
/// type-attribute conversions. No other member is named in the initializer
/// list, so the conversion caches start out empty in the copy.
// NOTE(review): contextAwareTypeConversionsIndex is not copied either, so the
// copy falls back to the member's default initializer even when `other` has
// recorded an index -- confirm that this is intentional.
50 TypeConverter(
const TypeConverter &other)
51 : conversions(other.conversions),
52 sourceMaterializations(other.sourceMaterializations),
53 targetMaterializations(other.targetMaterializations),
54 typeAttributeConversions(other.typeAttributeConversions) {}
/// Copy assignment. Overwrites the four registered callback lists with those
/// of `other`.
// NOTE(review): the visible statements neither clear this converter's cached
// conversions nor assign contextAwareTypeConversionsIndex, and the tail of the
// body is not visible in this listing -- confirm whether stale cache entries
// can survive an assignment.
55 TypeConverter &operator=(
const TypeConverter &other) {
56 conversions = other.conversions;
57 sourceMaterializations = other.sourceMaterializations;
58 targetMaterializations = other.targetMaterializations;
59 typeAttributeConversions = other.typeAttributeConversions;
/// Records how the inputs of a signature are remapped: one optional
/// InputMapping per original input (`remappedInputs`) plus the list of
/// converted argument types (`argTypes`).
65 class SignatureConversion {
/// Creates a conversion with `numOrigInputs` mapping slots; each slot is an
/// empty optional until it is remapped.
67 SignatureConversion(
unsigned numOrigInputs)
68 : remappedInputs(numOrigInputs) {}
/// Replacement values recorded for an input.
// NOTE(review): the declaration that encloses this field is not visible in
// this listing; presumably it belongs to InputMapping -- confirm.
74 SmallVector<Value, 1> replacementValues;
/// Returns true when `replacementValues` is non-empty.
77 bool replacedWithValues()
const {
return !replacementValues.empty(); }
/// Returns a view of the converted argument types stored in `argTypes`.
81 ArrayRef<Type> getConvertedTypes()
const {
return argTypes; }
/// Returns the mapping stored for original input `input`. Indexes
/// `remappedInputs` directly, so `input` must be a valid index.
84 std::optional<InputMapping> getInputMapping(
unsigned input)
const {
85 return remappedInputs[input];
/// Adds converted types for the original input `origInputNo`
/// (declaration only).
94 void addInputs(
unsigned origInputNo, ArrayRef<Type> types);
/// Adds converted types without naming an original input (declaration only).
98 void addInputs(ArrayRef<Type> types);
/// Remaps original input `origInputNo` to the given replacement values
/// (declaration only).
102 void remapInput(
unsigned origInputNo, ArrayRef<Value> replacements);
/// Remaps original input `origInputNo` to new inputs; `newInputCount`
/// defaults to 1 (declaration only).
// NOTE(review): presumably the new inputs form the range
// [newInputNo, newInputNo + newInputCount) -- confirm against the definition.
107 void remapInput(
unsigned origInputNo,
unsigned newInputNo,
108 unsigned newInputCount = 1);
/// One optional mapping per original input, sized by the constructor.
111 SmallVector<std::optional<InputMapping>, 4> remappedInputs;
/// Converted argument types, exposed through getConvertedTypes().
114 SmallVector<Type, 4> argTypes;
/// Result of a type-attribute conversion callback. Packs an Attribute together
/// with a 2-bit tag (na / result / abort) into a PointerIntPair.
119 class AttributeConversionResult {
/// Default state: `impl` is default-constructed.
// NOTE(review): presumably this is a null attribute with tag 0, i.e. naTag --
// confirm against PointerIntPair's default constructor.
121 constexpr AttributeConversionResult() : impl() {}
/// Implicit conversion from an Attribute; stores it with `resultTag`.
122 AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
/// Named factories for the three states (declarations only).
124 static AttributeConversionResult
result(Attribute attr);
125 static AttributeConversionResult na();
126 static AttributeConversionResult abort();
/// State queries and accessor for the stored attribute (declarations only).
128 bool hasResult()
const;
130 bool isAbort()
const;
132 Attribute getResult()
const;
/// Stores `attr` together with an explicit tag value.
135 AttributeConversionResult(Attribute attr,
unsigned tag) : impl(attr, tag) {}
/// Packed storage: the attribute plus a 2-bit integer tag.
137 llvm::PointerIntPair<Attribute, 2> impl;
/// Tag values stored in the integer part of `impl`.
140 static constexpr unsigned naTag = 0;
141 static constexpr unsigned resultTag = 1;
142 static constexpr unsigned abortTag = 2;
/// Registers a type conversion callback. `T` defaults to the type of the
/// callback's first parameter; the callback is wrapped with wrapCallback<T>
/// and handed to registerConversion.
172 template <
typename FnT,
typename T =
typename llvm::function_traits<
173 std::decay_t<FnT>>::template arg_t<0>>
174 void addConversion(FnT &&callback) {
175 registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
/// Registers a source materialization callback. `T` defaults to the type of
/// the callback's second parameter; the callback is wrapped with
/// wrapSourceMaterialization<T> and appended to `sourceMaterializations`.
198 template <
typename FnT,
typename T =
typename llvm::function_traits<
199 std::decay_t<FnT>>::template arg_t<1>>
200 void addSourceMaterialization(FnT &&callback) {
201 sourceMaterializations.emplace_back(
202 wrapSourceMaterialization<T>(std::forward<FnT>(callback)));
/// Registers a target materialization callback. `T` defaults to the type of
/// the callback's second parameter; the callback is wrapped with
/// wrapTargetMaterialization<T> and appended to `targetMaterializations`.
222 template <
typename FnT,
typename T =
typename llvm::function_traits<
223 std::decay_t<FnT>>::template arg_t<1>>
224 void addTargetMaterialization(FnT &&callback) {
225 targetMaterializations.emplace_back(
226 wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
/// Registers a type-attribute conversion callback: wraps it with
/// wrapTypeAttributeConversion<T, A> and passes it to
/// registerTypeAttributeConversion.
// NOTE(review): the start of the template header is not visible in this
// listing; the two defaults below are the callback's first and second
// parameter types, presumably bound to `T` and `A` respectively -- confirm.
250 typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
252 typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
253 void addTypeAttributeConversion(FnT &&callback) {
254 registerTypeAttributeConversion(
255 wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
/// Converts type `t`; converted types are delivered through `results`
/// (declaration only).
265 LogicalResult convertType(Type t, SmallVectorImpl<Type> &results)
const;
/// Same as above but takes a Value instead of a Type (declaration only).
274 LogicalResult convertType(Value v, SmallVectorImpl<Type> &results)
const;
/// Single-result forms that return the converted Type directly
/// (declarations only).
278 Type convertType(Type t)
const;
279 Type convertType(Value v)
const;
/// Converts `t` with the single-result overload and casts the outcome to
/// `TargetType` using dyn_cast_or_null.
284 template <
typename TargetType>
285 TargetType convertType(Type t)
const {
286 return dyn_cast_or_null<TargetType>(convertType(t));
/// Value counterpart of the templated overload above.
288 template <
typename TargetType>
289 TargetType convertType(Value v)
const {
290 return dyn_cast_or_null<TargetType>(convertType(v));
/// Converts a range of types; output goes through `results`
/// (declaration only).
296 LogicalResult convertTypes(
TypeRange types,
297 SmallVectorImpl<Type> &results)
const;
// NOTE(review): the head of the declaration that ends on the next lines is
// not visible in this listing.
303 SmallVectorImpl<Type> &results)
const;
/// Legality queries for a single Type or Value (declarations only).
307 bool isLegal(Type type)
const;
308 bool isLegal(Value value)
const;
/// Body of a range overload: true when isLegal holds for every Type in
/// `range`.
// NOTE(review): the signature for this body is not visible in this listing.
312 return llvm::all_of(range, [
this](Type type) {
return isLegal(type); });
/// Body of a range overload: true when isLegal holds for every Value in
/// `range`.
// NOTE(review): the signature for this body is not visible in this listing.
315 return llvm::all_of(range, [
this](Value value) {
return isLegal(value); });
/// Legality query for an operation (declaration only).
319 bool isLegal(Operation *op)
const;
/// Legality query for a region (declaration only).
322 bool isLegal(Region *region)
const;
/// Legality query for a function type's signature (declaration only).
326 bool isSignatureLegal(FunctionType ty)
const;
/// Converts the signature argument `inputNo` of the given type; the outcome
/// is recorded in `result` (declaration only).
331 LogicalResult convertSignatureArg(
unsigned inputNo, Type type,
332 SignatureConversion &
result)
const;
/// Converts a range of signature argument types into `result`;
/// `origInputOffset` defaults to 0 (declaration only).
333 LogicalResult convertSignatureArgs(
TypeRange types,
334 SignatureConversion &
result,
335 unsigned origInputOffset = 0)
const;
/// Value-based counterpart of the single-argument overload
/// (declaration only).
336 LogicalResult convertSignatureArg(
unsigned inputNo, Value value,
337 SignatureConversion &
result)
const;
/// Value-based counterpart of the range overload (declaration only).
338 LogicalResult convertSignatureArgs(
ValueRange values,
339 SignatureConversion &
result,
340 unsigned origInputOffset = 0)
const;
/// Computes a signature conversion for `block`; the optional is empty when
/// no conversion is produced (declaration only).
345 std::optional<SignatureConversion> convertBlockSignature(Block *block)
const;
/// Materialization entry points and type-attribute conversion (declarations
/// only).
// NOTE(review): several parameter lines of these declarations are not
// visible in this listing; only the visible parts are described.
/// Source materialization returning a single Value.
351 Value materializeSourceConversion(OpBuilder &builder, Location loc,
/// Target materialization returning a single Value; `originalType` defaults
/// to a null Type.
353 Value materializeTargetConversion(OpBuilder &builder, Location loc,
355 Type originalType = {})
const;
/// Target materialization returning several Values; `originalType` defaults
/// to a null Type.
356 SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
360 Type originalType = {})
const;
/// Converts `attr` in the context of `type`; the optional is empty when no
/// attribute is produced.
366 std::optional<Attribute> convertTypeAttribute(Type type,
367 Attribute attr)
const;
/// Type-erased conversion callback: receives either a Type or a Value and an
/// output vector of types, and returns an optional LogicalResult.
373 using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
374 PointerUnion<Type, Value>, SmallVectorImpl<Type> &)>;
/// Type-erased source materialization callback: (builder, result type,
/// inputs, location) -> Value.
379 using SourceMaterializationCallbackFn =
380 std::function<Value(OpBuilder &, Type,
ValueRange, Location)>;
/// Type-erased target materialization callback returning a vector of Values.
// NOTE(review): the parameter list of this alias is not visible in this listing.
385 using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
/// Type-erased type-attribute conversion callback:
/// (Type, Attribute) -> AttributeConversionResult.
389 using TypeAttributeConversionCallbackFn =
390 std::function<AttributeConversionResult(Type, Attribute)>;
/// Overload selected when the callback is invocable with a single `T`.
/// Adapts it to the two-argument (input, results) form and forwards to the
/// other wrapCallback<T> overloads.
396 template <
typename T,
typename FnT>
397 std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
398 wrapCallback(FnT &&callback) {
399 return wrapCallback<T>([callback = std::forward<FnT>(callback)](
400 T typeOrValue, SmallVectorImpl<Type> &results) {
// The callback's return value is held as std::optional<Type>; an engaged
// optional means the callback made a decision for this input.
401 if (std::optional<Type> resultOpt = callback(typeOrValue)) {
// A null Type inside the engaged optional is reported as failure.
402 bool wasSuccess =
static_cast<bool>(*resultOpt);
// NOTE(review): a line between the previous statement and the push_back is
// not visible in this listing, so whether the push_back is guarded cannot be
// confirmed from here.
404 results.push_back(*resultOpt);
405 return std::optional<LogicalResult>(
success(wasSuccess));
// Disengaged optional: this callback has no opinion on the input.
407 return std::optional<LogicalResult>();
/// Overload selected when the callback is invocable with
/// (T, SmallVectorImpl<Type> &) and `T` derives from Type. Produces a
/// ConversionCallbackFn that extracts a `T` from the Type-or-Value input
/// before invoking the callback.
412 template <
typename T,
typename FnT>
413 std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
414 std::is_base_of_v<Type, T>,
415 ConversionCallbackFn>
416 wrapCallback(FnT &&callback)
const {
417 return [callback = std::forward<FnT>(callback)](
418 PointerUnion<Type, Value> typeOrValue,
419 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
// A Type input is cast to `T` directly.
421 if (Type t = dyn_cast<Type>(typeOrValue)) {
422 derivedType = dyn_cast<T>(t);
423 }
// A Value input contributes its type, which is then cast to `T`.
else if (Value v = dyn_cast<Value>(typeOrValue)) {
424 derivedType = dyn_cast<T>(v.getType());
426 llvm_unreachable(
"unexpected variant");
// NOTE(review): the declaration of `derivedType` and any check performed
// before this call are not visible in this listing.
430 return callback(derivedType, results);
/// Overload selected when the callback is invocable with
/// (T, SmallVectorImpl<Type> &) and `T` is exactly Value. Unlike the other
/// overloads it is non-const, because it updates
/// `contextAwareTypeConversionsIndex`.
435 template <
typename T,
typename FnT>
436 std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
437 std::is_same_v<T, Value>,
438 ConversionCallbackFn>
439 wrapCallback(FnT &&callback) {
// Records the number of conversions registered so far; every call to this
// overload overwrites the previously recorded index.
440 contextAwareTypeConversionsIndex = conversions.size();
441 return [callback = std::forward<FnT>(callback)](
442 PointerUnion<Type, Value> typeOrValue,
443 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
// NOTE(review): the body of the Type branch is not visible in this listing.
444 if (Type t = dyn_cast<Type>(typeOrValue)) {
447 }
// A Value input is passed to the callback unchanged.
else if (Value v = dyn_cast<Value>(typeOrValue)) {
448 return callback(v, results);
450 llvm_unreachable(
"unexpected variant");
/// Appends `callback` to the list of registered conversions and clears both
/// conversion caches, since previously cached results were computed without
/// the new callback.
456 void registerConversion(ConversionCallbackFn callback) {
457 conversions.emplace_back(std::move(callback));
458 cachedDirectConversions.clear();
459 cachedMultiConversions.clear();
/// Adapts a source materialization callback that expects result type `T` to
/// the type-erased SourceMaterializationCallbackFn form.
465 template <
typename T,
typename FnT>
466 SourceMaterializationCallbackFn
467 wrapSourceMaterialization(FnT &&callback)
const {
468 return [callback = std::forward<FnT>(callback)](
469 OpBuilder &builder, Type resultType,
ValueRange inputs,
470 Location loc) -> Value {
// The callback is invoked only when the requested result type is a `T`.
// NOTE(review): the fall-through return is not visible in this listing.
471 if (T derivedType = dyn_cast<T>(resultType))
472 return callback(builder, derivedType, inputs, loc);
/// Overload for target materialization callbacks that accept a trailing
/// original-type argument (builder, T, inputs, loc, originalType). Adapts the
/// callback to TargetMaterializationCallbackFn.
// NOTE(review): some lines of the signature and lambda header are not visible
// in this listing.
485 template <
typename T,
typename FnT>
487 std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
488 TargetMaterializationCallbackFn>
489 wrapTargetMaterialization(FnT &&callback)
const {
490 return [callback = std::forward<FnT>(callback)](
492 Location loc, Type originalType) -> SmallVector<Value> {
493 SmallVector<Value>
result;
// T == TypeRange: the callback receives all result types at once.
494 if constexpr (std::is_same<T, TypeRange>::value) {
497 result = callback(builder, resultTypes, inputs, loc, originalType);
498 }
// T is a single type: the callback is considered only when exactly one
// result type is requested and that type casts to `T`.
else if constexpr (std::is_assignable<Type, T>::value) {
501 if (resultTypes.size() == 1) {
504 if (T derivedType = dyn_cast<T>(resultTypes.front())) {
509 callback(builder, derivedType, inputs, loc, originalType);
// Any other `T` is rejected at compile time; `sizeof(T) == 0` makes the
// condition depend on `T`.
515 static_assert(
sizeof(T) == 0,
"T must be a Type or a TypeRange");
/// Overload for target materialization callbacks that do not take the
/// original type (builder, T, inputs, loc). Wraps the callback in a lambda
/// that calls it with those four arguments and delegates to
/// wrapTargetMaterialization<T>.
// NOTE(review): the end of the lambda's parameter list is not visible in this
// listing.
523 template <
typename T,
typename FnT>
525 std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
526 TargetMaterializationCallbackFn>
527 wrapTargetMaterialization(FnT &&callback)
const {
528 return wrapTargetMaterialization<T>(
529 [callback = std::forward<FnT>(callback)](
530 OpBuilder &builder, T resultTypes,
ValueRange inputs, Location loc,
532 return callback(builder, resultTypes, inputs, loc);
/// Adapts a callback expecting a specific type `T` and attribute `A` to the
/// type-erased TypeAttributeConversionCallbackFn form.
540 template <
typename T,
typename A,
typename FnT>
541 TypeAttributeConversionCallbackFn
542 wrapTypeAttributeConversion(FnT &&callback)
const {
543 return [callback = std::forward<FnT>(callback)](
544 Type type, Attribute attr) -> AttributeConversionResult {
// The callback runs only when `type` is a `T` and `attr` is a non-null `A`;
// dyn_cast_or_null is used for the attribute, so a null `attr` is tolerated.
545 if (T derivedType = dyn_cast<T>(type)) {
546 if (A derivedAttr = dyn_cast_or_null<A>(attr))
547 return callback(derivedType, derivedAttr);
// Otherwise report "not applicable".
549 return AttributeConversionResult::na();
/// Appends `callback` to the registered type-attribute conversions and clears
/// both conversion caches.
555 registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
556 typeAttributeConversions.emplace_back(std::move(callback));
558 cachedDirectConversions.clear();
559 cachedMultiConversions.clear();
/// Shared implementation behind the convertType overloads; accepts either a
/// Type or a Value (declaration only).
563 LogicalResult convertTypeImpl(PointerUnion<Type, Value> t,
564 SmallVectorImpl<Type> &results)
const;
/// Registered type conversion callbacks.
567 SmallVector<ConversionCallbackFn, 4> conversions;
/// Registered source and target materialization callbacks.
570 SmallVector<SourceMaterializationCallbackFn, 2> sourceMaterializations;
571 SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
/// Registered type-attribute conversion callbacks.
574 SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
/// Caches for 1:1 and 1:N conversion results. Declared mutable so const
/// member functions can update them; cleared whenever a conversion or
/// type-attribute conversion is registered.
579 mutable DenseMap<Type, Type> cachedDirectConversions;
581 mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
/// Reader/writer mutex, presumably guarding the caches above (TODO confirm
/// against the implementation file).
583 mutable llvm::sys::SmartRWMutex<true> cacheMutex;
/// Set to `conversions.size()` by the Value overload of wrapCallback;
/// -1 until that overload has been used.
592 int contextAwareTypeConversionsIndex = -1;
/// Operand containers handed to the two matchAndRewrite flavours: a flat list
/// of values (1:1) and a list of value ranges (1:N).
604 using OpAdaptor = ArrayRef<Value>;
605 using OneToNOpAdaptor = ArrayRef<ValueRange>;
/// 1:1 hook. The default implementation aborts via llvm_unreachable, so a
/// derived pattern must override at least one of the hooks.
611 virtual LogicalResult
612 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
613 ConversionPatternRewriter &rewriter)
const {
614 llvm_unreachable(
"matchAndRewrite is not implemented");
/// 1:N hook. The default implementation forwards to dispatchTo1To1, which
/// falls back to the 1:1 hook.
619 virtual LogicalResult
620 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
621 ConversionPatternRewriter &rewriter)
const {
622 return dispatchTo1To1(*
this, op, operands, rewriter);
/// Entry point taking a plain PatternRewriter; marked final so derived
/// patterns cannot override it (declaration only).
626 LogicalResult matchAndRewrite(Operation *op,
627 PatternRewriter &rewriter)
const final;
/// Returns the type converter held by this pattern; may be null.
631 const TypeConverter *getTypeConverter()
const {
return typeConverter; }
/// Returns the held type converter statically cast to `ConverterTy`, which
/// must derive from TypeConverter. No runtime check is performed.
633 template <
typename ConverterTy>
634 std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
636 getTypeConverter()
const {
637 return static_cast<const ConverterTy *
>(typeConverter);
/// Inherits the RewritePattern constructors (no type converter is set).
643 using RewritePattern::RewritePattern;
/// Constructs a pattern bound to `typeConverter`; the remaining arguments are
/// forwarded to RewritePattern. Only the address of the converter is stored,
/// so the converter must outlive the pattern.
646 template <
typename... Args>
647 ConversionPattern(
const TypeConverter &typeConverter, Args &&...args)
648 : RewritePattern(std::forward<Args>(args)...),
649 typeConverter(&typeConverter) {}
/// Produces a flat operand list from 1:N operand ranges, or failure
/// (declaration only).
656 FailureOr<SmallVector<Value>>
657 getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands)
const;
/// Helper that routes a 1:N matchAndRewrite call to the 1:1 overload of
/// `self`; this form takes raw operand ranges (declaration only).
663 template <
typename SelfPattern,
typename SourceOp>
664 static LogicalResult dispatchTo1To1(
const SelfPattern &self, SourceOp op,
665 ArrayRef<ValueRange> operands,
666 ConversionPatternRewriter &rewriter);
/// Same helper for patterns that receive a generated 1:N adaptor
/// (declaration only).
669 template <
typename SelfPattern,
typename SourceOp>
670 static LogicalResult dispatchTo1To1(
671 const SelfPattern &self, SourceOp op,
672 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
673 ConversionPatternRewriter &rewriter);
/// Optional type converter; null unless supplied at construction.
677 const TypeConverter *typeConverter =
nullptr;
/// Conversion pattern specialised for a concrete operation class `SourceOp`.
/// Translates the generic Operation*-based hooks into hooks that receive the
/// typed op and its adaptor.
683template <
typename SourceOp>
684class OpConversionPattern :
public ConversionPattern {
688 using Base = OpConversionPattern;
/// Adaptor types for the 1:1 and 1:N hooks.
690 using OpAdaptor =
typename SourceOp::Adaptor;
691 using OneToNOpAdaptor =
692 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
/// Constructs a pattern matching `SourceOp` by operation name, without a
/// type converter. `benefit` defaults to 1.
694 OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
695 : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
/// Same, but bound to `typeConverter`.
// NOTE(review): the end of this initializer is not visible in this listing.
696 OpConversionPattern(
const TypeConverter &typeConverter, MLIRContext *context,
697 PatternBenefit benefit = 1)
698 : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
/// Final override of the generic 1:1 hook: casts `op` to `SourceOp`, builds
/// an OpAdaptor from the operands and calls the typed hook.
704 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
705 ConversionPatternRewriter &rewriter)
const final {
706 auto sourceOp = cast<SourceOp>(op);
707 return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
/// Final override of the generic 1:N hook: casts `op` to `SourceOp`, builds
/// a OneToNOpAdaptor and calls the typed hook.
710 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
711 ConversionPatternRewriter &rewriter)
const final {
712 auto sourceOp = cast<SourceOp>(op);
713 return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
/// Typed 1:1 hook; the default implementation aborts via llvm_unreachable.
719 virtual LogicalResult
720 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
721 ConversionPatternRewriter &rewriter)
const {
722 llvm_unreachable(
"matchAndRewrite is not implemented");
/// Typed 1:N hook; the default implementation forwards to dispatchTo1To1.
724 virtual LogicalResult
725 matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
726 ConversionPatternRewriter &rewriter)
const {
727 return dispatchTo1To1(*
this, op, adaptor, rewriter);
/// Keeps the base-class matchAndRewrite overloads visible alongside the
/// overloads declared in this class.
731 using ConversionPattern::matchAndRewrite;
/// Conversion pattern that matches operations by interface: `SourceOp` is an
/// op interface class and matching uses its interface ID.
737template <
typename SourceOp>
738class OpInterfaceConversionPattern :
public ConversionPattern {
742 using Base = OpInterfaceConversionPattern;
/// Constructs a pattern matching any op implementing the interface, without
/// a type converter. `benefit` defaults to 1.
744 OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
745 : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
746 SourceOp::getInterfaceID(), benefit, context) {}
/// Same, but bound to `typeConverter`.
747 OpInterfaceConversionPattern(
const TypeConverter &typeConverter,
748 MLIRContext *context, PatternBenefit benefit = 1)
749 : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
750 SourceOp::getInterfaceID(), benefit, context) {}
/// Final override of the generic 1:1 hook: casts `op` to the interface type
/// and calls the typed hook.
755 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
756 ConversionPatternRewriter &rewriter)
const final {
757 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
/// Final override of the generic 1:N hook: casts `op` to the interface type
/// and calls the typed hook.
760 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
761 ConversionPatternRewriter &rewriter)
const final {
762 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
/// Typed 1:1 hook; the default implementation aborts via llvm_unreachable.
767 virtual LogicalResult
768 matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
769 ConversionPatternRewriter &rewriter)
const {
770 llvm_unreachable(
"matchAndRewrite is not implemented");
/// Typed 1:N hook; the default implementation forwards to dispatchTo1To1.
772 virtual LogicalResult
773 matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
774 ConversionPatternRewriter &rewriter)
const {
775 return dispatchTo1To1(*
this, op, operands, rewriter);
/// Keeps the base-class matchAndRewrite overloads visible alongside the
/// overloads declared in this class.
779 using ConversionPattern::matchAndRewrite;
/// Conversion pattern that matches operations by trait: `TraitType` is a
/// trait class template and matching uses the TypeID of that template.
785template <
template <
typename>
class TraitType>
786class OpTraitConversionPattern :
public ConversionPattern {
790 using Base = OpTraitConversionPattern;
/// Constructs a pattern matching any op carrying the trait, without a type
/// converter. `benefit` defaults to 1.
792 OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
793 : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
794 TypeID::
get<TraitType>(), benefit, context) {}
/// Same, but bound to `typeConverter`.
795 OpTraitConversionPattern(
const TypeConverter &typeConverter,
796 MLIRContext *context, PatternBenefit benefit = 1)
797 : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
798 TypeID::
get<TraitType>(), benefit, context) {}
// NOTE(review): the name and leading parameters of the function declared on
// the next three lines are not visible in this listing; it takes a
// TypeConverter and a ConversionPatternRewriter and returns
// FailureOr<Operation *>.
804FailureOr<Operation *>
806 const TypeConverter &converter,
807 ConversionPatternRewriter &rewriter);
/// Non-template overload (parameter list not visible in this listing).
812void populateFunctionOpInterfaceTypeConversionPattern(
/// Convenience template: forwards to the overload above using the operation
/// name of `FuncOpT`, together with `patterns`, `converter` and `benefit`.
816template <
typename FuncOpT>
817void populateFunctionOpInterfaceTypeConversionPattern(
820 populateFunctionOpInterfaceTypeConversionPattern(
821 FuncOpT::getOperationName(),
patterns, converter, benefit);
/// Variant not tied to a specific function op (parameter list not visible in
/// this listing).
824void populateAnyFunctionOpInterfaceTypeConversionPattern(
/// Destructor, defined out of line (declaration only).
841 ~ConversionPatternRewriter()
override;
/// Returns the dialect conversion configuration (declaration only).
844 const ConversionConfig &getConfig()
const;
/// Applies `conversion` to the signature of `block`; `converter` is optional
/// and defaults to null.
// NOTE(review): the return type line is not visible in this listing.
861 applySignatureConversion(Block *block,
862 TypeConverter::SignatureConversion &conversion,
863 const TypeConverter *converter =
nullptr);
/// Converts the block argument types of `region` using `converter`. An
/// explicit conversion for the entry block may be supplied through
/// `entryConversion`, which defaults to null. Returns a block on success.
878 FailureOr<Block *> convertRegionTypes(
879 Region *region,
const TypeConverter &converter,
880 TypeConverter::SignatureConversion *entryConversion =
nullptr);
/// Replaces all uses of `from` with the (possibly multiple) values in `to`
/// (declaration only).
901 void replaceAllUsesWith(Value from,
ValueRange to);
/// Single-value override of the base-class method.
// NOTE(review): the body of this override is not visible in this listing.
902 void replaceAllUsesWith(Value from, Value to)
override {
/// Single-value override: wraps `to` in a one-element ValueRange and
/// delegates to the ValueRange overload below.
918 void replaceUsesWithIf(Value from, Value to,
919 function_ref<
bool(OpOperand &)> functor,
920 bool *allUsesReplaced =
nullptr)
override {
921 replaceUsesWithIf(from,
ValueRange{to}, functor, allUsesReplaced);
/// Replaces the uses of `from` selected by `functor` with the values in `to`;
/// `allUsesReplaced` is an optional out-parameter (declaration only).
923 void replaceUsesWithIf(Value from,
ValueRange to,
924 function_ref<
bool(OpOperand &)> functor,
925 bool *allUsesReplaced =
nullptr);
/// Returns the value that `key` is currently remapped to (declaration only).
930 Value getRemappedValue(Value key);
/// Looks up the remapped values for all `keys`; output goes through `results`
/// (declaration only).
935 LogicalResult getRemappedValues(
ValueRange keys,
936 SmallVectorImpl<Value> &results);
/// This rewriter always reports that it can recover from a failed rewrite.
945 bool canRecoverFromRewriteFailure()
const override {
return true; }
/// Overrides of the base rewriter's op replacement (declarations only).
952 void replaceOp(Operation *op,
ValueRange newValues)
override;
959 void replaceOp(Operation *op, Operation *newOp)
override;
/// Replaces `op`, supplying a list of replacement values per result
/// (declaration only). The overloads below convert their argument and funnel
/// into this one.
963 void replaceOpWithMultiple(Operation *op,
964 SmallVector<SmallVector<Value>> &&newValues);
/// Convenience overload: copies each range of `newValues` into a
/// SmallVector<Value> and forwards to the overload above.
965 template <
typename RangeT = ValueRange>
966 void replaceOpWithMultiple(Operation *op, ArrayRef<RangeT> newValues) {
967 replaceOpWithMultiple(op,
968 llvm::to_vector_of<SmallVector<Value>>(newValues));
/// Convenience overload for arbitrary ranges: converts the elements to
/// ValueRange and forwards to the ArrayRef overload.
970 template <
typename RangeT>
971 void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
972 replaceOpWithMultiple(op,
973 ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
/// Overrides of the base rewriter's erase operations (declarations only).
979 void eraseOp(Operation *op)
override;
983 void eraseBlock(Block *block)
override;
/// Block inlining.
// NOTE(review): the remainder of this declaration is not visible in this
// listing.
986 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
/// Keeps the remaining PatternRewriter overloads of inlineBlockBefore visible.
988 using PatternRewriter::inlineBlockBefore;
/// Overrides of the in-place modification notifications (declarations only).
994 void startOpModification(Operation *op)
override;
997 void finalizeOpModification(Operation *op)
override;
1000 void cancelOpModification(Operation *op)
override;
/// Returns the rewriter's implementation object (declaration only).
1003 detail::ConversionPatternRewriterImpl &getImpl();
/// Attempts to legalize a single operation (declaration only).
1012 LogicalResult legalize(Operation *op);
/// Attempts to legalize a region (declaration only).
1025 LogicalResult legalize(Region *r);
/// OperationConverter is granted access to non-public members.
1029 friend struct OperationConverter;
/// Constructor taking the context, the conversion configuration and the
/// driving OperationConverter (declaration only).
1034 explicit ConversionPatternRewriter(MLIRContext *ctx,
1035 const ConversionConfig &config,
1036 OperationConverter &converter);
/// Re-declares OpBuilder::setListener at this point in the class.
// NOTE(review): presumably this sits in a non-public section so that users
// cannot replace the listener -- confirm the access specifier.
1039 using OpBuilder::setListener;
/// Owning pointer to the implementation object.
1041 std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
/// Out-of-line definition of the operand-range overload of dispatchTo1To1.
/// Calls self.getOneToOneAdaptorOperands(operands). On failure it reports a
/// match failure on `op` that names the pattern; on success it forwards to
/// self.matchAndRewrite with the resulting flat operand list.
// NOTE(review): the return type line and the operand parameter line are not
// visible in this listing.
1044template <
typename SelfPattern,
typename SourceOp>
1046ConversionPattern::dispatchTo1To1(
const SelfPattern &self, SourceOp op,
1048 ConversionPatternRewriter &rewriter) {
1049 FailureOr<SmallVector<Value>> oneToOneOperands =
1050 self.getOneToOneAdaptorOperands(operands);
// The pattern only implements the 1:1 hook but was handed operands that
// could not be reduced to a 1:1 list.
1051 if (
failed(oneToOneOperands))
1052 return rewriter.notifyMatchFailure(op,
1053 "pattern '" + self.getDebugName() +
1054 "' does not support 1:N conversion");
1055 return self.matchAndRewrite(op, *oneToOneOperands, rewriter);
/// Out-of-line definition of the adaptor overload of dispatchTo1To1. Calls
/// self.getOneToOneAdaptorOperands on the adaptor's operands. On failure it
/// reports a match failure on `op` that names the pattern; on success it
/// builds a `SourceOp::Adaptor` from the flat operands and the original
/// adaptor and forwards to self.matchAndRewrite.
// NOTE(review): the adaptor parameter line is not visible in this listing.
1058template <
typename SelfPattern,
typename SourceOp>
1059LogicalResult ConversionPattern::dispatchTo1To1(
1060 const SelfPattern &self, SourceOp op,
1062 ConversionPatternRewriter &rewriter) {
1063 FailureOr<SmallVector<Value>> oneToOneOperands =
1064 self.getOneToOneAdaptorOperands(adaptor.getOperands());
1065 if (
failed(oneToOneOperands))
1066 return rewriter.notifyMatchFailure(op,
1067 "pattern '" + self.getDebugName() +
1068 "' does not support 1:N conversion");
1069 return self.matchAndRewrite(
1070 op,
typename SourceOp::Adaptor(*oneToOneOperands, adaptor), rewriter);
/// Describes which operations and dialects count as legal for a conversion.
1078class ConversionTarget {
/// Legalization action attached to an operation or dialect. The visible code
/// uses the enumerators Legal, Dynamic and Illegal.
// NOTE(review): the enumerator list itself is not visible in this listing.
1082 enum class LegalizationAction {
/// Extra information reported for an operation found to be legal.
1096 struct LegalOpDetails {
/// False unless set otherwise.
1100 bool isRecursivelyLegal =
false;
/// Callback deciding legality of an operation at run time. An empty optional
/// presumably means "no decision" (TODO confirm against the implementation).
1105 using DynamicLegalityCallbackFn =
1106 std::function<std::optional<bool>(Operation *)>;
/// Binds the target to `ctx`; only a reference is stored, so the context
/// must outlive the target.
1108 ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
1109 virtual ~ConversionTarget() =
default;
/// Returns the context this target was created with.
1111 MLIRContext &
getContext()
const {
return ctx; }
/// Registers `action` for the operation named `op` (declaration only).
1118 void setOpAction(OperationName op, LegalizationAction action);
/// Typed convenience form: resolves the operation name from `OpT`.
1119 template <
typename OpT>
1120 void setOpAction(LegalizationAction action) {
1121 setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
/// Marks the operation named `op` as legal.
1125 void addLegalOp(OperationName op) {
1126 setOpAction(op, LegalizationAction::Legal);
/// Typed form for a single op class.
// NOTE(review): the function header is not visible in this listing.
1128 template <
typename OpT>
1130 addLegalOp(OperationName(OpT::getOperationName(), &ctx));
/// Variadic form; the visible statement recurses on the remaining op classes.
// NOTE(review): the function header and the handling of the first op class
// are not visible in this listing.
1132 template <
typename OpT,
typename OpT2,
typename... OpTs>
1135 addLegalOp<OpT2, OpTs...>();
/// Marks the operation named `op` as dynamically legal and installs
/// `callback` as its legality callback.
1140 void addDynamicallyLegalOp(OperationName op,
1141 const DynamicLegalityCallbackFn &callback) {
1142 setOpAction(op, LegalizationAction::Dynamic);
1143 setLegalityCallback(op, callback);
/// Typed form for a single op class.
// NOTE(review): the tail of the forwarded call is not visible in this listing.
1145 template <
typename OpT>
1146 void addDynamicallyLegalOp(
const DynamicLegalityCallbackFn &callback) {
1147 addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
/// Variadic form: applies the same callback to the first op class, then
/// recurses on the rest.
1150 template <
typename OpT,
typename OpT2,
typename... OpTs>
1151 void addDynamicallyLegalOp(
const DynamicLegalityCallbackFn &callback) {
1152 addDynamicallyLegalOp<OpT>(callback);
1153 addDynamicallyLegalOp<OpT2, OpTs...>(callback);
/// Overload for callbacks that are NOT invocable with `Operation *`. The
/// callback is copied into a lambda that casts the operation to `OpT` before
/// calling it.
1155 template <
typename OpT,
class Callable>
1156 std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1157 addDynamicallyLegalOp(Callable &&callback) {
1158 addDynamicallyLegalOp<OpT>(
1159 [=](Operation *op) {
return callback(cast<OpT>(op)); });
/// Marks the operation named `op` as illegal.
1164 void addIllegalOp(OperationName op) {
1165 setOpAction(op, LegalizationAction::Illegal);
/// Typed form for a single op class: resolves the operation name from `OpT`.
1167 template <
typename OpT>
1168 void addIllegalOp() {
1169 addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
/// Variadic form: handles the first op class, then recurses on the rest.
1171 template <
typename OpT,
typename OpT2,
typename... OpTs>
1172 void addIllegalOp() {
1173 addIllegalOp<OpT>();
1174 addIllegalOp<OpT2, OpTs...>();
/// Marks the operation `name` as recursively legal, with an optional callback
/// (declaration only).
1182 void markOpRecursivelyLegal(OperationName name,
1183 const DynamicLegalityCallbackFn &callback);
/// Typed form for a single op class; `callback` defaults to an empty function.
1184 template <
typename OpT>
1185 void markOpRecursivelyLegal(
const DynamicLegalityCallbackFn &callback = {}) {
1186 OperationName opName(OpT::getOperationName(), &ctx);
1187 markOpRecursivelyLegal(opName, callback);
/// Variadic form: applies the same callback to the first op class, then
/// recurses on the rest.
1189 template <
typename OpT,
typename OpT2,
typename... OpTs>
1190 void markOpRecursivelyLegal(
const DynamicLegalityCallbackFn &callback = {}) {
1191 markOpRecursivelyLegal<OpT>(callback);
1192 markOpRecursivelyLegal<OpT2, OpTs...>(callback);
/// Overload for callbacks that are NOT invocable with `Operation *`. The
/// callback is copied into a lambda that casts the operation to `OpT` before
/// calling it.
1194 template <
typename OpT,
class Callable>
1195 std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1196 markOpRecursivelyLegal(Callable &&callback) {
1197 markOpRecursivelyLegal<OpT>(
1198 [=](Operation *op) {
return callback(cast<OpT>(op)); });
/// Registers `action` for every dialect in `dialectNames` (declaration only).
1202 void setDialectAction(ArrayRef<StringRef> dialectNames,
1203 LegalizationAction action);
/// Marks the named dialects as legal.
1206 template <
typename... Names>
1207 void addLegalDialect(StringRef name, Names... names) {
1208 SmallVector<StringRef, 2> dialectNames({name, names...});
1209 setDialectAction(dialectNames, LegalizationAction::Legal);
/// Typed form: takes the namespaces from the dialect classes `Args`.
1211 template <
typename... Args>
1212 void addLegalDialect() {
1213 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1214 setDialectAction(dialectNames, LegalizationAction::Legal);
/// Marks the named dialects as dynamically legal and installs `callback` for
/// them.
1219 template <
typename... Names>
1220 void addDynamicallyLegalDialect(
const DynamicLegalityCallbackFn &callback,
1221 StringRef name, Names... names) {
1222 SmallVector<StringRef, 2> dialectNames({name, names...});
1223 setDialectAction(dialectNames, LegalizationAction::Dynamic);
1224 setLegalityCallback(dialectNames, callback);
/// Typed form: forwards to the name-based overload with the namespaces of
/// the dialect classes `Args`.
1226 template <
typename... Args>
1227 void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
1228 addDynamicallyLegalDialect(std::move(callback),
1229 Args::getDialectNamespace()...);
/// Installs `fn` through the setLegalityCallback overload that takes neither
/// an operation name nor a dialect list.
1235 void markUnknownOpDynamicallyLegal(
const DynamicLegalityCallbackFn &fn) {
1236 setLegalityCallback(fn);
/// Marks the named dialects as illegal.
1241 template <
typename... Names>
1242 void addIllegalDialect(StringRef name, Names... names) {
1243 SmallVector<StringRef, 2> dialectNames({name, names...});
1244 setDialectAction(dialectNames, LegalizationAction::Illegal);
/// Typed form: takes the namespaces from the dialect classes `Args`.
1246 template <
typename... Args>
1247 void addIllegalDialect() {
1248 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1249 setDialectAction(dialectNames, LegalizationAction::Illegal);
/// Returns the action registered for `op`, if any (declaration only).
1257 std::optional<LegalizationAction> getOpAction(OperationName op)
const;
/// Returns legality details when `op` is legal; empty otherwise
/// (declaration only).
1268 std::optional<LegalOpDetails> isLegal(Operation *op)
const;
/// Returns whether `op` is illegal for this target (declaration only).
1273 bool isIllegal(Operation *op)
const;
/// Installs a legality callback for one operation, for a set of dialects, or
/// with no operation or dialect qualifier (declarations only).
1277 void setLegalityCallback(OperationName name,
1278 const DynamicLegalityCallbackFn &callback);
1281 void setLegalityCallback(ArrayRef<StringRef> dialects,
1282 const DynamicLegalityCallbackFn &callback);
1285 void setLegalityCallback(
const DynamicLegalityCallbackFn &callback);
/// Per-operation legalization record.
1288 struct LegalizationInfo {
/// Action for the operation; Illegal unless set otherwise.
1290 LegalizationAction action = LegalizationAction::Illegal;
/// False unless set otherwise.
1293 bool isRecursivelyLegal =
false;
/// Dynamic legality callback; may be empty.
1296 DynamicLegalityCallbackFn legalityFn;
/// Returns the legalization record applying to `op`, if any
/// (declaration only).
1300 std::optional<LegalizationInfo> getOpInfo(OperationName op)
const;
/// Legalization records keyed by operation name. MapVector keeps insertion
/// order.
1304 llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
/// Callbacks keyed by operation name for recursive legality.
1308 DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
/// Actions keyed by dialect namespace.
1312 llvm::StringMap<LegalizationAction> legalDialects;
/// Dynamic legality callbacks keyed by dialect namespace.
1315 llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
/// Callback for operations with no more specific registration; presumably the
/// one installed by markUnknownOpDynamicallyLegal (TODO confirm).
1318 DynamicLegalityCallbackFn unknownLegalityFn;
// The declarations that follow this guard depend on PDL support being
// enabled in the pattern matcher.
1324#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
/// PDL pattern configuration that carries a TypeConverter for use during
/// dialect conversion.
1331class PDLConversionConfig final
1332 :
public PDLPatternConfigBase<PDLConversionConfig> {
/// Stores the (possibly null) converter pointer; the converter is not owned.
1334 PDLConversionConfig(
const TypeConverter *converter) : converter(converter) {}
1335 ~PDLConversionConfig() final = default;
/// Returns the converter supplied at construction.
1339 const TypeConverter *getTypeConverter()
const {
return converter; }
/// Hooks invoked around a rewrite (declarations only).
1343 void notifyRewriteBegin(PatternRewriter &rewriter)
final;
1344 void notifyRewriteEnd(PatternRewriter &rewriter)
final;
/// Non-owning pointer to the type converter.
1348 const TypeConverter *converter;
/// Stub with the same constructor signature that ignores its argument.
// NOTE(review): presumably this is the variant used when PDL support is
// disabled; the preprocessor branch separating the two classes is not
// visible in this listing -- confirm.
1360class PDLConversionConfig final {
1362 PDLConversionConfig(
const TypeConverter * ) {}
/// Folding mode of the dialect conversion driver. ConversionConfig below
/// defaults to the enumerator BeforePatterns.
// NOTE(review): the enumerator list is not visible in this listing.
1372enum class DialectConversionFoldingMode {
/// Options controlling a dialect conversion run.
1382struct ConversionConfig {
/// Optional callback receiving diagnostics; null by default.
1386 function_ref<void(Diagnostic &)> notifyCallback =
nullptr;
/// Optional caller-owned set of operations; null by default.
// NOTE(review): presumably filled with operations that could not be
// legalized -- confirm against the driver.
1392 DenseSet<Operation *> *unlegalizedOps =
nullptr;
/// Optional caller-owned set of operations; null by default.
// NOTE(review): presumably filled with operations found to be legalizable
// during analysis -- confirm against the driver.
1398 DenseSet<Operation *> *legalizableOps =
nullptr;
/// Optional rewriter listener; null by default.
1432 RewriterBase::Listener *listener =
nullptr;
/// Whether materializations are built; enabled by default.
1442 bool buildMaterializations =
true;
/// Whether pattern rollback is allowed; enabled by default.
1463 bool allowPatternRollback =
true;
/// Folding mode; defaults to BeforePatterns.
1466 DialectConversionFoldingMode foldingMode =
1467 DialectConversionFoldingMode::BeforePatterns;
/// Whether the kind of materialization is attached for debugging; disabled
/// by default.
1473 bool attachDebugMaterializationKind =
false;
// Tails of the conversion entry-point declarations. Most of the function
// names and leading parameters are not visible in this listing; every
// visible tail takes a ConversionConfig by value that defaults to a
// default-constructed configuration.
1518 const ConversionTarget &
target,
1520 ConversionConfig
config = ConversionConfig());
1524 ConversionConfig
config = ConversionConfig());
1531 const ConversionTarget &
target,
1533 ConversionConfig
config = ConversionConfig());
/// Full-conversion entry point for a single root operation and a conversion
/// target (one parameter line is not visible in this listing).
1534LogicalResult applyFullConversion(
Operation *op,
const ConversionTarget &
target,
1536 ConversionConfig
config = ConversionConfig());
1548 ConversionConfig
config = ConversionConfig());
1552 ConversionConfig
config = ConversionConfig());
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref