MLIR  22.0.0git
EmulateUnsupportedFloats.cpp
Go to the documentation of this file.
1 //===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This pass promotes small floats (of some unsupported types T) to a supported
9 // type U by wrapping all float operations on Ts with expansion to and
10 // truncation from U, then operating on U.
11 //===----------------------------------------------------------------------===//
12 
14 
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Location.h"
20 #include "mlir/IR/PatternMatch.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/Support/ErrorHandling.h"
24 #include <optional>
25 
26 namespace mlir::arith {
27 #define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
28 #include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29 } // namespace mlir::arith
30 
31 using namespace mlir;
32 
33 namespace {
34 struct EmulateUnsupportedFloatsPass
35  : arith::impl::ArithEmulateUnsupportedFloatsBase<
36  EmulateUnsupportedFloatsPass> {
37  using arith::impl::ArithEmulateUnsupportedFloatsBase<
38  EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
39 
40  void runOnOperation() override;
41 };
42 
43 struct EmulateFloatPattern final : ConversionPattern {
44  EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
46  converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
47 
48  LogicalResult
49  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
50  ConversionPatternRewriter &rewriter) const override;
51 };
52 } // end namespace
53 
54 LogicalResult EmulateFloatPattern::matchAndRewrite(
55  Operation *op, ArrayRef<Value> operands,
56  ConversionPatternRewriter &rewriter) const {
57  if (getTypeConverter()->isLegal(op))
58  return failure();
59  // The rewrite doesn't handle cloning regions.
60  if (op->getNumRegions() != 0)
61  return failure();
62 
63  Location loc = op->getLoc();
64  const TypeConverter *converter = getTypeConverter();
65  SmallVector<Type> resultTypes;
66  if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) {
67  // Note to anyone looking for this error message: this is a "can't happen".
68  // If you're seeing it, there's a bug.
69  return op->emitOpError("type conversion failed in float emulation");
70  }
71  Operation *expandedOp =
72  rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
73  op->getAttrs(), op->getSuccessors(), /*regions=*/{});
74  SmallVector<Value> newResults(expandedOp->getResults());
75  for (auto [res, oldType, newType] : llvm::zip_equal(
76  MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
77  if (oldType != newType) {
78  auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res);
79  truncFOp.setFastmath(arith::FastMathFlags::contract);
80  res = truncFOp.getResult();
81  }
82  }
83  rewriter.replaceOp(op, newResults);
84  return success();
85 }
86 
88  TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
89  converter.addConversion([sourceTypes = SmallVector<Type>(sourceTypes),
90  targetType](Type type) -> std::optional<Type> {
91  if (llvm::is_contained(sourceTypes, type))
92  return targetType;
93  if (auto shaped = dyn_cast<ShapedType>(type))
94  if (llvm::is_contained(sourceTypes, shaped.getElementType()))
95  return shaped.clone(targetType);
96  // All other types legal
97  return type;
98  });
99  converter.addTargetMaterialization(
100  [](OpBuilder &b, Type target, ValueRange input, Location loc) {
101  auto extFOp = arith::ExtFOp::create(b, loc, target, input);
102  extFOp.setFastmath(arith::FastMathFlags::contract);
103  return extFOp;
104  });
105 }
106 
108  RewritePatternSet &patterns, const TypeConverter &converter) {
109  patterns.add<EmulateFloatPattern>(converter, patterns.getContext());
110 }
111 
113  ConversionTarget &target, const TypeConverter &converter) {
114  // Don't try to legalize functions and other ops that don't need expansion.
115  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
116  target.addDynamicallyLegalDialect<arith::ArithDialect>(
117  [&](Operation *op) -> std::optional<bool> {
118  return converter.isLegal(op);
119  });
120  // Manually mark arithmetic-performing vector instructions.
121  target.addDynamicallyLegalOp<vector::ContractionOp, vector::ReductionOp,
122  vector::MultiDimReductionOp, vector::FMAOp,
123  vector::OuterProductOp, vector::ScanOp>(
124  [&](Operation *op) { return converter.isLegal(op); });
125  target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
126  arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>();
127 }
128 
129 void EmulateUnsupportedFloatsPass::runOnOperation() {
130  MLIRContext *ctx = &getContext();
131  Operation *op = getOperation();
132  SmallVector<Type> sourceTypes;
133  Type targetType;
134 
135  std::optional<FloatType> maybeTargetType =
136  arith::parseFloatType(ctx, targetTypeStr);
137  if (!maybeTargetType) {
138  emitError(UnknownLoc::get(ctx), "could not map target type '" +
139  targetTypeStr +
140  "' to a known floating-point type");
141  return signalPassFailure();
142  }
143  targetType = *maybeTargetType;
144  for (StringRef sourceTypeStr : sourceTypeStrs) {
145  std::optional<FloatType> maybeSourceType =
146  arith::parseFloatType(ctx, sourceTypeStr);
147  if (!maybeSourceType) {
148  emitError(UnknownLoc::get(ctx), "could not map source type '" +
149  sourceTypeStr +
150  "' to a known floating-point type");
151  return signalPassFailure();
152  }
153  sourceTypes.push_back(*maybeSourceType);
154  }
155  if (sourceTypes.empty())
156  (void)emitOptionalWarning(
157  std::nullopt,
158  "no source types specified, float emulation will do nothing");
159 
160  if (llvm::is_contained(sourceTypes, targetType)) {
162  "target type cannot be an unsupported source type");
163  return signalPassFailure();
164  }
165  TypeConverter converter;
167  targetType);
170  ConversionTarget target(getContext());
172 
173  if (failed(applyPartialConversion(op, target, std::move(patterns))))
174  signalPassFailure();
175 }
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Base class for the conversion patterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
void addDynamicallyLegalOp(OperationName op, const DynamicLegalityCallbackFn &callback)
Register the given operation as dynamically legal and set the dynamic legalization callback to the on...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:205
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
SuccessorRange getSuccessors()
Definition: Operation.h:703
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Definition: PatternMatch.h:73
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
LogicalResult convertTypes(TypeRange types, SmallVectorImpl< Type > &results) const
Convert the given types, filling 'results' as necessary.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition: Utils.cpp:360
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
Add rewrite patterns for converting operations that use illegal float types to ones that use legal on...
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target, const TypeConverter &converter)
Set up a dialect conversion to reject arithmetic operations on unsupported float types.
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter, ArrayRef< Type > sourceTypes, Type targetType)
Populate the type conversions needed to emulate the unsupported sourceTypes with targetType.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult emitOptionalWarning(std::optional< Location > loc, Args &&...args)
Definition: Diagnostics.h:503
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.