MLIR 22.0.0git
ExtendToSupportedTypes.cpp
Go to the documentation of this file.
1//===- ExtendToSupportedTypes.cpp - Legalize functions on unsupported floats
2//----------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements legalizing math operations on unsupported floating-point
11// types through arith.extf and arith.truncf.
12//
13//===----------------------------------------------------------------------===//
14
19#include "mlir/IR/Diagnostics.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SetVector.h"
25
// Pulls the generated pass definitions into namespace mlir::math. The
// GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES macro is defined immediately before
// the include; presumably it selects which generated definition
// (math::impl::MathExtendToSupportedTypesBase, used by the pass below) the
// .inc file emits -- TODO confirm against the generated Passes.h.inc.
26namespace mlir::math {
27#define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
28#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
29} // namespace mlir::math
30
31using namespace mlir;
32
33namespace {
34struct ExtendToSupportedTypesRewritePattern final : ConversionPattern {
35 ExtendToSupportedTypesRewritePattern(const TypeConverter &converter,
36 MLIRContext *context)
37 : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
38 LogicalResult
39 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
40 ConversionPatternRewriter &rewriter) const override;
41};
42
43struct ExtendToSupportedTypesPass
45 ExtendToSupportedTypesPass> {
46 using math::impl::MathExtendToSupportedTypesBase<
47 ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase;
48
49 void runOnOperation() override;
50};
51} // namespace
52
54 TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
55 Type targetType) {
56
57 typeConverter.addConversion(
58 [](Type type) -> std::optional<Type> { return type; });
59 typeConverter.addConversion(
60 [&sourceTypes, targetType](FloatType type) -> std::optional<Type> {
61 if (!sourceTypes.contains(type))
62 return targetType;
63
64 return std::nullopt;
65 });
66 typeConverter.addConversion(
67 [&sourceTypes, targetType](ShapedType type) -> std::optional<Type> {
68 if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
69 if (!sourceTypes.contains(elemTy))
70 return type.clone(targetType);
71
72 return std::nullopt;
73 });
74 typeConverter.addTargetMaterialization(
75 [](OpBuilder &b, Type target, ValueRange input, Location loc) {
76 auto extFOp = arith::ExtFOp::create(b, loc, target, input);
77 extFOp.setFastmath(arith::FastMathFlags::contract);
78 return extFOp;
79 });
80}
81
83 ConversionTarget &target, TypeConverter &typeConverter) {
84 target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
85 if (isa<MathDialect>(op->getDialect()))
86 return typeConverter.isLegal(op);
87 return true;
88 });
89 target.addLegalOp<FmaOp>();
90 target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
91}
92
93LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
94 Operation *op, ArrayRef<Value> operands,
95 ConversionPatternRewriter &rewriter) const {
96 Location loc = op->getLoc();
97 const TypeConverter *converter = getTypeConverter();
98 FailureOr<Operation *> legalized =
99 convertOpResultTypes(op, operands, *converter, rewriter);
100 if (failed(legalized))
101 return failure();
102
103 SmallVector<Value> results = (*legalized)->getResults();
104 for (auto [result, newType, origType] : llvm::zip_equal(
105 results, (*legalized)->getResultTypes(), op->getResultTypes())) {
106 if (newType != origType) {
107 auto truncFOp = arith::TruncFOp::create(rewriter, loc, origType, result);
108 truncFOp.setFastmath(arith::FastMathFlags::contract);
109 result = truncFOp.getResult();
110 }
111 }
112 rewriter.replaceOp(op, results);
113 return success();
114}
115
117 RewritePatternSet &patterns, const TypeConverter &typeConverter) {
118 patterns.add<ExtendToSupportedTypesRewritePattern>(typeConverter,
119 patterns.getContext());
120}
121
122void ExtendToSupportedTypesPass::runOnOperation() {
123 Operation *op = getOperation();
124 MLIRContext *ctx = &getContext();
125
126 // Parse target type
127 std::optional<Type> maybeTargetType =
128 arith::parseFloatType(ctx, targetTypeStr);
129 if (!maybeTargetType.has_value()) {
130 emitError(UnknownLoc::get(ctx), "could not map target type '" +
131 targetTypeStr +
132 "' to a known floating-point type");
133 return signalPassFailure();
135 Type targetType = maybeTargetType.value();
137 // Parse source types
139 for (const auto &extraTypeStr : extraTypeStrs) {
140 std::optional<FloatType> maybeExtraType =
141 arith::parseFloatType(ctx, extraTypeStr);
142 if (!maybeExtraType.has_value()) {
143 emitError(UnknownLoc::get(ctx), "could not map source type '" +
144 extraTypeStr +
145 "' to a known floating-point type");
146 return signalPassFailure();
148 sourceTypes.insert(maybeExtraType.value());
150 // f64 and f32 are implicitly supported
151 Builder b(ctx);
152 sourceTypes.insert(b.getF64Type());
153 sourceTypes.insert(b.getF32Type());
154
155 TypeConverter typeConverter;
156 math::populateExtendToSupportedTypesTypeConverter(typeConverter, sourceTypes,
157 targetType);
162 if (failed(applyPartialConversion(op, target, std::move(patterns))))
164}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
Definition Operation.h:220
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
result_type_range getResultTypes()
Definition Operation.h:428
void signalPassFailure()
Signal that some invariant was broken when running.
Definition Pass.h:225
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
::mlir::Pass::ListOption< std::string > extraTypeStrs
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition Utils.cpp:360
void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns, const TypeConverter &typeConverter)
void populateExtendToSupportedTypesConversionTarget(ConversionTarget &target, TypeConverter &typeConverter)
void populateExtendToSupportedTypesTypeConverter(TypeConverter &typeConverter, const SetVector< Type > &sourceTypes, Type targetType)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
const FrozenRewritePatternSet & patterns