23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SetVector.h"
27#define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
28#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
35 ExtendToSupportedTypesRewritePattern(
const TypeConverter &converter,
37 : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
39 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
40 ConversionPatternRewriter &rewriter)
const override;
43struct ExtendToSupportedTypesPass
44 : mlir::math::impl::MathExtendToSupportedTypesBase<
45 ExtendToSupportedTypesPass> {
46 using math::impl::MathExtendToSupportedTypesBase<
47 ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase;
49 void runOnOperation()
override;
57 typeConverter.addConversion(
58 [](
Type type) -> std::optional<Type> {
return type; });
59 typeConverter.addConversion(
60 [&sourceTypes, targetType](FloatType type) -> std::optional<Type> {
61 if (!sourceTypes.contains(type))
66 typeConverter.addConversion(
67 [&sourceTypes, targetType](ShapedType type) -> std::optional<Type> {
68 if (
auto elemTy = dyn_cast<FloatType>(type.getElementType()))
69 if (!sourceTypes.contains(elemTy))
70 return type.clone(targetType);
74 typeConverter.addTargetMaterialization(
76 auto extFOp = arith::ExtFOp::create(
b, loc,
target, input);
77 extFOp.setFastmath(arith::FastMathFlags::contract);
84 target.markUnknownOpDynamicallyLegal([&typeConverter](
Operation *op) ->
bool {
86 return typeConverter.isLegal(op);
89 target.addLegalOp<FmaOp>();
90 target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
93LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
95 ConversionPatternRewriter &rewriter)
const {
98 FailureOr<Operation *> legalized =
99 convertOpResultTypes(op, operands, *converter, rewriter);
100 if (failed(legalized))
104 for (
auto [
result, newType, origType] : llvm::zip_equal(
106 if (newType != origType) {
107 auto truncFOp = arith::TruncFOp::create(rewriter, loc, origType,
result);
108 truncFOp.setFastmath(arith::FastMathFlags::contract);
109 result = truncFOp.getResult();
112 rewriter.replaceOp(op, results);
118 patterns.add<ExtendToSupportedTypesRewritePattern>(typeConverter,
122void ExtendToSupportedTypesPass::runOnOperation() {
127 std::optional<Type> maybeTargetType =
129 if (!maybeTargetType.has_value()) {
130 emitError(UnknownLoc::get(ctx),
"could not map target type '" +
132 "' to a known floating-point type");
133 return signalPassFailure();
135 Type targetType = maybeTargetType.value();
139 for (
const auto &extraTypeStr : extraTypeStrs) {
140 std::optional<FloatType> maybeExtraType =
142 if (!maybeExtraType.has_value()) {
143 emitError(UnknownLoc::get(ctx),
"could not map source type '" +
145 "' to a known floating-point type");
146 return signalPassFailure();
148 sourceTypes.insert(maybeExtraType.value());
152 sourceTypes.insert(
b.getF64Type());
153 sourceTypes.insert(
b.getF32Type());
155 TypeConverter typeConverter;
158 ConversionTarget
target(*ctx);
163 return signalPassFailure();
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns, const TypeConverter &typeConverter)
void populateExtendToSupportedTypesConversionTarget(ConversionTarget &target, TypeConverter &typeConverter)
void populateExtendToSupportedTypesTypeConverter(TypeConverter &typeConverter, const SetVector< Type > &sourceTypes, Type targetType)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
const FrozenRewritePatternSet & patterns