#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <type_traits>
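// Several TOSA ops carry a "nan_mode" attribute with two possible values:
// "PROPAGATE" (IEEE-style: any NaN operand produces a NaN result) and
// "IGNORE" (a NaN operand is skipped and the other operand wins). The helper
// below materializes the IGNORE behaviour: an unordered self-compare is true
// exactly when the value is NaN, so two such compares plus two selects route
// the non-NaN operand through to the result.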
template <typename OpTy>
static Value
materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                    Value lhs, Value rhs, Value result) {
  auto nanMode = op.getNanMode();
  if (nanMode == "PROPAGATE")
    return result;

  // Unordered comparison of NaN against itself is always true.
  Value lhsIsNaN = rewriter.create<arith::CmpFOp>(
      op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
  Value rhsIsNaN = rewriter.create<arith::CmpFOp>(
      op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
  Value rhsOrResult =
      rewriter.create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
  return rewriter.create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
                                          rhsOrResult);
}
  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
  }
  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }
  // tosa::MulOp
  if (isa<tosa::MulOp>(op)) {
    auto shiftVal = cast<tosa::MulOp>(op).getShift();
    DenseElementsAttr shiftElem;
    if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
      return nullptr;
    }
    int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0) {
        auto shiftConst =
            rewriter.create<arith::ConstantIntOp>(loc, shift, /*width=*/8);
        if (!a.getType().isInteger(32))
          a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
        if (!b.getType().isInteger(32))
          b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

        auto result = rewriter.create<tosa::ApplyScaleOp>(
            loc, rewriter.getI32Type(), a, b, shiftConst,
            rewriter.getStringAttr("SINGLE_ROUND"));

        if (elementTy.isInteger(32))
          return result;

        return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
      }

      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

      return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
    }
  }
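  // tosa.negate on quantized types computes
  //   output = inZp + outZp - input
  // in a widened intermediate type, then clamps back to the storage range.
  // For example, with i8 data, inZp = -128 and outZp = 0, the input 127 maps
  // to -128 - 127 = -255, which needs a 16-bit intermediate before clamping
  // to [-128, 127].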
  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op)) {
    auto negate = cast<tosa::NegateOp>(op);

    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    if (failed(maybeInZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "input1 zero point cannot be statically determined");
      return nullptr;
    }

    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    if (failed(maybeOutZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "output zero point cannot be statically determined");
      return nullptr;
    }

    int64_t inZp = *maybeInZp;
    int64_t outZp = *maybeOutZp;

    if (isa<FloatType>(elementTy))
      return rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);

    if (isa<IntegerType>(elementTy)) {
      if (!inZp && !outZp) {
        auto constant = rewriter.create<arith::ConstantOp>(
            loc, IntegerAttr::get(elementTy, 0));
        return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
                                              args[0]);
      }

      // Compute the maximum value that can occur in the intermediate buffer.
      const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
      const int64_t zpAdd = inZp + outZp;
      const int64_t maxValue =
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
          std::abs(zpAdd) + 1;

      // Convert that maximum value into the maximum bitwidth needed to
      // represent it. We assume 48-bit numbers may be supported further in
      // the pipeline.
      int intermediateBitWidth = 64;
      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
        intermediateBitWidth = 16;
      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
        intermediateBitWidth = 32;
      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
        intermediateBitWidth = 48;
      }

      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
      Value zpAddValue = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

      // The negation can be applied by doing:
      //   outputValue = inZp + outZp - inputValue
      auto ext =
          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
      auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

      // Clamp to the negation range.
      Value min = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
          intermediateType);
      Value max = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
          intermediateType);
      auto clamp =
          clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);

      // Truncate to the final value.
      return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
    }
  }
  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
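  // When the "round" attribute is set, tosa.arithmetic_right_shift adds 1 to
  // the shifted result if the most significant discarded bit was 1:
  //   result = (x >> s) + ((s > 0) && ((x >> (s - 1)) & 1))
  // For example, 7 >> 1 = 3 rounds up to 4 because the discarded bit was 1.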
  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that input2 != 0.
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Checking the last bit shifted out of the result.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }
  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::SinOp
  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);

  // tosa::CosOp
  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }
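  // arith.maximumf / arith.minimumf already propagate NaN, which matches the
  // "PROPAGATE" nan_mode; for "IGNORE" the result is post-processed by
  // materializeBinaryNanCheckIfRequired above.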
  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    auto max = rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
                                               rewriter, args[0], args[1], max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    auto min = rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
                                               rewriter, args[0], args[1], min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }
  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);

    auto clampOp = llvm::cast<tosa::ClampOp>(op);
    const auto nanMode = clampOp.getNanMode();

    // NaN propagation has no meaning for non floating-point types.
    if (!isa<FloatType>(elementTy))
      return result;

    // In the case of "PROPAGATE" semantics no compare and selection is
    // required.
    if (nanMode == "PROPAGATE")
      return result;

    // In the case of "IGNORE" semantics, compare the input against itself:
    // the unordered predicate is true exactly for NaN inputs.
    Value isNaN = rewriter.create<arith::CmpFOp>(
        op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
    // TOSA specifies that in "ignore" NaN mode the result is "min" if the
    // input is NaN.
    return rewriter.create<arith::SelectOp>(op->getLoc(), isNaN, min, result);
  }
  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min & max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }
    // Ensure that the bounds are representable.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }
  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }
  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
      (void)rewriter.notifyMatchFailure(op, "unsupported type");
      return nullptr;
    }

    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
                                            ArrayRef<NamedAttribute>());

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              ArrayRef<NamedAttribute>());

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              ArrayRef<NamedAttribute>());

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             ArrayRef<NamedAttribute>());

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              ArrayRef<NamedAttribute>());

    // Casting to boolean, floats need to only be checked as not-equal to
    // zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }
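    // Saturating float-to-int casts pick one of three strategies depending on
    // what the source float type can represent:
    //  1. If the exponent range is too small to hold int min/max (e.g. f16's
    //     largest finite value 65504 < INT32_MAX), only infinities can
    //     overflow, so the converted value is patched with cmp+select against
    //     +/-inf.
    //  2. If the mantissa can hold int max exactly, both bounds are clamped
    //     in the float domain before the conversion.
    //  3. Otherwise the lower bound is clamped in float (int min is a power
    //     of two, hence exact) and the upper bound is handled by comparing
    //     against int max + 1 (also a power of two) plus a select in the
    //     integer domain.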
    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // Check whether neither int min nor int max can be represented in the
      // input floating-point type due to too short exponent range.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        // Use cmp + select to replace infinites by int min / int max. Other
        // integer values can be represented in the integer space.
        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
        auto posInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                       APFloat::getInf(fltSemantics)));
        auto negInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                maxClamped);
      }
      auto intMinFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      // Check whether the mantissa has enough bits to represent int max.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        // Int min can also be represented since it is a power of two and thus
        // consists of a single leading bit, so we can clamp the input in the
        // floating-point domain.
        auto intMaxFP = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                         .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
      }

      // Due to the earlier check we know the exponent range is big enough to
      // represent int min, so int max + 1 is representable as well because it
      // is just int min with a positive sign. Clamp the lower limit in float
      // and the higher limit in the integer domain.
      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   static_cast<double>(
                       APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                           .getSExtValue()) +
                       1.0f));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(
                   getElementTypeOrSelf(dstTy),
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
      auto minClamped =
          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
      auto overflow = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
                                              minClamped);
    }
    // Casting to boolean, integers need to only be checked as not-equal to
    // zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             ArrayRef<NamedAttribute>());

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
  return it->second;
}

static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
}

static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}
static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // If any operand has a static dimension greater than 1, no runtime
  // computation is needed.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If no operand has a dynamic dimension, it means all sizes were 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Emit 'arith.maxui' operations to compute the greatest size.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was directly taken from this
  // operand, it is already broadcast.
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions: new constants could
    // potentially violate dominance requirements.
    IndexPool localPool;

    // Emit 'tensor.empty' op.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op.
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op.
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit 'scf.if' op.
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations.
  if (operands.size() == 1)
    return operands;

  // Broadcast dynamic dimensions operand by operand.
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor.
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of size
  // 1. The output affine map is an identity map.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Prefer producing identity maps whenever possible (i.e. no broadcast
      // needed) because some transforms (like reshape folding) do not support
      // affine constant exprs.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast 'linalg.generic' result into original result type if needed.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
static ValueRange getBroadcastableOperands(Operation *operation,
                                           ValueRange operands) {
  // The shift operand of mul cannot broadcast.
  if (isa<tosa::MulOp>(operation))
    return operands.take_front(2);
  // The zero-point operands of negate cannot broadcast.
  if (isa<tosa::NegateOp>(operation))
    return operands.take_front(1);
  return operands;
}

static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 &&
         "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
  auto broadcastOperands =
      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                 targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape, converter);
}
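// Reductions are lowered in three steps: create an init value that is the
// identity of the combining operation (0 for sum, 1 for product, +/-largest
// for min/max), fill the output with it, then combine elementwise along the
// reduction axis and finally expand the dropped dimension back with
// tensor.expand_shape.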
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}
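// For reduce_min/reduce_max with nan_mode = "IGNORE" the TOSA spec requires
// the result to be NaN only if every element along the axis is NaN. The
// helper below therefore carries a second boolean accumulator ("all results
// NaN so far") through the linalg.reduce and performs one final
// linalg.select that substitutes a NaN-filled tensor wherever that flag
// survived as true.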
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();
  outputs.push_back(filledTensor);

  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN propagation has no meaning for non floating-point types.
    if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") {
      isNanIgnoreMode = true;
      // Because the TOSA spec requires the result to be NaN iff all elements
      // in the reduction are NaN, we maintain a flag for each reduced value
      // and only do a final select in that case.
      auto trueAttr = rewriter.getBoolAttr(true);
      auto trueValue = rewriter.create<arith::ConstantOp>(loc, trueAttr);
      auto emptyBoolTensor =
          rewriter
              .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
                                       dynDims)
              .getResult();
      auto allResultsNaNTensor =
          rewriter
              .create<linalg::FillOp>(loc, ValueRange{trueValue},
                                      ValueRange{emptyBoolTensor})
              .result();
      inputs.push_back(input);
      outputs.push_back(allResultsNaNTensor);
    }
  }

  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
        auto result = createLinalgBodyCalculationForReduceOp(
            op, binaryArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself is always true.
          Value isNaN = nestedBuilder.create<arith::CmpFOp>(
              op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
          // If we are in NaN ignore mode, replace a NaN result with the
          // initial value.
          auto selectOp = nestedBuilder.create<arith::SelectOp>(
              op->getLoc(), isNaN, initialValue, result);
          // Update the flag tracking whether all reduced values were NaN.
          auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
              op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  if (isNanIgnoreMode) {
    // If all values were NaN, select a tensor of NaNs, since the identity
    // value would otherwise propagate through all the compares and selects
    // inside the reduction.
    auto nanValueAttr = rewriter.getFloatAttr(
        elementTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
    auto nanValue = rewriter.create<arith::ConstantOp>(loc, nanValueAttr);
    auto emptyNanTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, reduceShape,
                                     resultTy.getElementType(), dynDims)
            .getResult();
    auto nanFilledTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{nanValue},
                                    ValueRange{emptyNanTensor})
            .result();

    // No need to fill this tensor, it will be overwritten by the select.
    auto finalEmptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, reduceShape,
                                     resultTy.getElementType(), dynDims)
            .getResult();

    SmallVector<Value> ins, outs;
    ins.push_back(linalgOp->getOpResult(1));
    ins.push_back(nanFilledTensor);
    ins.push_back(linalgOp->getResult(0));
    outs.push_back(finalEmptyTensor);
    auto linalgSelect =
        rewriter.create<linalg::SelectOp>(op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }

  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp->getResults()[0], reassociationMap);
  return success();
}
template <typename OpTy>
class PointwiseConverter : public OpConversionPattern<OpTy> {
public:
  using OpConversionPattern<OpTy>::OpConversionPattern;
  using typename OpConversionPattern<OpTy>::OpAdaptor;

  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
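// tosa.rescale requantizes an integer tensor:
//   out = clamp(apply_scale(in - inZp, multiplier, shift) + outZp)
// with the multiply-shift-round performed by tosa.apply_scale. As a worked
// example, rescaling the i8 value 64 with multiplier = 1 << 30 and
// shift = 31 in SINGLE_ROUND mode computes (64 * 2^30 + 2^30) >> 31 = 32,
// i.e. a rescale by 0.5.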
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // These are illegal configurations: terminate and log an error.
    if (op.getRoundingMode() == "INEXACT_ROUND")
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
              "currently supported");
    if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");
    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    DenseElementsAttr shiftElems;
    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant shift input values");

    DenseElementsAttr multiplierElems;
    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant multiplier input values");

    llvm::SmallVector<int8_t> shiftValues =
        llvm::to_vector(shiftElems.getValues<int8_t>());
    // An explicit cast is required here.
    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
                        [](IntegerAttr attr) -> int32_t {
                          return static_cast<int32_t>(attr.getInt());
                        }));

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only occurs if shift is greater than 31; check whether
    // that is ever true.
    bool doubleRound =
        op.getRoundingMode() == "DOUBLE_ROUND" &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
    StringAttr roundingMode = doubleRound
                                  ? rewriter.getStringAttr("DOUBLE_ROUND")
                                  : rewriter.getStringAttr("SINGLE_ROUND");
    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing map for the output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the linalg.generic op.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
          if (failed(maybeIZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "input zero point cannot be statically determined");
            return;
          }

          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
          // Extend the zero point for sub-32-bit widths.
          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
          auto inputZp = nestedBuilder.create<arith::ConstantOp>(
              loc, IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
                                    *maybeIZp));

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
          if (failed(maybeOZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "output zero point cannot be statically determined");
            return;
          }

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          const int32_t outAttrBitwidth = 32;
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
          auto outputZp = nestedBuilder.create<arith::ConstantOp>(
              loc, IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
                                    *maybeOZp));

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              roundingMode);

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, /*isUnsigned=*/false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
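// Resize lowering is split across three patterns: ResizeUnaryConverter
// handles the degenerate 1x1 -> 1x1 case as a collapse + optional scale +
// expand, MaterializeResizeBroadcast peels off pure broadcast dimensions, and
// GenericResizeConverter emits the general gather-style computation.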
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    SmallVector<int64_t> scale;
    if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
      return failure();
    }

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation for casting/scaling the input value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
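// If the input is 1 along H or W while the output is larger, resizing along
// that axis is a pure broadcast. The pattern below shrinks the resize to the
// non-broadcast shape, collapses the unit dimensions away, and broadcasts the
// result back out with a linalg.generic whose input map drops those
// dimensions.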
struct MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any broadcastable dimension we generate a width of 1 on the
    // intermediate output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
                                                 op.getOffset(), op.getBorder(),
                                                 op.getMode());

    // Collapse the unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    llvm::SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
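// The general resize computes, per output coordinate (oy, ox), a source
// position in fixed-point arithmetic:
//   x  = ox * scale_d + offset
//   ix = floor(x / scale_n),  dx = remainder (as a fraction of scale_n)
// For example, upscaling W = 2 -> W = 4 uses scale_n/scale_d = 3/1; with
// offset 0, ox = 1 maps to x = 1, ix = 0, dx = 1/3. NEAREST_NEIGHBOR rounds
// ix by dx; BILINEAR blends the four neighbours weighted by dx and dy.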
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border))
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      // Compute the index and fractional delta - floating-point case.
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        // x = x * scale_d + offset
        // ix = floor(x / scale_n)
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::FloorDivSIOp>(val, scaleN);

        // rx = x % scale_n
        // dx = rx / scale_n
        Value r = b.create<arith::RemSIOp>(val, scaleN);
        Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
        Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
        delta = b.create<arith::DivFOp>(rFp, scaleNfp);
      };

      // Compute the index and fractional delta - integer case.
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        // x = x * scale_d + offset
        // ix = floor(x / scale_n)
        // dx = x - ix * scale_n
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::DivSIOp>(val, scaleN);
        delta = b.create<arith::MulIOp>(index, scaleN);
        delta = b.create<arith::SubIOp>(val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == "NEAREST_NEIGHBOR") {
        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return b.create<arith::ConstantIndexOp>(0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
                                           dvalDouble, scale);
          }

          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
          val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, iy, ix, channel});

        b.create<linalg::YieldOp>(result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == "BILINEAR");

        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 =
              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
          val1 =
              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
        };

        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
            Value mul1 = b.create<arith::MulFOp>(val1, delta);
            return b.create<arith::AddFOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          b.create<linalg::YieldOp>(result);
        } else {
          // Perform the interpolation in the quantized domain.
          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = b.create<arith::ExtSIOp>(resultETy, dx);
            dy = b.create<arith::ExtSIOp>(resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return b.create<arith::MulIOp>(val0, scale);
            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
            return b.create<arith::AddIOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          b.create<linalg::YieldOp>(result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
// At this point, we should have a workable pattern for each remaining op.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};
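// tosa.reverse is lowered to a linalg.generic over the output that gathers
// input[..., size - 1 - i, ...] along the reversed axis; the flipped index is
// computed from tensor.dim so dynamic axis sizes work too.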
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // First fill the output buffer.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
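// tosa.tile interleaves one broadcast dimension per input dimension: an input
// of shape [d0, d1] tiled by [m0, m1] first becomes an [m0, d0, m1, d1]
// linalg.generic result, which a trailing tosa.reshape collapses to
// [m0 * d0, m1 * d1].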
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcasted dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
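// tosa.argmax keeps two accumulators in one linalg.generic: the running
// maximum and its index. For floats the nan_mode decides the predicate:
// "IGNORE" uses an ordered > (a NaN candidate never wins), while the default
// "PROPAGATE" accepts a new maximum via unordered > provided the incumbent
// is not already NaN, so the first NaN encountered sticks.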
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       outElementTy, dynDims)
                              .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       inElementTy, dynDims)
                              .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations
    // along the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == "IGNORE") {
              // Only update the index and max value for non-NaN inputs. If
              // all values are NaN the initial index (0) is returned.
              predicate = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
            } else {
              // Update the max value if either the new value is bigger, or
              // the current max is not NaN and the new value is NaN.
              Value gt = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
              Value oldNonNaN = rewriter.create<arith::CmpFOp>(
                  nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
              predicate = rewriter.create<arith::AndIOp>(
                  nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
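// tosa.gather indexes an [N, K, C] values tensor with an [N, W] index
// tensor. The linalg.generic below iterates over the [N, W, C] result; the
// batch and channel coordinates come from linalg.index, and the middle
// coordinate is the gathered index value cast to index type.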
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
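// tosa.table has two lowerings: i8 inputs index a 256-entry table directly
// (after re-centering by +128), while i16 inputs use the top 9 bits to pick
// a base entry in a 513-entry table and linearly interpolate with the low
// 7 bits:
//   result = (base << 7) + (next - base) * fraction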
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)
                           .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        // value = value + 32768
        // index = value >> 7
        // fraction = value & 0b1111111
        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table:
        // base = (int32_t) table[index]
        // next = (int32_t) table[index + 1]
        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the table values:
        // result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W].
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }
  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType = dyn_cast<FloatType>(
        cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for the angle computation.
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // Compute the angle without the integer parts of the components, since
      // sin/cos are periodic:
      //   angle = 2 * pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal =
          builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag =
          builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
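// Lowers tosa.fft2d along the same lines, but with a complex (separate real
// and imaginary) input, full-sized [N, H, W] outputs, and an `inverse`
// attribute that flips the sign of the angle:
//
//   real[n, oy, ox] += valReal * cos(angle) + valImag * sin(angle)
//   imag[n, oy, ox] += valImag * cos(angle) - valReal * sin(angle)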
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d,
                                         "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    assert(real_el_ty == imag_el_ty);

    // Compute the output type and set of dynamic sizes from the [N, H, W]
    // input shape.
    llvm::SmallVector<Value> dynamicSizes;
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input_real, input_imag};
    llvm::SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                       RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                       RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                       RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Indices for the angle computation.
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      //   angle = 2 * pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);

      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      Value angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // For the inverse FFT, flip the sign of the angle.
      if (inverse.getValue()) {
        angle = builder.create<arith::MulFOp>(
            loc, angle,
            rewriter.create<arith::ConstantOp>(
                loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
      }

      // realComponent = valReal * cos(angle) + valImag * sin(angle)
      // imagComponent = valImag * cos(angle) - valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);

      auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
      auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);

      auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
      auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
      auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);

      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent
      auto outReal =
          builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag =
          builder.create<arith::AddFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
/// Populates conversion patterns from the TOSA dialect to the Linalg dialect.
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  // We have multiple resize converters to handle degenerate cases.
  patterns->add<GenericResizeConverter>(patterns->getContext(),
                                        /*benefit=*/100);
  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
                                      /*benefit=*/200);
  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
                                            /*benefit=*/300);

  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>
        >(converter, patterns->getContext());
  // clang-format on

  patterns->add<
      // clang-format off
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProductOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
  // clang-format on
}
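// A minimal sketch of how a conversion pass might drive these patterns; the
// target setup below is illustrative and not part of this file:
//
//   MLIRContext *ctx = &getContext();
//   RewritePatternSet patterns(ctx);
//   TypeConverter converter; // configured with the TOSA-to-Linalg type rules
//   ConversionTarget target(*ctx);
//   target.addIllegalDialect<tosa::TosaDialect>();
//   target.addLegalDialect<linalg::LinalgDialect, arith::ArithDialect,
//                          index::IndexDialect, math::MathDialect,
//                          tensor::TensorDialect>();
//   mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
//   if (failed(applyFullConversion(getOperation(), target,
//                                  std::move(patterns))))
//     signalPassFailure();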