#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

#include <type_traits>
// Helper to materialize the semantically correct compare-and-select for a
// binary op with a NaN propagation mode: with PROPAGATE semantics nothing is
// required; with IGNORE semantics a NaN operand is replaced by the other one.
template <typename OpTy>
static Value
materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                    Value lhs, Value rhs, Value result) {
  auto nanMode = op.getNanMode();
  if (nanMode == NanPropagationMode::PROPAGATE)
    return result;

  // An unordered self-comparison is true exactly when the value is NaN.
  Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
                                         arith::CmpFPredicate::UNO, lhs, lhs);
  Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
                                         arith::CmpFPredicate::UNO, rhs, rhs);
  Value rhsOrResult =
      arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
  return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
                                 rhsOrResult);
}
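
// Illustrative sketch (not part of the lowering itself): for a tosa.maximum
// with nan_mode = "IGNORE", the chain above expands to roughly the following
// IR, where a single NaN operand is discarded and NaN is produced only when
// both operands are NaN:
//
//   %lhs_nan = arith.cmpf uno, %lhs, %lhs : f32
//   %rhs_nan = arith.cmpf uno, %rhs, %rhs : f32
//   %t = arith.select %lhs_nan, %rhs, %max : f32
//   %r = arith.select %rhs_nan, %lhs, %t : f32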
static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return math::AbsFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = arith::ConstantOp::create(rewriter, loc,
                                          rewriter.getZeroAttr(elementTy));
    auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
    return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return arith::AddFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return arith::AddIOp::create(rewriter, loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return arith::SubFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return arith::SubIOp::create(rewriter, loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return arith::DivSIOp::create(rewriter, loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
    return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
  }
  // tosa::MulOp
  if (isa<tosa::MulOp>(op)) {
    auto shiftVal = cast<tosa::MulOp>(op).getShift();
    DenseElementsAttr shiftElem;
    bool shiftIsConstant = true;
    int32_t shift = 0;
    if (matchPattern(shiftVal, m_Constant(&shiftElem)))
      shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
    else
      shiftIsConstant = false;

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
                                   args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0 || !shiftIsConstant) {
        Value shiftConst;
        if (shiftIsConstant)
          shiftConst = arith::ConstantOp::create(
              rewriter, loc, rewriter.getI8IntegerAttr(shift));
        if (!a.getType().isInteger(32))
          a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
        if (!b.getType().isInteger(32))
          b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);

        auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
        auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
                                                  RoundingMode::SINGLE_ROUND);
        auto result =
            tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
                                       b, shiftAmount, roundingAttr);

        if (elementTy.isInteger(32))
          return result;
        return arith::TruncIOp::create(rewriter, loc, elementTy, result);
      }

      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);

      return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
    }
  }
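
  // Worked example (illustrative only, not from the source): for i8 operands
  // with a constant shift s = 2, both inputs are widened to i32 and
  // tosa.apply_scale computes roughly (a * b) >> s with SINGLE_ROUND
  // semantics; a = 5, b = 12 yields (60 >> 2) = 15.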
  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op)) {
    auto negate = cast<tosa::NegateOp>(op);

    int64_t inZp = 0, outZp = 0;
    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    bool hasInZp = !failed(maybeInZp);
    bool hasOutZp = !failed(maybeOutZp);
    if (hasInZp)
      inZp = *maybeInZp;
    if (hasOutZp)
      outZp = *maybeOutZp;

    if (isa<FloatType>(elementTy))
      return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);

    if (isa<IntegerType>(elementTy)) {
      if (hasInZp && hasOutZp && !inZp && !outZp) {
        auto constant = arith::ConstantOp::create(
            rewriter, loc, IntegerAttr::get(elementTy, 0));
        return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
                                     args[0]);
      }

      // Compute the maximum value that can occur in the intermediate buffer.
      Type intermediateType;
      const int inputBitWidth = elementTy.getIntOrFloatBitWidth();
      int intermediateBitWidth = 64;
      Value zpAddValue;
      if (hasInZp && hasOutZp) {
        const int64_t zpAdd = inZp + outZp;
        const int64_t maxValue =
            APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
            std::abs(zpAdd) + 1;

        // Select the smallest intermediate bitwidth that can hold maxValue.
        if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
          intermediateBitWidth = 16;
        } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
          intermediateBitWidth = 32;
        } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
          intermediateBitWidth = 48;
        }

        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
        zpAddValue = arith::ConstantOp::create(
            rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
      } else {
        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
        auto arg1 =
            arith::ExtSIOp::create(rewriter, loc, intermediateType, args[1]);
        auto arg2 =
            arith::ExtSIOp::create(rewriter, loc, intermediateType, args[2]);
        zpAddValue =
            arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
      }

      // The negation can be applied as: outputValue = inZp + outZp - input.
      auto ext =
          arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]);
      auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);

      // Clamp to the negation range of the input type.
      Value min = arith::ConstantIntOp::create(
          rewriter, loc, intermediateType,
          APInt::getSignedMinValue(inputBitWidth).getSExtValue());
      Value max = arith::ConstantIntOp::create(
          rewriter, loc, intermediateType,
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
      auto clamp =
          clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);

      // Truncate to the final value.
      return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
    }
  }
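
  // Worked example (illustrative only, not from the source): for an i8
  // negate with inZp = -128 and outZp = 0, zpAdd = -128 and
  // maxValue = 127 + 128 + 1 = 256, so a 16-bit intermediate is selected; an
  // input of 100 produces -128 - 100 = -228, which saturates to -128 after
  // the clamp to the i8 range.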
  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return arith::AndIOp::create(rewriter, loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return arith::OrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return arith::ShLIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round)
      return result;

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one = arith::ConstantOp::create(rewriter, loc,
                                         IntegerAttr::get(elementTy, 1));
    auto zero = arith::ConstantOp::create(rewriter, loc,
                                          IntegerAttr::get(elementTy, 0));
    auto i1one =
        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
    auto i1zero =
        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));

    // Check that the shift amount (input2) is non-zero.
    auto shiftValueGreaterThanZero = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Inspect the last bit that is discarded by the shift.
    auto subtract =
        arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
    auto shifted =
        arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
            ->getResult(0);
    auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
                                             std::nullopt);
    auto isInputOdd =
        arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);

    auto shouldRound = arith::SelectOp::create(
        rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
    auto extended =
        arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
    return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
  }
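
  // Worked example (illustrative only, not from the source): 7 >> 1 with
  // round = true computes result = 3, inspects the last discarded bit
  // (7 >> 0 = 7, lsb = 1) and adds 1, yielding 4; without rounding the
  // result stays 3.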
  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIntegerAttr(elementTy, 1));
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);

  // tosa::SinOp
  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);

  // tosa::CosOp
  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
                                 args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
                                 args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
                                 args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
                                 args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
                                 args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                 args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
  }
  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
                                               rewriter, args[0], args[1],
                                               max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
                                               rewriter, args[0], args[1],
                                               min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return math::CeilOp::create(rewriter, loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return math::FloorOp::create(rewriter, loc, resultTypes, args);
  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = arith::ConstantOp::create(
        rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = arith::ConstantOp::create(
        rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    Value result = clampFloatHelper(loc, args[0], min, max, rewriter);

    auto clampOp = llvm::cast<tosa::ClampOp>(op);
    const auto nanMode = clampOp.getNanMode();

    // NaN propagation has no meaning for non floating point types.
    if (!isa<FloatType>(elementTy))
      return result;

    // In the case of "PROPAGATE" semantics no compare and selection is
    // required.
    if (nanMode == NanPropagationMode::PROPAGATE)
      return result;

    // In the case of "IGNORE" semantics, a NaN input maps to the clamp
    // minimum: an unordered self-comparison is true exactly for NaN.
    Value isNaN = arith::CmpFOp::create(
        rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
    return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
  }

  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min and max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }
    // Ensure that the bounds are representable in the element type.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
                                               intTy.getIntOrFloatBitWidth());
    auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
                                               intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }
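
  // Worked example (illustrative only, not from the source): for ui8 the
  // representable range is [0, 255], so a clamp with min_val = -4 and
  // max_val = 300 is tightened to [0, 255] before the constant bounds are
  // materialized.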
  // tosa::SigmoidOp: sigmoid(x) = 1 / (1 + exp(-x)).
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
    auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
    auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
    auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
    return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
  }
  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
      (void)rewriter.notifyMatchFailure(op, "unsupported type");
      return nullptr;
    }

    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
                                   ArrayRef<NamedAttribute>());

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
                                     ArrayRef<NamedAttribute>());

    // 1-bit integers are treated as unsigned.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
                                     ArrayRef<NamedAttribute>());

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
                                    ArrayRef<NamedAttribute>());

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          UnrealizedConversionCastOp::create(
              rewriter, loc,
              rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
              .getResult(0);
      return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
                                     unrealizedCast);
    }

    // All other signed-int-to-float conversions are handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
                                     ArrayRef<NamedAttribute>());

    // Casting float to boolean only requires a not-equal-to-zero check.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getFloatAttr(srcTy, 0.0));
      return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
                                   args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // Check whether neither int min nor int max can be represented in the
      // input floating-point type because its exponent range is too short.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        // Use cmp + select to replace infinities by int min / int max. All
        // other integral values are representable.
        auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
        auto posInf = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                  APFloat::getInf(fltSemantics)));
        auto negInf = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = arith::CmpFOp::create(
            rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = arith::CmpFOp::create(
            rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
        return arith::SelectOp::create(rewriter, loc, underflow, intMin,
                                       maxClamped);
      }

      auto intMinFP = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getFloatAttr(
              getElementTypeOrSelf(srcTy),
              APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                  .getSExtValue()));

      // Check whether the mantissa has enough bits to represent int max.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        // Int min is a power of two, so it is representable as well; the
        // input can therefore be clamped in the floating-point domain.
        auto intMaxFP = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                    .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
      }

      // The exponent range is big enough for int min, so int max + 1 (a power
      // of two) is representable too. Clamp the minimum and compare against
      // int max + 1 to select the max int value if needed.
      auto intMaxPlusOneFP = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getFloatAttr(
              getElementTypeOrSelf(srcTy),
              static_cast<double>(
                  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                      .getSExtValue()) +
                  1.0f));

      auto intMax = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getIntegerAttr(
              getElementTypeOrSelf(dstTy),
              APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
      auto minClamped =
          arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
      auto overflow = arith::CmpFOp::create(
          rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return arith::SelectOp::create(rewriter, loc, overflow, intMax,
                                     minClamped);
    }

    // Casting an integer to boolean only requires a not-equal-to-zero check.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
                                                srcTy.getIntOrFloatBitWidth());
      return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
                                   args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
                                    ArrayRef<NamedAttribute>());

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
using IndexPool = DenseMap<int64_t, Value>;

static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index));
  return it->second;
}

static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
}

static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}

static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
// Compute the target size of a broadcast dimension and, when a single operand
// pins that size, the "master operand" that determines it.
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // If any operand has a static size greater than 1 in this dimension, that
  // operand determines the target size.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (ShapedType::isStatic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter the operands with a dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If no operand has a dynamic dimension, all static sizes were 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Multiple dynamic operands: the target size is their runtime maximum.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}

// Compute the target shape for all dimensions of a broadcasted operation.
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
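
// Example (illustrative only, not from the source): broadcasting
// tensor<2x1xf32> with tensor<1x?xf32> produces targetShape = [2, ?], where
// dim 0 is pinned by the first operand (static size 2) and dim 1 is the
// runtime tensor.dim of the second operand, which becomes its master operand.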
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was determined by this operand,
  // it needs no broadcast.
  if (operand == masterOperand)
    return operand;

  // Affine maps for the 'linalg.generic' op.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check at runtime whether a broadcast is necessary (size == 1).
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = arith::CmpIOp::create(
      rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions: new constants could
    // violate dominance requirements.
    IndexPool localPool;

    // Emit 'tensor.empty' op.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = tensor::EmptyOp::create(
        opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op.
    auto resultTensor =
        linalg::GenericOp::create(
            opBuilder, loc, outputTensor.getType(), operand, outputTensor,
            affineMaps, getNParallelLoopsAttrs(rank),
            [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
              linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
            })
            .getResult(0);

    // Cast to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op.
    scf::YieldOp::create(opBuilder, loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    scf::YieldOp::create(opBuilder, loc, operand);
  };

  // Emit 'scf.if' op.
  auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
                                emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // A single operand needs no broadcast against siblings.
  if (operands.size() == 1)
    return llvm::to_vector(operands);

  // Skip the work when every operand is statically shaped.
  bool hasDynamic = false;
  for (auto op : operands) {
    const auto tType = dyn_cast<RankedTensorType>(op.getType());
    if (tType && !tType.hasStaticShape()) {
      hasDynamic = true;
      break;
    }
  }
  if (!hasDynamic)
    return llvm::to_vector(operands);

  // Broadcast dynamic dimensions operand by operand.
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate the output tensor.
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType)
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
                                               resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of
  // size 1.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Prefer identity maps whenever possible (i.e. no broadcasting needed)
      // because some transforms (like reshape folding) do not support affine
      // constant exprs.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit the 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = linalg::GenericOp::create(
      rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        linalg::YieldOp::create(opBuilder, loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast the 'linalg.generic' result into the original result type if needed.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
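
// Illustrative end-to-end example (not part of the code): lowering
//   %0 = tosa.add %a, %b : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
// produces roughly
//   %init = tensor.empty() : tensor<2x3xf32>
//   %0 = linalg.generic {indexing_maps = [#id, #id, #id],
//                        iterator_types = ["parallel", "parallel"]}
//        ins(%a, %b : tensor<2x3xf32>, tensor<2x3xf32>)
//        outs(%init : tensor<2x3xf32>) {
//   ^bb0(%x: f32, %y: f32, %out: f32):
//     %s = arith.addf %x, %y : f32
//     linalg.yield %s : f32
//   }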
static ValueRange getBroadcastableOperands(Operation *operation,
                                           ValueRange operands) {
  // The shift operand of tosa.mul cannot broadcast; keep it only when it is
  // not a compile-time constant.
  if (isa<tosa::MulOp>(operation)) {
    DenseElementsAttr shift;
    if (matchPattern(cast<tosa::MulOp>(operation).getShift(),
                     m_Constant(&shift)))
      return operands.take_front(2);
    return operands.take_front(3);
  }
  // The zero-point operands of tosa.negate cannot broadcast; drop them when
  // they are statically known.
  if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    if (succeeded(maybeInZp) && succeeded(maybeOutZp))
      return operands.take_front(1);
  }
  return operands;
}

static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 &&
         "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower the operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
  auto broadcastOperands =
      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                 targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape, converter);
}
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return arith::AddFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::AddIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MulFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MulIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, args);

  return {};
}
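
// Summary of the reduction identities chosen above (for reference): sum -> 0,
// product -> 1, min -> +largest, max -> -largest (or the signed int max/min
// equivalents), all -> true, any -> false; argmax shares the max identity so
// the first real element always replaces the fill value.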
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // For bf16 reduce_sum, accumulate in f32 and truncate once at the end.
  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
                    isa<FloatType>(elementTy) &&
                    cast<FloatType>(elementTy).isBF16();
  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
  auto filledTensor =
      linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                             ValueRange{emptyTensor})
          .result();
  outputs.push_back(filledTensor);

  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN propagation has no meaning for non floating point types.
    if (isa<FloatType>(elementTy) &&
        op.getNanMode() == NanPropagationMode::IGNORE) {
      isNanIgnoreMode = true;
      // The TOSA spec requires the result to be NaN iff all elements in the
      // reduction are NaN, so carry a boolean flag alongside the reduction
      // that tracks whether every element seen so far was NaN.
      auto trueAttr = rewriter.getBoolAttr(true);
      auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
      auto emptyBoolTensor =
          tensor::EmptyOp::create(rewriter, loc, reduceShape,
                                  trueValue.getType(), dynDims)
              .getResult();
      auto allResultsNaNTensor =
          linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
                                 ValueRange{emptyBoolTensor})
              .result();
      inputs.push_back(input);
      outputs.push_back(allResultsNaNTensor);
    }
  }

  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
      rewriter, loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};

        // Widen the input element when accumulating in a wider type.
        if (binaryArgs[0].getType() != accTy)
          binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
                                                binaryArgs[0]);

        auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
                                                             accTy, rewriter);
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself is always true.
          Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
                                              arith::CmpFPredicate::UNO,
                                              inputValue, inputValue);
          // In NaN ignore mode, keep the running result when the current
          // value is NaN.
          auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
                                                  isNaN, initialValue, result);
          // Update the all-results-NaN flag.
          auto newAllResultsNanFlagValue = arith::AndIOp::create(
              nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  if (isNanIgnoreMode) {
    // If all inputs were NaN, the flag is still set and the result must be
    // replaced by NaN, selected via linalg.select on the flag tensor.
    auto nanValueAttr = rewriter.getFloatAttr(
        accTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
    auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
    auto emptyNanTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();
    auto nanFilledTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
                               ValueRange{emptyNanTensor})
            .result();

    // Create a tensor to hold the result of the select.
    auto finalEmptyTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();

    SmallVector<Value> ins, outs;
    ins.push_back(linalgOp->getOpResult(1));
    ins.push_back(nanFilledTensor);
    ins.push_back(linalgOp->getResult(0));
    outs.push_back(finalEmptyTensor);
    auto linalgSelect =
        linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }

  // Truncate back to the result element type if the accumulator was widened.
  Value reducedRes = linalgOp->getResult(0);
  if (widenAccTy) {
    auto emptyTruncTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
            .getResult();
    const unsigned reducedRank =
        cast<ShapedType>(reducedRes.getType()).getRank();
    SmallVector<AffineMap, 2> affineMaps(
        2, rewriter.getMultiDimIdentityMap(reducedRank));
    reducedRes =
        linalg::GenericOp::create(
            rewriter, loc, emptyTruncTensor.getType(), ValueRange{reducedRes},
            ValueRange{emptyTruncTensor}, affineMaps,
            getNParallelLoopsAttrs(reducedRank),
            [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
              Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
                                                     elementTy, args[0]);
              linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
            })
            .getResult(0);
  }

  // Expand the reduced result back to the original rank with a unit dim at
  // the reduction axis.
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
                                                     reassociationMap);
  return success();
}
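
// Illustrative example (not part of the code): tosa.reduce_sum over axis 1
// of tensor<4x5xf32> reduces to tensor<4xf32> via linalg.reduce and is then
// rebuilt as tensor<4x1xf32> by the tensor.expand_shape above.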
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // These are illegal configurations: terminate and log an error.
    if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
              "currently supported");
    if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    // The shift and multiplier values must be constants.
    DenseElementsAttr shiftElems;
    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant shift input values");

    DenseElementsAttr multiplierElems;
    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant multiplier input values");

    llvm::SmallVector<int8_t> shiftValues =
        llvm::to_vector(shiftElems.getValues<int8_t>());
    // An explicit cast is required here.
    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
                        [](IntegerAttr attr) -> int32_t {
                          return static_cast<int32_t>(attr.getInt());
                        }));

    // Shifting by more than the bitwidth just sets the result to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only occurs if the shift is greater than 31; check
    // whether that is ever the case.
    bool doubleRound =
        op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
    RoundingMode roundingMode =
        doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(arith::ConstantOp::create(
          rewriter, loc,
          DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(arith::ConstantOp::create(
          rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing map for the output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the output tensor for the linalg.generic op.
    Value emptyTensor = tensor::EmptyOp::create(
        rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = linalg::GenericOp::create(
        rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
        indexingMaps, getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
          if (failed(maybeIZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "input zero point cannot be statically determined");
            return;
          }

          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
          // Extend the zero point for sub-32-bit widths.
          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
          auto inputZp = arith::ConstantOp::create(
              nestedBuilder, loc,
              IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
                               *maybeIZp));

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
          if (failed(maybeOZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "output zero point cannot be statically determined");
            return;
          }

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          const int32_t outAttrBitwidth = 32;
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
          auto outputZp = arith::ConstantOp::create(
              nestedBuilder, loc,
              IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
                               *maybeOZp));

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.isUnsignedInteger()) {
            value = UnrealizedConversionCastOp::create(
                        nestedBuilder, nestedLoc,
                        nestedBuilder.getIntegerType(
                            valueTy.getIntOrFloatBitWidth()),
                        value)
                        .getResult(0);
          }
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
                                             nestedBuilder.getI32Type(), value);
            } else {
              value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
                                             nestedBuilder.getI32Type(), value);
            }
          }

          value =
              arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);

          value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
                                             nestedBuilder.getI32Type(), value,
                                             multiplier, shift, roundingMode);

          // Move to the new zero-point.
          value =
              arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);

          // Saturate to the output size.
          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different saturation range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = arith::ConstantOp::create(
              nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = arith::ConstantOp::create(
              nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, /*isUnsigned=*/false);

          if (outIntType.getWidth() < 32) {
            value = arith::TruncIOp::create(
                nestedBuilder, nestedLoc,
                rewriter.getIntegerType(outIntType.getWidth()), value);
          }

          if (outIntType.isUnsignedInteger()) {
            value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
                                                       outIntType, value)
                        .getResult(0);
          }
          linalg::YieldOp::create(nestedBuilder, loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
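
// Worked example (illustrative only, not from the source): an i8 -> i8
// rescale with multiplier = 1073741824 (0.5 in Q31 fixed point), shift = 31
// and zero points of 0 computes apply_scale(x, m, s) ~= x / 2 and then
// clamps the result to [-128, 127].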
// Handle the resize case where the input is a 1x1 image: this case can
// entirely avoid computing coordinates for each output element.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    SmallVector<int64_t> scale;
    if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale))
      return failure();

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
                                                     reassociationMap);

    // Collect any dynamic shapes of the input.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    // Generate the elementwise operation for casting and scaling the input.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty =
        tensor::EmptyOp::create(builder, genericTy.getShape(),
                                resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = linalg::GenericOp::create(
        builder, genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
                                           value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[0]));
              value = arith::MulIOp::create(b, loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[2]));
              value = arith::MulIOp::create(b, loc, value, scaleX);
            }
          }

          linalg::YieldOp::create(b, loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
// Handle resizes that have a broadcasting behavior along height or width by
// resizing to the non-broadcast dims and then expanding explicitly.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
                               op.getOffset(), op.getBorder(), op.getMode());

    // Collapse the unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
                                                     resize, reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = tensor::EmptyOp::create(
        builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          linalg::YieldOp::create(b, loc, value);
        });

    return success();
  }
};
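
// Example (illustrative only, not from the source): for a 1x1x4x8 ->
// 1x6x8x8 bilinear resize, this pattern first resizes to 1x1x8x8, collapses
// the unit height away to 1x8x8, and then broadcasts the rows back out to
// 1x6x8x8 with the linalg.generic above.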
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = isa<FloatType>(resultETy);
    auto floatTy = resultETy;

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
                                               resultETy, *dynamicDimsOr);
    auto genericOp = linalg::GenericOp::create(
        b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = linalg::IndexOp::create(b, 0);
      Value y = linalg::IndexOp::create(b, 1);
      Value x = linalg::IndexOp::create(b, 2);
      Value channel = linalg::IndexOp::create(b, 3);

      Value zeroI32 =
          arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
      Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
      Value hMax =
          arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
      Value wMax =
          arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));

      Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
      Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);

      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border))
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
      yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
      xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
      xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
      xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
      yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
      xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));

      // Compute the index and fractional delta - floating-point case:
      //   x = x * scale_d + offset
      //   ix = floor(x / scale_n), dx = (x mod scale_n) / scale_n
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        Value val = arith::MulIOp::create(b, in, scaleD);
        val = arith::AddIOp::create(b, val, offset);
        index = arith::FloorDivSIOp::create(b, val, scaleN);

        Value r = arith::RemSIOp::create(b, val, scaleN);
        Value rFp = arith::SIToFPOp::create(b, floatTy, r);
        Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
        delta = arith::DivFOp::create(b, rFp, scaleNfp);
      };

      // Compute the index and fractional delta - integer case:
      //   x = x * scale_d + offset
      //   ix = x / scale_n, dx = x - ix * scale_n
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        Value val = arith::MulIOp::create(b, in, scaleD);
        val = arith::AddIOp::create(b, val, offset);
        index = arith::DivSIOp::create(b, val, scaleN);
        delta = arith::MulIOp::create(b, index, scaleN);
        delta = arith::SubIOp::create(b, val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
        auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1)
            return arith::ConstantIndexOp::create(b, 0);

          // Round to the nearest sample: bump the index when the fractional
          // part is at least one half.
          Value pred;
          if (floatingPointMode) {
            auto h =
                arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
            pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = arith::ShLIOp::create(b, dval, one);
            pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
                                         dvalDouble, scale);
          }

          auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
          val = arith::AddIOp::create(b, val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
          return arith::IndexCastOp::create(b, b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = tensor::ExtractOp::create(
            b, input, ValueRange{batch, iy, ix, channel});

        linalg::YieldOp::create(b, result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == ResizeMode::BILINEAR);

        auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = arith::AddIOp::create(b, val0, oneVal);
          val0 =
              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
          val1 =
              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
          val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
          val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
        };

        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
            Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
            Value mul1 = arith::MulFOp::create(b, val1, delta);
            return arith::AddFOp::create(b, mul0, mul1);
          };

          // Interpolate the x-direction for both rows, then the y-direction.
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          linalg::YieldOp::create(b, result);
        } else {
          // Perform the interpolation in the quantized space.
          y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
          y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
          y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
          y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = arith::ExtSIOp::create(b, resultETy, dx);
            dy = arith::ExtSIOp::create(b, resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
            xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return arith::MulIOp::create(b, val0, scale);
            Value weight0 = arith::SubIOp::create(b, scale, weight1);
            Value mul0 = arith::MulIOp::create(b, val0, weight0);
            Value mul1 = arith::MulIOp::create(b, val1, weight1);
            return arith::AddIOp::create(b, mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          linalg::YieldOp::create(b, result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
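
// Reference formula for the bilinear path above (illustrative): with
// fractional offsets dx, dy in [0, 1),
//   out = (1 - dy) * ((1 - dx) * y0x0 + dx * y0x1)
//       +      dy  * ((1 - dx) * y1x0 + dx * y1x1)
// and the integer path computes the same expression scaled by
// scale_y_n * scale_x_n.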
// At this point all tensors should have buffers allocated, so identity-like
// operations can simply forward their operands.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);

    // First create an empty tensor of the output shape.
    auto emptyTensor = tensor::EmptyOp::create(
                           rewriter, loc, inputTy.getShape(),
                           inputTy.getElementType(), ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
            if (i == axis) {
              // Mirror the index along the reversed axis:
              // index = (size - 1) - index.
              auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
              auto sizeMinusOne =
                  arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
              index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
                                            index);
            }
            indices.push_back(index);
          }

          auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
                                                   input, indices);
          linalg::YieldOp::create(nestedBuilder, op.getLoc(),
                                  extract.getResult());
        });
    return success();
  }
};
// This converter translates tosa.tile into an empty tensor of the doubled
// (multiple, size) shape, a broadcasting linalg.generic, and a final reshape
// back to the result type.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    auto emptyTensor = tensor::EmptyOp::create(
        rewriter, op.getLoc(), genericShape, elementTy, dynDims);

    // Map the input shape to the non-broadcasted (odd) dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
        });

    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
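
// Example (illustrative only, not from the source): tiling tensor<2x3xf32>
// by multiples = [2, 1] builds an intermediate tensor<2x2x1x3xf32> where the
// even dims carry the multiples and the odd dims the input shape, then
// reshapes to tensor<4x3xf32>.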
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                outElementTy, dynDims)
            .getResult();
    auto fillValueIdx = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
                               ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
                                dynDims)
            .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
    auto filledTensorMax =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
                               ValueRange{emptyTensorMax})
            .result();

    // Reduce along the arg-max axis, with parallel iteration elsewhere.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = arith::IndexCastOp::create(
              rewriter, nestedLoc, oldIndex.getType(),
              linalg::IndexOp::create(rewriter, loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
              // Only update the index and max for non-NaN values; if all
              // values are NaN the initial index, 0, is returned.
              predicate = arith::CmpFOp::create(rewriter, nestedLoc,
                                                arith::CmpFPredicate::OGT,
                                                newValue, oldValue);
            } else {
              // Update when the new value is larger (unordered, so a NaN
              // candidate wins) and the old max is not already NaN; once the
              // max is NaN it propagates.
              Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
                                               arith::CmpFPredicate::UGT,
                                               newValue, oldValue);
              Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
                                                      arith::CmpFPredicate::ORD,
                                                      oldValue, oldValue);
              predicate = arith::AndIOp::create(
                  rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = arith::CmpIOp::create(rewriter, nestedLoc,
                                              arith::CmpIPredicate::sgt,
                                              newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = arith::SelectOp::create(
              rewriter, nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = arith::SelectOp::create(
              rewriter, nestedLoc, predicate, newIndex, oldIndex);
          linalg::YieldOp::create(nestedBuilder, nestedLoc,
                                  ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
          Value index1 = arith::IndexCastOp::create(
              rewriter, loc, rewriter.getIndexType(), indexValue);
          auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
          Value extract = tensor::ExtractOp::create(
              rewriter, loc, input, ValueRange{index0, index1, index2});
          linalg::YieldOp::create(rewriter, loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
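
// Example (illustrative only, not from the source): tosa.gather with values
// tensor<2x4x3xf32> and indices tensor<2x5xi32> yields tensor<2x5x3xf32>;
// for each (n, w, c) the generic body extracts values[n, indices[n, w], c].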
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
        affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        // Direct lookup: index = input + 128.
        Value index = arith::IndexCastOp::create(
            rewriter, loc, rewriter.getIndexType(), inputValue);
        Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
        index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
                                      index, offset);
        Value extract =
            tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
        linalg::YieldOp::create(rewriter, loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = arith::ExtSIOp::create(
            rewriter, loc, rewriter.getI32Type(), inputValue);

        auto offset = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32IntegerAttr(32768));
        auto seven = arith::ConstantOp::create(rewriter, loc,
                                               rewriter.getI32IntegerAttr(7));
        auto one = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getI32IntegerAttr(1));
        auto b1111111 = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part of the input value:
        //   value = value + 32768
        //   index = value >> 7
        //   fraction = value & 0x7f
        auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
        Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
        Value fraction =
            arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);

        // Extract the base and next values from the table:
        //   base = (int32_t) table[index]
        //   next = (int32_t) table[index + 1]
        Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);

        index = arith::IndexCastOp::create(rewriter, loc,
                                           rewriter.getIndexType(), index);
        indexPlusOne = arith::IndexCastOp::create(
            rewriter, loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
        Value next = tensor::ExtractOp::create(rewriter, loc, table,
                                               ValueRange{indexPlusOne});

        base =
            arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
        next =
            arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the table values:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
        Value diff = arith::SubIOp::create(rewriter, loc, next, base);
        Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
        Value result =
            arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);

        linalg::YieldOp::create(rewriter, loc, result);
        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
2605 static bool isRankedTensor(
Type type) {
return isa<RankedTensorType>(type); }
2613 auto divBy2 = builder.
createOrFold<arith::DivUIOp>(loc, value, two);
2614 auto plusOne = builder.
createOrFold<arith::AddIOp>(loc, divBy2, one);
2618 static RankedTensorType
2626 dims[2] = halfPlusOne(builder, loc, dims[2]);
2631 auto elementType = cast<RankedTensorType>(input.
getType()).getElementType();
2636 RankedTensorType type,
2639 tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
2640 auto fillValueAttr = rewriter.
getZeroAttr(type.getElementType());
2641 auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
2643 linalg::FillOp::create(rewriter, loc,
ValueRange{fillValue},
2646 return filledTensor;
  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = arith::IndexCastUIOp::create(
        builder, loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);
    return arith::UIToFPOp::create(builder, loc, type, integerVal);
  }

  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = linalg::IndexOp::create(builder, loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }
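  // For example, affineDimsExpr(rewriter, 0, 3, 4) yields (d0, d3, d4), so
  // inferFromExprList produces the map (d0, ..., d4) -> (d0, d3, d4): the
  // input is indexed by the reduction dims (iy, ix) while the accumulators
  // are indexed by the parallel dims (oy, ox).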
  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType = dyn_cast<FloatType>(
        cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and the set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for the input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Loop indices: (n, oy, ox, iy, ix).
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // Compute the angle modulo one period, since sin/cos are periodic:
      // angle = 2 * pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);
      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);
      auto realComponent =
          arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto imagComponent =
          arith::MulFOp::create(builder, loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::SubFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
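// In scalar terms, the generic op above accumulates the real-input DFT of
// the TOSA RFFT2D definition:
//   out_real[n, oy, ox] =  sum_{iy, ix} in[n, iy, ix] * cos(a)
//   out_imag[n, oy, ox] = -sum_{iy, ix} in[n, iy, ix] * sin(a)
//   where a = 2 * pi * ((iy * oy) / H + (ix * ox) / W)
// Only ox in [0, W/2] is materialized; the remaining columns are the complex
// conjugates of these, which is why computeOutputShape halves W.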
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    assert(real_el_ty == imag_el_ty);
    // Compute the [N, H, W] output type and the set of dynamic sizes; fft2d
    // keeps the full shape of its inputs.
    SmallVector<Value> dynamicSizes;
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
    SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation.
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    SmallVector<Value> genericOpInputs = {input_real, input_imag};
    SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for the input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Loop indices: (n, oy, ox, iy, ix).
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // angle = 2 * pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);
      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      // The inverse transform flips the sign of the angle.
      if (inverse.getValue()) {
        angle = arith::MulFOp::create(
            builder, loc, angle,
            arith::ConstantOp::create(rewriter, loc,
                                      rewriter.getFloatAttr(real_el_ty, -1.0)));
      }

      // realComponent = valReal * cos(angle) + valImag * sin(angle)
      // imagComponent = valImag * cos(angle) - valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);

      auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
      auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);

      auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
      auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
      auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);

      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::AddFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
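// Scalar form of the accumulation above (TOSA FFT2D), with the angle sign
// flipped when inverse = true:
//   out_real[n, oy, ox] = sum_{iy, ix} (in_real * cos(a) + in_imag * sin(a))
//   out_imag[n, oy, ox] = sum_{iy, ix} (in_imag * cos(a) - in_real * sin(a))
//   where a = 2 * pi * ((iy * oy) / H + (ix * ox) / W)
// i.e. a pointwise multiply by e^(-i*a) followed by a sum over (iy, ix).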
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  // clang-format off
  patterns->add<
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>
        >(converter, patterns->getContext());
  patterns->add<
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProductOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
  // clang-format on
}
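// A minimal sketch of how these patterns are typically consumed by a
// conversion pass; the pass body and the exact legality rules here are
// illustrative assumptions, not part of this file:
//
//   void runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     ConversionTarget target(getContext());
//     target.addIllegalDialect<tosa::TosaDialect>();
//     target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect,
//                            arith::ArithDialect, math::MathDialect,
//                            index::IndexDialect>();
//     TypeConverter converter;
//     mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
//     if (failed(applyFullConversion(getOperation(), target,
//                                    std::move(patterns))))
//       signalPassFailure();
//   }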