21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/ErrorHandling.h"
26 #define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
27 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
33 struct EmulateUnsupportedFloatsPass
34 : arith::impl::ArithEmulateUnsupportedFloatsBase<
35 EmulateUnsupportedFloatsPass> {
36 using arith::impl::ArithEmulateUnsupportedFloatsBase<
37 EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
39 void runOnOperation()
override;
46 LogicalResult match(
Operation *op)
const override;
69 .Default(std::nullopt);
72 LogicalResult EmulateFloatPattern::match(
Operation *op)
const {
73 if (getTypeConverter()->isLegal(op))
89 op->
emitOpError(
"type conversion failed in float emulation");
96 for (
auto [res, oldType, newType] : llvm::zip_equal(
98 if (oldType != newType) {
99 auto truncFOp = rewriter.
create<arith::TruncFOp>(loc, oldType, res);
110 targetType](
Type type) -> std::optional<Type> {
111 if (llvm::is_contained(sourceTypes, type))
113 if (
auto shaped = dyn_cast<ShapedType>(type))
114 if (llvm::is_contained(sourceTypes, shaped.getElementType()))
115 return shaped.clone(targetType);
121 auto extFOp = b.
create<arith::ExtFOp>(loc, target, input);
129 patterns.
add<EmulateFloatPattern>(converter, patterns.
getContext());
137 [&](
Operation *op) -> std::optional<bool> {
142 vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
143 vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
145 target.
addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
146 arith::ConstantOp, vector::SplatOp>();
149 void EmulateUnsupportedFloatsPass::runOnOperation() {
155 std::optional<FloatType> maybeTargetType =
parseFloatType(ctx, targetTypeStr);
156 if (!maybeTargetType) {
159 "' to a known floating-point type");
160 return signalPassFailure();
162 targetType = *maybeTargetType;
163 for (StringRef sourceTypeStr : sourceTypeStrs) {
164 std::optional<FloatType> maybeSourceType =
166 if (!maybeSourceType) {
169 "' to a known floating-point type");
170 return signalPassFailure();
172 sourceTypes.push_back(*maybeSourceType);
174 if (sourceTypes.empty())
177 "no source types specified, float emulation will do nothing");
179 if (llvm::is_contained(sourceTypes, targetType)) {
181 "target type cannot be an unsupported source type");
182 return signalPassFailure();
static std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This class is a general helper class for creating context-global objects like types,...
FloatType getFloat8E5M2Type()
FloatType getFloat8E4M3Type()
FloatType getFloat8E4M3FNType()
FloatType getFloat8E4M3FNUZType()
FloatType getFloat8E5M2FNUZType()
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
Base class for the conversion patterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
SuccessorRange getSuccessors()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target, TypeConverter &converter)
Set up a dialect conversion to reject arithmetic operations on unsupported float types.
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns, TypeConverter &converter)
Add rewrite patterns for converting operations that use illegal float types to ones that use legal on...
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter, ArrayRef< Type > sourceTypes, Type targetType)
Populate the type conversions needed to emulate the unsupported sourceTypes with destType
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult emitOptionalWarning(std::optional< Location > loc, Args &&...args)
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.