23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/SetVector.h"
27 #define GEN_PASS_DEF_MATHEXTENDTOSUPPORTEDTYPES
28 #include "mlir/Dialect/Math/Transforms/Passes.h.inc"
35 ExtendToSupportedTypesRewritePattern(
const TypeConverter &converter,
43 struct ExtendToSupportedTypesPass
44 : mlir::math::impl::MathExtendToSupportedTypesBase<
45 ExtendToSupportedTypesPass> {
46 using math::impl::MathExtendToSupportedTypesBase<
47 ExtendToSupportedTypesPass>::MathExtendToSupportedTypesBase;
49 void runOnOperation()
override;
58 [](
Type type) -> std::optional<Type> {
return type; });
60 [&sourceTypes, targetType](
FloatType type) -> std::optional<Type> {
61 if (!sourceTypes.contains(type))
67 [&sourceTypes, targetType](ShapedType type) -> std::optional<Type> {
68 if (
auto elemTy = dyn_cast<FloatType>(type.getElementType()))
69 if (!sourceTypes.contains(elemTy))
70 return type.clone(targetType);
76 auto extFOp = b.
create<arith::ExtFOp>(loc, target, input);
86 return typeConverter.
isLegal(op);
90 target.
addLegalOp<arith::ExtFOp, arith::TruncFOp>();
93 LogicalResult ExtendToSupportedTypesRewritePattern::matchAndRewrite(
98 FailureOr<Operation *> legalized =
100 if (failed(legalized))
104 for (
auto [result, newType, origType] : llvm::zip_equal(
106 if (newType != origType) {
107 auto truncFOp = rewriter.
create<arith::TruncFOp>(loc, origType, result);
118 patterns.
add<ExtendToSupportedTypesRewritePattern>(typeConverter,
122 void ExtendToSupportedTypesPass::runOnOperation() {
127 std::optional<Type> maybeTargetType =
129 if (!maybeTargetType.has_value()) {
132 "' to a known floating-point type");
133 return signalPassFailure();
135 Type targetType = maybeTargetType.value();
139 for (
const auto &extraTypeStr : extraTypeStrs) {
140 std::optional<FloatType> maybeExtraType =
142 if (!maybeExtraType.has_value()) {
145 "' to a known floating-point type");
146 return signalPassFailure();
148 sourceTypes.insert(maybeExtraType.value());
152 sourceTypes.insert(b.getF64Type());
153 sourceTypes.insert(b.getF32Type());
163 return signalPassFailure();
static MLIRContext * getContext(OpFoldResult val)
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
This class is a general helper class for creating context-global objects like types,...
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Base class for the conversion patterns.
This class describes a specific conversion target.
void addLegalOp(OperationName op)
Register the given operations as legal.
void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn)
Register unknown operations as dynamically legal.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
result_type_range getResultTypes()
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void addConversion(FnT &&callback)
Register a conversion function.
bool isLegal(Type type) const
Return true if the given type is legal for this type converter, i.e.
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
void populateExtendToSupportedTypesPatterns(RewritePatternSet &patterns, const TypeConverter &typeConverter)
void populateExtendToSupportedTypesConversionTarget(ConversionTarget &target, TypeConverter &typeConverter)
void populateExtendToSupportedTypesTypeConverter(TypeConverter &typeConverter, const SetVector< Type > &sourceTypes, Type targetType)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
FailureOr< Operation * > convertOpResultTypes(Operation *op, ValueRange operands, const TypeConverter &converter, ConversionPatternRewriter &rewriter)
Generic utility to convert op result types according to type converter without knowing exact op type.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.