16#include "llvm/ADT/APFloat.h"
29#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
30#define GEN_PASS_DEF_TOSANARROWF64TOF32PASS
31#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
/// Selects which narrowing the templated helpers and conversion patterns in
/// this file perform; the choice is dispatched at compile time through
/// `if constexpr (Kind == ...)` branches.
///  - Int64ToInt32: integer types for which `isInteger(64)` holds are treated
///    as the source type and rewritten to 32-bit `IntegerType`s.
///  - Float64ToFloat32: float types accepted by `isSourceFloat` are rewritten
///    to `Float32Type`. NOTE(review): the predicate body is not visible here;
///    presumably it matches only f64, per the enumerator name. Confirm.
41enum class TosaNarrowKind { Int64ToInt32, Float64ToFloat32 };
47template <TosaNarrowKind Kind>
48bool isSourceInteger(IntegerType type) {
49 if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
50 return type.isInteger(64);
54template <TosaNarrowKind Kind>
55bool isSourceFloat(FloatType type) {
56 if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
61template <TosaNarrowKind Kind>
62Type convertInteger(IntegerType type) {
63 if (!isSourceInteger<Kind>(type))
65 if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
66 return IntegerType::get(type.getContext(), 32);
70template <TosaNarrowKind Kind>
71Type convertFloat(FloatType type) {
72 if (!isSourceFloat<Kind>(type))
74 if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
75 return Float32Type::get(type.getContext());
79template <TosaNarrowKind Kind>
80bool isSourceElement(
Type type) {
81 if (
auto intTy = dyn_cast<IntegerType>(type))
82 return isSourceInteger<Kind>(intTy);
83 if (
auto floatTy = dyn_cast<FloatType>(type))
84 return isSourceFloat<Kind>(floatTy);
88template <TosaNarrowKind Kind>
90 if (
auto intTy = dyn_cast<IntegerType>(type))
91 return convertInteger<Kind>(intTy);
92 if (
auto floatTy = dyn_cast<FloatType>(type))
93 return convertFloat<Kind>(floatTy);
97template <TosaNarrowKind Kind>
98bool typeNeedsConversion(
Type type) {
99 if (
auto shaped = dyn_cast<ShapedType>(type))
100 return isSourceElement<Kind>(shaped.getElementType());
101 return isSourceElement<Kind>(type);
104FailureOr<APInt> convertIntegerConstant(IntegerType targetType,
106 bool allowLossyConversion) {
107 const unsigned targetWidth = targetType.getWidth();
108 if (!allowLossyConversion && !value.isSignedIntN(targetWidth))
111 if (allowLossyConversion)
112 return value.truncSSat(targetWidth);
113 return value.sextOrTrunc(targetWidth);
116FailureOr<APFloat> convertFloatConstant(FloatType targetType,
117 const APFloat &value,
118 bool allowLossyConversion) {
119 APFloat converted(value);
120 bool losesInfo =
false;
121 converted.convert(targetType.getFloatSemantics(),
122 APFloat::rmNearestTiesToEven, &losesInfo);
123 if (!allowLossyConversion && losesInfo)
130template <TosaNarrowKind Kind>
131FailureOr<Attribute> tryConvertScalarAttribute(
Attribute attribute,
132 bool allowLossyConversion) {
133 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
134 if (
const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
135 if (
const auto intType = dyn_cast<IntegerType>(intAttr.getType());
136 intType && isSourceInteger<Kind>(intType)) {
137 const auto convertedType =
138 cast<IntegerType>(convertInteger<Kind>(intType));
139 FailureOr<APInt> convertedValue = convertIntegerConstant(
140 convertedType, intAttr.getValue(), allowLossyConversion);
141 if (
failed(convertedValue))
143 return IntegerAttr::get(convertedType, convertedValue.value());
146 }
else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
147 if (
const auto floatAttr = dyn_cast<FloatAttr>(attribute)) {
148 if (
const auto floatType = dyn_cast<FloatType>(floatAttr.getType());
149 floatType && isSourceFloat<Kind>(floatType)) {
150 const auto convertedType =
151 cast<FloatType>(convertFloat<Kind>(floatType));
152 FailureOr<APFloat> convertedValue = convertFloatConstant(
153 convertedType, floatAttr.getValue(), allowLossyConversion);
154 if (
failed(convertedValue))
156 return FloatAttr::get(convertedType, convertedValue.value());
164template <TosaNarrowKind Kind>
168 bool allowLossyConversion) {
169 if constexpr (Kind != TosaNarrowKind::Int64ToInt32)
172 const auto oldElementType = dyn_cast<IntegerType>(type.getElementType());
173 if (!oldElementType || !isSourceInteger<Kind>(oldElementType))
177 dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
181 const auto newElementType = dyn_cast<IntegerType>(newType.getElementType());
185 if (!allowLossyConversion) {
186 for (APInt value : attr.getValues<APInt>())
187 if (
failed(convertIntegerConstant(newElementType, value,
193 attr.
mapValues(newElementType, [&](
const APInt &value) -> APInt {
194 return convertIntegerConstant(newElementType, value,
198 return convertedAttr;
201template <TosaNarrowKind Kind>
205 bool allowLossyConversion) {
206 if constexpr (Kind != TosaNarrowKind::Float64ToFloat32)
209 const auto oldElementType = dyn_cast<FloatType>(type.getElementType());
210 if (!oldElementType || !isSourceFloat<Kind>(oldElementType))
214 dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
218 const auto newElementType = dyn_cast<FloatType>(newType.getElementType());
222 if (!allowLossyConversion) {
223 for (APFloat value : attr.getValues<APFloat>())
224 if (
failed(convertFloatConstant(newElementType, value,
230 attr.
mapValues(newElementType, [&](
const APFloat &value) -> APInt {
231 APFloat converted = convertFloatConstant(newElementType, value,
236 return converted.bitcastToAPInt();
238 return convertedAttr;
241template <TosaNarrowKind Kind,
typename AttrT>
243convertAttributeWithTypeConverter(AttrT attr,
Type type,
245 if (!typeNeedsConversion<Kind>(type))
248 const std::optional<Attribute> convertedAttribute =
249 typeConverter->convertTypeAttribute(type, attr);
250 if (!convertedAttribute)
253 return convertedAttribute.value();
258template <TosaNarrowKind Kind>
260verifyCastDoesNotLosePrecision(
Operation *op, ShapedType inputType,
261 ShapedType resultType,
262 ConversionPatternRewriter &rewriter) {
263 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
264 const auto elementInputIntType =
265 dyn_cast<IntegerType>(inputType.getElementType());
266 const auto elementResultIntType =
267 dyn_cast<IntegerType>(resultType.getElementType());
268 if (elementInputIntType && elementResultIntType &&
269 elementInputIntType.getWidth() > elementResultIntType.getWidth())
270 return rewriter.notifyMatchFailure(
271 op,
"Narrowing cast may lead to data loss.");
272 }
else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
273 const auto elementInputFloatType =
274 dyn_cast<FloatType>(inputType.getElementType());
275 const auto elementResultFloatType =
276 dyn_cast<FloatType>(resultType.getElementType());
277 if (elementInputFloatType && elementResultFloatType &&
278 elementInputFloatType.getIntOrFloatBitWidth() >
279 elementResultFloatType.getIntOrFloatBitWidth())
280 return rewriter.notifyMatchFailure(
281 op,
"Narrowing cast may lead to data loss.");
293template <TosaNarrowKind Kind>
295 ConversionPatternRewriter &rewriter,
297 bool allowLossyConversion) {
307 const Attribute attribute = namedAttribute.getValue();
309 if (isa<IntegerAttr>(attribute) || isa<FloatAttr>(attribute)) {
310 FailureOr<Attribute> convertedAttr =
311 tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
312 if (
failed(convertedAttr))
313 return rewriter.notifyMatchFailure(
314 op,
"Scalar attribute narrowing would lose precision; enable "
315 "aggressive rewrite to override.");
316 state.addAttribute(namedAttribute.getName(), convertedAttr.value());
320 if (
const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
321 FailureOr<Attribute> convertedAttr =
322 convertAttributeWithTypeConverter<Kind>(typeAttr, typeAttr.getValue(),
324 if (
failed(convertedAttr))
325 return rewriter.notifyMatchFailure(op,
326 "Failed to convert type attribute.");
327 state.addAttribute(namedAttribute.getName(), convertedAttr.value());
331 if (
const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
332 FailureOr<Attribute> convertedAttr =
333 convertAttributeWithTypeConverter<Kind>(
334 denseElementsAttr, denseElementsAttr.getType(), typeConverter);
335 if (
failed(convertedAttr))
336 return rewriter.notifyMatchFailure(
337 op,
"Failed to convert dense elements attribute without precision "
338 "loss; enable aggressive rewrite to override.");
339 state.addAttribute(namedAttribute.getName(), convertedAttr.value());
343 state.addAttribute(namedAttribute.getName(), attribute);
347 Region *newRegion = state.addRegion();
348 rewriter.inlineRegionBefore(region, *newRegion, newRegion->
begin());
349 if (
failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
353 Operation *newOp = rewriter.create(state);
358template <TosaNarrowKind Kind>
361 ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context,
362 bool allowLossyConversion)
363 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context),
364 allowLossyConversion(allowLossyConversion) {}
367 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
368 ConversionPatternRewriter &rewriter)
const final {
369 if (!isa<tosa::TosaOp>(op))
370 return rewriter.notifyMatchFailure(
372 "Support for operations other than TOSA has not been implemented.");
374 return convertGenericOp<Kind>(op, operands, rewriter, typeConverter,
375 allowLossyConversion);
379 const bool allowLossyConversion;
382template <
typename OpTy, TosaNarrowKind Kind>
383class ConvertTypedOp :
public OpConversionPattern<OpTy> {
385 ConvertTypedOp(TypeConverter &typeConverter, MLIRContext *context)
386 : OpConversionPattern<OpTy>(typeConverter, context) {}
389 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
390 ConversionPatternRewriter &rewriter)
const final {
391 return convertGenericOp<Kind>(op, adaptor.getOperands(), rewriter,
392 this->getTypeConverter(),
402template <TosaNarrowKind Kind>
403class ConvertCastOpWithBoundsChecking
404 :
public OpConversionPattern<tosa::CastOp> {
405 using OpConversionPattern<tosa::CastOp>::OpConversionPattern;
408 matchAndRewrite(tosa::CastOp op,
typename tosa::CastOp::Adaptor adaptor,
409 ConversionPatternRewriter &rewriter)
const final {
410 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
412 if (!inputType || !resultType)
415 const TypeConverter *typeConverter = this->getTypeConverter();
416 if (
failed(verifyCastDoesNotLosePrecision<Kind>(op, inputType, resultType,
420 rewriter.replaceOpWithNewOp<tosa::CastOp>(
421 op, typeConverter->convertType(resultType), adaptor.getInput());
427class ConvertArgMaxOpWithBoundsChecking
428 :
public OpConversionPattern<tosa::ArgMaxOp> {
429 using OpConversionPattern::OpConversionPattern;
432 matchAndRewrite(tosa::ArgMaxOp op,
typename tosa::ArgMaxOp::Adaptor adaptor,
433 ConversionPatternRewriter &rewriter)
const final {
434 const int32_t axis = op.getAxis();
435 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
436 if (!inputType || !inputType.isStaticDim(axis))
437 return rewriter.notifyMatchFailure(
438 op,
"Requires a static axis dimension for bounds checking.");
439 const int64_t axisDim = inputType.getDimSize(axis);
440 if (axisDim >= std::numeric_limits<int32_t>::max())
441 return rewriter.notifyMatchFailure(
442 op,
"Axis dimension is too large to narrow safely.");
444 const Type resultType = op.getOutput().getType();
445 const Type newResultType =
446 this->getTypeConverter()->convertType(resultType);
447 rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
448 adaptor.getInput(), axis);
453template <TosaNarrowKind Kind>
454class ConvertClampOpWithBoundsChecking
455 :
public OpConversionPattern<tosa::ClampOp> {
456 static_assert(
Kind == TosaNarrowKind::Int64ToInt32,
457 "Clamp bounds checking only supported for integer narrowing");
458 using OpConversionPattern<tosa::ClampOp>::OpConversionPattern;
461 matchAndRewrite(tosa::ClampOp op,
typename tosa::ClampOp::Adaptor adaptor,
462 ConversionPatternRewriter &rewriter)
const final {
463 auto minAttr = dyn_cast<IntegerAttr>(op.getMinValAttr());
464 auto maxAttr = dyn_cast<IntegerAttr>(op.getMaxValAttr());
465 if (!minAttr || !maxAttr)
466 return rewriter.notifyMatchFailure(
467 op,
"Clamp attributes must be integer constants.");
469 const int64_t
min = minAttr.getInt();
470 const int64_t
max = maxAttr.getInt();
471 if (
min < std::numeric_limits<int32_t>::min() ||
472 max > std::numeric_limits<int32_t>::max())
473 return rewriter.notifyMatchFailure(
474 op,
"Clamp bounds exceed int32 range. Narrowing may lose data.");
476 const Type resultType = op.getOutput().getType();
477 const Type newResultType =
478 this->getTypeConverter()->convertType(resultType);
479 const auto newResultShaped = dyn_cast<ShapedType>(newResultType);
480 if (!newResultShaped)
482 const auto newElementType =
483 dyn_cast<IntegerType>(newResultShaped.getElementType());
487 const IntegerAttr newMinAttr = IntegerAttr::get(newElementType,
min);
488 const IntegerAttr newMaxAttr = IntegerAttr::get(newElementType,
max);
490 rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, newResultType,
491 adaptor.getInput(), newMinAttr,
492 newMaxAttr, op.getNanModeAttr());
499template <TosaNarrowKind Kind>
500LogicalResult runTosaNarrowing(
Operation *op,
bool aggressiveRewrite,
501 bool convertFunctionBoundaries) {
503 const bool allowLossyConversion = aggressiveRewrite;
506 typeConverter.addConversion([](
Type type) ->
Type {
return type; });
508 typeConverter.addConversion(
509 [](IntegerType type) ->
Type {
return convertInteger<Kind>(type); });
510 typeConverter.addConversion(
511 [](FloatType type) ->
Type {
return convertFloat<Kind>(type); });
512 typeConverter.addConversion([&typeConverter](RankedTensorType type) ->
Type {
513 Type elementType = type.getElementType();
514 if (!isSourceElement<Kind>(elementType))
516 Type converted = typeConverter.convertType(elementType);
517 if (!converted || converted == elementType)
519 return RankedTensorType::get(type.getShape(), converted,
522 typeConverter.addConversion(
523 [&typeConverter](UnrankedTensorType type) ->
Type {
524 Type elementType = type.getElementType();
525 if (!isSourceElement<Kind>(elementType))
527 Type converted = typeConverter.convertType(elementType);
528 if (!converted || converted == elementType)
530 return UnrankedTensorType::get(converted);
533 const auto materializeCast = [](
OpBuilder &builder,
Type resultType,
535 if (inputs.size() != 1)
537 return tosa::CastOp::create(builder, loc, resultType, inputs.front());
539 typeConverter.addSourceMaterialization(materializeCast);
540 typeConverter.addTargetMaterialization(materializeCast);
542 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
543 typeConverter.addTypeAttributeConversion(
544 [allowLossyConversion](IntegerType , IntegerAttr attribute)
545 -> TypeConverter::AttributeConversionResult {
546 FailureOr<Attribute> converted =
547 tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
549 return TypeConverter::AttributeConversionResult::abort();
550 return TypeConverter::AttributeConversionResult::result(
553 typeConverter.addTypeAttributeConversion(
554 [&typeConverter, allowLossyConversion](ShapedType type,
556 -> TypeConverter::AttributeConversionResult {
557 FailureOr<Attribute> converted = convertDenseIntElementsAttr<Kind>(
558 type, attr, typeConverter, allowLossyConversion);
560 return TypeConverter::AttributeConversionResult::abort();
561 return TypeConverter::AttributeConversionResult::result(
564 }
else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
565 typeConverter.addTypeAttributeConversion(
566 [allowLossyConversion](FloatType , FloatAttr attribute)
567 -> TypeConverter::AttributeConversionResult {
568 FailureOr<Attribute> converted =
569 tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
571 return TypeConverter::AttributeConversionResult::abort();
572 return TypeConverter::AttributeConversionResult::result(
575 typeConverter.addTypeAttributeConversion(
576 [&typeConverter, allowLossyConversion](ShapedType type,
578 -> TypeConverter::AttributeConversionResult {
579 FailureOr<Attribute> converted = convertDenseFPElementsAttr<Kind>(
580 type, attr, typeConverter, allowLossyConversion);
582 return TypeConverter::AttributeConversionResult::abort();
583 return TypeConverter::AttributeConversionResult::result(
589 target.addDynamicallyLegalDialect<tosa::TosaDialect>(
594 if (convertFunctionBoundaries) {
595 target.addDynamicallyLegalOp<func::FuncOp>(
596 [&typeConverter](func::FuncOp op) {
597 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
598 typeConverter.isLegal(&op.getBody());
600 target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
601 const FunctionType funcType =
606 target.addDynamicallyLegalOp<func::FuncOp>(
607 [](func::FuncOp) {
return true; });
608 target.addDynamicallyLegalOp<func::ReturnOp>(
609 [](func::ReturnOp) {
return true; });
613 if (convertFunctionBoundaries) {
614 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
618 if (aggressiveRewrite) {
619 patterns.add<ConvertGenericOp<Kind>>(typeConverter, context,
620 allowLossyConversion);
622 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
623 patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
624 patterns.add<ConvertClampOpWithBoundsChecking<Kind>>(typeConverter,
627 patterns.add<ConvertTypedOp<tosa::ConstOp, Kind>>(typeConverter, context);
628 patterns.add<ConvertTypedOp<tosa::ConcatOp, Kind>>(typeConverter, context);
629 patterns.add<ConvertTypedOp<tosa::PadOp, Kind>>(typeConverter, context);
630 patterns.add<ConvertTypedOp<tosa::ReshapeOp, Kind>>(typeConverter, context);
631 patterns.add<ConvertTypedOp<tosa::ReverseOp, Kind>>(typeConverter, context);
632 patterns.add<ConvertTypedOp<tosa::SliceOp, Kind>>(typeConverter, context);
633 patterns.add<ConvertTypedOp<tosa::TileOp, Kind>>(typeConverter, context);
634 patterns.add<ConvertTypedOp<tosa::TransposeOp, Kind>>(typeConverter,
636 patterns.add<ConvertTypedOp<tosa::IdentityOp, Kind>>(typeConverter,
638 patterns.add<ConvertCastOpWithBoundsChecking<Kind>>(typeConverter, context);
639 patterns.add<ConvertTypedOp<tosa::IfOp, Kind>>(typeConverter, context);
640 patterns.add<ConvertTypedOp<tosa::WhileOp, Kind>>(typeConverter, context);
641 patterns.add<ConvertTypedOp<tosa::YieldOp, Kind>>(typeConverter, context);
653struct TosaNarrowI64ToI32
654 :
public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
655 using Base = tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32>;
657 TosaNarrowI64ToI32() =
default;
659 explicit TosaNarrowI64ToI32(
const TosaNarrowI64ToI32PassOptions &
options) {
660 this->aggressiveRewrite =
options.aggressiveRewrite;
661 this->convertFunctionBoundaries =
options.convertFunctionBoundaries;
664 void runOnOperation()
override {
665 if (
failed(runTosaNarrowing<TosaNarrowKind::Int64ToInt32>(
666 getOperation(), this->aggressiveRewrite,
667 this->convertFunctionBoundaries)))
672struct TosaNarrowF64ToF32
673 :
public tosa::impl::TosaNarrowF64ToF32PassBase<TosaNarrowF64ToF32> {
674 using Base = tosa::impl::TosaNarrowF64ToF32PassBase<TosaNarrowF64ToF32>;
676 TosaNarrowF64ToF32() =
default;
678 explicit TosaNarrowF64ToF32(
const TosaNarrowF64ToF32PassOptions &
options) {
679 this->aggressiveRewrite =
options.aggressiveRewrite;
680 this->convertFunctionBoundaries =
options.convertFunctionBoundaries;
683 void runOnOperation()
override {
684 if (
failed(runTosaNarrowing<TosaNarrowKind::Float64ToFloat32>(
685 getOperation(), this->aggressiveRewrite,
686 this->convertFunctionBoundaries)))
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
An attribute that represents a reference to a dense float vector or tensor object.
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APFloat &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
An attribute that represents a reference to a dense integer vector or tensor object.
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APInt &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
This class helps build Operations.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
result_type_range getResultTypes()
SuccessorRange getSuccessors()
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.