MLIR 23.0.0git
ExtendToSupportedTypes.cpp
Go to the documentation of this file.
1//===- ExtendToSupportedTypes.cpp - Legalize functions on unsupported floats
2//----------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9//
// This file implements legalization of math operations on unsupported
// floating-point types: operands are widened with arith.extf and results are
// narrowed back with arith.truncf.
12//
13//===----------------------------------------------------------------------===//
14
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
25
26namespace mlir::math {
27#define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
28#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
29} // namespace mlir::math
30
31using namespace mlir;
32
33namespace {
34struct ExtendToSupportedTypesRewritePattern final : ConversionPattern {
35 ExtendToSupportedTypesRewritePattern(const TypeConverter &converter,
36 MLIRContext *context)
37 : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
38 LogicalResult
39 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
40 ConversionPatternRewriter &rewriter) const override;
41};
42
43struct ExtendToSupportedTypesPass
45 ExtendToSupportedTypesPass> {
46 using math::impl::MathExtendToSupportedTypesBase<
47 ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase;
48
49 void runOnOperation() override;
50};
51} // namespace
52
54 TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
55 Type targetType) {
56
57 typeConverter.addConversion(
58 [](Type type) -> std::optional<Type> { return type; });
59 typeConverter.addConversion(
60 [&sourceTypes, targetType](FloatType type) -> std::optional<Type> {
61 if (!sourceTypes.contains(type))
62 return targetType;
63
64 return std::nullopt;
65 });
66 typeConverter.addConversion(
67 [&sourceTypes, targetType](ShapedType type) -> std::optional<Type> {
68 if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
69 if (!sourceTypes.contains(elemTy))
70 return type.clone(targetType);
71
72 return std::nullopt;
73 });
74 typeConverter.addTargetMaterialization(
75 [](OpBuilder &b, Type target, ValueRange input, Location loc) {
76 auto extFOp = arith::ExtFOp::create(b, loc, target, input);
77 extFOp.setFastmath(arith::FastMathFlags::contract);
78 return extFOp;
79 });
80}
81
83 ConversionTarget &target, TypeConverter &typeConverter) {
84 target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
85 if (isa<MathDialect>(op->getDialect()))
86 return typeConverter.isLegal(op);
87 return true;
88 });
89 target.addLegalOp<FmaOp>();
90 target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
91}
92
93LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
94 Operation *op, ArrayRef<Value> operands,
95 ConversionPatternRewriter &rewriter) const {
96 Location loc = op->getLoc();
97 const TypeConverter *converter = getTypeConverter();
98 FailureOr<Operation *> legalized =
99 convertOpResultTypes(op, operands, *converter, rewriter);
100 if (failed(legalized))
101 return failure();
102
103 SmallVector<Value> results = (*legalized)->getResults();
104 for (auto [result, newType, origType] : llvm::zip_equal(
105 results, (*legalized)->getResultTypes(), op->getResultTypes())) {
106 if (newType != origType) {
107 auto truncFOp = arith::TruncFOp::create(rewriter, loc, origType, result);
108 truncFOp.setFastmath(arith::FastMathFlags::contract);
109 result = truncFOp.getResult();
110 }
111 }
112 rewriter.replaceOp(op, results);
113 return success();
114}
115
117 RewritePatternSet &patterns, const TypeConverter &typeConverter) {
118 patterns.add<ExtendToSupportedTypesRewritePattern>(typeConverter,
119 patterns.getContext());
120}
121
122void ExtendToSupportedTypesPass::runOnOperation() {
123 Operation *op = getOperation();
124 MLIRContext *ctx = &getContext();
125
126 // Parse target type
127 FloatType targetType = arith::parseFloatType(ctx, targetTypeStr);
128 if (!targetType) {
129 emitError(UnknownLoc::get(ctx), "could not map target type '" +
130 targetTypeStr +
131 "' to a known floating-point type");
133 }
135 // Parse source types
137 for (const auto &extraTypeStr : extraTypeStrs) {
138 FloatType extraType = arith::parseFloatType(ctx, extraTypeStr);
139 if (!extraType) {
140 emitError(UnknownLoc::get(ctx), "could not map source type '" +
141 extraTypeStr +
142 "' to a known floating-point type");
143 return signalPassFailure();
145 sourceTypes.insert(extraType);
146 }
147 // f64 and f32 are implicitly supported
148 Builder b(ctx);
149 sourceTypes.insert(b.getF64Type());
150 sourceTypes.insert(b.getF32Type());
151
152 TypeConverter typeConverter;
153 math::populateExtendToSupportedTypesTypeConverter(typeConverter, sourceTypes,
154 targetType);
157 RewritePatternSet patterns(ctx);
159 if (failed(applyPartialConversion(op, target, std::move(patterns))))
160 return signalPassFailure();
161}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:241
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
result_type_range getResultTypes()
Definition Operation.h:457
void signalPassFailure()
Signal that some invariant was broken when running.
Definition Pass.h:226
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
::mlir::Pass::ListOption< std::string > extraTypeStrs
FloatType parseFloatType(MLIRContext *ctx, StringRef name)
Definition Utils.cpp:362
void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns, const TypeConverter &typeConverter)
void populateExtendToSupportedTypesConversionTarget(ConversionTarget &target, TypeConverter &typeConverter)
void populateExtendToSupportedTypesTypeConverter(TypeConverter &typeConverter, const SetVector< Type > &sourceTypes, Type targetType)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125