MLIR  20.0.0git
ExtendToSupportedTypes.cpp
Go to the documentation of this file.
1 //===- ExtendToSupportedTypes.cpp - Legalize functions on unsupported floats
2 //----------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements legalizing math operations on unsupported floating-point
11 // types through arith.extf and arith.truncf.
12 //
13 //===----------------------------------------------------------------------===//
14 
19 #include "mlir/IR/Diagnostics.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "mlir/IR/TypeUtilities.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SetVector.h"
25 
26 namespace mlir::math {
27 #define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
28 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
29 } // namespace mlir::math
30 
31 using namespace mlir;
32 
33 namespace {
34 struct ExtendToSupportedTypesRewritePattern final : ConversionPattern {
35  ExtendToSupportedTypesRewritePattern(const TypeConverter &converter,
36  MLIRContext *context)
37  : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
38  LogicalResult
39  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
40  ConversionPatternRewriter &rewriter) const override;
41 };
42 
43 struct ExtendToSupportedTypesPass
44  : mlir::math::impl::MathExtendToSupportedTypesBase<
45  ExtendToSupportedTypesPass> {
46  using math::impl::MathExtendToSupportedTypesBase<
47  ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase;
48 
49  void runOnOperation() override;
50 };
51 } // namespace
52 
54  TypeConverter &typeConverter, const SetVector<Type> &sourceTypes,
55  Type targetType) {
56 
57  typeConverter.addConversion(
58  [](Type type) -> std::optional<Type> { return type; });
59  typeConverter.addConversion(
60  [&sourceTypes, targetType](FloatType type) -> std::optional<Type> {
61  if (!sourceTypes.contains(type))
62  return targetType;
63 
64  return std::nullopt;
65  });
66  typeConverter.addConversion(
67  [&sourceTypes, targetType](ShapedType type) -> std::optional<Type> {
68  if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
69  if (!sourceTypes.contains(elemTy))
70  return type.clone(targetType);
71 
72  return std::nullopt;
73  });
74  typeConverter.addTargetMaterialization(
75  [](OpBuilder &b, Type target, ValueRange input, Location loc) {
76  auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
77  extFOp.setFastmath(arith::FastMathFlags::contract);
78  return extFOp;
79  });
80 }
81 
83  ConversionTarget &target, TypeConverter &typeConverter) {
84  target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
85  if (isa<MathDialect>(op->getDialect()))
86  return typeConverter.isLegal(op);
87  return true;
88  });
89  target.addLegalOp<FmaOp>();
90  target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
91 }
92 
93 LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
94  Operation *op, ArrayRef<Value> operands,
95  ConversionPatternRewriter &rewriter) const {
96  Location loc = op->getLoc();
97  const TypeConverter *converter = getTypeConverter();
98  FailureOr<Operation *> legalized =
99  convertOpResultTypes(op, operands, *converter, rewriter);
100  if (failed(legalized))
101  return failure();
102 
103  SmallVector<Value> results = (*legalized)->getResults();
104  for (auto [result, newType, origType] : llvm::zip_equal(
105  results, (*legalized)->getResultTypes(), op->getResultTypes())) {
106  if (newType != origType) {
107  auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
108  truncFOp.setFastmath(arith::FastMathFlags::contract);
109  result = truncFOp.getResult();
110  }
111  }
112  rewriter.replaceOp(op, results);
113  return success();
114 }
115 
117  RewritePatternSet &patterns, const TypeConverter &typeConverter) {
118  patterns.add<ExtendToSupportedTypesRewritePattern>(typeConverter,
119  patterns.getContext());
120 }
121 
122 void ExtendToSupportedTypesPass::runOnOperation() {
123  Operation *op = getOperation();
124  MLIRContext *ctx = &getContext();
125 
126  // Parse target type
127  std::optional<Type> maybeTargetType =
128  arith::parseFloatType(ctx, targetTypeStr);
129  if (!maybeTargetType.has_value()) {
130  emitError(UnknownLoc::get(ctx), "could not map target type '" +
131  targetTypeStr +
132  "' to a known floating-point type");
133  return signalPassFailure();
134  }
135  Type targetType = maybeTargetType.value();
136 
137  // Parse source types
138  llvm::SetVector<Type> sourceTypes;
139  for (const auto &extraTypeStr : extraTypeStrs) {
140  std::optional<FloatType> maybeExtraType =
141  arith::parseFloatType(ctx, extraTypeStr);
142  if (!maybeExtraType.has_value()) {
143  emitError(UnknownLoc::get(ctx), "could not map source type '" +
144  extraTypeStr +
145  "' to a known floating-point type");
146  return signalPassFailure();
147  }
148  sourceTypes.insert(maybeExtraType.value());
149  }
150  // f64 and f32 are implicitly supported
151  Builder b(ctx);
152  sourceTypes.insert(b.getF64Type());
153  sourceTypes.insert(b.getF32Type());
154 
155  TypeConverter typeConverter;
156  math::populateExtendToSupportedTypesTypeConverter(typeConverter, sourceTypes,
157  targetType);
158  ConversionTarget target(*ctx);
162  if (failed(applyPartialConversion(op, target, std::move(patterns))))
163  return signalPassFailure();
164 }
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Base class for the conversion patterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:428
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition: Utils.cpp:361
void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns, const TypeConverter &typeConverter)
void populateExtendToSupportedTypesConversionTarget(ConversionTarget &target, TypeConverter &typeConverter)
void populateExtendToSupportedTypesTypeConverter(TypeConverter &typeConverter, const SetVector< Type > &sourceTypes, Type targetType)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.