22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/ErrorHandling.h"
27#define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
28#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
34struct EmulateUnsupportedFloatsPass
36 EmulateUnsupportedFloatsPass> {
38 EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
46 converter,
Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
50 ConversionPatternRewriter &rewriter)
const override;
54LogicalResult EmulateFloatPattern::matchAndRewrite(
56 ConversionPatternRewriter &rewriter)
const {
57 if (getTypeConverter()->isLegal(op))
66 if (failed(converter->convertTypes(op->
getResultTypes(), resultTypes))) {
69 return op->
emitOpError(
"type conversion failed in float emulation");
75 for (
auto [res, oldType, newType] : llvm::zip_equal(
77 if (oldType != newType) {
78 auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res);
79 truncFOp.setFastmath(arith::FastMathFlags::contract);
80 res = truncFOp.getResult();
83 rewriter.replaceOp(op, newResults);
90 targetType](
Type type) -> std::optional<Type> {
91 if (llvm::is_contained(sourceTypes, type))
93 if (
auto shaped = dyn_cast<ShapedType>(type))
94 if (llvm::is_contained(sourceTypes, shaped.getElementType()))
95 return shaped.clone(targetType);
99 converter.addTargetMaterialization(
101 auto extFOp = arith::ExtFOp::create(
b, loc,
target, input);
102 extFOp.setFastmath(arith::FastMathFlags::contract);
115 target.markUnknownOpDynamicallyLegal([](
Operation *op) {
return true; });
116 target.addDynamicallyLegalDialect<arith::ArithDialect>(
117 [&](
Operation *op) -> std::optional<bool> {
118 return converter.isLegal(op);
121 target.addDynamicallyLegalOp<vector::ContractionOp, vector::ReductionOp,
122 vector::MultiDimReductionOp, vector::FMAOp,
123 vector::OuterProductOp, vector::ScanOp>(
124 [&](
Operation *op) {
return converter.isLegal(op); });
125 target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
126 arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
129void EmulateUnsupportedFloatsPass::runOnOperation() {
135 std::optional<FloatType> maybeTargetType =
137 if (!maybeTargetType) {
138 emitError(UnknownLoc::get(ctx),
"could not map target type '" +
140 "' to a known floating-point type");
141 return signalPassFailure();
143 targetType = *maybeTargetType;
144 for (StringRef sourceTypeStr : sourceTypeStrs) {
145 std::optional<FloatType> maybeSourceType =
147 if (!maybeSourceType) {
148 emitError(UnknownLoc::get(ctx),
"could not map source type '" +
150 "' to a known floating-point type");
151 return signalPassFailure();
153 sourceTypes.push_back(*maybeSourceType);
155 if (sourceTypes.empty())
158 "no source types specified, float emulation will do nothing");
160 if (llvm::is_contained(sourceTypes, targetType)) {
162 "target type cannot be an unsupported source type");
163 return signalPassFailure();
165 TypeConverter converter;
166 arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
169 arith::populateEmulateUnsupportedFloatsPatterns(
patterns, converter);
171 arith::populateEmulateUnsupportedFloatsLegality(
target, converter);
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
SuccessorRange getSuccessors()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class provides an abstraction over the different types of ranges over Values.
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
Add rewrite patterns for converting operations that use illegal float types to ones that use legal ones.
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target, const TypeConverter &converter)
Set up a dialect conversion to reject arithmetic operations on unsupported float types.
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter, ArrayRef< Type > sourceTypes, Type targetType)
Populate the type conversions needed to emulate the unsupported sourceTypes with targetType.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
LogicalResult emitOptionalWarning(std::optional< Location > loc, Args &&...args)