MLIR  19.0.0git
LegalizeToF32.cpp
Go to the documentation of this file.
1 //===- LegalizeToF32.cpp - Legalize functions on small floats ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements legalizing math operations on small floating-point
10 // types through arith.extf and arith.truncf.
11 //
12 //===----------------------------------------------------------------------===//
13 
17 #include "mlir/IR/Diagnostics.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/TypeUtilities.h"
21 #include "llvm/ADT/STLExtras.h"
22 
23 namespace mlir::math {
24 #define GEN_PASS_DEF_MATHLEGALIZETOF32
25 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
26 } // namespace mlir::math
27 
28 using namespace mlir;
29 namespace {
30 struct LegalizeToF32RewritePattern final : ConversionPattern {
31  LegalizeToF32RewritePattern(TypeConverter &converter, MLIRContext *context)
32  : ConversionPattern(converter, MatchAnyOpTypeTag{}, 1, context) {}
34  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
35  ConversionPatternRewriter &rewriter) const override;
36 };
37 
38 struct LegalizeToF32Pass final
39  : mlir::math::impl::MathLegalizeToF32Base<LegalizeToF32Pass> {
40  void runOnOperation() override;
41 };
42 } // namespace
43 
45  TypeConverter &typeConverter) {
46  typeConverter.addConversion(
47  [](Type type) -> std::optional<Type> { return type; });
48  typeConverter.addConversion([](FloatType type) -> std::optional<Type> {
49  if (type.getWidth() < 32)
50  return Float32Type::get(type.getContext());
51  return std::nullopt;
52  });
53  typeConverter.addConversion([](ShapedType type) -> std::optional<Type> {
54  if (auto elemTy = dyn_cast<FloatType>(type.getElementType()))
55  return type.clone(Float32Type::get(type.getContext()));
56  return std::nullopt;
57  });
58  typeConverter.addTargetMaterialization(
59  [](OpBuilder &b, Type target, ValueRange input, Location loc) {
60  return b.create<arith::ExtFOp>(loc, target, input);
61  });
62 }
63 
65  ConversionTarget &target, TypeConverter &typeConverter) {
66  target.addDynamicallyLegalDialect<MathDialect>(
67  [&typeConverter](Operation *op) -> bool {
68  return typeConverter.isLegal(op);
69  });
70  target.addLegalOp<FmaOp>();
71  target.addLegalOp<arith::ExtFOp, arith::TruncFOp>();
72 }
73 
74 LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
75  Operation *op, ArrayRef<Value> operands,
76  ConversionPatternRewriter &rewriter) const {
77  Location loc = op->getLoc();
78  const TypeConverter *converter = getTypeConverter();
79  FailureOr<Operation *> legalized =
80  convertOpResultTypes(op, operands, *converter, rewriter);
81  if (failed(legalized))
82  return failure();
83 
84  SmallVector<Value> results = (*legalized)->getResults();
85  for (auto [result, newType, origType] : llvm::zip_equal(
86  results, (*legalized)->getResultTypes(), op->getResultTypes())) {
87  if (newType != origType)
88  result = rewriter.create<arith::TruncFOp>(loc, origType, result);
89  }
90  rewriter.replaceOp(op, results);
91  return success();
92 }
93 
95  TypeConverter &typeConverter) {
96  patterns.add<LegalizeToF32RewritePattern>(typeConverter,
97  patterns.getContext());
98 }
99 
100 void LegalizeToF32Pass::runOnOperation() {
101  Operation *op = getOperation();
102  MLIRContext &ctx = getContext();
103 
104  TypeConverter typeConverter;
106  ConversionTarget target(ctx);
107  math::populateLegalizeToF32ConversionTarget(target, typeConverter);
108  RewritePatternSet patterns(&ctx);
109  math::populateLegalizeToF32Patterns(patterns, typeConverter);
110  if (failed(applyPartialConversion(op, target, std::move(patterns))))
111  return signalPassFailure();
112 }
static MLIRContext * getContext(OpFoldResult val)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
Base class for the conversion patterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:423
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
Type conversion class.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
void populateLegalizeToF32Patterns(RewritePatternSet &patterns, TypeConverter &typeConverter)
void populateLegalizeToF32TypeConverter(TypeConverter &typeConverter)
void populateLegalizeToF32ConversionTarget(ConversionTarget &target, TypeConverter &typeConverter)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26