MLIR 23.0.0git
EmulateUnsupportedFloats.cpp
Go to the documentation of this file.
1//===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This pass promotes small floats (of some unsupported types T) to a supported
9// type U by wrapping all float operations on Ts with expansion to and
10// truncation from U, then operating on U.
11//===----------------------------------------------------------------------===//
12
14
19#include "mlir/IR/Location.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/ErrorHandling.h"
24#include <optional>
25
26namespace mlir::arith {
27#define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
28#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29} // namespace mlir::arith
30
31using namespace mlir;
32
33namespace {
34struct EmulateUnsupportedFloatsPass
35 : arith::impl::ArithEmulateUnsupportedFloatsBase<
36 EmulateUnsupportedFloatsPass> {
37 using arith::impl::ArithEmulateUnsupportedFloatsBase<
38 EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
39
40 void runOnOperation() override;
41};
42
43struct EmulateFloatPattern final : ConversionPattern {
44 EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
45 : ConversionPattern::ConversionPattern(
46 converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
47
48 LogicalResult
49 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
50 ConversionPatternRewriter &rewriter) const override;
51};
52} // end namespace
53
54LogicalResult EmulateFloatPattern::matchAndRewrite(
55 Operation *op, ArrayRef<Value> operands,
56 ConversionPatternRewriter &rewriter) const {
57 if (getTypeConverter()->isLegal(op))
58 return failure();
59 // The rewrite doesn't handle cloning regions.
60 if (op->getNumRegions() != 0)
61 return failure();
62
63 Location loc = op->getLoc();
64 const TypeConverter *converter = getTypeConverter();
65 SmallVector<Type> resultTypes;
66 if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) {
67 // Note to anyone looking for this error message: this is a "can't happen".
68 // If you're seeing it, there's a bug.
69 return op->emitOpError("type conversion failed in float emulation");
70 }
71 Operation *expandedOp =
72 rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
73 op->getAttrs(), op->getSuccessors(), /*regions=*/{});
74 SmallVector<Value> newResults(expandedOp->getResults());
75 for (auto [res, oldType, newType] : llvm::zip_equal(
76 MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
77 if (oldType != newType) {
78 auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res);
79 truncFOp.setFastmath(arith::FastMathFlags::contract);
80 res = truncFOp.getResult();
81 }
82 }
83 rewriter.replaceOp(op, newResults);
84 return success();
85}
86
88 TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
89 converter.addConversion([sourceTypes = SmallVector<Type>(sourceTypes),
90 targetType](Type type) -> std::optional<Type> {
91 if (llvm::is_contained(sourceTypes, type))
92 return targetType;
93 if (auto shaped = dyn_cast<ShapedType>(type))
94 if (llvm::is_contained(sourceTypes, shaped.getElementType()))
95 return shaped.clone(targetType);
96 // All other types legal
97 return type;
98 });
99 converter.addTargetMaterialization(
100 [](OpBuilder &b, Type target, ValueRange input, Location loc) {
101 auto extFOp = arith::ExtFOp::create(b, loc, target, input);
102 extFOp.setFastmath(arith::FastMathFlags::contract);
103 return extFOp;
104 });
105}
106
108 RewritePatternSet &patterns, const TypeConverter &converter) {
109 patterns.add<EmulateFloatPattern>(converter, patterns.getContext());
110}
111
113 ConversionTarget &target, const TypeConverter &converter) {
114 // Don't try to legalize functions and other ops that don't need expansion.
115 target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
116 target.addDynamicallyLegalDialect<arith::ArithDialect>(
117 [&](Operation *op) -> std::optional<bool> {
118 return converter.isLegal(op);
119 });
120 // Manually mark arithmetic-performing vector instructions.
121 target.addDynamicallyLegalOp<vector::ContractionOp, vector::ReductionOp,
122 vector::MultiDimReductionOp, vector::FMAOp,
123 vector::OuterProductOp, vector::ScanOp>(
124 [&](Operation *op) { return converter.isLegal(op); });
125 target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
126 arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
127}
128
129void EmulateUnsupportedFloatsPass::runOnOperation() {
130 MLIRContext *ctx = &getContext();
131 Operation *op = getOperation();
132 SmallVector<Type> sourceTypes;
133 Type targetType;
134
135 std::optional<FloatType> maybeTargetType =
136 arith::parseFloatType(ctx, targetTypeStr);
137 if (!maybeTargetType) {
138 emitError(UnknownLoc::get(ctx), "could not map target type '" +
139 targetTypeStr +
140 "' to a known floating-point type");
141 return signalPassFailure();
142 }
143 targetType = *maybeTargetType;
144 for (StringRef sourceTypeStr : sourceTypeStrs) {
145 std::optional<FloatType> maybeSourceType =
146 arith::parseFloatType(ctx, sourceTypeStr);
147 if (!maybeSourceType) {
148 emitError(UnknownLoc::get(ctx), "could not map source type '" +
149 sourceTypeStr +
150 "' to a known floating-point type");
151 return signalPassFailure();
152 }
153 sourceTypes.push_back(*maybeSourceType);
154 }
155 if (sourceTypes.empty())
157 std::nullopt,
158 "no source types specified, float emulation will do nothing");
159
160 if (llvm::is_contained(sourceTypes, targetType)) {
161 emitError(UnknownLoc::get(ctx),
162 "target type cannot be an unsupported source type");
163 return signalPassFailure();
164 }
165 TypeConverter converter;
166 arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
167 targetType);
168 RewritePatternSet patterns(ctx);
169 arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
170 ConversionTarget target(getContext());
171 arith::populateEmulateUnsupportedFloatsLegality(target, converter);
172
173 if (failed(applyPartialConversion(op, target, std::move(patterns))))
174 signalPassFailure();
175}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:520
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:682
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
result_type_range getResultTypes()
Definition Operation.h:436
SuccessorRange getSuccessors()
Definition Operation.h:711
result_range getResults()
Definition Operation.h:423
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition Utils.cpp:361
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
Add rewrite patterns for converting operations that use illegal float types to ones that use legal on...
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target, const TypeConverter &converter)
Set up a dialect conversion to reject arithmetic operations on unsupported float types.
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter, ArrayRef< Type > sourceTypes, Type targetType)
Populate the type conversions needed to emulate the unsupported sourceTypes with targetType.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult emitOptionalWarning(std::optional< Location > loc, Args &&...args)