13#ifndef MLIR_TRANSFORMS_DIALECTCONVERSION_H_
14#define MLIR_TRANSFORMS_DIALECTCONVERSION_H_
16#include "mlir/Config/mlir-config.h"
18#include "llvm/ADT/MapVector.h"
19#include "llvm/ADT/StringMap.h"
27struct ConversionConfig;
28class ConversionPatternRewriter;
45 using Base = TypeConverter;
47 virtual ~TypeConverter() =
default;
48 TypeConverter() =
default;
50 TypeConverter(
const TypeConverter &other)
51 : conversions(other.conversions),
52 sourceMaterializations(other.sourceMaterializations),
53 targetMaterializations(other.targetMaterializations),
54 typeAttributeConversions(other.typeAttributeConversions) {}
55 TypeConverter &operator=(
const TypeConverter &other) {
56 conversions = other.conversions;
57 sourceMaterializations = other.sourceMaterializations;
58 targetMaterializations = other.targetMaterializations;
59 typeAttributeConversions = other.typeAttributeConversions;
65 class SignatureConversion {
67 SignatureConversion(
unsigned numOrigInputs)
68 : remappedInputs(numOrigInputs) {}
74 SmallVector<Value, 1> replacementValues;
77 bool replacedWithValues()
const {
return !replacementValues.empty(); }
81 ArrayRef<Type> getConvertedTypes()
const {
return argTypes; }
84 std::optional<InputMapping> getInputMapping(
unsigned input)
const {
85 return remappedInputs[input];
94 void addInputs(
unsigned origInputNo, ArrayRef<Type> types);
98 void addInputs(ArrayRef<Type> types);
102 void remapInput(
unsigned origInputNo, ArrayRef<Value> replacements);
107 void remapInput(
unsigned origInputNo,
unsigned newInputNo,
108 unsigned newInputCount = 1);
111 SmallVector<std::optional<InputMapping>, 4> remappedInputs;
114 SmallVector<Type, 4> argTypes;
// Result type of a type-attribute conversion callback. It is one of three
// states, encoded in the 2-bit integer of a PointerIntPair next to the
// attribute:
//   - resultTag: a converted attribute is available;
//   - naTag: the callback did not apply (wrapTypeAttributeConversion returns
//     na() when the type/attribute do not match the callback's parameters);
//   - abortTag: presumably a request to abort the conversion -- confirm in
//     convertTypeAttribute's implementation.
119 class AttributeConversionResult {
// Value-initializes `impl`: null attribute with integer 0, which equals naTag.
121 constexpr AttributeConversionResult() : impl() {}
// Non-explicit, presumably so callbacks can simply return an Attribute; the
// attribute is tagged as a result.
122 AttributeConversionResult(Attribute attr) : impl(attr, resultTag) {}
// Named factories for the three states (definitions are not in this header
// excerpt).
124 static AttributeConversionResult
result(Attribute attr);
125 static AttributeConversionResult na();
126 static AttributeConversionResult abort();
// State queries; presumably they inspect the tag of `impl`.
128 bool hasResult()
const;
130 bool isAbort()
const;
132 Attribute getResult()
const;
// Internal constructor pairing an attribute with an explicit state tag.
135 AttributeConversionResult(Attribute attr,
unsigned tag) : impl(attr, tag) {}
// Attribute pointer plus a 2-bit state tag (one of the *Tag constants below).
137 llvm::PointerIntPair<Attribute, 2> impl;
// State tags; they must fit in the 2 bits reserved in `impl`.
140 static constexpr unsigned naTag = 0;
141 static constexpr unsigned resultTag = 1;
142 static constexpr unsigned abortTag = 2;
172 template <
typename FnT,
typename T =
typename llvm::function_traits<
173 std::decay_t<FnT>>::template arg_t<0>>
174 void addConversion(FnT &&callback) {
175 registerConversion(wrapCallback<T>(std::forward<FnT>(callback)));
198 template <
typename FnT,
typename T =
typename llvm::function_traits<
199 std::decay_t<FnT>>::template arg_t<1>>
200 void addSourceMaterialization(FnT &&callback) {
201 sourceMaterializations.emplace_back(
202 wrapSourceMaterialization<T>(std::forward<FnT>(callback)));
222 template <
typename FnT,
typename T =
typename llvm::function_traits<
223 std::decay_t<FnT>>::template arg_t<1>>
224 void addTargetMaterialization(FnT &&callback) {
225 targetMaterializations.emplace_back(
226 wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
250 typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<0>,
252 typename llvm::function_traits<std::decay_t<FnT>>::template arg_t<1>>
253 void addTypeAttributeConversion(FnT &&callback) {
254 registerTypeAttributeConversion(
255 wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
265 LogicalResult convertType(Type t, SmallVectorImpl<Type> &results)
const;
274 LogicalResult convertType(Value v, SmallVectorImpl<Type> &results)
const;
278 Type convertType(Type t)
const;
279 Type convertType(Value v)
const;
284 template <
typename TargetType>
285 TargetType convertType(Type t)
const {
286 return dyn_cast_or_null<TargetType>(convertType(t));
288 template <
typename TargetType>
289 TargetType convertType(Value v)
const {
290 return dyn_cast_or_null<TargetType>(convertType(v));
296 LogicalResult convertTypes(
TypeRange types,
297 SmallVectorImpl<Type> &results)
const;
303 SmallVectorImpl<Type> &results)
const;
307 bool isLegal(Type type)
const;
308 bool isLegal(Value value)
const;
312 return llvm::all_of(range, [
this](Type type) {
return isLegal(type); });
315 return llvm::all_of(range, [
this](Value value) {
return isLegal(value); });
319 bool isLegal(Operation *op)
const;
322 bool isLegal(Region *region)
const;
326 bool isSignatureLegal(FunctionType ty)
const;
331 LogicalResult convertSignatureArg(
unsigned inputNo, Type type,
332 SignatureConversion &
result)
const;
333 LogicalResult convertSignatureArgs(
TypeRange types,
334 SignatureConversion &
result,
335 unsigned origInputOffset = 0)
const;
336 LogicalResult convertSignatureArg(
unsigned inputNo, Value value,
337 SignatureConversion &
result)
const;
338 LogicalResult convertSignatureArgs(
ValueRange values,
339 SignatureConversion &
result,
340 unsigned origInputOffset = 0)
const;
345 std::optional<SignatureConversion> convertBlockSignature(Block *block)
const;
351 Value materializeSourceConversion(OpBuilder &builder, Location loc,
353 Value materializeTargetConversion(OpBuilder &builder, Location loc,
355 Type originalType = {})
const;
356 SmallVector<Value> materializeTargetConversion(OpBuilder &builder,
360 Type originalType = {})
const;
366 std::optional<Attribute> convertTypeAttribute(Type type,
367 Attribute attr)
const;
373 using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
374 PointerUnion<Type, Value>, SmallVectorImpl<Type> &)>;
379 using SourceMaterializationCallbackFn =
380 std::function<Value(OpBuilder &, Type,
ValueRange, Location)>;
385 using TargetMaterializationCallbackFn = std::function<SmallVector<Value>(
389 using TypeAttributeConversionCallbackFn =
390 std::function<AttributeConversionResult(Type, Attribute)>;
// Adapter for simple conversion callbacks of the form
// `std::optional<Type>(T)`. The callback is rewrapped into the
// `(T, SmallVectorImpl<Type> &)` form and handed to the matching
// wrapCallback overload below.
396 template <
typename T,
typename FnT>
397 std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
398 wrapCallback(FnT &&callback) {
399 return wrapCallback<T>([callback = std::forward<FnT>(callback)](
400 T typeOrValue, SmallVectorImpl<Type> &results) {
// A non-empty optional means the callback handled the input: a non-null Type
// inside it is reported as success and a null Type as failure. An empty
// optional is passed through as "no answer" (presumably so that other
// registered conversions get a chance -- confirm in convertTypeImpl).
401 if (std::optional<Type> resultOpt = callback(typeOrValue)) {
402 bool wasSuccess =
static_cast<bool>(*resultOpt);
// NOTE: this listing skips original line 403 between these statements; do
// not assume from this view that the push_back below is unconditional.
404 results.push_back(*resultOpt);
405 return std::optional<LogicalResult>(
success(wasSuccess));
407 return std::optional<LogicalResult>();
// Overload for callbacks `(T, SmallVectorImpl<Type> &)` where T is Type or a
// subclass of it. The query may be either a Type or a Value; for a Value its
// type is used. Either way the type is narrowed to T with dyn_cast before the
// callback is invoked. Declared const: it does not modify the converter.
412 template <
typename T,
typename FnT>
413 std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
414 std::is_base_of_v<Type, T>,
415 ConversionCallbackFn>
416 wrapCallback(FnT &&callback)
const {
417 return [callback = std::forward<FnT>(callback)](
418 PointerUnion<Type, Value> typeOrValue,
419 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
421 if (Type t = dyn_cast<Type>(typeOrValue)) {
422 derivedType = dyn_cast<T>(t);
423 }
else if (Value v = dyn_cast<Value>(typeOrValue)) {
424 derivedType = dyn_cast<T>(v.getType());
// PointerUnion<Type, Value> holds one of the two alternatives, so reaching
// this point is a programming error.
426 llvm_unreachable(
"unexpected variant");
430 return callback(derivedType, results);
// Overload for context-aware callbacks `(Value, SmallVectorImpl<Type> &)`.
// Unlike the Type overload it is non-const because it stores
// conversions.size() in contextAwareTypeConversionsIndex. addConversion()
// passes the wrapped callback directly to registerConversion(), which appends
// it to `conversions`, so the stored value is the index this callback ends up
// at.
435 template <
typename T,
typename FnT>
436 std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
437 std::is_same_v<T, Value>,
438 ConversionCallbackFn>
439 wrapCallback(FnT &&callback) {
440 contextAwareTypeConversionsIndex = conversions.size();
441 return [callback = std::forward<FnT>(callback)](
442 PointerUnion<Type, Value> typeOrValue,
443 SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
// The callback needs a Value, so it cannot be invoked for a bare Type query.
// The body of this branch (original lines 445-446) is not shown in this
// listing.
444 if (Type t = dyn_cast<Type>(typeOrValue)) {
447 }
else if (Value v = dyn_cast<Value>(typeOrValue)) {
448 return callback(v, results);
// PointerUnion<Type, Value> holds one of the two alternatives.
450 llvm_unreachable(
"unexpected variant");
// Append a wrapped conversion callback to `conversions`. Both result caches
// are cleared, since answers cached before this registration may differ from
// what the updated set of conversions would produce.
// NOTE(review): cacheMutex is not taken here; presumably conversions are only
// registered before the converter is used concurrently -- confirm.
456 void registerConversion(ConversionCallbackFn callback) {
457 conversions.emplace_back(std::move(callback));
458 cachedDirectConversions.clear();
459 cachedMultiConversions.clear();
465 template <
typename T,
typename FnT>
466 SourceMaterializationCallbackFn
467 wrapSourceMaterialization(FnT &&callback)
const {
468 return [callback = std::forward<FnT>(callback)](
469 OpBuilder &builder, Type resultType,
ValueRange inputs,
470 Location loc) -> Value {
471 if (T derivedType = dyn_cast<T>(resultType))
472 return callback(builder, derivedType, inputs, loc);
485 template <
typename T,
typename FnT>
487 std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
488 TargetMaterializationCallbackFn>
489 wrapTargetMaterialization(FnT &&callback)
const {
490 return [callback = std::forward<FnT>(callback)](
492 Location loc, Type originalType) -> SmallVector<Value> {
493 SmallVector<Value>
result;
494 if constexpr (std::is_same<T, TypeRange>::value) {
497 result = callback(builder, resultTypes, inputs, loc, originalType);
498 }
else if constexpr (std::is_assignable<Type, T>::value) {
501 if (resultTypes.size() == 1) {
504 if (T derivedType = dyn_cast<T>(resultTypes.front())) {
509 callback(builder, derivedType, inputs, loc, originalType);
515 static_assert(
sizeof(T) == 0,
"T must be a Type or a TypeRange");
523 template <
typename T,
typename FnT>
525 std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
526 TargetMaterializationCallbackFn>
527 wrapTargetMaterialization(FnT &&callback)
const {
528 return wrapTargetMaterialization<T>(
529 [callback = std::forward<FnT>(callback)](
530 OpBuilder &builder, T resultTypes,
ValueRange inputs, Location loc,
532 return callback(builder, resultTypes, inputs, loc);
540 template <
typename T,
typename A,
typename FnT>
541 TypeAttributeConversionCallbackFn
542 wrapTypeAttributeConversion(FnT &&callback)
const {
543 return [callback = std::forward<FnT>(callback)](
544 Type type, Attribute attr) -> AttributeConversionResult {
545 if (T derivedType = dyn_cast<T>(type)) {
546 if (A derivedAttr = dyn_cast_or_null<A>(attr))
547 return callback(derivedType, derivedAttr);
549 return AttributeConversionResult::na();
555 registerTypeAttributeConversion(TypeAttributeConversionCallbackFn callback) {
556 typeAttributeConversions.emplace_back(std::move(callback));
558 cachedDirectConversions.clear();
559 cachedMultiConversions.clear();
563 LogicalResult convertTypeImpl(PointerUnion<Type, Value> t,
564 SmallVectorImpl<Type> &results)
const;
567 SmallVector<ConversionCallbackFn, 4> conversions;
570 SmallVector<SourceMaterializationCallbackFn, 2> sourceMaterializations;
571 SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
574 SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;
579 mutable DenseMap<Type, Type> cachedDirectConversions;
581 mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
583 mutable llvm::sys::SmartRWMutex<true> cacheMutex;
592 int contextAwareTypeConversionsIndex = -1;
604 using OpAdaptor = ArrayRef<Value>;
605 using OneToNOpAdaptor = ArrayRef<ValueRange>;
611 virtual LogicalResult
612 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
613 ConversionPatternRewriter &rewriter)
const {
614 llvm_unreachable(
"matchAndRewrite is not implemented");
619 virtual LogicalResult
620 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
621 ConversionPatternRewriter &rewriter)
const {
622 return dispatchTo1To1(*
this, op, operands, rewriter);
626 LogicalResult matchAndRewrite(Operation *op,
627 PatternRewriter &rewriter)
const final;
631 const TypeConverter *getTypeConverter()
const {
return typeConverter; }
633 template <
typename ConverterTy>
634 std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
636 getTypeConverter()
const {
637 return static_cast<const ConverterTy *
>(typeConverter);
643 using RewritePattern::RewritePattern;
646 template <
typename... Args>
647 ConversionPattern(
const TypeConverter &typeConverter, Args &&...args)
648 : RewritePattern(std::forward<Args>(args)...),
649 typeConverter(&typeConverter) {}
656 FailureOr<SmallVector<Value>>
657 getOneToOneAdaptorOperands(ArrayRef<ValueRange> operands)
const;
663 template <
typename SelfPattern,
typename SourceOp>
664 static LogicalResult dispatchTo1To1(
const SelfPattern &self, SourceOp op,
665 ArrayRef<ValueRange> operands,
666 ConversionPatternRewriter &rewriter);
669 template <
typename SelfPattern,
typename SourceOp>
670 static LogicalResult dispatchTo1To1(
671 const SelfPattern &self, SourceOp op,
672 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>> adaptor,
673 ConversionPatternRewriter &rewriter);
677 const TypeConverter *typeConverter =
nullptr;
683template <
typename SourceOp>
684class OpConversionPattern :
public ConversionPattern {
688 using Base = OpConversionPattern;
690 using OpAdaptor =
typename SourceOp::Adaptor;
691 using OneToNOpAdaptor =
692 typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
694 OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
695 : ConversionPattern(SourceOp::getOperationName(), benefit, context) {}
696 OpConversionPattern(
const TypeConverter &typeConverter, MLIRContext *context,
697 PatternBenefit benefit = 1)
698 : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit,
704 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
705 ConversionPatternRewriter &rewriter)
const final {
706 auto sourceOp = cast<SourceOp>(op);
707 return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
710 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
711 ConversionPatternRewriter &rewriter)
const final {
712 auto sourceOp = cast<SourceOp>(op);
713 return matchAndRewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp),
719 virtual LogicalResult
720 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
721 ConversionPatternRewriter &rewriter)
const {
722 llvm_unreachable(
"matchAndRewrite is not implemented");
724 virtual LogicalResult
725 matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
726 ConversionPatternRewriter &rewriter)
const {
727 return dispatchTo1To1(*
this, op, adaptor, rewriter);
731 using ConversionPattern::matchAndRewrite;
737template <
typename SourceOp>
738class OpInterfaceConversionPattern :
public ConversionPattern {
742 using Base = OpInterfaceConversionPattern;
744 OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
745 : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(),
746 SourceOp::getInterfaceID(), benefit, context) {}
747 OpInterfaceConversionPattern(
const TypeConverter &typeConverter,
748 MLIRContext *context, PatternBenefit benefit = 1)
749 : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
750 SourceOp::getInterfaceID(), benefit, context) {}
755 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
756 ConversionPatternRewriter &rewriter)
const final {
757 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
760 matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
761 ConversionPatternRewriter &rewriter)
const final {
762 return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
767 virtual LogicalResult
768 matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
769 ConversionPatternRewriter &rewriter)
const {
770 llvm_unreachable(
"matchAndRewrite is not implemented");
772 virtual LogicalResult
773 matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
774 ConversionPatternRewriter &rewriter)
const {
775 return dispatchTo1To1(*
this, op, operands, rewriter);
779 using ConversionPattern::matchAndRewrite;
785template <
template <
typename>
class TraitType>
786class OpTraitConversionPattern :
public ConversionPattern {
790 using Base = OpTraitConversionPattern;
792 OpTraitConversionPattern(MLIRContext *context, PatternBenefit benefit = 1)
793 : ConversionPattern(Pattern::MatchTraitOpTypeTag(),
794 TypeID::
get<TraitType>(), benefit, context) {}
795 OpTraitConversionPattern(
const TypeConverter &typeConverter,
796 MLIRContext *context, PatternBenefit benefit = 1)
797 : ConversionPattern(typeConverter, Pattern::MatchTraitOpTypeTag(),
798 TypeID::
get<TraitType>(), benefit, context) {}
804FailureOr<Operation *>
806 const TypeConverter &converter,
807 ConversionPatternRewriter &rewriter);
812void populateFunctionOpInterfaceTypeConversionPattern(
816template <
typename FuncOpT>
817void populateFunctionOpInterfaceTypeConversionPattern(
820 populateFunctionOpInterfaceTypeConversionPattern(
821 FuncOpT::getOperationName(),
patterns, converter, benefit);
824void populateAnyFunctionOpInterfaceTypeConversionPattern(
841 ~ConversionPatternRewriter()
override;
844 const ConversionConfig &getConfig()
const;
861 applySignatureConversion(Block *block,
862 TypeConverter::SignatureConversion &conversion,
863 const TypeConverter *converter =
nullptr);
878 FailureOr<Block *> convertRegionTypes(
879 Region *region,
const TypeConverter &converter,
880 TypeConverter::SignatureConversion *entryConversion =
nullptr);
901 void replaceAllUsesWith(Value from,
ValueRange to);
902 void replaceAllUsesWith(Value from, Value to)
override {
909 Value getRemappedValue(Value key);
914 LogicalResult getRemappedValues(
ValueRange keys,
915 SmallVectorImpl<Value> &results);
924 bool canRecoverFromRewriteFailure()
const override {
return true; }
931 void replaceOp(Operation *op,
ValueRange newValues)
override;
938 void replaceOp(Operation *op, Operation *newOp)
override;
942 void replaceOpWithMultiple(Operation *op,
943 SmallVector<SmallVector<Value>> &&newValues);
944 template <
typename RangeT = ValueRange>
945 void replaceOpWithMultiple(Operation *op, ArrayRef<RangeT> newValues) {
946 replaceOpWithMultiple(op,
947 llvm::to_vector_of<SmallVector<Value>>(newValues));
949 template <
typename RangeT>
950 void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
951 replaceOpWithMultiple(op,
952 ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
958 void eraseOp(Operation *op)
override;
962 void eraseBlock(Block *block)
override;
965 void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
967 using PatternRewriter::inlineBlockBefore;
973 void startOpModification(Operation *op)
override;
976 void finalizeOpModification(Operation *op)
override;
979 void cancelOpModification(Operation *op)
override;
982 detail::ConversionPatternRewriterImpl &getImpl();
991 LogicalResult legalize(Operation *op);
1004 LogicalResult legalize(Region *r);
1008 friend struct OperationConverter;
1013 explicit ConversionPatternRewriter(MLIRContext *ctx,
1014 const ConversionConfig &config,
1015 OperationConverter &converter);
1018 using OpBuilder::setListener;
1020 std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
// Out-of-line definition of the dispatchTo1To1 overload that takes the
// operands as ArrayRef<ValueRange>. (Its return-type line and `operands`
// parameter line are omitted from this listing; see the in-class
// declaration.) It is the default body of the 1:N matchAndRewrite hooks of
// ConversionPattern and OpInterfaceConversionPattern. It collapses the 1:N
// operand lists into a 1:1 list and calls the pattern's 1:1 matchAndRewrite
// overload with it.
1023template <
typename SelfPattern,
typename SourceOp>
1025ConversionPattern::dispatchTo1To1(
const SelfPattern &self, SourceOp op,
1027 ConversionPatternRewriter &rewriter) {
// getOneToOneAdaptorOperands() fails when the operands cannot be expressed
// 1:1 (presumably when some value was replaced by zero or several values --
// confirm in its implementation). In that case a match failure naming the
// pattern is reported instead of rewriting.
1028 FailureOr<SmallVector<Value>> oneToOneOperands =
1029 self.getOneToOneAdaptorOperands(operands);
1030 if (
failed(oneToOneOperands))
1031 return rewriter.notifyMatchFailure(op,
1032 "pattern '" + self.getDebugName() +
1033 "' does not support 1:N conversion");
1034 return self.matchAndRewrite(op, *oneToOneOperands, rewriter);
// Out-of-line definition of the adaptor-based dispatchTo1To1 overload, used
// by OpConversionPattern's 1:N matchAndRewrite. (The `adaptor` parameter line
// is omitted from this listing.) The logic matches the overload above, except
// that the operands come from a GenericAdaptor over ArrayRef<ValueRange>, and
// the 1:1 call receives a SourceOp::Adaptor rebuilt from the collapsed
// operands plus the original adaptor.
1037template <
typename SelfPattern,
typename SourceOp>
1038LogicalResult ConversionPattern::dispatchTo1To1(
1039 const SelfPattern &self, SourceOp op,
1041 ConversionPatternRewriter &rewriter) {
1042 FailureOr<SmallVector<Value>> oneToOneOperands =
1043 self.getOneToOneAdaptorOperands(adaptor.getOperands());
1044 if (
failed(oneToOneOperands))
1045 return rewriter.notifyMatchFailure(op,
1046 "pattern '" + self.getDebugName() +
1047 "' does not support 1:N conversion");
1048 return self.matchAndRewrite(
1049 op,
typename SourceOp::Adaptor(*oneToOneOperands, adaptor), rewriter);
1057class ConversionTarget {
1061 enum class LegalizationAction {
1075 struct LegalOpDetails {
1079 bool isRecursivelyLegal =
false;
1084 using DynamicLegalityCallbackFn =
1085 std::function<std::optional<bool>(Operation *)>;
1087 ConversionTarget(MLIRContext &ctx) : ctx(ctx) {}
1088 virtual ~ConversionTarget() =
default;
1095 void setOpAction(OperationName op, LegalizationAction action);
1096 template <
typename OpT>
1097 void setOpAction(LegalizationAction action) {
1098 setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
1102 void addLegalOp(OperationName op) {
1103 setOpAction(op, LegalizationAction::Legal);
1105 template <
typename OpT>
1107 addLegalOp(OperationName(OpT::getOperationName(), &ctx));
1109 template <
typename OpT,
typename OpT2,
typename... OpTs>
1112 addLegalOp<OpT2, OpTs...>();
1117 void addDynamicallyLegalOp(OperationName op,
1118 const DynamicLegalityCallbackFn &callback) {
1119 setOpAction(op, LegalizationAction::Dynamic);
1120 setLegalityCallback(op, callback);
1122 template <
typename OpT>
1123 void addDynamicallyLegalOp(
const DynamicLegalityCallbackFn &callback) {
1124 addDynamicallyLegalOp(OperationName(OpT::getOperationName(), &ctx),
1127 template <
typename OpT,
typename OpT2,
typename... OpTs>
1128 void addDynamicallyLegalOp(
const DynamicLegalityCallbackFn &callback) {
1129 addDynamicallyLegalOp<OpT>(callback);
1130 addDynamicallyLegalOp<OpT2, OpTs...>(callback);
1132 template <
typename OpT,
class Callable>
1133 std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1134 addDynamicallyLegalOp(Callable &&callback) {
1135 addDynamicallyLegalOp<OpT>(
1136 [=](Operation *op) {
return callback(cast<OpT>(op)); });
1141 void addIllegalOp(OperationName op) {
1142 setOpAction(op, LegalizationAction::Illegal);
1144 template <
typename OpT>
1145 void addIllegalOp() {
1146 addIllegalOp(OperationName(OpT::getOperationName(), &ctx));
1148 template <
typename OpT,
typename OpT2,
typename... OpTs>
1149 void addIllegalOp() {
1150 addIllegalOp<OpT>();
1151 addIllegalOp<OpT2, OpTs...>();
1159 void markOpRecursivelyLegal(OperationName name,
1160 const DynamicLegalityCallbackFn &callback);
1161 template <
typename OpT>
1162 void markOpRecursivelyLegal(
const DynamicLegalityCallbackFn &callback = {}) {
1163 OperationName opName(OpT::getOperationName(), &ctx);
1164 markOpRecursivelyLegal(opName, callback);
1166 template <
typename OpT,
typename OpT2,
typename... OpTs>
1167 void markOpRecursivelyLegal(
const DynamicLegalityCallbackFn &callback = {}) {
1168 markOpRecursivelyLegal<OpT>(callback);
1169 markOpRecursivelyLegal<OpT2, OpTs...>(callback);
1171 template <
typename OpT,
class Callable>
1172 std::enable_if_t<!std::is_invocable_v<Callable, Operation *>>
1173 markOpRecursivelyLegal(Callable &&callback) {
1174 markOpRecursivelyLegal<OpT>(
1175 [=](Operation *op) {
return callback(cast<OpT>(op)); });
1179 void setDialectAction(ArrayRef<StringRef> dialectNames,
1180 LegalizationAction action);
1183 template <
typename... Names>
1184 void addLegalDialect(StringRef name, Names... names) {
1185 SmallVector<StringRef, 2> dialectNames({name, names...});
1186 setDialectAction(dialectNames, LegalizationAction::Legal);
1188 template <
typename... Args>
1189 void addLegalDialect() {
1190 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1191 setDialectAction(dialectNames, LegalizationAction::Legal);
1196 template <
typename... Names>
1197 void addDynamicallyLegalDialect(
const DynamicLegalityCallbackFn &callback,
1198 StringRef name, Names... names) {
1199 SmallVector<StringRef, 2> dialectNames({name, names...});
1200 setDialectAction(dialectNames, LegalizationAction::Dynamic);
1201 setLegalityCallback(dialectNames, callback);
// Mark the dialects given as template arguments (identified by their
// getDialectNamespace()) as dynamically legal, with `callback` deciding
// legality. This forwards to the (callback, StringRef name, names...)
// overload.
// NOTE(review): with an empty Args pack the inner call has a single argument.
// That does not match the (callback, name, ...) overload, so it re-selects
// this template and recurses without bound. Consider
// static_assert(sizeof...(Args) > 0, ...). Also, std::move(callback) has no
// effect because the target overload takes a const reference.
1203 template <
typename... Args>
1204 void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) {
1205 addDynamicallyLegalDialect(std::move(callback),
1206 Args::getDialectNamespace()...);
1212 void markUnknownOpDynamicallyLegal(
const DynamicLegalityCallbackFn &fn) {
1213 setLegalityCallback(fn);
1218 template <
typename... Names>
1219 void addIllegalDialect(StringRef name, Names... names) {
1220 SmallVector<StringRef, 2> dialectNames({name, names...});
1221 setDialectAction(dialectNames, LegalizationAction::Illegal);
1223 template <
typename... Args>
1224 void addIllegalDialect() {
1225 SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
1226 setDialectAction(dialectNames, LegalizationAction::Illegal);
1234 std::optional<LegalizationAction> getOpAction(OperationName op)
const;
1245 std::optional<LegalOpDetails> isLegal(Operation *op)
const;
1250 bool isIllegal(Operation *op)
const;
1254 void setLegalityCallback(OperationName name,
1255 const DynamicLegalityCallbackFn &callback);
1258 void setLegalityCallback(ArrayRef<StringRef> dialects,
1259 const DynamicLegalityCallbackFn &callback);
1262 void setLegalityCallback(
const DynamicLegalityCallbackFn &callback);
1265 struct LegalizationInfo {
1267 LegalizationAction action = LegalizationAction::Illegal;
1270 bool isRecursivelyLegal =
false;
1273 DynamicLegalityCallbackFn legalityFn;
1277 std::optional<LegalizationInfo> getOpInfo(OperationName op)
const;
1281 llvm::MapVector<OperationName, LegalizationInfo> legalOperations;
1285 DenseMap<OperationName, DynamicLegalityCallbackFn> opRecursiveLegalityFns;
1289 llvm::StringMap<LegalizationAction> legalDialects;
1292 llvm::StringMap<DynamicLegalityCallbackFn> dialectLegalityFns;
1295 DynamicLegalityCallbackFn unknownLegalityFn;
1301#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
1308class PDLConversionConfig final
1309 :
public PDLPatternConfigBase<PDLConversionConfig> {
1311 PDLConversionConfig(
const TypeConverter *converter) : converter(converter) {}
1312 ~PDLConversionConfig() final = default;
1316 const TypeConverter *getTypeConverter()
const {
return converter; }
1320 void notifyRewriteBegin(PatternRewriter &rewriter)
final;
1321 void notifyRewriteEnd(PatternRewriter &rewriter)
final;
1325 const TypeConverter *converter;
1337class PDLConversionConfig final {
1339 PDLConversionConfig(
const TypeConverter * ) {}
1349enum class DialectConversionFoldingMode {
1359struct ConversionConfig {
1363 function_ref<void(Diagnostic &)> notifyCallback =
nullptr;
1369 DenseSet<Operation *> *unlegalizedOps =
nullptr;
1375 DenseSet<Operation *> *legalizableOps =
nullptr;
1409 RewriterBase::Listener *listener =
nullptr;
1419 bool buildMaterializations =
true;
1440 bool allowPatternRollback =
true;
1443 DialectConversionFoldingMode foldingMode =
1444 DialectConversionFoldingMode::BeforePatterns;
1450 bool attachDebugMaterializationKind =
false;
1495 const ConversionTarget &
target,
1497 ConversionConfig
config = ConversionConfig());
1501 ConversionConfig
config = ConversionConfig());
1508 const ConversionTarget &
target,
1510 ConversionConfig
config = ConversionConfig());
1511LogicalResult applyFullConversion(
Operation *op,
const ConversionTarget &
target,
1513 ConversionConfig
config = ConversionConfig());
1525 ConversionConfig
config = ConversionConfig());
1529 ConversionConfig
config = ConversionConfig());
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
This class represents a frozen set of patterns that can be processed by a pattern applicator.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePattern is the common base class for all DAG to DAG replacements.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
static void reconcileUnrealizedCasts(const DenseMap< UnrealizedConversionCastOp, UnresolvedMaterializationInfo > &castOps, SmallVectorImpl< UnrealizedConversionCastOp > *remainingCastOps)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::function_ref< Fn > function_ref