MLIR 22.0.0git
EmulateUnsupportedFloats.cpp
Go to the documentation of this file.
1//===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This pass promotes small floats (of some unsupported types T) to a supported
9// type U by wrapping all float operations on Ts with expansion to and
10// truncation from U, then operating on U.
11//===----------------------------------------------------------------------===//
12
14
19#include "mlir/IR/Location.h"
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/Support/ErrorHandling.h"
24#include <optional>
25
26namespace mlir::arith {
27#define GEN_PASS_DEF_ARITHEMULATEUNSUPPORTEDFLOATS
28#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
29} // namespace mlir::arith
30
31using namespace mlir;
32
33namespace {
34struct EmulateUnsupportedFloatsPass
36 EmulateUnsupportedFloatsPass> {
38 EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
40 void runOnOperation() override;
41};
42
43struct EmulateFloatPattern final : ConversionPattern {
44 EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
46 converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
48 LogicalResult
49 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
50 ConversionPatternRewriter &rewriter) const override;
51};
52} // end namespace
53
54LogicalResult EmulateFloatPattern::matchAndRewrite(
55 Operation *op, ArrayRef<Value> operands,
56 ConversionPatternRewriter &rewriter) const {
57 if (getTypeConverter()->isLegal(op))
58 return failure();
59 // The rewrite doesn't handle cloning regions.
60 if (op->getNumRegions() != 0)
61 return failure();
63 Location loc = op->getLoc();
64 const TypeConverter *converter = getTypeConverter();
65 SmallVector<Type> resultTypes;
66 if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) {
67 // Note to anyone looking for this error message: this is a "can't happen".
68 // If you're seeing it, there's a bug.
69 return op->emitOpError("type conversion failed in float emulation");
70 }
71 Operation *expandedOp =
72 rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
73 op->getAttrs(), op->getSuccessors(), /*regions=*/{});
74 SmallVector<Value> newResults(expandedOp->getResults());
75 for (auto [res, oldType, newType] : llvm::zip_equal(
76 MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
77 if (oldType != newType) {
78 auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res);
79 truncFOp.setFastmath(arith::FastMathFlags::contract);
80 res = truncFOp.getResult();
81 }
82 }
83 rewriter.replaceOp(op, newResults);
88 TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
89 converter.addConversion([sourceTypes = SmallVector<Type>(sourceTypes),
90 targetType](Type type) -> std::optional<Type> {
91 if (llvm::is_contained(sourceTypes, type))
92 return targetType;
93 if (auto shaped = dyn_cast<ShapedType>(type))
94 if (llvm::is_contained(sourceTypes, shaped.getElementType()))
95 return shaped.clone(targetType);
96 // All other types legal
97 return type;
98 });
99 converter.addTargetMaterialization(
100 [](OpBuilder &b, Type target, ValueRange input, Location loc) {
101 auto extFOp = arith::ExtFOp::create(b, loc, target, input);
102 extFOp.setFastmath(arith::FastMathFlags::contract);
103 return extFOp;
104 });
105}
106
108 RewritePatternSet &patterns, const TypeConverter &converter) {
109 patterns.add<EmulateFloatPattern>(converter, patterns.getContext());
110}
111
113 ConversionTarget &target, const TypeConverter &converter) {
114 // Don't try to legalize functions and other ops that don't need expansion.
115 target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
116 target.addDynamicallyLegalDialect<arith::ArithDialect>(
117 [&](Operation *op) -> std::optional<bool> {
118 return converter.isLegal(op);
119 });
120 // Manually mark arithmetic-performing vector instructions.
121 target.addDynamicallyLegalOp<vector::ContractionOp, vector::ReductionOp,
122 vector::MultiDimReductionOp, vector::FMAOp,
123 vector::OuterProductOp, vector::ScanOp>(
124 [&](Operation *op) { return converter.isLegal(op); });
125 target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
126 arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
127}
128
129void EmulateUnsupportedFloatsPass::runOnOperation() {
130 MLIRContext *ctx = &getContext();
131 Operation *op = getOperation();
132 SmallVector<Type> sourceTypes;
133 Type targetType;
134
135 std::optional<FloatType> maybeTargetType =
136 arith::parseFloatType(ctx, targetTypeStr);
137 if (!maybeTargetType) {
138 emitError(UnknownLoc::get(ctx), "could not map target type '" +
139 targetTypeStr +
140 "' to a known floating-point type");
141 return signalPassFailure();
142 }
143 targetType = *maybeTargetType;
144 for (StringRef sourceTypeStr : sourceTypeStrs) {
145 std::optional<FloatType> maybeSourceType =
146 arith::parseFloatType(ctx, sourceTypeStr);
147 if (!maybeSourceType) {
148 emitError(UnknownLoc::get(ctx), "could not map source type '" +
149 sourceTypeStr +
150 "' to a known floating-point type");
151 return signalPassFailure();
152 }
153 sourceTypes.push_back(*maybeSourceType);
154 }
155 if (sourceTypes.empty())
157 std::nullopt,
158 "no source types specified, float emulation will do nothing");
159
160 if (llvm::is_contained(sourceTypes, targetType)) {
161 emitError(UnknownLoc::get(ctx),
162 "target type cannot be an unsupported source type");
163 return signalPassFailure();
164 }
165 TypeConverter converter;
166 arith::populateEmulateUnsupportedFloatsConversions(converter, sourceTypes,
167 targetType);
168 RewritePatternSet patterns(ctx);
169 arith::populateEmulateUnsupportedFloatsPatterns(patterns, converter);
170 ConversionTarget target(getContext());
171 arith::populateEmulateUnsupportedFloatsLegality(target, converter);
172
173 if (failed(applyPartialConversion(op, target, std::move(patterns))))
174 signalPassFailure();
175}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
result_type_range getResultTypes()
Definition Operation.h:428
SuccessorRange getSuccessors()
Definition Operation.h:703
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
virtual void runOnOperation()=0
The polymorphic API that runs the pass over the currently held operation.
This class contains all of the data related to a pattern, but does not contain any methods or logic f...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition Utils.cpp:360
void populateEmulateUnsupportedFloatsPatterns(RewritePatternSet &patterns, const TypeConverter &converter)
Add rewrite patterns for converting operations that use illegal float types to ones that use legal on...
void populateEmulateUnsupportedFloatsLegality(ConversionTarget &target, const TypeConverter &converter)
Set up a dialect conversion to reject arithmetic operations on unsupported float types.
void populateEmulateUnsupportedFloatsConversions(TypeConverter &converter, ArrayRef< Type > sourceTypes, Type targetType)
Populate the type conversions needed to emulate the unsupported sourceTypes with targetType
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
const FrozenRewritePatternSet & patterns
LogicalResult emitOptionalWarning(std::optional< Location > loc, Args &&...args)