#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <type_traits>
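// The patterns in this file lower TOSA operations onto a mix of linalg,
// arith, math, tensor, scf, and index dialect operations.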
template <typename OpTy>
static Value
materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                    Value lhs, Value rhs, Value result) {
  auto nanMode = op.getNanMode();
  if (nanMode == "PROPAGATE")
    return result;

  // Unordered comparison of NaN against itself is true exactly for NaN.
  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
      op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
      op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
  Value rhsOrResult =
      rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
  return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
                                          rhsOrResult);
}
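// Builds the scalar computation for a single TOSA elementwise op on the block
// arguments of an enclosing linalg.generic. Each (op, element type) pair
// below maps to one or a handful of arith/math ops; returning a null Value
// signals an unsupported combination to the caller.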
static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
  }
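  // Example: for i8, abs(-5) becomes max(-5, 0 - (-5)) = 5. Since arith.subi
  // wraps in two's complement, abs(INT_MIN) yields INT_MIN again.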
  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }
  // tosa::MulOp
  if (isa<tosa::MulOp>(op)) {
    auto shiftVal = cast<tosa::MulOp>(op).getShift();
    DenseElementsAttr shiftElem;
    if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
      return nullptr;
    }

    int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0) {
        auto shiftConst =
            rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
        if (!a.getType().isInteger(32))
          a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
        if (!b.getType().isInteger(32))
          b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

        auto result = rewriter.create<tosa::ApplyScaleOp>(
            loc, rewriter.getI32Type(), a, b, shiftConst,
            rewriter.getStringAttr("SINGLE_ROUND"));

        if (elementTy.isInteger(32))
          return result;

        return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
      }
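      // Example: with shift == 1 and single rounding, apply_scale of 3 and 5
      // computes (3 * 5 + (1 << 0)) >> 1 = 8, a rounded fixed-point product.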
      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

      return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
    }
  }
  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op)) {
    auto negate = cast<tosa::NegateOp>(op);

    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    if (failed(maybeInZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "input1 zero point cannot be statically determined");
      return nullptr;
    }

    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    if (failed(maybeOutZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "output zero point cannot be statically determined");
      return nullptr;
    }

    int64_t inZp = *maybeInZp;
    int64_t outZp = *maybeOutZp;

    if (isa<FloatType>(elementTy))
      return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);

    if (isa<IntegerType>(elementTy)) {
      if (!inZp && !outZp) {
        auto constant = rewriter.create<arith::ConstantOp>(
            loc, IntegerAttr::get(elementTy, 0));
        return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
                                              args[0]);
      }

      // Compute the maximum value that can occur in the intermediate buffer.
      const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
      const int64_t zpAdd = inZp + outZp;
      const int64_t maxValue =
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
          std::abs(zpAdd) + 1;

      // Convert that maximum value into the smallest intermediate bitwidth
      // that can represent it.
      int intermediateBitWidth = 64;
      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
        intermediateBitWidth = 16;
      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
        intermediateBitWidth = 32;
      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
        intermediateBitWidth = 48;
      }
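      // Example: for i8 with inZp + outZp == 10, maxValue = 127 + 10 + 1 =
      // 138, which fits in 16 bits, so the arithmetic below runs in i16.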
      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
      Value zpAddValue = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

      // The negation is computed as:
      //   outputValue = inZp + outZp - inputValue
      auto ext =
          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
      auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

      // Clamp to the negation range.
      Value min = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
          intermediateType);
      Value max = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
          intermediateType);
      auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);

      // Truncate to the final value.
      return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
    }
  }
  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round)
      return result;

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that input2 != 0.
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Checking the last bit shifted out of the original value.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }
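  // Example: 7 >> 1 with round == true: the plain shift gives 3, the last
  // bit shifted out is 1, so 1 is added back and the result is 4.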
  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAndOp
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNotOp
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOrOp
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXorOp
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::SinOp
  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);

  // tosa::CosOp
  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }
  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
                                               rewriter, args[0], args[1], max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
                                               rewriter, args[0], args[1], min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);
  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);

    auto clampOp = llvm::cast<tosa::ClampOp>(op);
    const auto nanMode = clampOp.getNanMode();

    // NaN propagation has no meaning for non floating point types.
    if (!isa<FloatType>(elementTy))
      return result;

    // Under "PROPAGATE" semantics no compare-and-select is required.
    if (nanMode == "PROPAGATE")
      return result;

    // Under "IGNORE" semantics a NaN input is replaced by the clamp minimum:
    // compare the operand against itself (unordered, so true exactly for
    // NaN) and select accordingly.
    Value isNaN = rewriter.create<arith::CmpFOp>(
        op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
    return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
  }
  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min & max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }
    // Clamp the requested bounds to the representable range.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }
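  // Example: clamping an i8 tensor with min_val == -200 first raises the
  // bound to -128, the smallest value representable in the element type.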
  // tosa::SigmoidOp, computed as 1.0 / (1.0 + exp(-x)).
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }
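  // tosa::CastOp is lowered by dispatching on the (source, destination)
  // element types: float<->float uses ext/trunc, int->float goes through
  // [su]itofp, float->int rounds to nearest even and then saturates,
  // int<->int sign-extends or truncates, and casts to i1 compare the input
  // against zero.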
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat())
      return nullptr;

    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
                                            std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              std::nullopt);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             std::nullopt);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other signed-int-to-float casts are handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    // Casting to boolean: floats are checked as not-equal to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }
    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // If neither int min nor int max is representable in the source float
      // type (exponent range too small), only infinities can overflow, so
      // replace them by int min / int max after conversion.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
        auto posInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                       APFloat::getInf(fltSemantics)));
        auto negInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                maxClamped);
      }

      auto intMinFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      // If the mantissa is wide enough to represent int max exactly, clamp
      // entirely in the floating-point domain.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        auto intMaxFP = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                         .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
      }

      // Otherwise int max + 1 is representable (it is a power of two), so
      // clamp the lower bound in the float domain and the upper bound in the
      // integer domain after conversion.
      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   static_cast<double>(
                       APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                           .getSExtValue()) +
                       1.0f));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(
                   getElementTypeOrSelf(dstTy),
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
      auto minClamped =
          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
      auto overflow = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
                                              minClamped);
    }
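    // Example: casting f32 3.0e9 to i32 saturates to 2147483647 because the
    // rounded value compares >= int max + 1 in the float domain.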
    // Casting to boolean: integers are checked as not-equal to zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
using IndexPool = DenseMap<int64_t, Value>;

// Emit an 'arith.constant' index op only once per distinct value, reusing a
// cached constant on subsequent requests.
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
  return it->second;
}

static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
}

static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}
static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
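// Picks the target size of one output dimension for a broadcastable
// elementwise op: a static size greater than 1 wins outright; otherwise the
// maximum of the dynamic sizes is computed at runtime. The returned Value is
// the "master" operand whose size defines the dimension, or null when
// several dynamic operands had to be combined.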
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc,
                  IndexPool &indexPool, ValueRange operands, int64_t dim) {
  // If any operand has a static dimension greater than 1, that is the target
  // size of this dimension.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter the operands with a dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If no operand has a dynamic dimension, all sizes were 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Compute the maximum of the dynamic sizes.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was taken directly from this
  // operand, it is the master operand; no broadcast is needed.
  if (operand == masterOperand)
    return operand;

  // Affine maps for the 'linalg.generic' op.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check at runtime whether a broadcast is necessary.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit the 'then' region of 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // Constants must not be cached across regions, so use a local pool.
    IndexPool localPool;

    // Emit 'tensor.empty' op.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(opBuilder, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op.
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast back to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit the 'else' region of 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit the 'scf.if' op.
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
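// Broadcasts every dynamic dimension of 'operand' in turn, so an operand only
// pays for the runtime scf.if + copying linalg.generic on dimensions whose
// size may actually be 1.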
static Value broadcastDynamicDimensions(PatternRewriter &rewriter,
                                        Location loc, IndexPool &indexPool,
                                        Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}

static SmallVector<Value> broadcastDynamicDimensions(
    PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
    ValueRange operands, ArrayRef<OpFoldResult> targetShape,
    ArrayRef<Value> masterOperands) {
  // No broadcast is needed for unary operations.
  if (operands.size() == 1)
    return operands;

  // Broadcast dynamic dimensions operand by operand.
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate the output tensor.
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input maps broadcast static dimensions of size 1;
  // the output map is an identity map.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Prefer identity maps whenever possible (i.e. no broadcasting needed)
      // because some transforms do not support affine constant exprs.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit the 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast the 'linalg.generic' result back to the original type if needed.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
static ValueRange getBroadcastableOperands(Operation *operation,
                                           ValueRange operands) {
  // The shift operand of mul cannot broadcast.
  if (isa<tosa::MulOp>(operation))
    return operands.take_front(2);
  // The input1_zp and output_zp operands of negate cannot broadcast.
  if (isa<tosa::NegateOp>(operation))
    return operands.take_front(1);
  return operands;
}

static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 &&
         "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower the operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
  auto broadcastOperands =
      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                 targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape, converter);
}
// Returns the identity value for a given reduction operation, used to fill
// the output tensor before reducing.
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
// Builds the scalar combiner for a reduction op on the block arguments of the
// surrounding linalg.reduce.
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}
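// Lowers a TOSA reduction: fill an output tensor with the combiner's identity
// value, run a linalg.reduce over the requested axis, then restore the
// collapsed dimension with tensor.expand_shape. ReduceMin/ReduceMax in
// "IGNORE" NaN mode additionally track, per output element, whether every
// reduced value was NaN, so the result can be fixed up to NaN in that case
// (the TOSA spec requires NaN iff all reduced elements are NaN).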
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();
  outputs.push_back(filledTensor);

  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN propagation has no meaning for non floating point types.
    if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
      isNanIgnoreMode = true;
      // Start with an all-true flag ("all values seen so far are NaN") and
      // AND in the NaN-ness of each reduced element.
      auto trueAttr = rewriter.getBoolAttr(true);
      auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
      auto emptyBoolTensor =
          rewriter
              .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
                                       dynDims)
              .getResult();
      auto allResultsNaNTensor =
          rewriter
              .create<linalg::FillOp>(loc, ValueRange{trueValue},
                                      ValueRange{emptyBoolTensor})
              .result();
      // linalg.reduce requires the same number of inputs and outputs, so a
      // second (unused) input is appended alongside the flag output.
      inputs.push_back(input);
      outputs.push_back(allResultsNaNTensor);
    }
  }

  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
        auto result = createLinalgBodyCalculationForReduceOp(
            op, binaryArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself is true only for NaN.
          Value isNaN = nestedBuilder.create<arith::CmpFOp>(
              op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
          // If a NaN was encountered, keep the running (non-NaN) value.
          auto selectOp = nestedBuilder.create<arith::SelectOp>(
              op->getLoc(), isNaN, initialValue, result);
          // Update the all-NaN tracking flag.
          auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
              op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  if (isNanIgnoreMode) {
    // If all reduced values were NaN, the result must be NaN: select between
    // a NaN-filled tensor and the computed reduction using the flag tensor.
    auto nanValueAttr = rewriter.getFloatAttr(
        elementTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
    auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
    auto emptyNanTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, reduceShape,
                                     resultTy.getElementType(), dynDims)
            .getResult();
    auto nanFilledTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{nanValue},
                                    ValueRange{emptyNanTensor})
            .result();

    // No need to fill this tensor; the select below overwrites it.
    auto finalEmptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, reduceShape,
                                     resultTy.getElementType(), dynDims)
            .getResult();

    SmallVector<Value> ins, outs;
    ins.push_back(linalgOp->getOpResult(1));
    ins.push_back(nanFilledTensor);
    ins.push_back(linalgOp->getResult(0));
    outs.push_back(finalEmptyTensor);
    auto linalgSelect =
        rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }

  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  // Lower directly to tensor::ExpandShapeOp instead of tosa::ReshapeOp,
  // since the dimension to expand is known here.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp->getResults()[0], reassociationMap);
  return success();
}
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
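// Lowers tosa.rescale, which maps an integer value between quantized spaces:
//   out = clamp(apply_scale(in - in_zp, multiplier, shift) + out_zp)
// with either per-tensor or per-channel multiplier/shift operands.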
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // These are illegal configurations: terminate and log an error.
    if (op.getRoundingMode() == "INEXACT_ROUND")
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
              "currently supported");
    if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values must be constants.
    DenseElementsAttr shiftElems;
    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant shift input values");

    DenseElementsAttr multiplierElems;
    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant multiplier input values");

    llvm::SmallVector<int8_t> shiftValues =
        llvm::to_vector(shiftElems.getValues<int8_t>());
    // An explicit cast is required here.
    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
                        [](IntegerAttr attr) -> int32_t {
                          return static_cast<int32_t>(attr.getInt());
                        }));

    // Shifting by more than the bitwidth simply produces 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only matters if some shift is greater than 31.
    bool doubleRound =
        op.getRoundingMode() == "DOUBLE_ROUND" &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
    StringAttr roundingMode = doubleRound
                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
                                  : rewriter.getStringAttr("SINGLE_ROUND");

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If rescaling per-channel, store the multiplier values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType = RankedTensorType::get(
          {static_cast<int64_t>(multiplierValues.size())},
          rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));
      multiplierArg = indexingMaps.size() - 1;
    }

    // If rescaling per-channel, store the shift values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing map for the output.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the linalg.generic op.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
          if (failed(maybeIZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "input zero point cannot be statically determined");
            return;
          }

          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
          // Extend the zero point attribute for sub-32-bit widths.
          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
          auto inputZp = nestedBuilder.create<arith::ConstantOp>(
              loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
                                    *maybeIZp));

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
          if (failed(maybeOZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "output zero point cannot be statically determined");
            return;
          }

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          const int32_t outAttrBitwidth = 32;
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
          auto outputZp = nestedBuilder.create<arith::ConstantOp>(
              loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
                                    *maybeOZp));

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(
                            nestedLoc,
                            nestedBuilder.getIntegerType(
                                valueTy.getIntOrFloatBitWidth()),
                            value)
                        .getResult(0);
          }
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              roundingMode);

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different saturation range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);
          }

          if (outIntType.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(nestedLoc,
                                                            outIntType, value)
                        .getResult(0);
          }
          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
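// tosa.resize of a 1x1 image to a 1x1 result degenerates to a (possibly
// scaled) element-wise copy; the pattern below handles exactly that case.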
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    // The scale operand must be a compile-time constant; its four values
    // (y_n, y_d, x_n, x_d) are extracted into 'scale'.
    SmallVector<int64_t> scale;
    // ...

    // Collapse the unit height/width dimensions away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Gather any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation that casts and scales the input.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any broadcastable dimension, generate a width of 1 on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
                                                 op.getOffset(), op.getBorder(),
                                                 op.getMode());

    // Collapse any unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());
    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
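// General tosa.resize lowering: a linalg.generic computes every output pixel
// by mapping its output coordinates back into the input image via
//   index = (out_coord * scale_d + offset) / scale_n
// and then either picks the nearest sample or bilinearly interpolates
// between the four neighbouring samples.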
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      // The scale, offset and border operands must all be compile-time
      // constants.
      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");
      }

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      // Float mode: index = floor((in * scale_d + offset) / scale_n), with
      // the fractional remainder kept as 'delta'.
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::FloorDivSIOp>(val, scaleN);

        Value r = b.create<arith::RemSIOp>(val, scaleN);
        Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
        Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
        delta = b.create<arith::DivFOp>(rFp, scaleNfp);
      };

      // Integer mode keeps delta in fixed point:
      //   delta = val - index * scale_n.
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::DivSIOp>(val, scaleN);
        delta = b.create<arith::MulIOp>(index, scaleN);
        delta = b.create<arith::SubIOp>(val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == "NEAREST_NEIGHBOR") {
        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return b.create<arith::ConstantIndexOp>(0);
          }

          // Round to the nearest sample: step to the next index when the
          // fractional part is at least one half.
          Value pred;
          if (floatingPointMode) {
            auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
                                           dvalDouble, scale);
          }

          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
          val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, iy, ix, channel});

        b.create<linalg::YieldOp>(result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == "BILINEAR");

        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 =
              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
          val1 =
              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
        };

        // Compute the clamped pairs of sample indices in each dimension.
        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
            Value mul1 = b.create<arith::MulFOp>(val1, delta);
            return b.create<arith::AddFOp>(mul0, mul1);
          };
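          // topAcc/bottomAcc below lerp along x within each row, then the
          // final result lerps along y:
          //   out = lerp(lerp(y0x0, y0x1, dx), lerp(y1x0, y1x1, dx), dy)
          // where lerp(v0, v1, d) = v0 * (1 - d) + v1 * d.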
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          b.create<linalg::YieldOp>(result);
        } else {
          // Perform the interpolation in quantized space.
          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = b.create<arith::ExtSIOp>(resultETy, dx);
            dy = b.create<arith::ExtSIOp>(resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return b.create<arith::MulIOp>(val0, scale);
            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
            return b.create<arith::AddIOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          b.create<linalg::YieldOp>(result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};

class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // The reversed tensor is produced by reading each element at a mirrored
    // index along the chosen axis: index' = size - 1 - index.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
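// tosa.tile is lowered by materializing a rank-2N tensor of shape
// [mult0, d0, mult1, d1, ...] whose input map ignores the multiple dims,
// then reshaping the result to the final tiled shape.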
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    // Map the input shape to the non-broadcasted dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
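// tosa.argmax keeps two running linalg outputs: the best value seen so far
// and its index along the reduction axis. For floats, the "IGNORE" NaN mode
// uses an ordered greater-than so NaN candidates never win, while the
// default mode deliberately lets a NaN overwrite a non-NaN running maximum.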
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the indices.
    auto emptyTensorIdx = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       outElementTy, dynDims)
                              .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running maximum.
    auto emptyTensorMax = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       inElementTy, dynDims)
                              .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // Reduce along the arg-max axis, with parallel iteration elsewhere.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == "IGNORE") {
              // Only update the index and max value for non-NaN inputs. If
              // all values are NaNs, the initial index 0 is returned.
              predicate = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
            } else {
              // Update the max value if either the new value is bigger, or
              // the current max is not NaN and the new value is NaN.
              Value gt = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
              Value oldNonNaN = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
              predicate = rewriter.create<arith::AndIOp>(
                  nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
// Lowering for the tosa.table op, a table lookup with optional interpolation.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)
                           .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);

      // 8-bit lookup: the signed input is offset by 128 to index directly
      // into a 256-entry table.
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }
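      // 16-bit inputs index a 513-entry table: after adding 32768 to make
      // the value unsigned, the top 9 bits select a base entry, the low
      // 7 bits form a fraction, and the result is
      //   base * 128 + (next - base) * fraction,
      // i.e. linear interpolation in 7-bit fixed point.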
      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional parts of the input value:
        //   value = value + 32768
        //   index = value >> 7
        //   fraction = value & 0x7f
        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between table entries:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }
    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};

struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }
  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    // The value is guaranteed to be non-negative, so (x / 2) + 1 computes
    // floor(x / 2) + 1.
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W] and compute the output shape [N, H, W / 2 + 1].
    auto dims = tensor::getMixedSizes(builder, loc, input);
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType =
        cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }
  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType = dyn_cast<FloatType>(
        cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for the input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
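    // The body below accumulates the real DFT:
    //   sumReal += valReal * cos(angle)
    //   sumImag -= valReal * sin(angle)
    // with angle = 2*pi * (((iy * oy) mod H) / H + ((ix * ox) mod W) / W).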
2650 Value valReal = args[0];
2651 Value sumReal = args[1];
2652 Value sumImag = args[2];
2655 Value oy = builder.
create<linalg::IndexOp>(loc, 1);
2656 Value ox = builder.
create<linalg::IndexOp>(loc, 2);
2657 Value iy = builder.
create<linalg::IndexOp>(loc, 3);
2658 Value ix = builder.
create<linalg::IndexOp>(loc, 4);
2663 auto iyXoy = builder.
create<index::MulOp>(loc, iy, oy);
2664 auto ixXox = builder.
create<index::MulOp>(loc, ix, ox);
2666 auto iyRem = builder.
create<index::RemUOp>(loc, iyXoy, dimH);
2667 auto ixRem = builder.
create<index::RemUOp>(loc, ixXox, dimW);
2669 auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2670 auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2672 auto yComponent = builder.
create<arith::DivFOp>(loc, iyRemFloat, constH);
2673 auto xComponent = builder.
create<arith::DivFOp>(loc, ixRemFloat, constW);
2674 auto sumXY = builder.
create<arith::AddFOp>(loc, yComponent, xComponent);
2675 auto angle = builder.
create<arith::MulFOp>(loc, twoPi, sumXY);
2679 auto cosAngle = builder.
create<math::CosOp>(loc, angle);
2680 auto sinAngle = builder.
create<math::SinOp>(loc, angle);
2681 auto realComponent =
2682 builder.
create<arith::MulFOp>(loc, valReal, cosAngle);
2683 auto imagComponent =
2684 builder.
create<arith::MulFOp>(loc, valReal, sinAngle);
2688 auto outReal = builder.
create<arith::AddFOp>(loc, sumReal, realComponent);
2689 auto outImag = builder.
create<arith::SubFOp>(loc, sumImag, imagComponent);
2695 rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2696 indexingMaps, iteratorTypes, buildBody);
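// For reference, the linalg.generic above accumulates the 2-D DFT of a real
// input, matching the TOSA RFFT2D pseudocode; the output width is W/2 + 1
// (the halfPlusOne helper) because the spectrum of a real signal is
// conjugate-symmetric:
//
//   angle = 2 * pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
//   out_real[n, oy, ox] =  sum_{iy, ix} in[n, iy, ix] * cos(angle)
//   out_imag[n, oy, ox] = -sum_{iy, ix} in[n, iy, ix] * sin(angle)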
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d,
                                         "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    assert(real_el_ty == imag_el_ty);

    // Compute the output type and set of dynamic sizes from the input's
    // [N, H, W] shape.
    llvm::SmallVector<Value> dynamicSizes;
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input_real, input_imag};
    llvm::SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Indices for angle computation.
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // angle = 2 * pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // The inverse transform flips the sign of the angle.
      if (inverse.getValue()) {
        angle = builder.create<arith::MulFOp>(
            loc, angle,
            rewriter.create<arith::ConstantOp>(
                loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
      }

      // realComponent = valReal * cos(angle) + valImag * sin(angle)
      // imagComponent = valImag * cos(angle) - valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);

      auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
      auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);

      auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
      auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
      auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);

      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent
      auto outReal =
          builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag =
          builder.create<arith::AddFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
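// For reference, the accumulation above is the full complex 2-D DFT; the
// inverse attribute only negates the angle:
//
//   angle = (inverse ? -1 : 1) * 2 * pi
//           * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
//   out_real[n, oy, ox] = sum_{iy, ix} in_real * cos(angle) + in_imag * sin(angle)
//   out_imag[n, oy, ox] = sum_{iy, ix} in_imag * cos(angle) - in_real * sin(angle)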
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>
      // clang-format on
      >(converter, patterns->getContext());

  patterns->add<
      // clang-format off
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProductOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter
      // clang-format on
      >(patterns->getContext());
}
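// A minimal, hypothetical driver for these patterns, sketched under the
// assumption that an identity TypeConverter suffices; the in-tree
// -tosa-to-linalg pass is more elaborate, and MyTosaToLinalgPass is an
// illustrative name, not code from this file.
namespace {
struct MyTosaToLinalgPass
    : public PassWrapper<MyTosaToLinalgPass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyTosaToLinalgPass)

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Assumption: an identity type conversion is enough for this sketch.
    TypeConverter converter;
    converter.addConversion([](Type type) { return type; });

    RewritePatternSet patterns(ctx);
    mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);

    // Everything is legal except remaining TOSA ops.
    ConversionTarget target(*ctx);
    target.addIllegalDialect<tosa::TosaDialect>();
    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

    if (failed(applyFullConversion(getOperation(), target,
                                   std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace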