1 //===- EmulateUnsupportedFloats.cpp - Promote small floats --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This pass promotes small floats (of some unsupported types T) to a supported
9 // type U by wrapping all float operations on Ts with expansion to and
10 // truncation from U, then operating on U.
11 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include <optional>
25 namespace mlir::arith {
27 #include "mlir/Dialect/Arith/Transforms/"
28 } // namespace mlir::arith
30 using namespace mlir;
32 namespace {
33 struct EmulateUnsupportedFloatsPass
34  : arith::impl::ArithEmulateUnsupportedFloatsBase<
35  EmulateUnsupportedFloatsPass> {
36  using arith::impl::ArithEmulateUnsupportedFloatsBase<
37  EmulateUnsupportedFloatsPass>::ArithEmulateUnsupportedFloatsBase;
39  void runOnOperation() override;
40 };
42 struct EmulateFloatPattern final : ConversionPattern {
43  EmulateFloatPattern(TypeConverter &converter, MLIRContext *ctx)
44  : ConversionPattern(converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
46  LogicalResult match(Operation *op) const override;
47  void rewrite(Operation *op, ArrayRef<Value> operands,
48  ConversionPatternRewriter &rewriter) const override;
49 };
50 } // end namespace
52 /// Map strings to float types. This function is here because no one else needs
53 /// it yet, feel free to abstract it out.
54 static std::optional<FloatType> parseFloatType(MLIRContext *ctx,
55  StringRef name) {
56  Builder b(ctx);
58  .Case("f8E5M2", b.getFloat8E5M2Type())
59  .Case("f8E4M3FN", b.getFloat8E4M3FNType())
60  .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
61  .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
62  .Case("bf16", b.getBF16Type())
63  .Case("f16", b.getF16Type())
64  .Case("f32", b.getF32Type())
65  .Case("f64", b.getF64Type())
66  .Case("f80", b.getF80Type())
67  .Case("f128", b.getF128Type())
68  .Default(std::nullopt);
69 }
71 LogicalResult EmulateFloatPattern::match(Operation *op) const {
72  if (getTypeConverter()->isLegal(op))
73  return failure();
74  // The rewrite doesn't handle cloning regions.
75  if (op->getNumRegions() != 0)
76  return failure();
77  return success();
78 }
81  ConversionPatternRewriter &rewriter) const {
82  Location loc = op->getLoc();
83  const TypeConverter *converter = getTypeConverter();
84  SmallVector<Type> resultTypes;
85  if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) {
86  // Note to anyone looking for this error message: this is a "can't happen".
87  // If you're seeing it, there's a bug.
88  op->emitOpError("type conversion failed in float emulation");
89  return;
90  }
91  Operation *expandedOp =
92  rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
93  op->getAttrs(), op->getSuccessors(), /*regions=*/{});
94  SmallVector<Value> newResults(expandedOp->getResults());
95  for (auto [res, oldType, newType] : llvm::zip_equal(
96  MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
97  if (oldType != newType)
98  res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
99  }
100  rewriter.replaceOp(op, newResults);
101 }
104  TypeConverter &converter, ArrayRef<Type> sourceTypes, Type targetType) {
105  converter.addConversion([sourceTypes = SmallVector<Type>(sourceTypes),
106  targetType](Type type) -> std::optional<Type> {
107  if (llvm::is_contained(sourceTypes, type))
108  return targetType;
109  if (auto shaped = dyn_cast<ShapedType>(type))
110  if (llvm::is_contained(sourceTypes, shaped.getElementType()))
111  return shaped.clone(targetType);
112  // All other types legal
113  return type;
114  });
115  converter.addTargetMaterialization(
116  [](OpBuilder &b, Type target, ValueRange input, Location loc) {
117  return b.create<arith::ExtFOp>(loc, target, input);
118  });
119 }
122  RewritePatternSet &patterns, TypeConverter &converter) {
123  patterns.add<EmulateFloatPattern>(converter, patterns.getContext());
124 }
127  ConversionTarget &target, TypeConverter &converter) {
128  // Don't try to legalize functions and other ops that don't need expansion.
129  target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
130  target.addDynamicallyLegalDialect<arith::ArithDialect>(
131  [&](Operation *op) -> std::optional<bool> {
132  return converter.isLegal(op);
133  });
134  // Manually mark arithmetic-performing vector instructions.
135  target.addDynamicallyLegalOp<
136  vector::ContractionOp, vector::ReductionOp, vector::MultiDimReductionOp,
137  vector::FMAOp, vector::OuterProductOp, vector::MatmulOp, vector::ScanOp>(
138  [&](Operation *op) { return converter.isLegal(op); });
139  target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
140  arith::ConstantOp, vector::SplatOp>();
141 }
143 void EmulateUnsupportedFloatsPass::runOnOperation() {
144  MLIRContext *ctx = &getContext();
145  Operation *op = getOperation();
146  SmallVector<Type> sourceTypes;
147  Type targetType;
149  std::optional<FloatType> maybeTargetType = parseFloatType(ctx, targetTypeStr);
150  if (!maybeTargetType) {
151  emitError(UnknownLoc::get(ctx), "could not map target type '" +
152  targetTypeStr +
153  "' to a known floating-point type");
154  return signalPassFailure();
155  }
156  targetType = *maybeTargetType;
157  for (StringRef sourceTypeStr : sourceTypeStrs) {
158  std::optional<FloatType> maybeSourceType =
159  parseFloatType(ctx, sourceTypeStr);
160  if (!maybeSourceType) {
161  emitError(UnknownLoc::get(ctx), "could not map source type '" +
162  sourceTypeStr +
163  "' to a known floating-point type");
164  return signalPassFailure();
165  }
166  sourceTypes.push_back(*maybeSourceType);
167  }
168  if (sourceTypes.empty())
169  (void)emitOptionalWarning(
170  std::nullopt,
171  "no source types specified, float emulation will do nothing");
173  if (llvm::is_contained(sourceTypes, targetType)) {
175  "target type cannot be an unsupported source type");
176  return signalPassFailure();
177  }
178  TypeConverter converter;
180  targetType);
181  RewritePatternSet patterns(ctx);
183  ConversionTarget target(getContext());
186  if (failed(applyPartialConversion(op, target, std::move(patterns))))
187  signalPassFailure();
188 }
