#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <type_traits>
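
// Helper that implements TOSA NaN-propagation semantics for binary float
// ops: with PROPAGATE the plain result is used unchanged; with IGNORE a NaN
// operand is replaced by the other operand via compare-and-select.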
template <typename OpTy>
static Value
materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                    Value lhs, Value rhs, Value result) {
  auto nanMode = op.getNanMode();
  if (nanMode == NanPropagationMode::PROPAGATE)
    return result;

  // Unordered comparison of a value against itself is true only for NaN.
  Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
                                         arith::CmpFPredicate::UNO, lhs, lhs);
  Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
                                         arith::CmpFPredicate::UNO, rhs, rhs);
  Value rhsOrResult =
      arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
  return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
                                 rhsOrResult);
}
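
// Builds the scalar computation for one element of a TOSA elementwise op
// inside a linalg.generic body. Dispatches on the TOSA op kind and on the
// element type; returns a null Value when the combination is not handled,
// which the caller reports as a match failure.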
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return math::AbsFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = arith::ConstantOp::create(rewriter, loc,
                                          rewriter.getZeroAttr(elementTy));
    auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
    return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
  }

  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return arith::AddFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return arith::AddIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return arith::SubFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return arith::SubIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return arith::DivSIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
    return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
  }
  if (isa<tosa::MulOp>(op)) {
    auto shiftVal = cast<tosa::MulOp>(op).getShift();
    DenseElementsAttr shiftElem;
    bool shiftIsConstant = true;
    int32_t shift = 0;
    if (matchPattern(shiftVal, m_Constant(&shiftElem)))
      shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
    else
      shiftIsConstant = false;

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
                                   args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0 || !shiftIsConstant) {
        auto shiftConst = arith::ConstantIntOp::create(rewriter, loc, shift,
                                                       /*bitwidth=*/8);
        if (!a.getType().isInteger(32))
          a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
        if (!b.getType().isInteger(32))
          b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);

        auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
        auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
                                                  RoundingMode::SINGLE_ROUND);
        auto result =
            tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
                                       b, shiftAmount, roundingAttr);

        if (elementTy.isInteger(32))
          return result;
        return arith::TruncIOp::create(rewriter, loc, elementTy, result);
      }

      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);

      return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
    }
  }
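
  // tosa::NegateOp. For quantized types the negation is computed as
  // (input_zp + output_zp) - input in a wider intermediate type, then
  // clamped back to the range of the original element type.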
  if (isa<tosa::NegateOp>(op)) {
    auto negate = cast<tosa::NegateOp>(op);

    int64_t inZp = 0, outZp = 0;
    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    bool hasInZp = !failed(maybeInZp);
    bool hasOutZp = !failed(maybeOutZp);
    if (hasInZp)
      inZp = *maybeInZp;
    if (hasOutZp)
      outZp = *maybeOutZp;

    if (isa<FloatType>(elementTy))
      return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);

    if (isa<IntegerType>(elementTy)) {
      if (hasInZp && hasOutZp && !inZp && !outZp) {
        auto constant = arith::ConstantOp::create(
            rewriter, loc, IntegerAttr::get(elementTy, 0));
        return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
                                     args[0]);
      }

      // Compute the maximum value that can occur in the intermediate buffer.
      const int64_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
      Type intermediateType;
      Value zpAddValue;
      int intermediateBitWidth = 64;
      if (hasInZp && hasOutZp) {
        const int64_t zpAdd = inZp + outZp;
        const int64_t maxValue =
            APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
            std::abs(zpAdd);

        // If the result can be represented in fewer than 64 bits, use a
        // narrower intermediate type.
        if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
          intermediateBitWidth = 16;
        } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
          intermediateBitWidth = 32;
        } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
          intermediateBitWidth = 48;
        }

        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
        zpAddValue = arith::ConstantOp::create(
            rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
      } else {
        // The zero points are dynamic: extend them and sum them at runtime.
        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
        auto arg1 =
            arith::ExtSIOp::create(rewriter, loc, intermediateType, args[1]);
        auto arg2 =
            arith::ExtSIOp::create(rewriter, loc, intermediateType, args[2]);
        zpAddValue =
            arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
      }

      // The negation can be applied by doing:
      //  outputValue = inZp + outZp - inputValue
      auto ext =
          arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]);
      auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);

      // Clamp to the negation range.
      Value min = arith::ConstantIntOp::create(
          rewriter, loc, intermediateType,
          APInt::getSignedMinValue(inputBitWidth).getSExtValue());
      Value max = arith::ConstantIntOp::create(
          rewriter, loc, intermediateType,
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
      auto clamp =
          clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);

      // Truncate to the final value.
      return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
    }
  }
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return arith::AndIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return arith::OrIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
  }

  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return arith::ShLIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one = arith::ConstantOp::create(rewriter, loc,
                                         IntegerAttr::get(elementTy, 1));
    auto zero = arith::ConstantOp::create(rewriter, loc,
                                          IntegerAttr::get(elementTy, 0));
    auto i1one =
        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
    auto i1zero =
        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));

    // Checking that input2 != 0.
    auto shiftValueGreaterThanZero = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Checking for the last bit of the input that is shifted out.
    auto subtract =
        arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
    auto shifted =
        arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
                                             std::nullopt);
    auto isInputOdd =
        arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);

    auto shouldRound = arith::SelectOp::create(
        rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
    auto extended =
        arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
    return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
  }
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
  }

  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIntegerAttr(elementTy, 1));
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
  }

  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
                                 args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
                                 args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
                                 args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
                                 args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
                                 args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                 args[0], args[1]);

  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
  }
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
                                               rewriter, args[0], args[1], max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
                                               rewriter, args[0], args[1], min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return math::CeilOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return math::FloorOp::create(rewriter, loc, resultTypes, args);
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = arith::ConstantOp::create(
        rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = arith::ConstantOp::create(
        rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);

    auto clampOp = llvm::cast<tosa::ClampOp>(op);
    const auto nanMode = clampOp.getNanMode();

    // NaN propagation has no meaning for non floating point types.
    if (!isa<FloatType>(elementTy))
      return result;

    // In the case of "PROPAGATE" semantics no compare and selection is
    // required.
    if (nanMode == NanPropagationMode::PROPAGATE)
      return result;

    // In the case of "IGNORE" semantics, NaN inputs are swapped for the min
    // value: an unordered self-comparison detects NaN.
    Value isNaN = arith::CmpFOp::create(
        rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
    return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
  }
  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min & max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }
    // Clamp the requested bounds into the representable range.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
                                               intTy.getIntOrFloatBitWidth());
    auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
                                               intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
    auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
    auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
    auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
    return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
  }
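
  // tosa::CastOp. Dispatch over the supported source/destination element
  // type pairs; float-to-integer casts additionally round to nearest-even
  // and saturate at the integer limits.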
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat())
      return nullptr;

    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
                                   ArrayRef<NamedAttribute>());

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
                                     ArrayRef<NamedAttribute>());

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
                                     ArrayRef<NamedAttribute>());

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
                                    ArrayRef<NamedAttribute>());

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          UnrealizedConversionCastOp::create(
              rewriter, loc,
              rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
              .getResult(0);
      return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
                                     unrealizedCast);
    }

    // All other signed-int-to-float conversions are handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
                                     ArrayRef<NamedAttribute>());

    // Casting to boolean, floats need only be checked as not-equal to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getFloatAttr(srcTy, 0.0));
      return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
                                   args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // Check whether neither int min nor int max can be represented in the
      // input floating-point type due to too short exponent range.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        // Use cmp + select to replace infinites by int min / int max. Other
        // integral values can be represented in the integer space.
        auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
        auto posInf = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                  APFloat::getInf(fltSemantics)));
        auto negInf = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = arith::CmpFOp::create(
            rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = arith::CmpFOp::create(
            rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
        return arith::SelectOp::create(rewriter, loc, underflow, intMin,
                                       maxClamped);
      }

      auto intMinFP = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getFloatAttr(
              getElementTypeOrSelf(srcTy),
              APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                  .getSExtValue()));

      // Check whether the mantissa has enough bits to represent int max.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        // Int min is a power of two, so it is representable as well; the
        // input can therefore be clamped in the floating-point domain.
        auto intMaxFP = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                    .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
      }

      // Int max + 1 is representable (it is int min with a positive sign),
      // so clamp the min value and compare against int max + 1 to select the
      // max int value if needed.
      auto intMaxPlusOneFP = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getFloatAttr(
              getElementTypeOrSelf(srcTy),
              static_cast<double>(
                  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                      .getSExtValue()) +
                  1.0f));

      auto intMax = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getIntegerAttr(
              getElementTypeOrSelf(dstTy),
              APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
      auto minClamped =
          arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
      auto overflow = arith::CmpFOp::create(
          rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return arith::SelectOp::create(rewriter, loc, overflow, intMax,
                                     minClamped);
    }

    // Casting to boolean, integers need only be checked as not-equal to zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
                                                srcTy.getIntOrFloatBitWidth());
      return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
                                   args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
                                    ArrayRef<NamedAttribute>());

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
static Value createIndex(OpBuilder &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index));
  return it->second;
}

static Value getTensorDim(OpBuilder &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
}

static OpFoldResult getOrFoldTensorDim(OpBuilder &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}
static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
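
// Computes the broadcast target size of one dimension across all operands,
// together with the "master" operand that provides that size (null when the
// size is a runtime maximum of several dynamic sizes).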
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // If any operand is statically known to have a size > 1 in this dimension,
  // that is the target size of the dimension.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (ShapedType::isStatic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with a dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If no operand has a dynamic dimension, all sizes were 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Multiple operands with a dynamic size: take the runtime maximum.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
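
// Broadcasts one dynamic dimension of an operand to the target size. Since
// the need to broadcast is only known at runtime, an scf.if with a
// broadcasting linalg.generic in its 'then' region is emitted.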
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was taken directly from this
  // operand, it already has the right size and no broadcast is needed.
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic'.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = arith::CmpIOp::create(
      rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions.
    IndexPool localPool;

    // Emit 'tensor.empty' op.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(opBuilder, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = tensor::EmptyOp::create(
        opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op.
    Value resultTensor =
        linalg::GenericOp::create(
            opBuilder, loc, outputTensor.getType(), operand, outputTensor,
            affineMaps, getNParallelLoopsAttrs(rank),
            [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
              linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
            })
            .getResult(0);

    // Cast to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    scf::YieldOp::create(opBuilder, loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    scf::YieldOp::create(opBuilder, loc, operand);
  };

  // Emit 'scf.if' op.
  auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
                                emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}
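
// Broadcasts all dynamic dimensions of every operand so that all operands
// agree with the computed target shape.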
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations.
  if (operands.size() == 1)
    return llvm::to_vector(operands);

  // Skip broadcasting altogether when every operand is statically shaped.
  bool hasDynamic = false;
  for (auto op : operands) {
    const auto tType = dyn_cast<RankedTensorType>(op.getType());
    if (tType && !tType.hasStaticShape()) {
      hasDynamic = true;
      break;
    }
  }
  if (!hasDynamic)
    return llvm::to_vector(operands);

  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor.
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType)
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
                                               resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of size
  // 1; the output affine map is an identity map.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Prefer identity maps whenever possible (i.e. no broadcast needed)
      // because some transforms (like reshape folding) do not support affine
      // constant exprs.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = linalg::GenericOp::create(
      rewriter, loc, outputTensor.getType(), operands, outputTensor,
      affineMaps, getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        linalg::YieldOp::create(opBuilder, loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast the 'linalg.generic' result into the original result type if needed.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
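
// Selects which operands participate in broadcasting: operands that encode
// quantization parameters (the mul shift, the negate zero points) are
// excluded when they are materialized as constants in the op body.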
static ValueRange getBroadcastableOperands(Operation *operation,
                                           ValueRange operands) {
  if (isa<tosa::MulOp>(operation)) {
    // The shift operand cannot broadcast; a constant shift is materialized
    // in the body, a dynamic one is passed through as a third operand.
    DenseElementsAttr shiftElems;
    if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
      return operands.take_front(2);
    return operands.take_front(3);
  }
  if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    // Constant zero points do not participate in broadcasting.
    if (!failed(maybeInZp) && !failed(maybeOutZp))
      return operands.take_front(1);
  }
  return operands;
}
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 &&
         "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
  auto broadcastOperands =
      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                 targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape, converter);
}
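
// Returns the identity value for each supported reduction (e.g. 0 for sum,
// the largest negative float for max), used to fill the accumulator tensor.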
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return arith::AddFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::AddIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MulFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MulIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, args);

  return {};
}
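
// Lowers a TOSA reduction to linalg.reduce. Handles two special cases: NaN
// "IGNORE" semantics for min/max (tracked with an all-NaN flag tensor and a
// final linalg.select) and widening of bf16 sum accumulation to f32.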
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // Widen the accumulator type for bf16 sum reductions.
  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
                    isa<FloatType>(elementTy) &&
                    cast<FloatType>(elementTy).isBF16();
  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
  auto filledTensor =
      linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                             ValueRange{emptyTensor})
          .result();
  outputs.push_back(filledTensor);

  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN propagation has no meaning for non floating point types.
    if (isa<FloatType>(elementTy) &&
        op.getNanMode() == NanPropagationMode::IGNORE) {
      isNanIgnoreMode = true;
      // Because the TOSA spec requires the result to be NaN iff all elements
      // in the reduction are NaN, a boolean flag is banked alongside the
      // reduction: it starts true and is cleared once a non-NaN element is
      // seen.
      auto trueAttr = rewriter.getBoolAttr(true);
      auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
      auto emptyBoolTensor =
          tensor::EmptyOp::create(rewriter, loc, reduceShape,
                                  trueValue.getType(), dynDims)
              .getResult();
      auto allResultsNaNTensor =
          linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
                                 ValueRange{emptyBoolTensor})
              .result();
      inputs.push_back(input);
      outputs.push_back(allResultsNaNTensor);
    }
  }

  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
      rewriter, loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
        // Widen the accumulating argument if needed.
        if (binaryArgs[0].getType() != accTy)
          binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
                                                binaryArgs[0]);
        auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
                                                             accTy, rewriter);
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself is always true.
          Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
                                              arith::CmpFPredicate::UNO,
                                              inputValue, inputValue);
          // If a NaN was encountered, keep the previous value.
          auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
                                                  isNaN, initialValue, result);
          // Update the all-results-NaN flag.
          auto newAllResultsNanFlagValue = arith::AndIOp::create(
              nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  if (isNanIgnoreMode) {
    // Select between the reduction result and NaN depending on whether all
    // inputs were NaN.
    auto nanValueAttr = rewriter.getFloatAttr(
        elementTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
    auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
    auto emptyNanTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();
    auto nanFilledTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
                               ValueRange{emptyNanTensor})
            .result();

    // No need to fill this tensor: it is fully overwritten by the select.
    auto finalEmptyTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();

    SmallVector<Value> ins, outs;
    ins.push_back(linalgOp->getOpResult(1));
    ins.push_back(nanFilledTensor);
    ins.push_back(linalgOp->getResult(0));
    outs.push_back(finalEmptyTensor);
    auto linalgSelect =
        linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }

  // If the accumulator type was widened, truncate back to the element type.
  Value reducedRes = linalgOp->getResult(0);
  if (widenAccTy) {
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
            .getResult();
    const unsigned reducedRank =
        cast<ShapedType>(reducedRes.getType()).getRank();
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    reducedRes =
        linalg::GenericOp::create(
            rewriter, loc, emptyTensor.getType(), ValueRange{reducedRes},
            ValueRange{emptyTensor},
            ArrayRef<AffineMap>{identityMap, identityMap},
            getNParallelLoopsAttrs(reducedRank),
            [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
              Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
                                                     elementTy, args[0]);
              linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
            })
            .getResult(0);
  }

  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
                                                     reassociationMap);
  return success();
}
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
// Collapse a tensor<1xT> parameter to a rank-0 tensor<T> so it can be
// broadcast as a scalar by an empty affine map.
static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
                                  Location loc) {
  auto inputType = cast<RankedTensorType>(input.getType());
  auto elemType = inputType.getElementType();
  auto collapsedType = RankedTensorType::get({}, elemType);
  // An empty reassociation collapses all unit dimensions away.
  SmallVector<ReassociationExprs, 1> reassociation;
  return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
                                         reassociation);
}
static SmallVector<int8_t> convertToInt8(ArrayRef<int32_t> input) {
  SmallVector<int8_t> output;
  output.reserve(input.size());

  for (auto v : llvm::map_range(
           input, [](int32_t val) { return static_cast<int8_t>(val); })) {
    output.push_back(v);
  }
  return output;
}
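
// Prepares one rescale parameter (multiplier or shift) for the
// linalg.generic: constants become arith.constant scalars or dense tensors,
// while dynamic parameters are added as generic inputs with a matching
// indexing map.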
static void setupLinalgGenericOpInputAndIndexingMap(
    PatternRewriter &rewriter, SmallVector<int32_t> &values,
    SmallVector<Value> &genericInputs, SmallVector<AffineMap> &indexingMaps,
    bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
    bool isShift = false) {

  auto loc = op.getLoc();
  auto inputTy = cast<ShapedType>(op.getInput().getType());
  unsigned rank = inputTy.getRank();
  SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};

  if (isConstant) {
    // A constant parameter is materialized as an arith.constant scalar
    // (single value) or as a dense constant tensor.
    if (values.size() == 1) {
      IntegerAttr intAttr = isShift
                                ? rewriter.getI8IntegerAttr(values.front())
                                : rewriter.getI32IntegerAttr(values.front());
      constant = arith::ConstantOp::create(rewriter, loc, intAttr);
    } else {
      auto elementType =
          isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
      auto tensorType = RankedTensorType::get(
          {static_cast<int64_t>(values.size())}, elementType);
      DenseIntElementsAttr EltAttr;
      if (isShift)
        EltAttr = DenseIntElementsAttr::get(tensorType, convertToInt8(values));
      else
        EltAttr = DenseIntElementsAttr::get(tensorType, values);
      genericInputs.push_back(
          arith::ConstantOp::create(rewriter, loc, EltAttr));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, exprs,
                                            rewriter.getContext()));
    }
  } else {
    // The parameter is dynamic and becomes an input of the linalg.generic.
    auto operand = isShift ? op.getShift() : op.getMultiplier();
    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
    if (tensorType && tensorType.hasStaticShape() &&
        tensorType.getShape()[0] == 1) {
      // A size-1 parameter tensor is collapsed to a scalar and broadcast.
      AffineMap broadcastMap =
          AffineMap::get(rank, 0, {}, rewriter.getContext());
      genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
      indexingMaps.push_back(broadcastMap);
    } else {
      genericInputs.push_back(operand);
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, exprs,
                                            rewriter.getContext()));
    }
  }
  arg = indexingMaps.size() - 1;
}
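
// Materializes a zero point inside the generic body: a statically known zero
// point becomes a constant; a dynamic one is read from the corresponding
// block argument and extended to the working bitwidth.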
static Value getExtendZp(OpBuilder &builder, Type valueTy,
                         FailureOr<int64_t> maybeZp, Location loc,
                         ValueRange blockArgs, int64_t zpArg,
                         bool isOutputZp = false) {
  Value result;
  const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
  const uint32_t attrBitwidth =
      isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
  auto extendType = builder.getIntegerType(attrBitwidth);
  if (failed(maybeZp)) {
    // The zero point is dynamic: read it from the block argument.
    result = blockArgs[zpArg];
    auto zpTy = result.getType();
    if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
      // Treat unsigned zero points as signless before extending.
      if (zpTy.isUnsignedInteger()) {
        result =
            UnrealizedConversionCastOp::create(
                builder, loc,
                builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
                .getResult(0);
      }
      if (zpTy.isUnsignedInteger()) {
        return arith::ExtUIOp::create(builder, loc, extendType, result);
      }
      return arith::ExtSIOp::create(builder, loc, extendType, result);
    }
    return result;
  }
  return arith::ConstantOp::create(builder, loc,
                                   IntegerAttr::get(extendType, *maybeZp));
}
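
// Lowers tosa.rescale to a linalg.generic whose body extends the input to
// i32, applies tosa.apply_scale with the selected rounding mode, re-adds the
// output zero point, and saturates to the output type.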
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // These are illegal configurations: terminate and log an error.
    if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
              "currently supported");
    if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    // The shift and multiplier values.
    DenseElementsAttr shiftElems;
    bool isShiftConstant = false;
    if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
      isShiftConstant = true;

    DenseElementsAttr multiplierElems;
    bool isMultiplierConstant = false;
    if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
      isMultiplierConstant = true;

    SmallVector<int32_t> shiftValues;
    SmallVector<int32_t> multiplierValues;
    bool doubleRound;

    if (isMultiplierConstant && isShiftConstant) {
      // Explicit cast is required here.
      shiftValues = llvm::to_vector(llvm::map_range(
          shiftElems.getValues<IntegerAttr>(),
          [](IntegerAttr attr) -> int32_t {
            return static_cast<int32_t>(attr.getInt());
          }));
      multiplierValues = llvm::to_vector(
          llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
                          [](IntegerAttr attr) -> int32_t {
                            return static_cast<int32_t>(attr.getInt());
                          }));

      // If we shift by more than the bitwidth, this just sets to 0.
      for (int i = 0, s = multiplierValues.size(); i < s; i++) {
        if (shiftValues[i] > 63) {
          shiftValues[i] = 0;
          multiplierValues[i] = 0;
        }
      }

      // Double rounding only occurs if the shift is greater than 31; check
      // whether that is ever true.
      doubleRound =
          op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
          llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
    } else {
      doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
    }

    RoundingMode roundingMode =
        doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
    SmallVector<Value> genericInputs = {input};
    SmallVector<AffineMap> indexingMaps;
    // Indexing map for the input values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    Value multiplierConstant;
    int64_t multiplierArg = 0;
    setupLinalgGenericOpInputAndIndexingMap(
        rewriter, multiplierValues, genericInputs, indexingMaps,
        isMultiplierConstant, op, multiplierConstant, multiplierArg);

    Value shiftConstant;
    int64_t shiftArg = 0;
    setupLinalgGenericOpInputAndIndexingMap(
        rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
        shiftConstant, shiftArg, true);

    // If a zero point is not a compile-time constant, pass it as an input.
    int64_t iZpArg = 0;
    int64_t oZpArg = 0;
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
    AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
    if (failed(maybeIZp)) {
      genericInputs.push_back(
          collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
      indexingMaps.push_back(broadcastMap);
      iZpArg = indexingMaps.size() - 1;
    }
    if (failed(maybeOZp)) {
      genericInputs.push_back(
          collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
      indexingMaps.push_back(broadcastMap);
      oZpArg = indexingMaps.size() - 1;
    }
    // Indexing map for the output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the destination tensor.
    Value emptyTensor = tensor::EmptyOp::create(
        rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = linalg::GenericOp::create(
        rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
        indexingMaps, getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
          auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
                                     nestedLoc, blockArgs, iZpArg);

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
          auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
                                      nestedLoc, blockArgs, oZpArg, true);

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Treat an unsigned input as signless before extending.
          if (valueTy.isUnsignedInteger()) {
            value = UnrealizedConversionCastOp::create(
                        nestedBuilder, nestedLoc,
                        nestedBuilder.getIntegerType(
                            valueTy.getIntOrFloatBitWidth()),
                        value)
                        .getResult(0);
          }
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
                                             nestedBuilder.getI32Type(), value);
            } else {
              value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
                                             nestedBuilder.getI32Type(), value);
            }
          }

          value =
              arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);

          value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
                                             nestedBuilder.getI32Type(), value,
                                             multiplier, shift, roundingMode);

          // Move to the new zero-point.
          value =
              arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);

          // Saturate to the output size.
          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different saturation range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = arith::ConstantOp::create(
              nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = arith::ConstantOp::create(
              nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, /*isUnsigned=*/false);

          if (outIntType.getWidth() < 32) {
            value = arith::TruncIOp::create(
                nestedBuilder, nestedLoc,
                rewriter.getIntegerType(outIntType.getWidth()), value);
          }

          if (outIntType.isUnsignedInteger()) {
            value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
                                                       outIntType, value)
                        .getResult(0);
          }
          linalg::YieldOp::create(nestedBuilder, loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
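
// Specialization for resize of a 1x1 image: no coordinate arithmetic is
// needed, the single input value is (optionally scaled and) broadcast to the
// output shape via collapse + generic + expand.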
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    SmallVector<int64_t> scale;
    if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale))
      return failure();

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[0].push_back(builder.getAffineDimExpr(1));
    reassociationMap[0].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
                                                     reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    // Generate the elementwise operation casting/scaling the input value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty =
        tensor::EmptyOp::create(builder, genericTy.getShape(),
                                resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = linalg::GenericOp::create(
        builder, genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
                                           value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[0]));
              value = arith::MulIOp::create(b, loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[2]));
              value = arith::MulIOp::create(b, loc, value, scaleX);
            }
          }

          linalg::YieldOp::create(b, loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
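
// Resize with broadcast behavior along H or W: resize the non-broadcast
// dimensions first, then broadcast the unit dimensions to the output shape.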
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable, generate a width of 1 on the
    // intermediate output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
                               op.getOffset(), op.getBorder(), op.getMode());

    // Collapse the unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(1);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
                                                     resize, reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = tensor::EmptyOp::create(
        builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());
    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          linalg::YieldOp::create(b, loc, value);
        });

    return success();
  }
};
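
// General resize lowering: a linalg.generic computes, per output element,
// the source coordinates and interpolation weights, then gathers from the
// input with tensor.extract (nearest neighbor or bilinear, in floating-point
// or quantized arithmetic).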
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = isa<FloatType>(resultETy);
    auto floatTy = resultETy;

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
                                               resultETy, *dynamicDimsOr);
    auto genericOp = linalg::GenericOp::create(
        b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = linalg::IndexOp::create(b, 0);
      Value y = linalg::IndexOp::create(b, 1);
      Value x = linalg::IndexOp::create(b, 2);
      Value channel = linalg::IndexOp::create(b, 3);

      Value zeroI32 =
          arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
      Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
      Value hMax =
          arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
      Value wMax =
          arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));

      Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
      Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);

      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border))
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
      yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
      xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
      xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
      xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
      yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
      xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));

      // Compute the index and delta values - floating-point case.
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        Value val = arith::MulIOp::create(b, in, scaleD);
        val = arith::AddIOp::create(b, val, offset);
        index = arith::FloorDivSIOp::create(b, val, scaleN);

        // rx = x % scale_n
        // dx = rx / scale_n
        Value r = arith::RemSIOp::create(b, val, scaleN);
        Value rFp = arith::SIToFPOp::create(b, floatTy, r);
        Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
        delta = arith::DivFOp::create(b, rFp, scaleNfp);
      };

      // Compute the index and delta values - integer case.
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        // dx = x - ix * scale_n;
        Value val = arith::MulIOp::create(b, in, scaleD);
        val = arith::AddIOp::create(b, val, offset);
        index = arith::DivSIOp::create(b, val, scaleN);
        delta = arith::MulIOp::create(b, index, scaleN);
        delta = arith::SubIOp::create(b, val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
        auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return arith::ConstantIndexOp::create(b, 0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h =
                arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
            pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = arith::ShLIOp::create(b, dval, one);
            pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
                                         dvalDouble, scale);
          }

          auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
          val = arith::AddIOp::create(b, val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
          return arith::IndexCastOp::create(b, b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = tensor::ExtractOp::create(
            b, input, ValueRange{batch, iy, ix, channel});

        linalg::YieldOp::create(b, result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == ResizeMode::BILINEAR);

        auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = arith::AddIOp::create(b, val0, oneVal);
          val0 =
              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
          val1 =
              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
          val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
          val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
        };

        // Linalg equivalent to the section below:
        //    int16_t iy0 = apply_max(iy, 0);
        //    int16_t iy1 = apply_min(iy + 1, IH - 1);
        //    int16_t ix0 = apply_max(ix, 0);
        //    int16_t ix1 = apply_min(ix + 1, IW - 1);
        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
            Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
            Value mul1 = arith::MulFOp::create(b, val1, delta);
            return arith::AddFOp::create(b, mul0, mul1);
          };

          // Interpolate the x-dimension first, then the y-dimension.
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          linalg::YieldOp::create(b, result);
        } else {
          // Perform in quantized space.
          y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
          y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
          y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
          y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = arith::ExtSIOp::create(b, resultETy, dx);
            dy = arith::ExtSIOp::create(b, resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
            xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return arith::MulIOp::create(b, val0, scale);
            Value weight0 = arith::SubIOp::create(b, scale, weight1);
            Value mul0 = arith::MulIOp::create(b, val0, weight0);
            Value mul1 = arith::MulIOp::create(b, val1, weight1);
            return arith::AddIOp::create(b, mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          linalg::YieldOp::create(b, result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
// Replaces an identity-like op with its own operands.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};
template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);

    // The output tensor; each element is read from a mirrored index.
    auto emptyTensor = tensor::EmptyOp::create(
                           rewriter, loc, inputTy.getShape(),
                           inputTy.getElementType(), ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
            if (i == axis) {
              auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
              auto sizeMinusOne =
                  arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
              index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
                                            index);
            }

            indices.push_back(index);
          }

          auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
                                                   input, indices);
          linalg::YieldOp::create(nestedBuilder, op.getLoc(),
                                  extract.getResult());
        });
    return success();
  }
};
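
// tosa.tile is lowered to a rank-doubling linalg.generic (a size-'multiple'
// dimension is inserted before each input dimension) followed by a reshape
// back to the result shape.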
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    auto emptyTensor = tensor::EmptyOp::create(
        rewriter, op.getLoc(), genericShape, elementTy, dynDims);

    // Map the input shape to the non-broadcast dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
        });

    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
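
// tosa.argmax keeps two accumulators, the running maximum and its index, and
// reduces along the requested axis; NaN handling follows the op's nan_mode.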
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                outElementTy, dynDims)
            .getResult();
    auto fillValueIdx = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
                               ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                inElementTy, dynDims)
            .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
    auto filledTensorMax =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
                               ValueRange{emptyTensorMax})
            .result();

    // Reduce along the arg-max axis, with parallel operations along the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = arith::IndexCastOp::create(
              rewriter, nestedLoc, oldIndex.getType(),
              linalg::IndexOp::create(rewriter, loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
              // Only update index & max value for non-NaN values. If all
              // values are NaNs, the initial index (0) is returned.
              predicate = arith::CmpFOp::create(rewriter, nestedLoc,
                                                arith::CmpFPredicate::OGT,
                                                newValue, oldValue);
            } else {
              // Update the max value if either of the following is true:
              // - the new value is bigger
              // - the current max is not NaN and the new value is NaN
              Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
                                               arith::CmpFPredicate::UGT,
                                               newValue, oldValue);
              Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
                                                      arith::CmpFPredicate::ORD,
                                                      oldValue, oldValue);
              predicate = arith::AndIOp::create(
                  rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = arith::CmpIOp::create(rewriter, nestedLoc,
                                              arith::CmpIPredicate::sgt,
                                              newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = arith::SelectOp::create(
              rewriter, nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = arith::SelectOp::create(
              rewriter, nestedLoc, predicate, newIndex, oldIndex);
          linalg::YieldOp::create(nestedBuilder, nestedLoc,
                                  ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
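
// tosa.gather becomes a linalg.generic over the result that index-casts each
// index element and extracts the corresponding element from the values
// tensor.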
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
          Value index1 = arith::IndexCastOp::create(
              rewriter, loc, rewriter.getIndexType(), indexValue);
          auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
          Value extract = tensor::ExtractOp::create(
              rewriter, loc, input, ValueRange{index0, index1, index2});
          linalg::YieldOp::create(rewriter, loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
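
// tosa.table is lowered to a per-element table lookup: a direct lookup for
// the i8 case, and a base/next interpolation over the 513-entry table for
// the i16 case.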
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
        affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = arith::IndexCastOp::create(
            rewriter, loc, rewriter.getIndexType(), inputValue);
        Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
        index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
                                      index, offset);
        Value extract =
            tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
        linalg::YieldOp::create(rewriter, loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = arith::ExtSIOp::create(
            rewriter, loc, rewriter.getI32Type(), inputValue);

        auto offset = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32IntegerAttr(32768));
        auto seven = arith::ConstantOp::create(rewriter, loc,
                                               rewriter.getI32IntegerAttr(7));
        auto one = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getI32IntegerAttr(1));
        auto b1111111 = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        //
        // value = value + 32768
        // index = value >> 7;
        // fraction = 0x01111111 & value
        auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
        Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
        Value fraction =
            arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);

        // Extract the base and next values from the table:
        //
        // base = (int32_t) table[index];
        // next = (int32_t) table[index + 1];
        Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);

        index = arith::IndexCastOp::create(rewriter, loc,
                                           rewriter.getIndexType(), index);
        indexPlusOne = arith::IndexCastOp::create(
            rewriter, loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
        Value next = tensor::ExtractOp::create(rewriter, loc, table,
                                               ValueRange{indexPlusOne});

        base =
            arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
        next =
            arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the table values:
        //
        // result = (base << 7) + (next - base) * fraction
        Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
        Value diff = arith::SubIOp::create(rewriter, loc, next, base);
        Value diffScaled =
            arith::MulIOp::create(rewriter, loc, diff, fraction);
        Value result =
            arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);

        linalg::YieldOp::create(rewriter, loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = arith::ConstantIndexOp::create(builder, loc, 1);
    auto two = arith::ConstantIndexOp::create(builder, loc, 2);
    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W] and set W = (W / 2) + 1 to account for the half-sized
    // W dimension of the output tensors.
    auto dims = tensor::getMixedSizes(builder, loc, input);
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType =
        cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
    auto filledTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                               ValueRange{emptyTensor})
            .result();
    return filledTensor;
  }

  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = arith::IndexCastUIOp::create(
        builder, loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);
    return arith::UIToFPOp::create(builder, loc, type, integerVal);
  }

  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = linalg::IndexOp::create(builder, loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }
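  // Editor's note (added commentary): halfPlusOne implements the usual
  // real-input FFT output width of W/2 + 1; the remaining bins of the DFT of
  // a real signal are redundant by conjugate symmetry, so e.g. W = 8 yields
  // 5 output columns. The unsigned div/add pair folds to a constant whenever
  // W is statically known.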
  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType = dyn_cast<FloatType>(
        cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for angle computation.
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // angle = 2 * pi() * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);
      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);
      auto realComponent =
          arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto imagComponent =
          arith::MulFOp::create(builder, loc, valReal, sinAngle);

      // outReal = sumReal + realComponent; outImag = sumImag - imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::SubFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
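// Editor's note: per output element (n, oy, ox), the linalg.generic body
// above performs the real-input DFT accumulation; one reduction step is
// equivalent to the following scalar sketch (editor's names, not upstream):
//
//   float angle = 2.0f * (float)M_PI *
//                 ((float)((iy * oy) % H) / (float)H +
//                  (float)((ix * ox) % W) / (float)W);
//   sumReal += valReal * cosf(angle);
//   sumImag -= valReal * sinf(angle);
//
// Reducing iy*oy and ix*ox modulo H and W before the floating-point divide
// keeps the index arithmetic exact; this is valid because sin and cos have
// period 2*pi.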
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d,
                                         "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    assert(real_el_ty == imag_el_ty);

    // Compute the output type and set of dynamic sizes from [N, H, W].
    SmallVector<Value> dynamicSizes;
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
    SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation.
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    SmallVector<Value> genericOpInputs = {input_real, input_imag};
    SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Indices for angle computation.
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // angle = sign * 2 * pi() * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);
      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      if (inverse.getValue()) {
        angle = arith::MulFOp::create(
            builder, loc, angle,
            arith::ConstantOp::create(rewriter, loc,
                                      rewriter.getFloatAttr(real_el_ty,
                                                            -1.0)));
      }

      // realComponent = valReal * cos(angle) + valImag * sin(angle)
      // imagComponent = valImag * cos(angle) - valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);

      auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
      auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);

      auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
      auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
      auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);

      // outReal = sumReal + realComponent; outImag = sumImag + imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::AddFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
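// Editor's note (added commentary): fft2d differs from rfft2d in two ways.
// A complex input contributes both parts to each accumulator,
//
//   sumReal += valReal * cos(a) + valImag * sin(a);
//   sumImag += valImag * cos(a) - valReal * sin(a);
//
// and the `inverse` attribute negates the angle (the multiply by -1.0
// above), conjugating the twiddle factors to produce the inverse transform.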
} // namespace

void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  // clang-format off
  patterns->add<
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>
        >(converter, patterns->getContext());

  patterns->add<
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProductOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
  // clang-format on
}
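// Editor's note: a minimal usage sketch (editor's assumptions: `typeConverter`
// and `target` are configured elsewhere, in the manner of the TosaToLinalg
// pass; this is illustrative, not the upstream pass body):
//
//   RewritePatternSet patterns(funcOp.getContext());
//   mlir::tosa::populateTosaToLinalgConversionPatterns(typeConverter,
//                                                      &patterns);
//   if (failed(applyFullConversion(funcOp, target, std::move(patterns))))
//     signalPassFailure();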