MLIR  19.0.0git
Go to the documentation of this file.
//===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass promotes small floats (of some unsupported types T) to a supported
// type U by wrapping all float operations on Ts with expansion to and
// truncation from U, then operating on U.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>

namespace mlir::arith {
#define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
} // namespace mlir::arith

using namespace mlir;
32 namespace {
33 struct EmulateUnsupportedFloatsPass
34  : arith::impl::ArithEmulateUnsupportedFloatsBase<
35  EmulateUnsupportedFloatsPass> {
36  using arith::impl::ArithEmulateUnsupportedFloatsBase<
37  EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
39  void runOnOperation() override;
40 };
42 struct EmulateFloatPattern final : ConversionPattern {
43  EmulateFloatPattern(TypeConverter &converter, MLIRContext *ctx)
44  : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
46  LogicalResult match(Operation *op) const override;
47  void rewrite(Operation *op, ArrayRef<Value> operands,
48  ConversionPatternRewriter &rewriter) const override;
49 };
50 } // end namespace
52 /// Map strings to float types. This function is here because no one else needs
53 /// it yet, feel free to abstract it out.
54 static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
55  StringRef name) {
56  Builder b(ctx);
58  .Case("f8E5M2", b.getFloat8E5M2Type())
59  .Case("f8E4M3FN", b.getFloat8E4M3FNType())
60  .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
61  .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
62  .Case("bf16", b.getBF16Type())
63  .Case("f16", b.getF16Type())
64  .Case("f32", b.getF32Type())
65  .Case("f64", b.getF64Type())
66  .Case("f80", b.getF80Type())
67  .Case("f128", b.getF128Type())
68  .Default(std::nullopt);
69 }
71 LogicalResult EmulateFloatPattern::match(Operation *op) const {
72  if (getTypeConverter()->isLegal(op))
73  return failure();
74  // The rewrite doesn't handle cloning regions.
75  if (op->getNumRegions() != 0)
76  return failure();
77  return success();
78 }
81  ConversionPatternRewriter &rewriter) const {
82  Location loc = op->getLoc();
83  const TypeConverter *converter = getTypeConverter();
84  SmallVector<Type> resultTypes;
85  if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) {
86  // Note to anyone looking for this error message: this is a "can't happen".
87  // If you're seeing it, there's a bug.
88  op->emitOpError("type conversion failed in float emulation");
89  return;
90  }
91  Operation *expandedOp =
92  rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
93  op->getAttrs(), op->getSuccessors(), /*regions=*/{});
94  SmallVector<Value> newResults(expandedOp->getResults());
95  for (auto [res, oldType, newType] : llvm::zip_equal(
96  MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
97  if (oldType != newType)
98  res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
99  }
100  rewriter.replaceOp(op, newResults);
101 }
104  TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
105  converter.addConversion([sourceTypes = SmallVector<Type>(sourceTypes),
106  targetType](Type type) -> std::optional<Type> {
107  if (llvm::is_contained(sourceTypes, type))
108  return targetType;
109  if (auto shaped = dyn_cast<ShapedType>(type))
110  if (llvm::is_contained(sourceTypes, shaped.getElementType()))
111  return shaped.clone(targetType);
112  // All other types legal
113  return type;
114  });
115  converter.addTargetMaterialization(
116  [](OpBuilder &b, Type target, ValueRange input, Location loc) {
117  return b.create<arith::ExtFOp>(loc, target, input);
118  });
119 }
122  RewritePatternSet &patterns, TypeConverter &converter) {
123  patterns.add<EmulateFloatPattern>(converter, patterns.getContext());
124 }
127  ConversionTarget &target, TypeConverter &converter) {
128  // Don't try to legalize functions and other ops that don't need expansion.
129  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
130  target.addDynamicallyLegalDialect<arith::ArithDialect>(
131  [&](Operation *op) -> std::optional<bool> {
132  return converter.isLegal(op);
133  });
134  // Manually mark arithmetic-performing vector instructions.
135  target.addDynamicallyLegalOp<
136  vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
137  vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
138  [&](Operation *op) { return converter.isLegal(op); });
139  target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
140  arith::ConstantOp, vector::SplatOp>();
141 }
143 void EmulateUnsupportedFloatsPass::runOnOperation() {
144  MLIRContext *ctx = &getContext();
145  Operation *op = getOperation();
146  SmallVector<Type> sourceTypes;
147  Type targetType;
149  std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
150  if (!maybeTargetType) {
151  emitError(UnknownLoc::get(ctx), "could not map target type '" +
152  targetTypeStr +
153  "' to a known floating-point type");
154  return signalPassFailure();
155  }
156  targetType = *maybeTargetType;
157  for (StringRef sourceTypeStr : sourceTypeStrs) {
158  std::optional<FloatType> maybeSourceType =
159  parseFloatType(ctx, sourceTypeStr);
160  if (!maybeSourceType) {
161  emitError(UnknownLoc::get(ctx), "could not map source type '" +
162  sourceTypeStr +
163  "' to a known floating-point type");
164  return signalPassFailure();
165  }
166  sourceTypes.push_back(*maybeSourceType);
167  }
168  if (sourceTypes.empty())
169  (void)emitOptionalWarning(
170  std::nullopt,
171  "no source types specified, float emulation will do nothing");
173  if (llvm::is_contained(sourceTypes, targetType)) {
175  "target type cannot be an unsupported source type");
176  return signalPassFailure();
177  }
178  TypeConverter converter;
180  targetType);
181  RewritePatternSet patterns(ctx);
183  ConversionTarget target(getContext());
186  if (failed(applyPartialConversion(op, target, std::move(patterns))))
187  signalPassFailure();
188 }
static std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
static MLIRContext * getContext(OpFoldResult val)
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
FloatType getFloat8E5M2Type()
Definition: Builders.cpp:37
FloatType getF80Type()
Definition: Builders.cpp:67
FloatType getF128Type()
Definition: Builders.cpp:69
FloatType getF32Type()
Definition: Builders.cpp:63
FloatType getF16Type()
Definition: Builders.cpp:59
FloatType getBF16Type()
Definition: Builders.cpp:57
FloatType getFloat8E4M3FNType()
Definition: Builders.cpp:41
FloatType getFloat8E4M3FNUZType()
Definition: Builders.cpp:49
FloatType getFloat8E5M2FNUZType()
Definition: Builders.cpp:45
FloatType getF64Type()
Definition: Builders.cpp:65
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
Base class for the conversion patterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
SuccessorRange getSuccessors()
Definition: Operation.h:699
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target, TypeConverter &converter)
Set up a dialect conversion to reject arithmetic operations on unsupported float types.
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns, TypeConverter &converter)
Add rewrite patterns for converting operations that use illegal float types to ones that use legal on...
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter, ArrayRef< Type > sourceTypes, Type targetType)
Populate the type conversions needed to emulate the unsupported sourceTypes with targetType.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult emitOptionalWarning(std::optional< Location > loc, Args &&...args)
Definition: Diagnostics.h:497
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26