MLIR  20.0.0git
EmulateUnsupportedFloats.cpp
Go to the documentation of this file.
1 //===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This pass promotes small floats (of some unsupported types T) to a supported
9 // type U by wrapping all float operations on Ts with expansion to and
10 // truncation from U, then operating on U.
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>
24 
25 namespace mlir::arith {
26 #define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
27 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
28 } // namespace mlir::arith
29 
30 using namespace mlir;
31 
32 namespace {
33 struct EmulateUnsupportedFloatsPass
34  : arith::impl::ArithEmulateUnsupportedFloatsBase<
35  EmulateUnsupportedFloatsPass> {
36  using arith::impl::ArithEmulateUnsupportedFloatsBase<
37  EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
38 
39  void runOnOperation() override;
40 };
41 
42 struct EmulateFloatPattern final : ConversionPattern {
43  EmulateFloatPattern(TypeConverter &converter, MLIRContext *ctx)
44  : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
45 
46  LogicalResult match(Operation *op) const override;
47  void rewrite(Operation *op, ArrayRef<Value> operands,
48  ConversionPatternRewriter &rewriter) const override;
49 };
50 } // end namespace
51 
52 /// Map strings to float types. This function is here because no one else needs
53 /// it yet, feel free to abstract it out.
54 static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
55  StringRef name) {
56  Builder b(ctx);
58  .Case("f8E5M2", b.getFloat8E5M2Type())
59  .Case("f8E4M3", b.getFloat8E4M3Type())
60  .Case("f8E4M3FN", b.getFloat8E4M3FNType())
61  .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
62  .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
63  .Case("bf16", b.getBF16Type())
64  .Case("f16", b.getF16Type())
65  .Case("f32", b.getF32Type())
66  .Case("f64", b.getF64Type())
67  .Case("f80", b.getF80Type())
68  .Case("f128", b.getF128Type())
69  .Default(std::nullopt);
70 }
71 
72 LogicalResult EmulateFloatPattern::match(Operation *op) const {
73  if (getTypeConverter()->isLegal(op))
74  return failure();
75  // The rewrite doesn't handle cloning regions.
76  if (op->getNumRegions() != 0)
77  return failure();
78  return success();
79 }
80 
82  ConversionPatternRewriter &rewriter) const {
83  Location loc = op->getLoc();
84  const TypeConverter *converter = getTypeConverter();
85  SmallVector<Type> resultTypes;
86  if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) {
87  // Note to anyone looking for this error message: this is a "can't happen".
88  // If you're seeing it, there's a bug.
89  op->emitOpError("type conversion failed in float emulation");
90  return;
91  }
92  Operation *expandedOp =
93  rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
94  op->getAttrs(), op->getSuccessors(), /*regions=*/{});
95  SmallVector<Value> newResults(expandedOp->getResults());
96  for (auto [res, oldType, newType] : llvm::zip_equal(
97  MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
98  if (oldType != newType) {
99  auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
100  truncFOp.setFastmath(arith::FastMathFlags::contract);
101  res = truncFOp.getResult();
102  }
103  }
104  rewriter.replaceOp(op, newResults);
105 }
106 
108  TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
109  converter.addConversion([sourceTypes = SmallVector<Type>(sourceTypes),
110  targetType](Type type) -> std::optional<Type> {
111  if (llvm::is_contained(sourceTypes, type))
112  return targetType;
113  if (auto shaped = dyn_cast<ShapedType>(type))
114  if (llvm::is_contained(sourceTypes, shaped.getElementType()))
115  return shaped.clone(targetType);
116  // All other types legal
117  return type;
118  });
119  converter.addTargetMaterialization(
120  [](OpBuilder &b, Type target, ValueRange input, Location loc) {
121  auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
122  extFOp.setFastmath(arith::FastMathFlags::contract);
123  return extFOp;
124  });
125 }
126 
128  RewritePatternSet &patterns, TypeConverter &converter) {
129  patterns.add<EmulateFloatPattern>(converter, patterns.getContext());
130 }
131 
133  ConversionTarget &target, TypeConverter &converter) {
134  // Don't try to legalize functions and other ops that don't need expansion.
135  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
136  target.addDynamicallyLegalDialect<arith::ArithDialect>(
137  [&](Operation *op) -> std::optional<bool> {
138  return converter.isLegal(op);
139  });
140  // Manually mark arithmetic-performing vector instructions.
141  target.addDynamicallyLegalOp<
142  vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
143  vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
144  [&](Operation *op) { return converter.isLegal(op); });
145  target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
146  arith::ConstantOp, vector::SplatOp>();
147 }
148 
149 void EmulateUnsupportedFloatsPass::runOnOperation() {
150  MLIRContext *ctx = &getContext();
151  Operation *op = getOperation();
152  SmallVector<Type> sourceTypes;
153  Type targetType;
154 
155  std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
156  if (!maybeTargetType) {
157  emitError(UnknownLoc::get(ctx), "could not map target type '" +
158  targetTypeStr +
159  "' to a known floating-point type");
160  return signalPassFailure();
161  }
162  targetType = *maybeTargetType;
163  for (StringRef sourceTypeStr : sourceTypeStrs) {
164  std::optional<FloatType> maybeSourceType =
165  parseFloatType(ctx, sourceTypeStr);
166  if (!maybeSourceType) {
167  emitError(UnknownLoc::get(ctx), "could not map source type '" +
168  sourceTypeStr +
169  "' to a known floating-point type");
170  return signalPassFailure();
171  }
172  sourceTypes.push_back(*maybeSourceType);
173  }
174  if (sourceTypes.empty())
175  (void)emitOptionalWarning(
176  std::nullopt,
177  "no source types specified, float emulation will do nothing");
178 
179  if (llvm::is_contained(sourceTypes, targetType)) {
181  "target type cannot be an unsupported source type");
182  return signalPassFailure();
183  }
184  TypeConverter converter;
186  targetType);
187  RewritePatternSet patterns(ctx);
189  ConversionTarget target(getContext());
191 
192  if (failed(applyPartialConversion(op, target, std::move(patterns))))
193  signalPassFailure();
194 }
static std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
FloatType getFloat8E5M2Type()
Definition: Builders.cpp:37
FloatType getF80Type()
Definition: Builders.cpp:71
FloatType getF128Type()
Definition: Builders.cpp:73
FloatType getF32Type()
Definition: Builders.cpp:67
FloatType getFloat8E4M3Type()
Definition: Builders.cpp:41
FloatType getF16Type()
Definition: Builders.cpp:63
FloatType getBF16Type()
Definition: Builders.cpp:61
FloatType getFloat8E4M3FNType()
Definition: Builders.cpp:45
FloatType getFloat8E4M3FNUZType()
Definition: Builders.cpp:53
FloatType getFloat8E5M2FNUZType()
Definition: Builders.cpp:49
FloatType getF64Type()
Definition: Builders.cpp:69
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
Base class for the conversion patterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:210
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
SuccessorRange getSuccessors()
Definition: Operation.h:699
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target, TypeConverter &converter)
Set up a dialect conversion to reject arithmetic operations on unsupported float types.
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns, TypeConverter &converter)
Add rewrite patterns for converting operations that use illegal float types to ones that use legal on...
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter, ArrayRef< Type > sourceTypes, Type targetType)
Populate the type conversions needed to emulate the unsupported sourceTypes with targetType.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult emitOptionalWarning(std::optional< Location > loc, Args &&...args)
Definition: Diagnostics.h:502
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.