MLIR 23.0.0git
TosaNarrowTypes.cpp
Go to the documentation of this file.
1//===- TosaNarrowTypes.cpp ------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the TOSA narrowing passes that rewrite tensor element
10// types to narrower equivalents (i64 -> i32, f64 -> f32, ...).
11//
12//===----------------------------------------------------------------------===//
13
15
#include "llvm/ADT/APFloat.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <type_traits>
21
28#include "mlir/IR/Verifier.h"
29#include "mlir/Pass/Pass.h"
30
31namespace mlir {
32namespace tosa {
33#define GEN_PASS_DEF_TOSANARROWI64TOI32PASS
34#define GEN_PASS_DEF_TOSANARROWF64TOF32PASS
35#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
36} // namespace tosa
37} // namespace mlir
38
39using namespace mlir;
40using namespace mlir::tosa;
41
42namespace {
43
// Selects which element-type narrowing a pass instance performs.
enum class TosaNarrowKind {
  Int64ToInt32,     // i64 -> i32
  Float64ToFloat32, // f64 -> f32
};
46
47// ---------------------------------------------------------------------------
48// Shared helpers
49// ---------------------------------------------------------------------------
50
51template <TosaNarrowKind Kind>
52bool isSourceInteger(IntegerType type) {
53 if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
54 return type.isInteger(64);
55 return false;
56}
57
58template <TosaNarrowKind Kind>
59bool isSourceFloat(FloatType type) {
60 if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
61 return type.isF64();
62 return false;
63}
64
65template <TosaNarrowKind Kind>
66Type convertInteger(IntegerType type) {
67 if (!isSourceInteger<Kind>(type))
68 return type;
69 if constexpr (Kind == TosaNarrowKind::Int64ToInt32)
70 return IntegerType::get(type.getContext(), 32);
71 return type;
72}
73
74template <TosaNarrowKind Kind>
75Type convertFloat(FloatType type) {
76 if (!isSourceFloat<Kind>(type))
77 return type;
78 if constexpr (Kind == TosaNarrowKind::Float64ToFloat32)
79 return Float32Type::get(type.getContext());
80 return type;
81}
82
83template <TosaNarrowKind Kind>
84bool isSourceElement(Type type) {
85 if (auto intTy = dyn_cast<IntegerType>(type))
86 return isSourceInteger<Kind>(intTy);
87 if (auto floatTy = dyn_cast<FloatType>(type))
88 return isSourceFloat<Kind>(floatTy);
89 return false;
90}
91
92template <TosaNarrowKind Kind>
93Type convertElement(Type type) {
94 if (auto intTy = dyn_cast<IntegerType>(type))
95 return convertInteger<Kind>(intTy);
96 if (auto floatTy = dyn_cast<FloatType>(type))
97 return convertFloat<Kind>(floatTy);
98 return type;
99}
100
101template <TosaNarrowKind Kind>
102bool typeNeedsConversion(Type type) {
103 if (auto shaped = dyn_cast<ShapedType>(type))
104 return isSourceElement<Kind>(shaped.getElementType());
105 return isSourceElement<Kind>(type);
106}
107
108FailureOr<APInt> convertIntegerConstant(IntegerType targetType,
109 const APInt &value,
110 bool allowLossyConversion) {
111 const unsigned targetWidth = targetType.getWidth();
112 if (!allowLossyConversion && !value.isSignedIntN(targetWidth))
113 return failure();
114
115 if (allowLossyConversion)
116 return value.truncSSat(targetWidth);
117 return value.sextOrTrunc(targetWidth);
118}
119
120FailureOr<APFloat> convertFloatConstant(FloatType targetType,
121 const APFloat &value,
122 bool allowLossyConversion) {
123 APFloat converted(value);
124 bool losesInfo = false;
125 converted.convert(targetType.getFloatSemantics(),
126 APFloat::rmNearestTiesToEven, &losesInfo);
127 if (!allowLossyConversion && losesInfo)
128 return failure();
129 return converted;
130}
131
132// Narrows scalar constant attributes so they keep matching the converted
133// element types.
134template <TosaNarrowKind Kind>
135FailureOr<Attribute> tryConvertScalarAttribute(Attribute attribute,
136 bool allowLossyConversion) {
137 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
138 if (const auto intAttr = dyn_cast<IntegerAttr>(attribute)) {
139 if (const auto intType = dyn_cast<IntegerType>(intAttr.getType());
140 intType && isSourceInteger<Kind>(intType)) {
141 const auto convertedType =
142 cast<IntegerType>(convertInteger<Kind>(intType));
143 FailureOr<APInt> convertedValue = convertIntegerConstant(
144 convertedType, intAttr.getValue(), allowLossyConversion);
145 if (failed(convertedValue))
146 return failure();
147 return IntegerAttr::get(convertedType, convertedValue.value());
148 }
149 }
150 } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
151 if (const auto floatAttr = dyn_cast<FloatAttr>(attribute)) {
152 if (const auto floatType = dyn_cast<FloatType>(floatAttr.getType());
153 floatType && isSourceFloat<Kind>(floatType)) {
154 const auto convertedType =
155 cast<FloatType>(convertFloat<Kind>(floatType));
156 FailureOr<APFloat> convertedValue = convertFloatConstant(
157 convertedType, floatAttr.getValue(), allowLossyConversion);
158 if (failed(convertedValue))
159 return failure();
160 return FloatAttr::get(convertedType, convertedValue.value());
161 }
162 }
163 }
164
165 return attribute;
166}
167
168template <TosaNarrowKind Kind>
169FailureOr<Attribute>
170convertDenseIntElementsAttr(ShapedType type, DenseIntElementsAttr attr,
171 const TypeConverter &typeConverter,
172 bool allowLossyConversion) {
173 if constexpr (Kind != TosaNarrowKind::Int64ToInt32)
174 return attr;
175
176 const auto oldElementType = dyn_cast<IntegerType>(type.getElementType());
177 if (!oldElementType || !isSourceInteger<Kind>(oldElementType))
178 return attr;
179
180 const auto newType =
181 dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
182 if (!newType)
183 return failure();
184
185 const auto newElementType = dyn_cast<IntegerType>(newType.getElementType());
186 if (!newElementType)
187 return failure();
188
189 if (!allowLossyConversion) {
190 for (APInt value : attr.getValues<APInt>())
191 if (failed(convertIntegerConstant(newElementType, value,
192 /*allowLossyConversion=*/false)))
193 return failure();
194 }
195
196 Attribute convertedAttr =
197 attr.mapValues(newElementType, [&](const APInt &value) -> APInt {
198 return convertIntegerConstant(newElementType, value,
199 /*allowLossyConversion=*/true)
200 .value();
201 });
202 return convertedAttr;
203}
204
205template <TosaNarrowKind Kind>
206FailureOr<Attribute>
207convertDenseFPElementsAttr(ShapedType type, DenseFPElementsAttr attr,
208 const TypeConverter &typeConverter,
209 bool allowLossyConversion) {
210 if constexpr (Kind != TosaNarrowKind::Float64ToFloat32)
211 return attr;
212
213 const auto oldElementType = dyn_cast<FloatType>(type.getElementType());
214 if (!oldElementType || !isSourceFloat<Kind>(oldElementType))
215 return attr;
216
217 const auto newType =
218 dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
219 if (!newType)
220 return failure();
221
222 const auto newElementType = dyn_cast<FloatType>(newType.getElementType());
223 if (!newElementType)
224 return failure();
225
226 if (!allowLossyConversion) {
227 for (APFloat value : attr.getValues<APFloat>())
228 if (failed(convertFloatConstant(newElementType, value,
229 /*allowLossyConversion=*/false)))
230 return failure();
231 }
232
233 Attribute convertedAttr =
234 attr.mapValues(newElementType, [&](const APFloat &value) -> APInt {
235 APFloat converted = convertFloatConstant(newElementType, value,
236 /*allowLossyConversion=*/true)
237 .value();
238 // DenseFPElementsAttr stores each float as raw bits, so emit the APInt
239 // representation that MLIR expects in the underlying buffer.
240 return converted.bitcastToAPInt();
241 });
242 return convertedAttr;
243}
244
245template <TosaNarrowKind Kind>
246FailureOr<Attribute> convertDenseResourceElementsAttr(
247 ShapedType type, DenseResourceElementsAttr attr,
248 const TypeConverter &typeConverter, bool allowLossyConversion) {
249 static_assert(Kind == TosaNarrowKind::Int64ToInt32 ||
250 Kind == TosaNarrowKind::Float64ToFloat32);
251 using From =
252 std::conditional_t<Kind == TosaNarrowKind::Int64ToInt32, int64_t, double>;
253 using To =
254 std::conditional_t<Kind == TosaNarrowKind::Int64ToInt32, int32_t, float>;
255
256 if (Kind == TosaNarrowKind::Int64ToInt32 &&
257 !isa<DenseI64ResourceElementsAttr>(attr)) {
258 return attr;
259 }
260
261 if (Kind == TosaNarrowKind::Float64ToFloat32 &&
262 !isa<DenseF64ResourceElementsAttr>(attr)) {
263 return attr;
264 }
265
266 auto narrow = [](From value) {
267 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
268 value = std::clamp<From>(value, std::numeric_limits<To>::min(),
269 std::numeric_limits<To>::max());
270 }
271
272 return static_cast<To>(value);
273 };
274
275 const auto newType =
276 dyn_cast_or_null<ShapedType>(typeConverter.convertType(type));
277 if (!newType) {
278 return failure();
279 }
280
281 const std::optional<ArrayRef<From>> values =
283 if (!values) {
284 return failure();
285 }
286
287 SmallVector<To> newValues;
288 newValues.reserve(values->size());
289 for (From value : *values) {
290 const To convertedValue = narrow(value);
291 if (!allowLossyConversion && convertedValue != value) {
292 return failure();
293 }
294
295 newValues.push_back(convertedValue);
296 }
297
299 ArrayRef<To>(newValues.data(), newValues.size()));
300
301 auto resourceManager =
303 resourceManager.getBlobManager().update(attr.getRawHandle().getKey(),
304 std::move(blob));
305
306 return DenseResourceElementsAttr::get(newType, attr.getRawHandle());
307}
308
309template <TosaNarrowKind Kind, typename AttrT>
310FailureOr<Attribute>
311convertAttributeWithTypeConverter(AttrT attr, Type type,
312 const TypeConverter *typeConverter) {
313 if (!typeNeedsConversion<Kind>(type))
314 return attr;
315
316 const std::optional<Attribute> convertedAttribute =
317 typeConverter->convertTypeAttribute(type, attr);
318 if (!convertedAttribute)
319 return failure();
320
321 return convertedAttribute.value();
322}
323
324// Rejects cast rewrites that would lose precision (unless aggressive mode is
325// enabled).
326template <TosaNarrowKind Kind>
327LogicalResult
328verifyCastDoesNotLosePrecision(Operation *op, ShapedType inputType,
329 ShapedType resultType,
330 ConversionPatternRewriter &rewriter) {
331 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
332 const auto elementInputIntType =
333 dyn_cast<IntegerType>(inputType.getElementType());
334 const auto elementResultIntType =
335 dyn_cast<IntegerType>(resultType.getElementType());
336 if (elementInputIntType && elementResultIntType &&
337 elementInputIntType.getWidth() > elementResultIntType.getWidth())
338 return rewriter.notifyMatchFailure(
339 op, "Narrowing cast may lead to data loss.");
340 } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
341 const auto elementInputFloatType =
342 dyn_cast<FloatType>(inputType.getElementType());
343 const auto elementResultFloatType =
344 dyn_cast<FloatType>(resultType.getElementType());
345 if (elementInputFloatType && elementResultFloatType &&
346 elementInputFloatType.getIntOrFloatBitWidth() >
347 elementResultFloatType.getIntOrFloatBitWidth())
348 return rewriter.notifyMatchFailure(
349 op, "Narrowing cast may lead to data loss.");
350 }
351
352 return success();
353}
354
355// ---------------------------------------------------------------------------
356// Conversion patterns
357// ---------------------------------------------------------------------------
358
359// Applies the narrowing TypeConverter to a single TOSA op, including its
360// attributes and nested regions.
361template <TosaNarrowKind Kind>
362LogicalResult convertGenericOp(Operation *op, ValueRange operands,
363 ConversionPatternRewriter &rewriter,
364 const TypeConverter *typeConverter,
365 bool allowLossyConversion) {
366 SmallVector<Type, 4> newResults;
367 if (failed(typeConverter->convertTypes(op->getResultTypes(), newResults)))
368 return failure();
369
370 OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
371 newResults, {}, op->getSuccessors());
372
373 // Keep attribute payloads consistent with the converted element types.
374 for (const NamedAttribute &namedAttribute : op->getAttrs()) {
375 const Attribute attribute = namedAttribute.getValue();
376
377 if (isa<IntegerAttr>(attribute) || isa<FloatAttr>(attribute)) {
378 FailureOr<Attribute> convertedAttr =
379 tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
380 if (failed(convertedAttr))
381 return rewriter.notifyMatchFailure(
382 op, "Scalar attribute narrowing would lose precision; enable "
383 "aggressive rewrite to override.");
384 state.addAttribute(namedAttribute.getName(), convertedAttr.value());
385 continue;
386 }
387
388 if (const auto typeAttr = dyn_cast<TypeAttr>(attribute)) {
389 FailureOr<Attribute> convertedAttr =
390 convertAttributeWithTypeConverter<Kind>(typeAttr, typeAttr.getValue(),
391 typeConverter);
392 if (failed(convertedAttr))
393 return rewriter.notifyMatchFailure(op,
394 "Failed to convert type attribute.");
395 state.addAttribute(namedAttribute.getName(), convertedAttr.value());
396 continue;
397 }
398
399 if (const auto denseElementsAttr = dyn_cast<DenseElementsAttr>(attribute)) {
400 FailureOr<Attribute> convertedAttr =
401 convertAttributeWithTypeConverter<Kind>(
402 denseElementsAttr, denseElementsAttr.getType(), typeConverter);
403 if (failed(convertedAttr))
404 return rewriter.notifyMatchFailure(
405 op, "Failed to convert dense elements attribute without precision "
406 "loss; enable aggressive rewrite to override.");
407 state.addAttribute(namedAttribute.getName(), convertedAttr.value());
408 continue;
409 }
410
411 if (const auto denseResourceElementsAttr =
412 dyn_cast<DenseResourceElementsAttr>(attribute)) {
413 FailureOr<Attribute> convertedAttr =
414 convertAttributeWithTypeConverter<Kind>(
415 denseResourceElementsAttr, denseResourceElementsAttr.getType(),
416 typeConverter);
417 if (failed(convertedAttr))
418 return rewriter.notifyMatchFailure(
419 op, "Failed to convert dense resource elements attribute without "
420 "precision loss; enable aggressive rewrite to override.");
421 state.addAttribute(namedAttribute.getName(), convertedAttr.value());
422 continue;
423 }
424
425 state.addAttribute(namedAttribute.getName(), attribute);
426 }
427
428 for (Region &region : op->getRegions()) {
429 Region *newRegion = state.addRegion();
430 rewriter.inlineRegionBefore(region, *newRegion, newRegion->begin());
431 if (failed(rewriter.convertRegionTypes(newRegion, *typeConverter)))
432 return failure();
433 }
434
435 Operation *newOp = rewriter.create(state);
436 rewriter.replaceOp(op, newOp->getResults());
437 return success();
438}
439
440template <TosaNarrowKind Kind>
441class ConvertGenericOp : public ConversionPattern {
442public:
443 ConvertGenericOp(TypeConverter &typeConverter, MLIRContext *context,
444 bool allowLossyConversion)
445 : ConversionPattern(typeConverter, MatchAnyOpTypeTag{}, 0, context),
446 allowLossyConversion(allowLossyConversion) {}
447
448 LogicalResult
449 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
450 ConversionPatternRewriter &rewriter) const final {
451 if (!isa<tosa::TosaOp>(op))
452 return rewriter.notifyMatchFailure(
453 op,
454 "Support for operations other than TOSA has not been implemented.");
455
456 return convertGenericOp<Kind>(op, operands, rewriter, typeConverter,
457 allowLossyConversion);
458 }
459
460private:
461 const bool allowLossyConversion;
462};
463
464template <typename OpTy, TosaNarrowKind Kind>
465class ConvertTypedOp : public OpConversionPattern<OpTy> {
466public:
467 ConvertTypedOp(TypeConverter &typeConverter, MLIRContext *context)
468 : OpConversionPattern<OpTy>(typeConverter, context) {}
469
470 LogicalResult
471 matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
472 ConversionPatternRewriter &rewriter) const final {
473 return convertGenericOp<Kind>(op, adaptor.getOperands(), rewriter,
474 this->getTypeConverter(),
475 /*allowLossyConversion=*/false);
476 }
477};
478
479// ---------------------------------------------------------------------------
480// Kind-specific helpers and patterns
481// ---------------------------------------------------------------------------
482
483// Casts get extra checking so we only narrow when it is probably safe.
484template <TosaNarrowKind Kind>
485class ConvertCastOpWithBoundsChecking
486 : public OpConversionPattern<tosa::CastOp> {
487 using OpConversionPattern<tosa::CastOp>::OpConversionPattern;
488
489 LogicalResult
490 matchAndRewrite(tosa::CastOp op, typename tosa::CastOp::Adaptor adaptor,
491 ConversionPatternRewriter &rewriter) const final {
492 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
493 const auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
494 if (!inputType || !resultType)
495 return failure();
496
497 const TypeConverter *typeConverter = this->getTypeConverter();
498 if (failed(verifyCastDoesNotLosePrecision<Kind>(op, inputType, resultType,
499 rewriter)))
500 return failure();
501
502 rewriter.replaceOpWithNewOp<tosa::CastOp>(
503 op, typeConverter->convertType(resultType), adaptor.getInput());
504 return success();
505 }
506};
507
508// ArgMax indices must fit the axis dimension, so we guard the integer rewrite.
509class ConvertArgMaxOpWithBoundsChecking
510 : public OpConversionPattern<tosa::ArgMaxOp> {
511 using OpConversionPattern::OpConversionPattern;
512
513 LogicalResult
514 matchAndRewrite(tosa::ArgMaxOp op, typename tosa::ArgMaxOp::Adaptor adaptor,
515 ConversionPatternRewriter &rewriter) const final {
516 const int32_t axis = op.getAxis();
517 const auto inputType = dyn_cast<ShapedType>(adaptor.getInput().getType());
518 if (!inputType || !inputType.isStaticDim(axis))
519 return rewriter.notifyMatchFailure(
520 op, "Requires a static axis dimension for bounds checking.");
521 const int64_t axisDim = inputType.getDimSize(axis);
522 if (axisDim >= std::numeric_limits<int32_t>::max())
523 return rewriter.notifyMatchFailure(
524 op, "Axis dimension is too large to narrow safely.");
525
526 const Type resultType = op.getOutput().getType();
527 const Type newResultType =
528 this->getTypeConverter()->convertType(resultType);
529 rewriter.replaceOpWithNewOp<tosa::ArgMaxOp>(op, newResultType,
530 adaptor.getInput(), axis);
531 return success();
532 }
533};
534
535template <TosaNarrowKind Kind>
536class ConvertClampOpWithBoundsChecking
537 : public OpConversionPattern<tosa::ClampOp> {
538 static_assert(Kind == TosaNarrowKind::Int64ToInt32,
539 "Clamp bounds checking only supported for integer narrowing");
540 using OpConversionPattern<tosa::ClampOp>::OpConversionPattern;
541
542 LogicalResult
543 matchAndRewrite(tosa::ClampOp op, typename tosa::ClampOp::Adaptor adaptor,
544 ConversionPatternRewriter &rewriter) const final {
545 auto minAttr = dyn_cast<IntegerAttr>(op.getMinValAttr());
546 auto maxAttr = dyn_cast<IntegerAttr>(op.getMaxValAttr());
547 if (!minAttr || !maxAttr)
548 return rewriter.notifyMatchFailure(
549 op, "Clamp attributes must be integer constants.");
550
551 const int64_t min = minAttr.getInt();
552 const int64_t max = maxAttr.getInt();
553 if (min < std::numeric_limits<int32_t>::min() ||
554 max > std::numeric_limits<int32_t>::max())
555 return rewriter.notifyMatchFailure(
556 op, "Clamp bounds exceed int32 range. Narrowing may lose data.");
557
558 const Type resultType = op.getOutput().getType();
559 const Type newResultType =
560 this->getTypeConverter()->convertType(resultType);
561 const auto newResultShaped = dyn_cast<ShapedType>(newResultType);
562 if (!newResultShaped)
563 return failure();
564 const auto newElementType =
565 dyn_cast<IntegerType>(newResultShaped.getElementType());
566 if (!newElementType)
567 return failure();
568
569 const IntegerAttr newMinAttr = IntegerAttr::get(newElementType, min);
570 const IntegerAttr newMaxAttr = IntegerAttr::get(newElementType, max);
571
572 rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, newResultType,
573 adaptor.getInput(), newMinAttr,
574 newMaxAttr, op.getNanModeAttr());
575 return success();
576 }
577};
578
579// Shared implementation for both narrowing passes; the mode decides which
580// element types and attribute payloads participate.
581template <TosaNarrowKind Kind>
582LogicalResult runTosaNarrowing(Operation *op, bool aggressiveRewrite,
583 bool convertFunctionBoundaries) {
584 MLIRContext *context = op->getContext();
585 const bool allowLossyConversion = aggressiveRewrite;
586
587 TypeConverter typeConverter;
588 typeConverter.addConversion([](Type type) -> Type { return type; });
589
590 typeConverter.addConversion(
591 [](IntegerType type) -> Type { return convertInteger<Kind>(type); });
592 typeConverter.addConversion(
593 [](FloatType type) -> Type { return convertFloat<Kind>(type); });
594 typeConverter.addConversion([&typeConverter](RankedTensorType type) -> Type {
595 Type elementType = type.getElementType();
596 if (!isSourceElement<Kind>(elementType))
597 return type;
598 Type converted = typeConverter.convertType(elementType);
599 if (!converted || converted == elementType)
600 return type;
601 return RankedTensorType::get(type.getShape(), converted,
602 type.getEncoding());
603 });
604 typeConverter.addConversion(
605 [&typeConverter](UnrankedTensorType type) -> Type {
606 Type elementType = type.getElementType();
607 if (!isSourceElement<Kind>(elementType))
608 return type;
609 Type converted = typeConverter.convertType(elementType);
610 if (!converted || converted == elementType)
611 return type;
612 return UnrankedTensorType::get(converted);
613 });
614
615 const auto materializeCast = [](OpBuilder &builder, Type resultType,
616 ValueRange inputs, Location loc) -> Value {
617 if (inputs.size() != 1)
618 return Value();
619 return tosa::CastOp::create(builder, loc, resultType, inputs.front());
620 };
621 typeConverter.addSourceMaterialization(materializeCast);
622 typeConverter.addTargetMaterialization(materializeCast);
623
624 typeConverter.addTypeAttributeConversion(
625 [&typeConverter, allowLossyConversion](ShapedType type,
626 DenseResourceElementsAttr attr)
627 -> TypeConverter::AttributeConversionResult {
628 FailureOr<Attribute> converted = convertDenseResourceElementsAttr<Kind>(
629 type, attr, typeConverter, allowLossyConversion);
630 if (failed(converted))
631 return TypeConverter::AttributeConversionResult::abort();
632 return TypeConverter::AttributeConversionResult::result(
633 converted.value());
634 });
635
636 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
637 typeConverter.addTypeAttributeConversion(
638 [allowLossyConversion](IntegerType /*type*/, IntegerAttr attribute)
639 -> TypeConverter::AttributeConversionResult {
640 FailureOr<Attribute> converted =
641 tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
642 if (failed(converted))
643 return TypeConverter::AttributeConversionResult::abort();
644 return TypeConverter::AttributeConversionResult::result(
645 converted.value());
646 });
647 typeConverter.addTypeAttributeConversion(
648 [&typeConverter, allowLossyConversion](ShapedType type,
650 -> TypeConverter::AttributeConversionResult {
651 FailureOr<Attribute> converted = convertDenseIntElementsAttr<Kind>(
652 type, attr, typeConverter, allowLossyConversion);
653 if (failed(converted))
654 return TypeConverter::AttributeConversionResult::abort();
655 return TypeConverter::AttributeConversionResult::result(
656 converted.value());
657 });
658 } else if constexpr (Kind == TosaNarrowKind::Float64ToFloat32) {
659 typeConverter.addTypeAttributeConversion(
660 [allowLossyConversion](FloatType /*type*/, FloatAttr attribute)
661 -> TypeConverter::AttributeConversionResult {
662 FailureOr<Attribute> converted =
663 tryConvertScalarAttribute<Kind>(attribute, allowLossyConversion);
664 if (failed(converted))
665 return TypeConverter::AttributeConversionResult::abort();
666 return TypeConverter::AttributeConversionResult::result(
667 converted.value());
668 });
669 typeConverter.addTypeAttributeConversion(
670 [&typeConverter, allowLossyConversion](ShapedType type,
672 -> TypeConverter::AttributeConversionResult {
673 FailureOr<Attribute> converted = convertDenseFPElementsAttr<Kind>(
674 type, attr, typeConverter, allowLossyConversion);
675 if (failed(converted))
676 return TypeConverter::AttributeConversionResult::abort();
677 return TypeConverter::AttributeConversionResult::result(
678 converted.value());
679 });
680 }
681
682 ConversionTarget target(*context);
683 target.addDynamicallyLegalDialect<tosa::TosaDialect>(
684 [&typeConverter](Operation *op) {
685 return typeConverter.isLegal(op->getResultTypes()) &&
686 typeConverter.isLegal(op->getOperandTypes());
687 });
688 if (convertFunctionBoundaries) {
689 target.addDynamicallyLegalOp<func::FuncOp>(
690 [&typeConverter](func::FuncOp op) {
691 return typeConverter.isSignatureLegal(op.getFunctionType()) &&
692 typeConverter.isLegal(&op.getBody());
693 });
694 target.addDynamicallyLegalOp<func::ReturnOp>([](func::ReturnOp op) {
695 const FunctionType funcType =
696 op->getParentOfType<func::FuncOp>().getFunctionType();
697 return llvm::equal(op.getOperandTypes(), funcType.getResults());
698 });
699 } else {
700 target.addDynamicallyLegalOp<func::FuncOp>(
701 [](func::FuncOp) { return true; });
702 target.addDynamicallyLegalOp<func::ReturnOp>(
703 [](func::ReturnOp) { return true; });
704 }
705
706 RewritePatternSet patterns(context);
707 if (convertFunctionBoundaries) {
708 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
709 patterns, typeConverter);
710 populateReturnOpTypeConversionPattern(patterns, typeConverter);
711 }
712 if (aggressiveRewrite) {
713 patterns.add<ConvertGenericOp<Kind>>(typeConverter, context,
714 allowLossyConversion);
715 } else {
716 if constexpr (Kind == TosaNarrowKind::Int64ToInt32) {
717 patterns.add<ConvertArgMaxOpWithBoundsChecking>(typeConverter, context);
718 patterns.add<ConvertClampOpWithBoundsChecking<Kind>>(typeConverter,
719 context);
720 }
721 patterns.add<ConvertTypedOp<tosa::ConstOp, Kind>>(typeConverter, context);
722 patterns.add<ConvertTypedOp<tosa::ConcatOp, Kind>>(typeConverter, context);
723 patterns.add<ConvertTypedOp<tosa::PadOp, Kind>>(typeConverter, context);
724 patterns.add<ConvertTypedOp<tosa::ReshapeOp, Kind>>(typeConverter, context);
725 patterns.add<ConvertTypedOp<tosa::ReverseOp, Kind>>(typeConverter, context);
726 patterns.add<ConvertTypedOp<tosa::SliceOp, Kind>>(typeConverter, context);
727 patterns.add<ConvertTypedOp<tosa::TileOp, Kind>>(typeConverter, context);
728 patterns.add<ConvertTypedOp<tosa::TransposeOp, Kind>>(typeConverter,
729 context);
730 patterns.add<ConvertTypedOp<tosa::IdentityOp, Kind>>(typeConverter,
731 context);
732 patterns.add<ConvertCastOpWithBoundsChecking<Kind>>(typeConverter, context);
733 patterns.add<ConvertTypedOp<tosa::IfOp, Kind>>(typeConverter, context);
734 patterns.add<ConvertTypedOp<tosa::WhileOp, Kind>>(typeConverter, context);
735 patterns.add<ConvertTypedOp<tosa::YieldOp, Kind>>(typeConverter, context);
736 }
737
738 if (failed(applyFullConversion(op, target, std::move(patterns))))
739 return failure();
740 return success();
741}
742
743// ---------------------------------------------------------------------------
744// Pass adapters that forward to the shared implementation
745// ---------------------------------------------------------------------------
746
747struct TosaNarrowI64ToI32
748 : public tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32> {
749 using Base = tosa::impl::TosaNarrowI64ToI32PassBase<TosaNarrowI64ToI32>;
750
751 TosaNarrowI64ToI32() = default;
752
753 explicit TosaNarrowI64ToI32(const TosaNarrowI64ToI32PassOptions &options) {
754 this->aggressiveRewrite = options.aggressiveRewrite;
755 this->convertFunctionBoundaries = options.convertFunctionBoundaries;
756 }
757
758 void runOnOperation() override {
759 if (failed(runTosaNarrowing<TosaNarrowKind::Int64ToInt32>(
760 getOperation(), this->aggressiveRewrite,
761 this->convertFunctionBoundaries)))
762 signalPassFailure();
763 }
764};
765
766struct TosaNarrowF64ToF32
767 : public tosa::impl::TosaNarrowF64ToF32PassBase<TosaNarrowF64ToF32> {
768 using Base = tosa::impl::TosaNarrowF64ToF32PassBase<TosaNarrowF64ToF32>;
769
770 TosaNarrowF64ToF32() = default;
771
772 explicit TosaNarrowF64ToF32(const TosaNarrowF64ToF32PassOptions &options) {
773 this->aggressiveRewrite = options.aggressiveRewrite;
774 this->convertFunctionBoundaries = options.convertFunctionBoundaries;
775 }
776
777 void runOnOperation() override {
778 if (failed(runTosaNarrowing<TosaNarrowKind::Float64ToFloat32>(
779 getOperation(), this->aggressiveRewrite,
780 this->convertFunctionBoundaries)))
781 signalPassFailure();
782 }
783};
784
785} // namespace
return success()
static llvm::Constant * convertDenseResourceElementsAttr(Location loc, DenseResourceElementsAttr denseResourceAttr, llvm::Type *llvmType, const ModuleTranslation &moduleTranslation)
Convert a dense resource elements attribute to an LLVM IR constant using its raw data storage if poss...
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This class represents a processed binary blob of data.
Definition AsmState.h:91
Attributes are known-constant values of operations.
Definition Attributes.h:25
An attribute that represents a reference to a dense float vector or tensor object.
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APFloat &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
An attribute that represents a reference to a dense integer vector or tensor object.
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APInt &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
static AsmResourceBlob allocateAndCopyInferAlign(ArrayRef< T > data, bool dataIsMutable=true)
Definition AsmState.h:212
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
This class helps build Operations.
Definition Builders.h:209
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:541
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:426
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:706
result_type_range getResultTypes()
Definition Operation.h:457
SuccessorRange getSuccessors()
Definition Operation.h:732
result_range getResults()
Definition Operation.h:444
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:237
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
iterator begin()
Definition Region.h:55
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Kind
An enumeration of the kinds of predicates.
Definition Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
std::optional< ArrayRef< T > > tryGetDenseResourceValues(ElementsAttr attr)
Include the generated interface declarations.
void populateReturnOpTypeConversionPattern(RewritePatternSet &patterns, const TypeConverter &converter, PatternBenefit benefit=1)
Add a pattern to the given pattern list to rewrite return ops to use operands that have been legalize...
static ManagerInterface & getManagerInterface(MLIRContext *ctx)
This represents an operation in an abstracted form, suitable for use with the builder APIs.