MLIR 22.0.0git
TosaNarrowTypes.cpp
Go to the documentation of this file.
1//===- TosaNarrowTypes.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the TOSA narrowing passes that rewrite tensor element
10// types to narrower equivalents (i64 -> i32, f64 -> f32, ...).
11//
12//===----------------------------------------------------------------------===//
13
15
16#include "llvm/ADT/APFloat.h"
17
18#include <limits>
19#include <type_traits>
20
24#include "mlir/IR/Verifier.h"
25#include "mlir/Pass/Pass.h"
26
27namespace mlir {
28namespace tosa {
29#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
30#define GEN_PASS_DEF_TOSANARROWF64TOF32PASS
31#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
32} // namespace tosa
33} // namespace mlir
34
35using namespace mlir;
36using namespace mlir::tosa;
37
38namespace {
39
// Selects which element-type narrowing a pass instance performs. The helpers
// and patterns in this file are templated on this value so that a single
// implementation serves both narrowing passes.
enum class TosaNarrowKind {
  // Rewrite 64-bit integer element types to signless i32.
  Int64ToInt32,
  // Rewrite f64 element types to f32.
  Float64ToFloat32,
};
42
43// ---------------------------------------------------------------------------
44// Shared helpers
45// ---------------------------------------------------------------------------
46
47template <TosaNarrowKind Kind>
48bool isSourceInteger(IntegerType type) {
49 if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
50 return type.isInteger(64);
51 return false;
52}
53
54template <TosaNarrowKind Kind>
55bool isSourceFloat(FloatType type) {
56 if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
57 return type.isF64();
58 return false;
59}
60
61template <TosaNarrowKind Kind>
62Type convertInteger(IntegerType type) {
63 if (!isSourceInteger<Kind>(type))
64 return type;
65 if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
66 return IntegerType::get(type.getContext(), 32);
67 return type;
68}
69
70template <TosaNarrowKind Kind>
71Type convertFloat(FloatType type) {
72 if (!isSourceFloat<Kind>(type))
73 return type;
74 if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
75 return Float32Type::get(type.getContext());
76 return type;
77}
78
79template <TosaNarrowKind Kind>
80bool isSourceElement(Type type) {
81 if (auto intTy = dyn_cast<IntegerType>(type))
82 return isSourceInteger<Kind>(intTy);
83 if (auto floatTy = dyn_cast<FloatType>(type))
84 return isSourceFloat<Kind>(floatTy);
85 return false;
86}
87
88template <TosaNarrowKind Kind>
89Type convertElement(Type type) {
90 if (auto intTy = dyn_cast<IntegerType>(type))
91 return convertInteger<Kind>(intTy);
92 if (auto floatTy = dyn_cast<FloatType>(type))
93 return convertFloat<Kind>(floatTy);
94 return type;
95}
96
97template <TosaNarrowKind Kind>
98bool typeNeedsConversion(Type type) {
99 if (auto shaped = dyn_cast<ShapedType>(type))
100 return isSourceElement<Kind>(shaped.getElementType());
101 return isSourceElement<Kind>(type);
102}
103
104FailureOr<APInt> convertIntegerConstant(IntegerType targetType,
105 const APInt &value,
106 bool allowLossyConversion) {
107 const unsigned targetWidth = targetType.getWidth();
108 if (!allowLossyConversion && !value.isSignedIntN(targetWidth))
109 return failure();
110
111 if (allowLossyConversion)
112 return value.truncSSat(targetWidth);
113 return value.sextOrTrunc(targetWidth);
114}
115
116FailureOr<APFloat> convertFloatConstant(FloatType targetType,
117 const APFloat &value,
118 bool allowLossyConversion) {
119 APFloat converted(value);
120 bool losesInfo = false;
121 converted.convert(targetType.getFloatSemantics(),
122 APFloat::rmNearestTiesToEven, &losesInfo);
123 if (!allowLossyConversion && losesInfo)
124 return failure();
125 return converted;
126}
127
128// Narrows scalar constant attributes so they keep matching the converted
129// element types.
130template <TosaNarrowKind Kind>
131FailureOr<Attribute> tryConvertScalarAttribute(Attribute attribute,
132 bool allowLossyConversion) {
133 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
134 if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
135 if (const auto intType = dyn_cast<IntegerType>(intAttr.getType());
136 intType && isSourceInteger<Kind>(intType)) {
137 const auto convertedType =
138 cast<IntegerType>(convertInteger<Kind>(intType));
139 FailureOr<APInt> convertedValue = convertIntegerConstant(
140 convertedType, intAttr.getValue(), allowLossyConversion);
141 if (failed(convertedValue))
142 return failure();
143 return IntegerAttr::get(convertedType, convertedValue.value());
144 }
145 }
146 } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
147 if (const auto floatAttr = dyn_cast<FloatAttr>(attribute)) {
148 if (const auto floatType = dyn_cast<FloatType>(floatAttr.getType());
149 floatType && isSourceFloat<Kind>(floatType)) {
150 const auto convertedType =
151 cast<FloatType>(convertFloat<Kind>(floatType));
152 FailureOr<APFloat> convertedValue = convertFloatConstant(
153 convertedType, floatAttr.getValue(), allowLossyConversion);
154 if (failed(convertedValue))
155 return failure();
156 return FloatAttr::get(convertedType, convertedValue.value());
157 }
158 }
159 }
160
161 return attribute;
162}
163
164template <TosaNarrowKind Kind>
165FailureOr<Attribute>
166convertDenseIntElementsAttr(ShapedType type, DenseIntElementsAttr attr,
167 const TypeConverter &typeConverter,
168 bool allowLossyConversion) {
169 if constexpr (Kind != TosaNarrowKind::Int64ToInt32)
170 return attr;
171
172 const auto oldElementType = dyn_cast<IntegerType>(type.getElementType());
173 if (!oldElementType || !isSourceInteger<Kind>(oldElementType))
174 return attr;
175
176 const auto newType =
177 dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
178 if (!newType)
179 return failure();
180
181 const auto newElementType = dyn_cast<IntegerType>(newType.getElementType());
182 if (!newElementType)
183 return failure();
184
185 if (!allowLossyConversion) {
186 for (APInt value : attr.getValues<APInt>())
187 if (failed(convertIntegerConstant(newElementType, value,
188 /*allowLossyConversion=*/false)))
189 return failure();
190 }
191
192 Attribute convertedAttr =
193 attr.mapValues(newElementType, [&](const APInt &value) -> APInt {
194 return convertIntegerConstant(newElementType, value,
195 /*allowLossyConversion=*/true)
196 .value();
197 });
198 return convertedAttr;
199}
200
201template <TosaNarrowKind Kind>
202FailureOr<Attribute>
203convertDenseFPElementsAttr(ShapedType type, DenseFPElementsAttr attr,
204 const TypeConverter &typeConverter,
205 bool allowLossyConversion) {
206 if constexpr (Kind != TosaNarrowKind::Float64ToFloat32)
207 return attr;
208
209 const auto oldElementType = dyn_cast<FloatType>(type.getElementType());
210 if (!oldElementType || !isSourceFloat<Kind>(oldElementType))
211 return attr;
212
213 const auto newType =
214 dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
215 if (!newType)
216 return failure();
217
218 const auto newElementType = dyn_cast<FloatType>(newType.getElementType());
219 if (!newElementType)
220 return failure();
221
222 if (!allowLossyConversion) {
223 for (APFloat value : attr.getValues<APFloat>())
224 if (failed(convertFloatConstant(newElementType, value,
225 /*allowLossyConversion=*/false)))
226 return failure();
227 }
228
229 Attribute convertedAttr =
230 attr.mapValues(newElementType, [&](const APFloat &value) -> APInt {
231 APFloat converted = convertFloatConstant(newElementType, value,
232 /*allowLossyConversion=*/true)
233 .value();
234 // DenseFPElementsAttr stores each float as raw bits, so emit the APInt
235 // representation that MLIR expects in the underlying buffer.
236 return converted.bitcastToAPInt();
237 });
238 return convertedAttr;
239}
240
241template <TosaNarrowKind Kind, typename AttrT>
242FailureOr<Attribute>
243convertAttributeWithTypeConverter(AttrT attr, Type type,
244 const TypeConverter *typeConverter) {
245 if (!typeNeedsConversion<Kind>(type))
246 return attr;
247
248 const std::optional<Attribute> convertedAttribute =
249 typeConverter->convertTypeAttribute(type, attr);
250 if (!convertedAttribute)
251 return failure();
252
253 return convertedAttribute.value();
254}
255
256// Rejects cast rewrites that would lose precision (unless aggressive mode is
257// enabled).
258template <TosaNarrowKind Kind>
259LogicalResult
260verifyCastDoesNotLosePrecision(Operation *op, ShapedType inputType,
261 ShapedType resultType,
262 ConversionPatternRewriter &rewriter) {
263 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
264 const auto elementInputIntType =
265 dyn_cast<IntegerType>(inputType.getElementType());
266 const auto elementResultIntType =
267 dyn_cast<IntegerType>(resultType.getElementType());
268 if (elementInputIntType && elementResultIntType &&
269 elementInputIntType.getWidth() > elementResultIntType.getWidth())
270 return rewriter.notifyMatchFailure(
271 op, "Narrowing cast may lead to data loss.");
272 } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
273 const auto elementInputFloatType =
274 dyn_cast<FloatType>(inputType.getElementType());
275 const auto elementResultFloatType =
276 dyn_cast<FloatType>(resultType.getElementType());
277 if (elementInputFloatType && elementResultFloatType &&
278 elementInputFloatType.getIntOrFloatBitWidth() >
279 elementResultFloatType.getIntOrFloatBitWidth())
280 return rewriter.notifyMatchFailure(
281 op, "Narrowing cast may lead to data loss.");
282 }
283
284 return success();
285}
286
287// ---------------------------------------------------------------------------
288// Conversion patterns
289// ---------------------------------------------------------------------------
290
291// Applies the narrowing TypeConverter to a single TOSA op, including its
292// attributes and nested regions.
293template <TosaNarrowKind Kind>
294LogicalResult convertGenericOp(Operation *op, ValueRange operands,
295 ConversionPatternRewriter &rewriter,
296 const TypeConverter *typeConverter,
297 bool allowLossyConversion) {
298 SmallVector<Type, 4> newResults;
299 if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
300 return failure();
301
302 OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
303 newResults, {}, op->getSuccessors());
304
305 // Keep attribute payloads consistent with the converted element types.
306 for (const NamedAttribute &namedAttribute : op->getAttrs()) {
307 const Attribute attribute = namedAttribute.getValue();
308
309 if (isa<IntegerAttr>(attribute) || isa<FloatAttr>(attribute)) {
310 FailureOr<Attribute> convertedAttr =
311 tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
312 if (failed(convertedAttr))
313 return rewriter.notifyMatchFailure(
314 op, "Scalar attribute narrowing would lose precision; enable "
315 "aggressive rewrite to override.");
316 state.addAttribute(namedAttribute.getName(), convertedAttr.value());
317 continue;
318 }
319
320 if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
321 FailureOr<Attribute> convertedAttr =
322 convertAttributeWithTypeConverter<Kind>(typeAttr, typeAttr.getValue(),
323 typeConverter);
324 if (failed(convertedAttr))
325 return rewriter.notifyMatchFailure(op,
326 "Failed to convert type attribute.");
327 state.addAttribute(namedAttribute.getName(), convertedAttr.value());
328 continue;
329 }
330
331 if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
332 FailureOr<Attribute> convertedAttr =
333 convertAttributeWithTypeConverter<Kind>(
334 denseElementsAttr, denseElementsAttr.getType(), typeConverter);
335 if (failed(convertedAttr))
336 return rewriter.notifyMatchFailure(
337 op, "Failed to convert dense elements attribute without precision "
338 "loss; enable aggressive rewrite to override.");
339 state.addAttribute(namedAttribute.getName(), convertedAttr.value());
340 continue;
341 }
342
343 state.addAttribute(namedAttribute.getName(), attribute);
344 }
345
346 for (Region &region : op->getRegions()) {
347 Region *newRegion = state.addRegion();
348 rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
349 if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
350 return failure();
351 }
352
353 Operation *newOp = rewriter.create(state);
354 rewriter.replaceOp(op, newOp->getResults());
355 return success();
356}
357
358template <TosaNarrowKind Kind>
359class ConvertGenericOp : public ConversionPattern {
360public:
361 ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context,
362 bool allowLossyConversion)
363 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context),
364 allowLossyConversion(allowLossyConversion) {}
365
366 LogicalResult
367 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
368 ConversionPatternRewriter &rewriter) const final {
369 if (!isa<tosa::TosaOp>(op))
370 return rewriter.notifyMatchFailure(
371 op,
372 "Support for operations other than TOSA has not been implemented.");
373
374 return convertGenericOp<Kind>(op, operands, rewriter, typeConverter,
375 allowLossyConversion);
376 }
377
378private:
379 const bool allowLossyConversion;
380};
381
382template <typename OpTy, TosaNarrowKind Kind>
383class ConvertTypedOp : public OpConversionPattern<OpTy> {
384public:
385 ConvertTypedOp(TypeConverter &typeConverter, MLIRContext *context)
386 : OpConversionPattern<OpTy>(typeConverter, context) {}
387
388 LogicalResult
389 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
390 ConversionPatternRewriter &rewriter) const final {
391 return convertGenericOp<Kind>(op, adaptor.getOperands(), rewriter,
392 this->getTypeConverter(),
393 /*allowLossyConversion=*/false);
394 }
395};
396
397// ---------------------------------------------------------------------------
398// Kind-specific helpers and patterns
399// ---------------------------------------------------------------------------
400
401// Casts get extra checking so we only narrow when it is probably safe.
402template <TosaNarrowKind Kind>
403class ConvertCastOpWithBoundsChecking
404 : public OpConversionPattern<tosa::CastOp> {
405 using OpConversionPattern<tosa::CastOp>::OpConversionPattern;
406
407 LogicalResult
408 matchAndRewrite(tosa::CastOp op, typename tosa::CastOp::Adaptor adaptor,
409 ConversionPatternRewriter &rewriter) const final {
410 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
411 const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
412 if (!inputType || !resultType)
413 return failure();
414
415 const TypeConverter *typeConverter = this->getTypeConverter();
416 if (failed(verifyCastDoesNotLosePrecision<Kind>(op, inputType, resultType,
417 rewriter)))
418 return failure();
419
420 rewriter.replaceOpWithNewOp<tosa::CastOp>(
421 op, typeConverter->convertType(resultType), adaptor.getInput());
422 return success();
423 }
424};
425
426// ArgMax indices must fit the axis dimension, so we guard the integer rewrite.
427class ConvertArgMaxOpWithBoundsChecking
428 : public OpConversionPattern<tosa::ArgMaxOp> {
429 using OpConversionPattern::OpConversionPattern;
430
431 LogicalResult
432 matchAndRewrite(tosa::ArgMaxOp op, typename tosa::ArgMaxOp::Adaptor adaptor,
433 ConversionPatternRewriter &rewriter) const final {
434 const int32_t axis = op.getAxis();
435 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
436 if (!inputType || !inputType.isStaticDim(axis))
437 return rewriter.notifyMatchFailure(
438 op, "Requires a static axis dimension for bounds checking.");
439 const int64_t axisDim = inputType.getDimSize(axis);
440 if (axisDim >= std::numeric_limits<int32_t>::max())
441 return rewriter.notifyMatchFailure(
442 op, "Axis dimension is too large to narrow safely.");
443
444 const Type resultType = op.getOutput().getType();
445 const Type newResultType =
446 this->getTypeConverter()->convertType(resultType);
447 rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
448 adaptor.getInput(), axis);
449 return success();
450 }
451};
452
453template <TosaNarrowKind Kind>
454class ConvertClampOpWithBoundsChecking
455 : public OpConversionPattern<tosa::ClampOp> {
456 static_assert(Kind == TosaNarrowKind::Int64ToInt32,
457 "Clamp bounds checking only supported for integer narrowing");
458 using OpConversionPattern<tosa::ClampOp>::OpConversionPattern;
459
460 LogicalResult
461 matchAndRewrite(tosa::ClampOp op, typename tosa::ClampOp::Adaptor adaptor,
462 ConversionPatternRewriter &rewriter) const final {
463 auto minAttr = dyn_cast<IntegerAttr>(op.getMinValAttr());
464 auto maxAttr = dyn_cast<IntegerAttr>(op.getMaxValAttr());
465 if (!minAttr || !maxAttr)
466 return rewriter.notifyMatchFailure(
467 op, "Clamp attributes must be integer constants.");
468
469 const int64_t min = minAttr.getInt();
470 const int64_t max = maxAttr.getInt();
471 if (min < std::numeric_limits<int32_t>::min() ||
472 max > std::numeric_limits<int32_t>::max())
473 return rewriter.notifyMatchFailure(
474 op, "Clamp bounds exceed int32 range. Narrowing may lose data.");
475
476 const Type resultType = op.getOutput().getType();
477 const Type newResultType =
478 this->getTypeConverter()->convertType(resultType);
479 const auto newResultShaped = dyn_cast<ShapedType>(newResultType);
480 if (!newResultShaped)
481 return failure();
482 const auto newElementType =
483 dyn_cast<IntegerType>(newResultShaped.getElementType());
484 if (!newElementType)
485 return failure();
486
487 const IntegerAttr newMinAttr = IntegerAttr::get(newElementType, min);
488 const IntegerAttr newMaxAttr = IntegerAttr::get(newElementType, max);
489
490 rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, newResultType,
491 adaptor.getInput(), newMinAttr,
492 newMaxAttr, op.getNanModeAttr());
493 return success();
494 }
495};
496
497// Shared implementation for both narrowing passes; the mode decides which
498// element types and attribute payloads participate.
499template <TosaNarrowKind Kind>
500LogicalResult runTosaNarrowing(Operation *op, bool aggressiveRewrite,
501 bool convertFunctionBoundaries) {
502 MLIRContext *context = op->getContext();
503 const bool allowLossyConversion = aggressiveRewrite;
504
505 TypeConverter typeConverter;
506 typeConverter.addConversion([](Type type) -> Type { return type; });
507
508 typeConverter.addConversion(
509 [](IntegerType type) -> Type { return convertInteger<Kind>(type); });
510 typeConverter.addConversion(
511 [](FloatType type) -> Type { return convertFloat<Kind>(type); });
512 typeConverter.addConversion([&typeConverter](RankedTensorType type) -> Type {
513 Type elementType = type.getElementType();
514 if (!isSourceElement<Kind>(elementType))
515 return type;
516 Type converted = typeConverter.convertType(elementType);
517 if (!converted || converted == elementType)
518 return type;
519 return RankedTensorType::get(type.getShape(), converted,
520 type.getEncoding());
521 });
522 typeConverter.addConversion(
523 [&typeConverter](UnrankedTensorType type) -> Type {
524 Type elementType = type.getElementType();
525 if (!isSourceElement<Kind>(elementType))
526 return type;
527 Type converted = typeConverter.convertType(elementType);
528 if (!converted || converted == elementType)
529 return type;
530 return UnrankedTensorType::get(converted);
531 });
532
533 const auto materializeCast = [](OpBuilder &builder, Type resultType,
534 ValueRange inputs, Location loc) -> Value {
535 if (inputs.size() != 1)
536 return Value();
537 return tosa::CastOp::create(builder, loc, resultType, inputs.front());
538 };
539 typeConverter.addSourceMaterialization(materializeCast);
540 typeConverter.addTargetMaterialization(materializeCast);
541
542 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
543 typeConverter.addTypeAttributeConversion(
544 [allowLossyConversion](IntegerType /*type*/, IntegerAttr attribute)
545 -> TypeConverter::AttributeConversionResult {
546 FailureOr<Attribute> converted =
547 tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
548 if (failed(converted))
549 return TypeConverter::AttributeConversionResult::abort();
550 return TypeConverter::AttributeConversionResult::result(
551 converted.value());
552 });
553 typeConverter.addTypeAttributeConversion(
554 [&typeConverter, allowLossyConversion](ShapedType type,
556 -> TypeConverter::AttributeConversionResult {
557 FailureOr<Attribute> converted = convertDenseIntElementsAttr<Kind>(
558 type, attr, typeConverter, allowLossyConversion);
559 if (failed(converted))
560 return TypeConverter::AttributeConversionResult::abort();
561 return TypeConverter::AttributeConversionResult::result(
562 converted.value());
563 });
564 } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
565 typeConverter.addTypeAttributeConversion(
566 [allowLossyConversion](FloatType /*type*/, FloatAttr attribute)
567 -> TypeConverter::AttributeConversionResult {
568 FailureOr<Attribute> converted =
569 tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
570 if (failed(converted))
571 return TypeConverter::AttributeConversionResult::abort();
572 return TypeConverter::AttributeConversionResult::result(
573 converted.value());
574 });
575 typeConverter.addTypeAttributeConversion(
576 [&typeConverter, allowLossyConversion](ShapedType type,
578 -> TypeConverter::AttributeConversionResult {
579 FailureOr<Attribute> converted = convertDenseFPElementsAttr<Kind>(
580 type, attr, typeConverter, allowLossyConversion);
581 if (failed(converted))
582 return TypeConverter::AttributeConversionResult::abort();
583 return TypeConverter::AttributeConversionResult::result(
584 converted.value());
585 });
586 }
587
588 ConversionTarget target(*context);
589 target.addDynamicallyLegalDialect<tosa::TosaDialect>(
590 [&typeConverter](Operation *op) {
591 return typeConverter.isLegal(op->getResultTypes()) &&
592 typeConverter.isLegal(op->getOperandTypes());
593 });
594 if (convertFunctionBoundaries) {
595 target.addDynamicallyLegalOp<func::FuncOp>(
596 [&typeConverter](func::FuncOp op) {
597 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
598 typeConverter.isLegal(&op.getBody());
599 });
600 target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
601 const FunctionType funcType =
602 op->getParentOfType<func::FuncOp>().getFunctionType();
603 return llvm::equal(op.getOperandTypes(), funcType.getResults());
604 });
605 } else {
606 target.addDynamicallyLegalOp<func::FuncOp>(
607 [](func::FuncOp) { return true; });
608 target.addDynamicallyLegalOp<func::ReturnOp>(
609 [](func::ReturnOp) { return true; });
610 }
611
613 if (convertFunctionBoundaries) {
614 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
615 patterns, typeConverter);
617 }
618 if (aggressiveRewrite) {
619 patterns.add<ConvertGenericOp<Kind>>(typeConverter, context,
620 allowLossyConversion);
621 } else {
622 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
623 patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
624 patterns.add<ConvertClampOpWithBoundsChecking<Kind>>(typeConverter,
625 context);
626 }
627 patterns.add<ConvertTypedOp<tosa::ConstOp, Kind>>(typeConverter, context);
628 patterns.add<ConvertTypedOp<tosa::ConcatOp, Kind>>(typeConverter, context);
629 patterns.add<ConvertTypedOp<tosa::PadOp, Kind>>(typeConverter, context);
630 patterns.add<ConvertTypedOp<tosa::ReshapeOp, Kind>>(typeConverter, context);
631 patterns.add<ConvertTypedOp<tosa::ReverseOp, Kind>>(typeConverter, context);
632 patterns.add<ConvertTypedOp<tosa::SliceOp, Kind>>(typeConverter, context);
633 patterns.add<ConvertTypedOp<tosa::TileOp, Kind>>(typeConverter, context);
634 patterns.add<ConvertTypedOp<tosa::TransposeOp, Kind>>(typeConverter,
635 context);
636 patterns.add<ConvertTypedOp<tosa::IdentityOp, Kind>>(typeConverter,
637 context);
638 patterns.add<ConvertCastOpWithBoundsChecking<Kind>>(typeConverter, context);
639 patterns.add<ConvertTypedOp<tosa::IfOp, Kind>>(typeConverter, context);
640 patterns.add<ConvertTypedOp<tosa::WhileOp, Kind>>(typeConverter, context);
641 patterns.add<ConvertTypedOp<tosa::YieldOp, Kind>>(typeConverter, context);
642 }
643
644 if (failed(applyFullConversion(op, target, std::move(patterns))))
645 return failure();
646 return success();
647}
648
649// ---------------------------------------------------------------------------
650// Pass adapters that forward to the shared implementation
651// ---------------------------------------------------------------------------
652
653struct TosaNarrowI64ToI32
654 : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
655 using Base = tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32>;
656
657 TosaNarrowI64ToI32() = default;
658
659 explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options) {
660 this->aggressiveRewrite = options.aggressiveRewrite;
661 this->convertFunctionBoundaries = options.convertFunctionBoundaries;
662 }
663
664 void runOnOperation() override {
665 if (failed(runTosaNarrowing<TosaNarrowKind::Int64ToInt32>(
666 getOperation(), this->aggressiveRewrite,
667 this->convertFunctionBoundaries)))
668 signalPassFailure();
669 }
670};
671
672struct TosaNarrowF64ToF32
673 : public tosa::impl::TosaNarrowF64ToF32PassBase<TosaNarrowF64ToF32> {
674 using Base = tosa::impl::TosaNarrowF64ToF32PassBase<TosaNarrowF64ToF32>;
675
676 TosaNarrowF64ToF32() = default;
677
678 explicit TosaNarrowF64ToF32(const TosaNarrowF64ToF32PassOptions &options) {
679 this->aggressiveRewrite = options.aggressiveRewrite;
680 this->convertFunctionBoundaries = options.convertFunctionBoundaries;
681 }
682
683 void runOnOperation() override {
684 if (failed(runTosaNarrowing<TosaNarrowKind::Float64ToFloat32>(
685 getOperation(), this->aggressiveRewrite,
686 this->convertFunctionBoundaries)))
687 signalPassFailure();
688 }
689};
690
691} // namespace
return success()
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
Attributes are known-constant values of operations.
Definition Attributes.h:25
An attribute that represents a reference to a dense float vector or tensor object.
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APFloat &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
An attribute that represents a reference to a dense integer vector or tensor object.
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APInt &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This class helps build Operations.
Definition Builders.h:207
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:397
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
result_type_range getResultTypes()
Definition Operation.h:428
SuccessorRange getSuccessors()
Definition Operation.h:703
result_range getResults()
Definition Operation.h:415
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
iterator begin()
Definition Region.h:55
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.