22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/ErrorHandling.h"
27 #define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
34 struct EmulateUnsupportedFloatsPass
35 : arith::impl::ArithEmulateUnsupportedFloatsBase<
36 EmulateUnsupportedFloatsPass> {
37 using arith::impl::ArithEmulateUnsupportedFloatsBase<
38 EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
40 void runOnOperation()
override;
47 LogicalResult match(
Operation *op)
const override;
53 LogicalResult EmulateFloatPattern::match(
Operation *op)
const {
54 if (getTypeConverter()->isLegal(op))
70 op->
emitOpError(
"type conversion failed in float emulation");
77 for (
auto [res, oldType, newType] : llvm::zip_equal(
79 if (oldType != newType) {
80 auto truncFOp = rewriter.
create<arith::TruncFOp>(loc, oldType, res);
91 targetType](
Type type) -> std::optional<Type> {
92 if (llvm::is_contained(sourceTypes, type))
94 if (
auto shaped = dyn_cast<ShapedType>(type))
95 if (llvm::is_contained(sourceTypes, shaped.getElementType()))
96 return shaped.clone(targetType);
102 auto extFOp = b.
create<arith::ExtFOp>(loc, target, input);
110 patterns.
add<EmulateFloatPattern>(converter, patterns.
getContext());
118 [&](
Operation *op) -> std::optional<bool> {
123 vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
124 vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
126 target.
addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
127 arith::ConstantOp, vector::SplatOp>();
130 void EmulateUnsupportedFloatsPass::runOnOperation() {
136 std::optional<FloatType> maybeTargetType =
138 if (!maybeTargetType) {
141 "' to a known floating-point type");
142 return signalPassFailure();
144 targetType = *maybeTargetType;
145 for (StringRef sourceTypeStr : sourceTypeStrs) {
146 std::optional<FloatType> maybeSourceType =
148 if (!maybeSourceType) {
151 "' to a known floating-point type");
152 return signalPassFailure();
154 sourceTypes.push_back(*maybeSourceType);
156 if (sourceTypes.empty())
159 "no source types specified, float emulation will do nothing");
161 if (llvm::is_contained(sourceTypes, targetType)) {
163 "target type cannot be an unsupported source type");
164 return signalPassFailure();
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Base class for the conversion patterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
SuccessorRange getSuccessors()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
Add rewrite patterns for converting operations that use illegal float types to ones that use legal on...
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target, const TypeConverter &converter)
Set up a dialect conversion to reject arithmetic operations on unsupported float types.
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter, ArrayRef< Type > sourceTypes, Type targetType)
Populate the type conversions needed to emulate the unsupported sourceTypes with destType
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult emitOptionalWarning(std::optional< Location > loc, Args &&...args)
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.