MLIR 23.0.0git
ExtendToSupportedTypes.cpp
Go to the documentation of this file.
1//===- ExtendToSupportedTypes.cpp - Legalize math functions on unsupported floats ===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements legalizing math operations on unsupported floating-point
11// types through arith.extf and arith.truncf.
12//
13//===----------------------------------------------------------------------===//
14
19#include "mlir/IR/Diagnostics.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SetVector.h"
25
26namespace mlir::math {
27#define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
28#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
29} // namespace mlir::math
30
31using namespace mlir;
32
33namespace {
34struct ExtendToSupportedTypesRewritePattern final : ConversionPattern {
35 ExtendToSupportedTypesRewritePattern(const TypeConverter &converter,
36 MLIRContext *context)
37 : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
38 LogicalResult
39 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
40 ConversionPatternRewriter &rewriter) const override;
41};
42
43struct ExtendToSupportedTypesPass
44 : mlir::math::impl::MathExtendToSupportedTypesBase<
45 ExtendToSupportedTypesPass> {
46 using math::impl::MathExtendToSupportedTypesBase<
47 ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase;
48
49 void runOnOperation() override;
50};
51} // namespace
52
54 TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
55 Type targetType) {
56
57 typeConverter.addConversion(
58 [](Type type) -> std::optional<Type> { return type; });
59 typeConverter.addConversion(
60 [&sourceTypes, targetType](FloatType type) -> std::optional<Type> {
61 if (!sourceTypes.contains(type))
62 return targetType;
63
64 return std::nullopt;
65 });
66 typeConverter.addConversion(
67 [&sourceTypes, targetType](ShapedType type) -> std::optional<Type> {
68 if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
69 if (!sourceTypes.contains(elemTy))
70 return type.clone(targetType);
71
72 return std::nullopt;
73 });
74 typeConverter.addTargetMaterialization(
75 [](OpBuilder &b, Type target, ValueRange input, Location loc) {
76 auto extFOp = arith::ExtFOp::create(b, loc, target, input);
77 extFOp.setFastmath(arith::FastMathFlags::contract);
78 return extFOp;
79 });
80}
81
83 ConversionTarget &target, TypeConverter &typeConverter) {
84 target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
85 if (isa<MathDialect>(op->getDialect()))
86 return typeConverter.isLegal(op);
87 return true;
88 });
89 target.addLegalOp<FmaOp>();
90 target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
91}
92
93LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
94 Operation *op, ArrayRef<Value> operands,
95 ConversionPatternRewriter &rewriter) const {
96 Location loc = op->getLoc();
97 const TypeConverter *converter = getTypeConverter();
98 FailureOr<Operation *> legalized =
99 convertOpResultTypes(op, operands, *converter, rewriter);
100 if (failed(legalized))
101 return failure();
102
103 SmallVector<Value> results = (*legalized)->getResults();
104 for (auto [result, newType, origType] : llvm::zip_equal(
105 results, (*legalized)->getResultTypes(), op->getResultTypes())) {
106 if (newType != origType) {
107 auto truncFOp = arith::TruncFOp::create(rewriter, loc, origType, result);
108 truncFOp.setFastmath(arith::FastMathFlags::contract);
109 result = truncFOp.getResult();
110 }
111 }
112 rewriter.replaceOp(op, results);
113 return success();
114}
115
117 RewritePatternSet &patterns, const TypeConverter &typeConverter) {
// Adds the single catch-all ExtendToSupportedTypesRewritePattern to
// `patterns`, constructed from `typeConverter` and the pattern set's context.
// NOTE(review): the pattern is built from a reference to `typeConverter`;
// presumably the converter must outlive `patterns` — confirm.
118 patterns.add<ExtendToSupportedTypesRewritePattern>(typeConverter,
119 patterns.getContext());
120}
121
122void ExtendToSupportedTypesPass::runOnOperation() {
123 Operation *op = getOperation();
124 MLIRContext *ctx = &getContext();
125
126 // Parse target type
127 std::optional<Type> maybeTargetType =
128 arith::parseFloatType(ctx, targetTypeStr);
129 if (!maybeTargetType.has_value()) {
130 emitError(UnknownLoc::get(ctx), "could not map target type '" +
131 targetTypeStr +
132 "' to a known floating-point type");
133 return signalPassFailure();
134 }
135 Type targetType = maybeTargetType.value();
136
137 // Parse source types
138 llvm::SetVector<Type> sourceTypes;
139 for (const auto &extraTypeStr : extraTypeStrs) {
140 std::optional<FloatType> maybeExtraType =
141 arith::parseFloatType(ctx, extraTypeStr);
142 if (!maybeExtraType.has_value()) {
143 emitError(UnknownLoc::get(ctx), "could not map source type '" +
144 extraTypeStr +
145 "' to a known floating-point type");
146 return signalPassFailure();
147 }
148 sourceTypes.insert(maybeExtraType.value());
149 }
150 // f64 and f32 are implicitly supported
151 Builder b(ctx);
152 sourceTypes.insert(b.getF64Type());
153 sourceTypes.insert(b.getF32Type());
154
155 TypeConverter typeConverter;
156 math::populateExtendToSupportedTypesTypeConverter(typeConverter, sourceTypes,
157 targetType);
158 ConversionTarget target(*ctx);
160 RewritePatternSet patterns(ctx);
161 math::populateExtendToSupportedTypesPatterns(patterns, typeConverter);
162 if (failed(applyPartialConversion(op, target, std::move(patterns))))
163 return signalPassFailure();
164}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition Operation.h:220
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
result_type_range getResultTypes()
Definition Operation.h:436
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition Utils.cpp:361
void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns, const TypeConverter &typeConverter)
void populateExtendToSupportedTypesConversionTarget(ConversionTarget &target, TypeConverter &typeConverter)
void populateExtendToSupportedTypesTypeConverter(TypeConverter &typeConverter, const SetVector< Type > &sourceTypes, Type targetType)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123