21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/ErrorHandling.h"
26 #define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
27 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
33 struct EmulateUnsupportedFloatsPass
34 : arith::impl::ArithEmulateUnsupportedFloatsBase<
35 EmulateUnsupportedFloatsPass> {
36 using arith::impl::ArithEmulateUnsupportedFloatsBase<
37 EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
39 void runOnOperation()
override;
68 .Default(std::nullopt);
72 if (getTypeConverter()->isLegal(op))
88 op->
emitOpError(
"type conversion failed in float emulation");
95 for (
auto [res, oldType, newType] : llvm::zip_equal(
97 if (oldType != newType)
98 res = rewriter.
create<arith::TruncFOp>(loc, oldType, res);
106 targetType](
Type type) -> std::optional<Type> {
107 if (llvm::is_contained(sourceTypes, type))
109 if (
auto shaped = type.
dyn_cast<ShapedType>())
110 if (llvm::is_contained(sourceTypes, shaped.getElementType()))
111 return shaped.clone(targetType);
117 return b.
create<arith::ExtFOp>(loc, target, input);
123 patterns.
add<EmulateFloatPattern>(converter, patterns.
getContext());
131 [&](
Operation *op) -> std::optional<bool> {
136 vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
137 vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
139 target.
addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
140 arith::ConstantOp, vector::SplatOp>();
143 void EmulateUnsupportedFloatsPass::runOnOperation() {
149 std::optional<FloatType> maybeTargetType =
parseFloatType(ctx, targetTypeStr);
150 if (!maybeTargetType) {
153 "' to a known floating-point type");
154 return signalPassFailure();
156 targetType = *maybeTargetType;
157 for (StringRef sourceTypeStr : sourceTypeStrs) {
158 std::optional<FloatType> maybeSourceType =
160 if (!maybeSourceType) {
163 "' to a known floating-point type");
164 return signalPassFailure();
166 sourceTypes.push_back(*maybeSourceType);
168 if (sourceTypes.empty())
171 "no source types specified, float emulation will do nothing");
173 if (llvm::is_contained(sourceTypes, targetType)) {
175 "target type cannot be an unsupported source type");
176 return signalPassFailure();
static std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
static MLIRContext * getContext(OpFoldResult val)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This class is a general helper class for creating context-global objects like types,...
FloatType getFloat8E5M2Type()
FloatType getFloat8E4M3FNType()
FloatType getFloat8E4M3FNUZType()
FloatType getFloat8E5M2FNUZType()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
Base class for the conversion patterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
SuccessorRange getSuccessors()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target, TypeConverter &converter)
Set up a dialect conversion to reject arithmetic operations on unsupported float types.
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns, TypeConverter &converter)
Add rewrite patterns for converting operations that use illegal float types to ones that use legal on...
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter, ArrayRef< Type > sourceTypes, Type targetType)
Populate the type conversions needed to emulate the unsupported sourceTypes with destType
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, DenseSet< Operation * > *unconvertedOps=nullptr)
Below we define several entry points for operation conversion.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult emitOptionalWarning(std::optional< Location > loc, Args &&...args)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.