MLIR  21.0.0git
EmulateUnsupportedFloats.cpp
Go to the documentation of this file.
1 //===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This pass promotes small floats (of some unsupported types T) to a supported
9 // type U by wrapping all float operations on Ts with expansion to and
10 // truncation from U, then operating on U.
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include <optional>
25 
26 namespace mlir::arith {
27 #define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29 } // namespace mlir::arith
30 
31 using namespace mlir;
32 
33 namespace {
/// Pass class for float emulation. It derives from the base class pulled in by
/// the GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS include above and only
/// declares runOnOperation(); the definition lives further down in this file.
34 struct EmulateUnsupportedFloatsPass
35  : arith::impl::ArithEmulateUnsupportedFloatsBase<
36  EmulateUnsupportedFloatsPass> {
// Inherit the constructors of the base class instead of redeclaring them.
37  using arith::impl::ArithEmulateUnsupportedFloatsBase<
38  EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
39 
// Overrides the base-class hook; this is where the conversion is set up and
// applied.
40  void runOnOperation() override;
41 };
42 
43 struct EmulateFloatPattern final : ConversionPattern {
44  EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
46  converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
47 
48  LogicalResult
49  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
50  ConversionPatternRewriter &rewriter) const override;
51 };
52 } // end namespace
53 
54 LogicalResult EmulateFloatPattern::matchAndRewrite(
55  Operation *op, ArrayRef<Value> operands,
56  ConversionPatternRewriter &rewriter) const {
57  if (getTypeConverter()->isLegal(op))
58  return failure();
59  // The rewrite doesn't handle cloning regions.
60  if (op->getNumRegions() != 0)
61  return failure();
62 
63  Location loc = op->getLoc();
64  const TypeConverter *converter = getTypeConverter();
65  SmallVector<Type> resultTypes;
66  if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) {
67  // Note to anyone looking for this error message: this is a "can't happen".
68  // If you're seeing it, there's a bug.
69  return op->emitOpError("type conversion failed in float emulation");
70  }
71  Operation *expandedOp =
72  rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
73  op->getAttrs(), op->getSuccessors(), /*regions=*/{});
74  SmallVector<Value> newResults(expandedOp->getResults());
75  for (auto [res, oldType, newType] : llvm::zip_equal(
76  MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
77  if (oldType != newType) {
78  auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
79  truncFOp.setFastmath(arith::FastMathFlags::contract);
80  res = truncFOp.getResult();
81  }
82  }
83  rewriter.replaceOp(op, newResults);
84  return success();
85 }
86 
88  TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
89  converter.addConversion([sourceTypes = SmallVector<Type>(sourceTypes),
90  targetType](Type type) -> std::optional<Type> {
91  if (llvm::is_contained(sourceTypes, type))
92  return targetType;
93  if (auto shaped = dyn_cast<ShapedType>(type))
94  if (llvm::is_contained(sourceTypes, shaped.getElementType()))
95  return shaped.clone(targetType);
96  // All other types legal
97  return type;
98  });
99  converter.addTargetMaterialization(
100  [](OpBuilder &b, Type target, ValueRange input, Location loc) {
101  auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
102  extFOp.setFastmath(arith::FastMathFlags::contract);
103  return extFOp;
104  });
105 }
106 
108  RewritePatternSet &patterns, const TypeConverter &converter) {
109  patterns.add<EmulateFloatPattern>(converter, patterns.getContext());
110 }
111 
113  ConversionTarget &target, const TypeConverter &converter) {
114  // Don't try to legalize functions and other ops that don't need expansion.
115  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
116  target.addDynamicallyLegalDialect<arith::ArithDialect>(
117  [&](Operation *op) -> std::optional<bool> {
118  return converter.isLegal(op);
119  });
120  // Manually mark arithmetic-performing vector instructions.
121  target.addDynamicallyLegalOp<
122  vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
123  vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
124  [&](Operation *op) { return converter.isLegal(op); });
125  target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
126  arith::ConstantOp, vector::SplatOp>();
127 }
128 
129 void EmulateUnsupportedFloatsPass::runOnOperation() {
130  MLIRContext *ctx = &getContext();
131  Operation *op = getOperation();
132  SmallVector<Type> sourceTypes;
133  Type targetType;
134 
135  std::optional<FloatType> maybeTargetType =
136  arith::parseFloatType(ctx, targetTypeStr);
137  if (!maybeTargetType) {
138  emitError(UnknownLoc::get(ctx), "could not map target type '" +
139  targetTypeStr +
140  "' to a known floating-point type");
141  return signalPassFailure();
142  }
143  targetType = *maybeTargetType;
144  for (StringRef sourceTypeStr : sourceTypeStrs) {
145  std::optional<FloatType> maybeSourceType =
146  arith::parseFloatType(ctx, sourceTypeStr);
147  if (!maybeSourceType) {
148  emitError(UnknownLoc::get(ctx), "could not map source type '" +
149  sourceTypeStr +
150  "' to a known floating-point type");
151  return signalPassFailure();
152  }
153  sourceTypes.push_back(*maybeSourceType);
154  }
155  if (sourceTypes.empty())
156  (void)emitOptionalWarning(
157  std::nullopt,
158  "no source types specified, float emulation will do nothing");
159 
160  if (llvm::is_contained(sourceTypes, targetType)) {
162  "target type cannot be an unsupported source type");
163  return signalPassFailure();
164  }
165  TypeConverter converter;
167  targetType);
170  ConversionTarget target(getContext());
172 
173  if (failed(applyPartialConversion(op, target, std::move(patterns))))
174  signalPassFailure();
175 }
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Base class for the conversion patterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
SuccessorRange getSuccessors()
Definition: Operation.h:704
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given set of types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition: Utils.cpp:361
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
Add rewrite patterns for converting operations that use illegal float types to ones that use legal on...
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target, const TypeConverter &converter)
Set up a dialect conversion to reject arithmetic operations on unsupported float types.
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter, ArrayRef< Type > sourceTypes, Type targetType)
Populate the type conversions needed to emulate the unsupported sourceTypes with targetType.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult emitOptionalWarning(std::optional< Location > loc, Args &&...args)
Definition: Diagnostics.h:503
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.