MLIR  20.0.0git
EmulateUnsupportedFloats.cpp
Go to the documentation of this file.
1 //===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This pass promotes small floats (of some unsupported types T) to a supported
9 // type U by wrapping all float operations on Ts with expansion to and
10 // truncation from U, then operating on U.
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include <optional>
25 
26 namespace mlir::arith {
27 #define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29 } // namespace mlir::arith
30 
31 using namespace mlir;
32 
33 namespace {
34 struct EmulateUnsupportedFloatsPass
35  : arith::impl::ArithEmulateUnsupportedFloatsBase<
36  EmulateUnsupportedFloatsPass> {
37  using arith::impl::ArithEmulateUnsupportedFloatsBase<
38  EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
39 
40  void runOnOperation() override;
41 };
42 
43 struct EmulateFloatPattern final : ConversionPattern {
44  EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
45  : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
46 
47  LogicalResult match(Operation *op) const override;
48  void rewrite(Operation *op, ArrayRef<Value> operands,
49  ConversionPatternRewriter &rewriter) const override;
50 };
51 } // end namespace
52 
53 LogicalResult EmulateFloatPattern::match(Operation *op) const {
54  if (getTypeConverter()->isLegal(op))
55  return failure();
56  // The rewrite doesn't handle cloning regions.
57  if (op->getNumRegions() != 0)
58  return failure();
59  return success();
60 }
61 
63  ConversionPatternRewriter &rewriter) const {
64  Location loc = op->getLoc();
65  const TypeConverter *converter = getTypeConverter();
66  SmallVector<Type> resultTypes;
67  if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) {
68  // Note to anyone looking for this error message: this is a "can't happen".
69  // If you're seeing it, there's a bug.
70  op->emitOpError("type conversion failed in float emulation");
71  return;
72  }
73  Operation *expandedOp =
74  rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
75  op->getAttrs(), op->getSuccessors(), /*regions=*/{});
76  SmallVector<Value> newResults(expandedOp->getResults());
77  for (auto [res, oldType, newType] : llvm::zip_equal(
78  MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
79  if (oldType != newType) {
80  auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
81  truncFOp.setFastmath(arith::FastMathFlags::contract);
82  res = truncFOp.getResult();
83  }
84  }
85  rewriter.replaceOp(op, newResults);
86 }
87 
89  TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
90  converter.addConversion([sourceTypes = SmallVector<Type>(sourceTypes),
91  targetType](Type type) -> std::optional<Type> {
92  if (llvm::is_contained(sourceTypes, type))
93  return targetType;
94  if (auto shaped = dyn_cast<ShapedType>(type))
95  if (llvm::is_contained(sourceTypes, shaped.getElementType()))
96  return shaped.clone(targetType);
97  // All other types legal
98  return type;
99  });
100  converter.addTargetMaterialization(
101  [](OpBuilder &b, Type target, ValueRange input, Location loc) {
102  auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
103  extFOp.setFastmath(arith::FastMathFlags::contract);
104  return extFOp;
105  });
106 }
107 
109  RewritePatternSet &patterns, const TypeConverter &converter) {
110  patterns.add<EmulateFloatPattern>(converter, patterns.getContext());
111 }
112 
114  ConversionTarget &target, const TypeConverter &converter) {
115  // Don't try to legalize functions and other ops that don't need expansion.
116  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
117  target.addDynamicallyLegalDialect<arith::ArithDialect>(
118  [&](Operation *op) -> std::optional<bool> {
119  return converter.isLegal(op);
120  });
121  // Manually mark arithmetic-performing vector instructions.
122  target.addDynamicallyLegalOp<
123  vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
124  vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
125  [&](Operation *op) { return converter.isLegal(op); });
126  target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
127  arith::ConstantOp, vector::SplatOp>();
128 }
129 
130 void EmulateUnsupportedFloatsPass::runOnOperation() {
131  MLIRContext *ctx = &getContext();
132  Operation *op = getOperation();
133  SmallVector<Type> sourceTypes;
134  Type targetType;
135 
136  std::optional<FloatType> maybeTargetType =
137  arith::parseFloatType(ctx, targetTypeStr);
138  if (!maybeTargetType) {
139  emitError(UnknownLoc::get(ctx), "could not map target type '" +
140  targetTypeStr +
141  "' to a known floating-point type");
142  return signalPassFailure();
143  }
144  targetType = *maybeTargetType;
145  for (StringRef sourceTypeStr : sourceTypeStrs) {
146  std::optional<FloatType> maybeSourceType =
147  arith::parseFloatType(ctx, sourceTypeStr);
148  if (!maybeSourceType) {
149  emitError(UnknownLoc::get(ctx), "could not map source type '" +
150  sourceTypeStr +
151  "' to a known floating-point type");
152  return signalPassFailure();
153  }
154  sourceTypes.push_back(*maybeSourceType);
155  }
156  if (sourceTypes.empty())
157  (void)emitOptionalWarning(
158  std::nullopt,
159  "no source types specified, float emulation will do nothing");
160 
161  if (llvm::is_contained(sourceTypes, targetType)) {
163  "target type cannot be an unsupported source type");
164  return signalPassFailure();
165  }
166  TypeConverter converter;
168  targetType);
171  ConversionTarget target(getContext());
173 
174  if (failed(applyPartialConversion(op, target, std::move(patterns))))
175  signalPassFailure();
176 }
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static void rewrite(DataFlowSolver &solver, MLIRContext *context, MutableArrayRef< Region > initialRegions)
Rewrite the given regions using the computing analysis.
Definition: SCCP.cpp:67
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Base class for the conversion patterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e. the legality of each operation instance is decided by the given callback.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the one provided.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
SuccessorRange getSuccessors()
Definition: Operation.h:704
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains all of the data related to a pattern, but does not contain any methods or logic for the actual matching.
Definition: PatternMatch.h:73
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e. the type converts to itself.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition: Utils.cpp:361
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
Add rewrite patterns for converting operations that use illegal float types to ones that use legal ones.
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target, const TypeConverter &converter)
Set up a dialect conversion to reject arithmetic operations on unsupported float types.
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter, ArrayRef< Type > sourceTypes, Type targetType)
Populate the type conversions needed to emulate the unsupported sourceTypes with targetType.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult emitOptionalWarning(std::optional< Location > loc, Args &&...args)
Definition: Diagnostics.h:503
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.