13#ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14#define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
16#include "mlir/Config/mlir-config.h"
18#include "llvm/ADT/MapVector.h"
19#include "llvm/ADT/StringMap.h"
27struct ConversionConfig;
28class ConversionPatternRewriter;
45 using Base = TypeConverter;
47 virtual ~TypeConverter() =
default;
48 TypeConverter() =
default;
50 TypeConverter(
const TypeConverter &other)
51 : conversions(other.conversions),
52 sourceMaterializations(other.sourceMaterializations),
53 targetMaterializations(other.targetMaterializations),
54 typeAttributeConversions(other.typeAttributeConversions) {}
55 TypeConverter &operator=(
const TypeConverter &other) {
56 conversions = other.conversions;
57 sourceMaterializations = other.sourceMaterializations;
58 targetMaterializations = other.targetMaterializations;
59 typeAttributeConversions = other.typeAttributeConversions;
65 class SignatureConversion {
67 SignatureConversion(
unsigned numOrigInputs)
68 : remappedInputs(numOrigInputs) {}
74 SmallVector<Value, 1> replacementValues;
77 bool replacedWithValues()
const {
return !replacementValues.empty(); }
81 ArrayRef<Type> getConvertedTypes()
const {
return argTypes; }
84 std::optional<InputMapping> getInputMapping(
unsigned input)
const {
85 return remappedInputs[input];
94 void addInputs(
unsigned origInputNo, ArrayRef<Type> types);
98 void addInputs(ArrayRef<Type> types);
102 void remapInput(
unsigned origInputNo, ArrayRef<Value> replacements);
107 void remapInput(
unsigned origInputNo,
unsigned newInputNo,
108 unsigned newInputCount = 1);
111 SmallVector<std::optional<InputMapping>, 4> remappedInputs;
114 SmallVector<Type, 4> argTypes;
119 class AttributeConversionResult {
121 constexpr AttributeConversionResult() : impl() {}
122 AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
124 static AttributeConversionResult
result(Attribute attr);
125 static AttributeConversionResult na();
126 static AttributeConversionResult abort();
128 bool hasResult()
const;
130 bool isAbort()
const;
132 Attribute getResult()
const;
135 AttributeConversionResult(Attribute attr,
unsigned tag) : impl(attr, tag) {}
137 llvm::PointerIntPair<Attribute, 2> impl;
140 static constexpr unsigned naTag = 0;
141 static constexpr unsigned resultTag = 1;
142 static constexpr unsigned abortTag = 2;
172 template <
typename FnT,
typename T =
typename llvm::function_traits<
173 std::decay_t<FnT>>::template arg_t<0>>
174 void addConversion(FnT &&callback) {
175 registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
198 template <
typename FnT,
typename T =
typename llvm::function_traits<
199 std::decay_t<FnT>>::template arg_t<1>>
200 void addSourceMaterialization(FnT &&callback) {
201 sourceMaterializations.emplace_back(
202 wrapSourceMaterialization<T>(std::forward<FnT>(callback)));
222 template <
typename FnT,
typename T =
typename llvm::function_traits<
223 std::decay_t<FnT>>::template arg_t<1>>
224 void addTargetMaterialization(FnT &&callback) {
225 targetMaterializations.emplace_back(
226 wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
250 typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
252 typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
253 void addTypeAttributeConversion(FnT &&callback) {
254 registerTypeAttributeConversion(
255 wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
265 LogicalResult convertType(Type t, SmallVectorImpl<Type> &results)
const;
274 LogicalResult convertType(Value v, SmallVectorImpl<Type> &results)
const;
278 Type convertType(Type t)
const;
279 Type convertType(Value v)
const;
284 template <
typename TargetType>
285 TargetType convertType(Type t)
const {
286 return dyn_cast_or_null<TargetType>(convertType(t));
288 template <
typename TargetType>
289 TargetType convertType(Value v)
const {
290 return dyn_cast_or_null<TargetType>(convertType(v));
296 LogicalResult convertTypes(
TypeRange types,
297 SmallVectorImpl<Type> &results)
const;
303 SmallVectorImpl<Type> &results)
const;
307 bool isLegal(Type type)
const;
308 bool isLegal(Value value)
const;
312 return llvm::all_of(range, [
this](Type type) {
return isLegal(type); });
315 return llvm::all_of(range, [
this](Value value) {
return isLegal(value); });
319 bool isLegal(Operation *op)
const;
322 bool isLegal(Region *region)
const;
326 bool isSignatureLegal(FunctionType ty)
const;
331 LogicalResult convertSignatureArg(
unsigned inputNo, Type type,
332 SignatureConversion &
result)
const;
333 LogicalResult convertSignatureArgs(
TypeRange types,
334 SignatureConversion &
result,
335 unsigned origInputOffset = 0)
const;
336 LogicalResult convertSignatureArg(
unsigned inputNo, Value value,
337 SignatureConversion &
result)
const;
338 LogicalResult convertSignatureArgs(
ValueRange values,
339 SignatureConversion &
result,
340 unsigned origInputOffset = 0)
const;
345 std::optional<SignatureConversion> convertBlockSignature(Block *block)
const;
351 Value materializeSourceConversion(OpBuilder &builder, Location loc,
353 Value materializeTargetConversion(OpBuilder &builder, Location loc,
355 Type originalType = {})
const;
356 SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
360 Type originalType = {})
const;
366 std::optional<Attribute> convertTypeAttribute(Type type,
367 Attribute attr)
const;
373 using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
374 PointerUnion<Type, Value>, SmallVectorImpl<Type> &)>;
379 using SourceMaterializationCallbackFn =
380 std::function<Value(OpBuilder &, Type,
ValueRange, Location)>;
385 using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
389 using TypeAttributeConversionCallbackFn =
390 std::function<AttributeConversionResult(Type, Attribute)>;
396 template <
typename T,
typename FnT>
397 std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
398 wrapCallback(FnT &&callback) {
399 return wrapCallback<T>([callback = std::forward<FnT>(callback)](
400 T typeOrValue, SmallVectorImpl<Type> &results) {
401 if (std::optional<Type> resultOpt = callback(typeOrValue)) {
402 bool wasSuccess =
static_cast<bool>(*resultOpt);
404 results.push_back(*resultOpt);
405 return std::optional<LogicalResult>(
success(wasSuccess));
407 return std::optional<LogicalResult>();
412 template <
typename T,
typename FnT>
413 std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
414 std::is_base_of_v<Type, T>,
415 ConversionCallbackFn>
416 wrapCallback(FnT &&callback)
const {
417 return [callback = std::forward<FnT>(callback)](
418 PointerUnion<Type, Value> typeOrValue,
419 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
421 if (Type t = dyn_cast<Type>(typeOrValue)) {
422 derivedType = dyn_cast<T>(t);
423 }
else if (Value v = dyn_cast<Value>(typeOrValue)) {
424 derivedType = dyn_cast<T>(v.getType());
426 llvm_unreachable(
"unexpected variant");
430 return callback(derivedType, results);
435 template <
typename T,
typename FnT>
436 std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
437 std::is_same_v<T, Value>,
438 ConversionCallbackFn>
439 wrapCallback(FnT &&callback) {
440 contextAwareTypeConversionsIndex = conversions.size();
441 return [callback = std::forward<FnT>(callback)](
442 PointerUnion<Type, Value> typeOrValue,
443 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
444 if (Type t = dyn_cast<Type>(typeOrValue)) {
447 }
else if (Value v = dyn_cast<Value>(typeOrValue)) {
448 return callback(v, results);
450 llvm_unreachable(
"unexpected variant");
456 void registerConversion(ConversionCallbackFn callback) {
457 conversions.emplace_back(std::move(callback));
458 cachedDirectConversions.clear();
459 cachedMultiConversions.clear();
465 template <
typename T,
typename FnT>
466 SourceMaterializationCallbackFn
467 wrapSourceMaterialization(FnT &&callback)
const {
468 return [callback = std::forward<FnT>(callback)](
469 OpBuilder &builder, Type resultType,
ValueRange inputs,
470 Location loc) -> Value {
471 if (T derivedType = dyn_cast<T>(resultType))
472 return callback(builder, derivedType, inputs, loc);
485 template <
typename T,
typename FnT>
487 std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
488 TargetMaterializationCallbackFn>
489 wrapTargetMaterialization(FnT &&callback)
const {
490 return [callback = std::forward<FnT>(callback)](
492 Location loc, Type originalType) -> SmallVector<Value> {
493 SmallVector<Value>
result;
494 if constexpr (std::is_same<T, TypeRange>::value) {
497 result = callback(builder, resultTypes, inputs, loc, originalType);
498 }
else if constexpr (std::is_assignable<Type, T>::value) {
501 if (resultTypes.size() == 1) {
504 if (T derivedType = dyn_cast<T>(resultTypes.front())) {
509 callback(builder, derivedType, inputs, loc, originalType);
515 static_assert(
sizeof(T) == 0,
"T must be a Type or a TypeRange");
523 template <
typename T,
typename FnT>
525 std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
526 TargetMaterializationCallbackFn>
527 wrapTargetMaterialization(FnT &&callback)
const {
528 return wrapTargetMaterialization<T>(
529 [callback = std::forward<FnT>(callback)](
530 OpBuilder &builder, T resultTypes,
ValueRange inputs, Location loc,
532 return callback(builder, resultTypes, inputs, loc);
540 template <
typename T,
typename A,
typename FnT>
541 TypeAttributeConversionCallbackFn
542 wrapTypeAttributeConversion(FnT &&callback)
const {
543 return [callback = std::forward<FnT>(callback)](
544 Type type, Attribute attr) -> AttributeConversionResult {
545 if (T derivedType = dyn_cast<T>(type)) {
546 if (A derivedAttr = dyn_cast_or_null<A>(attr))
547 return callback(derivedType, derivedAttr);
549 return AttributeConversionResult::na();
555 registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
556 typeAttributeConversions.emplace_back(std::move(callback));
558 cachedDirectConversions.clear();
559 cachedMultiConversions.clear();
563 LogicalResult convertTypeImpl(PointerUnion<Type, Value> t,
564 SmallVectorImpl<Type> &results)
const;
567 SmallVector<ConversionCallbackFn, 4> conversions;
570 SmallVector<SourceMaterializationCallbackFn, 2> sourceMaterializations;
571 SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
574 SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
579 mutable DenseMap<Type, Type> cachedDirectConversions;
581 mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
583 mutable llvm::sys::SmartRWMutex<true> cacheMutex;
592 int contextAwareTypeConversionsIndex = -1;
604 using OpAdaptor = ArrayRef<Value>;
605 using OneToNOpAdaptor = ArrayRef<ValueRange>;
611 virtual LogicalResult
612 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
613 ConversionPatternRewriter &rewriter)
const {
614 llvm_unreachable(
"matchAndRewrite is not implemented");
619 virtual LogicalResult
620 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
621 ConversionPatternRewriter &rewriter)
const {
622 return dispatchTo1To1(*
this, op, operands, rewriter);
626 LogicalResult matchAndRewrite(Operation *op,
627 PatternRewriter &rewriter)
const final;
631 const TypeConverter *getTypeConverter()
const {
return typeConverter; }
633 template <
typename ConverterTy>
634 std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
636 getTypeConverter()
const {
637 return static_cast<const ConverterTy *
>(typeConverter);
643 using RewritePattern::RewritePattern;
646 template <
typename... Args>
647 ConversionPattern(
const TypeConverter &typeConverter, Args &&...args)
648 : RewritePattern(std::forward<Args>(args)...),
649 typeConverter(&typeConverter) {}
656 FailureOr<SmallVector<Value>>
657 getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands)
const;
663 template <
typename SelfPattern,
typename SourceOp>
664 static LogicalResult dispatchTo1To1(
const SelfPattern &self, SourceOp op,
665 ArrayRef<ValueRange> operands,
666 ConversionPatternRewriter &rewriter);
669 template <
typename SelfPattern,
typename SourceOp>
670 static LogicalResult dispatchTo1To1(
671 const SelfPattern &self, SourceOp op,
672 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
673 ConversionPatternRewriter &rewriter);
677 const TypeConverter *typeConverter =
nullptr;
683template <
typename SourceOp>
684class OpConversionPattern :
public ConversionPattern {
688 using Base = OpConversionPattern;
690 using OpAdaptor =
typename SourceOp::Adaptor;
691 using OneToNOpAdaptor =
692 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
694 OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
695 : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
696 OpConversionPattern(
const TypeConverter &typeConverter, MLIRContext *context,
697 PatternBenefit benefit = 1)
698 : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
704 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
705 ConversionPatternRewriter &rewriter)
const final {
706 auto sourceOp = cast<SourceOp>(op);
707 return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
710 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
711 ConversionPatternRewriter &rewriter)
const final {
712 auto sourceOp = cast<SourceOp>(op);
713 return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
719 virtual LogicalResult
720 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
721 ConversionPatternRewriter &rewriter)
const {
722 llvm_unreachable(
"matchAndRewrite is not implemented");
724 virtual LogicalResult
725 matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
726 ConversionPatternRewriter &rewriter)
const {
727 return dispatchTo1To1(*
this, op, adaptor, rewriter);
731 using ConversionPattern::matchAndRewrite;
737template <
typename SourceOp>
738class OpInterfaceConversionPattern :
public ConversionPattern {
742 using Base = OpInterfaceConversionPattern;
744 OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
745 : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
746 SourceOp::getInterfaceID(), benefit, context) {}
747 OpInterfaceConversionPattern(
const TypeConverter &typeConverter,
748 MLIRContext *context, PatternBenefit benefit = 1)
749 : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
750 SourceOp::getInterfaceID(), benefit, context) {}
755 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
756 ConversionPatternRewriter &rewriter)
const final {
757 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
760 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
761 ConversionPatternRewriter &rewriter)
const final {
762 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
767 virtual LogicalResult
768 matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
769 ConversionPatternRewriter &rewriter)
const {
770 llvm_unreachable(
"matchAndRewrite is not implemented");
772 virtual LogicalResult
773 matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
774 ConversionPatternRewriter &rewriter)
const {
775 return dispatchTo1To1(*
this, op, operands, rewriter);
779 using ConversionPattern::matchAndRewrite;
785template <
template <
typename>
class TraitType>
786class OpTraitConversionPattern :
public ConversionPattern {
790 using Base = OpTraitConversionPattern;
792 OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
793 : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
794 TypeID::
get<TraitType>(), benefit, context) {}
795 OpTraitConversionPattern(
const TypeConverter &typeConverter,
796 MLIRContext *context, PatternBenefit benefit = 1)
797 : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
798 TypeID::
get<TraitType>(), benefit, context) {}
804FailureOr<Operation *>
806 const TypeConverter &converter,
807 ConversionPatternRewriter &rewriter);
812void populateFunctionOpInterfaceTypeConversionPattern(
816template <
typename FuncOpT>
817void populateFunctionOpInterfaceTypeConversionPattern(
820 populateFunctionOpInterfaceTypeConversionPattern(
821 FuncOpT::getOperationName(),
patterns, converter, benefit);
824void populateAnyFunctionOpInterfaceTypeConversionPattern(
841 ~ConversionPatternRewriter()
override;
844 const ConversionConfig &getConfig()
const;
861 applySignatureConversion(Block *block,
862 TypeConverter::SignatureConversion &conversion,
863 const TypeConverter *converter =
nullptr);
878 FailureOr<Block *> convertRegionTypes(
879 Region *region,
const TypeConverter &converter,
880 TypeConverter::SignatureConversion *entryConversion =
nullptr);
901 void replaceAllUsesWith(Value from,
ValueRange to);
902 void replaceAllUsesWith(Value from, Value to)
override {
918 void replaceUsesWithIf(Value from, Value to,
919 function_ref<
bool(OpOperand &)> functor,
920 bool *allUsesReplaced =
nullptr)
override {
921 replaceUsesWithIf(from,
ValueRange{to}, functor, allUsesReplaced);
923 void replaceUsesWithIf(Value from,
ValueRange to,
924 function_ref<
bool(OpOperand &)> functor,
925 bool *allUsesReplaced =
nullptr);
930 Value getRemappedValue(Value key);
935 LogicalResult getRemappedValues(
ValueRange keys,
936 SmallVectorImpl<Value> &results);
945 bool canRecoverFromRewriteFailure()
const override {
return true; }
952 void replaceOp(Operation *op,
ValueRange newValues)
override;
959 void replaceOp(Operation *op, Operation *newOp)
override;
963 void replaceOpWithMultiple(Operation *op,
964 SmallVector<SmallVector<Value>> &&newValues);
965 template <
typename RangeT = ValueRange>
966 void replaceOpWithMultiple(Operation *op, ArrayRef<RangeT> newValues) {
967 replaceOpWithMultiple(op,
968 llvm::to_vector_of<SmallVector<Value>>(newValues));
970 template <
typename RangeT>
971 void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
972 replaceOpWithMultiple(op,
973 ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
979 void eraseOp(Operation *op)
override;
983 void eraseBlock(Block *block)
override;
986 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
988 using PatternRewriter::inlineBlockBefore;
994 void startOpModification(Operation *op)
override;
997 void finalizeOpModification(Operation *op)
override;
1000 void cancelOpModification(Operation *op)
override;
1003 detail::ConversionPatternRewriterImpl &getImpl();
1012 LogicalResult legalize(Operation *op);
1025 LogicalResult legalize(Region *r);
1029 friend struct OperationConverter;
1034 explicit ConversionPatternRewriter(MLIRContext *ctx,
1035 const ConversionConfig &config,
1036 OperationConverter &converter);
1039 using OpBuilder::setListener;
1041 std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
1044template <
typename SelfPattern,
typename SourceOp>
1046ConversionPattern::dispatchTo1To1(
const SelfPattern &self, SourceOp op,
1048 ConversionPatternRewriter &rewriter) {
1049 FailureOr<SmallVector<Value>> oneToOneOperands =
1050 self.getOneToOneAdaptorOperands(operands);
1051 if (
failed(oneToOneOperands))
1052 return rewriter.notifyMatchFailure(op,
1053 "pattern '" + self.getDebugName() +
1054 "' does not support 1:N conversion");
1055 return self.matchAndRewrite(op, *oneToOneOperands, rewriter);
1058template <
typename SelfPattern,
typename SourceOp>
1059LogicalResult ConversionPattern::dispatchTo1To1(
1060 const SelfPattern &self, SourceOp op,
1062 ConversionPatternRewriter &rewriter) {
1063 FailureOr<SmallVector<Value>> oneToOneOperands =
1064 self.getOneToOneAdaptorOperands(adaptor.getOperands());
1065 if (
failed(oneToOneOperands))
1066 return rewriter.notifyMatchFailure(op,
1067 "pattern '" + self.getDebugName() +
1068 "' does not support 1:N conversion");
1069 return self.matchAndRewrite(
1070 op,
typename SourceOp::Adaptor(*oneToOneOperands, adaptor), rewriter);
1078class ConversionTarget {
1082 enum class LegalizationAction {
1096 struct LegalOpDetails {
1100 bool isRecursivelyLegal =
false;
1105 using DynamicLegalityCallbackFn =
1106 std::function<std::optional<bool>(Operation *)>;
1108 ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
1109 virtual ~ConversionTarget() =
default;
1116 void setOpAction(OperationName op, LegalizationAction action);
1117 template <
typename OpT>
1118 void setOpAction(LegalizationAction action) {
1119 setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
1123 void addLegalOp(OperationName op) {
1124 setOpAction(op, LegalizationAction::Legal);
1126 template <
typename OpT>
1128 addLegalOp(OperationName(OpT::getOperationName(), &ctx));
1130 template <
typename OpT,
typename OpT2,
typename... OpTs>
1133 addLegalOp<OpT2, OpTs...>();
1138 void addDynamicallyLegalOp(OperationName op,
1139 const DynamicLegalityCallbackFn &callback) {
1140 setOpAction(op, LegalizationAction::Dynamic);
1141 setLegalityCallback(op, callback);
1143 template <
typename OpT>
1144 void addDynamicallyLegalOp(
const DynamicLegalityCallbackFn &callback) {
1145 addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
1148 template <
typename OpT,
typename OpT2,
typename... OpTs>
1149 void addDynamicallyLegalOp(
const DynamicLegalityCallbackFn &callback) {
1150 addDynamicallyLegalOp<OpT>(callback);
1151 addDynamicallyLegalOp<OpT2, OpTs...>(callback);
1153 template <
typename OpT,
class Callable>
1154 std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1155 addDynamicallyLegalOp(Callable &&callback) {
1156 addDynamicallyLegalOp<OpT>(
1157 [=](Operation *op) {
return callback(cast<OpT>(op)); });
1162 void addIllegalOp(OperationName op) {
1163 setOpAction(op, LegalizationAction::Illegal);
1165 template <
typename OpT>
1166 void addIllegalOp() {
1167 addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
/// Registers every listed operation type as illegal, in the order the types
/// are given.
template <typename OpT, typename OpT2, typename... OpTs>
void addIllegalOp() {
  addIllegalOp<OpT>();
  addIllegalOp<OpT2>();
  (addIllegalOp<OpTs>(), ...);
}
1180 void markOpRecursivelyLegal(OperationName name,
1181 const DynamicLegalityCallbackFn &callback);
1182 template <
typename OpT>
1183 void markOpRecursivelyLegal(
const DynamicLegalityCallbackFn &callback = {}) {
1184 OperationName opName(OpT::getOperationName(), &ctx);
1185 markOpRecursivelyLegal(opName, callback);
1187 template <
typename OpT,
typename OpT2,
typename... OpTs>
1188 void markOpRecursivelyLegal(
const DynamicLegalityCallbackFn &callback = {}) {
1189 markOpRecursivelyLegal<OpT>(callback);
1190 markOpRecursivelyLegal<OpT2, OpTs...>(callback);
1192 template <
typename OpT,
class Callable>
1193 std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1194 markOpRecursivelyLegal(Callable &&callback) {
1195 markOpRecursivelyLegal<OpT>(
1196 [=](Operation *op) {
return callback(cast<OpT>(op)); });
1200 void setDialectAction(ArrayRef<StringRef> dialectNames,
1201 LegalizationAction action);
1204 template <
typename... Names>
1205 void addLegalDialect(StringRef name, Names... names) {
1206 SmallVector<StringRef, 2> dialectNames({name, names...});
1207 setDialectAction(dialectNames, LegalizationAction::Legal);
1209 template <
typename... Args>
1210 void addLegalDialect() {
1211 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1212 setDialectAction(dialectNames, LegalizationAction::Legal);
1217 template <
typename... Names>
1218 void addDynamicallyLegalDialect(
const DynamicLegalityCallbackFn &callback,
1219 StringRef name, Names... names) {
1220 SmallVector<StringRef, 2> dialectNames({name, names...});
1221 setDialectAction(dialectNames, LegalizationAction::Dynamic);
1222 setLegalityCallback(dialectNames, callback);
1224 template <
typename... Args>
1225 void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
1226 addDynamicallyLegalDialect(std::move(callback),
1227 Args::getDialectNamespace()...);
1233 void markUnknownOpDynamicallyLegal(
const DynamicLegalityCallbackFn &fn) {
1234 setLegalityCallback(fn);
1239 template <
typename... Names>
1240 void addIllegalDialect(StringRef name, Names... names) {
1241 SmallVector<StringRef, 2> dialectNames({name, names...});
1242 setDialectAction(dialectNames, LegalizationAction::Illegal);
1244 template <
typename... Args>
1245 void addIllegalDialect() {
1246 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1247 setDialectAction(dialectNames, LegalizationAction::Illegal);
1255 std::optional<LegalizationAction> getOpAction(OperationName op)
const;
1266 std::optional<LegalOpDetails> isLegal(Operation *op)
const;
1271 bool isIllegal(Operation *op)
const;
1275 void setLegalityCallback(OperationName name,
1276 const DynamicLegalityCallbackFn &callback);
1279 void setLegalityCallback(ArrayRef<StringRef> dialects,
1280 const DynamicLegalityCallbackFn &callback);
1283 void setLegalityCallback(
const DynamicLegalityCallbackFn &callback);
1286 struct LegalizationInfo {
1288 LegalizationAction action = LegalizationAction::Illegal;
1291 bool isRecursivelyLegal =
false;
1294 DynamicLegalityCallbackFn legalityFn;
1298 std::optional<LegalizationInfo> getOpInfo(OperationName op)
const;
1302 llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
1306 DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1310 llvm::StringMap<LegalizationAction> legalDialects;
1313 llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1316 DynamicLegalityCallbackFn unknownLegalityFn;
1322#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1329class PDLConversionConfig final
1330 :
public PDLPatternConfigBase<PDLConversionConfig> {
1332 PDLConversionConfig(
const TypeConverter *converter) : converter(converter) {}
1333 ~PDLConversionConfig() final = default;
1337 const TypeConverter *getTypeConverter()
const {
return converter; }
1341 void notifyRewriteBegin(PatternRewriter &rewriter)
final;
1342 void notifyRewriteEnd(PatternRewriter &rewriter)
final;
1346 const TypeConverter *converter;
1358class PDLConversionConfig final {
1360 PDLConversionConfig(
const TypeConverter * ) {}
1370enum class DialectConversionFoldingMode {
1380struct ConversionConfig {
1384 function_ref<void(Diagnostic &)> notifyCallback =
nullptr;
1390 DenseSet<Operation *> *unlegalizedOps =
nullptr;
1396 DenseSet<Operation *> *legalizableOps =
nullptr;
1430 RewriterBase::Listener *listener =
nullptr;
1440 bool buildMaterializations =
true;
1461 bool allowPatternRollback =
true;
1464 DialectConversionFoldingMode foldingMode =
1465 DialectConversionFoldingMode::BeforePatterns;
1471 bool attachDebugMaterializationKind =
false;
1516 const ConversionTarget &
target,
1518 ConversionConfig
config = ConversionConfig());
1522 ConversionConfig
config = ConversionConfig());
1529 const ConversionTarget &
target,
1531 ConversionConfig
config = ConversionConfig());
1532LogicalResult applyFullConversion(
Operation *op,
const ConversionTarget &
target,
1534 ConversionConfig
config = ConversionConfig());
1546 ConversionConfig
config = ConversionConfig());
1550 ConversionConfig
config = ConversionConfig());
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
RewritePattern is the common base class for all DAG to DAG replacements.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
llvm::function_ref< Fn > function_ref