29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/Sequence.h"
32 #include <type_traits>
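// The helper below materializes the TOSA nan_mode semantics for binary
// floating-point ops (used by the tosa.maximum / tosa.minimum lowerings
// further down): PROPAGATE keeps the raw arith result, while IGNORE
// substitutes the non-NaN operand via compare + select.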
template <typename OpTy>
static Value materializeBinaryNanCheckIfRequired(OpTy op,
                                                 PatternRewriter &rewriter,
                                                 Value lhs, Value rhs,
                                                 Value result) {
  auto nanMode = op.getNanMode();
  if (nanMode == NanPropagationMode::PROPAGATE)
    return result;

  // IGNORE mode: an operand is NaN iff it compares unordered with itself.
  Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
                                         arith::CmpFPredicate::UNO, lhs, lhs);
  Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
                                         arith::CmpFPredicate::UNO, rhs, rhs);
  Value lhsSelect =
      arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
  return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
                                 lhsSelect);
}
static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return math::AbsFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = arith::ConstantOp::create(rewriter, loc,
                                          rewriter.getZeroAttr(elementTy));
    auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
    return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return arith::AddFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return arith::AddIOp::create(rewriter, loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return arith::SubFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return arith::SubIOp::create(rewriter, loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return arith::DivSIOp::create(rewriter, loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one = arith::ConstantOp::create(rewriter, loc,
                                         FloatAttr::get(elementTy, 1));
    return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
  }
  // tosa::MulOp
  if (isa<tosa::MulOp>(op)) {
    auto shiftVal = cast<tosa::MulOp>(op).getShift();
    DenseElementsAttr shiftElem;
    int32_t shift = 0;
    bool shiftIsConstant = true;
    if (matchPattern(shiftVal, m_Constant(&shiftElem)))
      shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
    else
      shiftIsConstant = false;

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
                                   args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0 || !shiftIsConstant) {
        auto shiftConst =
            arith::ConstantIntOp::create(rewriter, loc, shift, /*width=*/8);
        if (!a.getType().isInteger(32))
          a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
        if (!b.getType().isInteger(32))
          b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
        auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
        auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
                                                  RoundingMode::SINGLE_ROUND);
        return tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(),
                                          a, b, shiftAmount, roundingAttr);
      }

      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);

      return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
    }
  }
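  // Note on the shifted integer path above: per the TOSA spec,
  // tosa.apply_scale computes roughly round((a * b) >> shift) in wide
  // intermediate arithmetic, so the multiply must go through it rather than
  // a plain arith.muli + arith.shrsi pair, which would drop the required
  // rounding behavior.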
  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op)) {
    auto negate = cast<tosa::NegateOp>(op);

    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    if (failed(maybeInZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "input1 zero point cannot be statically determined");
      return nullptr;
    }

    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    if (failed(maybeOutZp)) {
      (void)rewriter.notifyMatchFailure(
          op, "output zero point cannot be statically determined");
      return nullptr;
    }

    int64_t inZp = *maybeInZp;
    int64_t outZp = *maybeOutZp;

    if (isa<FloatType>(elementTy))
      return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);

    if (isa<IntegerType>(elementTy)) {
      if (!inZp && !outZp) {
        auto constant = arith::ConstantOp::create(
            rewriter, loc, IntegerAttr::get(elementTy, 0));
        return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
                                     args[0]);
      }

      // Compute the maximum value that can occur in the intermediate buffer.
      const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
      const int64_t zpAdd = inZp + outZp;
      const int64_t maxValue =
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
          std::abs(zpAdd) + 1;

      // Convert that maximum value into the narrowest intermediate bitwidth
      // that can represent it.
      int intermediateBitWidth = 64;
      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
        intermediateBitWidth = 16;
      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
        intermediateBitWidth = 32;
      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
        intermediateBitWidth = 48;
      }

      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
      Value zpAddValue = arith::ConstantOp::create(
          rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

      // The negation can be applied as: outputValue = inZp + outZp - input.
      auto ext =
          arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]);
      auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);

      // Clamp to the negation range and truncate back.
      Value min = arith::ConstantIntOp::create(
          rewriter, loc, intermediateType,
          APInt::getSignedMinValue(inputBitWidth).getSExtValue());
      Value max = arith::ConstantIntOp::create(
          rewriter, loc, intermediateType,
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
      auto clamp =
          clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
      return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
    }
  }
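  // Worked example for the clamped path (illustrative, not from the source):
  // for i8 with inZp = -128 and outZp = 0, zpAdd = -128 and
  // maxValue = 127 + 128 + 1 = 256, so a 16-bit intermediate is chosen; the
  // result -128 - x is clamped back into [-128, 127] before truncating to i8.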
  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return arith::AndIOp::create(rewriter, loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return arith::OrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return arith::ShLIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);

  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round)
      return result;

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one = arith::ConstantOp::create(rewriter, loc,
                                         IntegerAttr::get(elementTy, 1));
    auto zero = arith::ConstantOp::create(rewriter, loc,
                                          IntegerAttr::get(elementTy, 0));
    auto i1one =
        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));

    // Only round when the shift amount is positive.
    auto shiftValueGreaterThanZero = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Extract the last bit shifted out: (x >> (shift - 1)) & 1.
    auto subtract =
        arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
    auto shifted =
        arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        arith::TruncIOp::create(rewriter, loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);

    auto shouldRound = arith::AndIOp::create(
        rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
    return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
  }
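  // The rounding above adds one to the shifted result exactly when the shift
  // amount is positive and the most significant discarded bit was set, i.e.
  // it rounds the dropped fraction of x >> s upward at one half.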
  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
  }

  // Boolean ops lower to their bitwise i1 equivalents.
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIntegerAttr(elementTy, 1));
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
  }

  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args);

  // Floating-point unary/binary math ops map 1:1 onto the math dialect.
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
                                 args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
                                 args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
                                 args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
                                 args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
                                 args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                 args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
  }

  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
                                               rewriter, args[0], args[1], max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
                                               rewriter, args[0], args[1], min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return math::CeilOp::create(rewriter, loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return math::FloorOp::create(rewriter, loc, resultTypes, args);
  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = arith::ConstantOp::create(
        rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = arith::ConstantOp::create(
        rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    Value result = clampFloatHelper(loc, args[0], min, max, rewriter);

    auto clampOp = llvm::cast<tosa::ClampOp>(op);
    const auto nanMode = clampOp.getNanMode();

    // NaN propagation has no meaning for non floating point types.
    if (!isa<FloatType>(elementTy))
      return result;

    // In "PROPAGATE" mode no compare-and-select is required.
    if (nanMode == NanPropagationMode::PROPAGATE)
      return result;

    // In "IGNORE" mode, TOSA specifies the result is the clamp minimum when
    // the input is NaN, so select between min and the clamped value.
    Value isNaN = arith::CmpFOp::create(
        rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
    return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
  }
  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min & max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }
    // Clip the clamp bounds into the representable range of the type.
    min = std::min(std::max(min, minRepresentable), maxRepresentable);
    max = std::min(std::max(max, minRepresentable), maxRepresentable);

    auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
                                               intTy.getIntOrFloatBitWidth());
    auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
                                               intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }
  // tosa::SigmoidOp, computed as 1 / (1 + exp(-x)).
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one = arith::ConstantOp::create(rewriter, loc,
                                         FloatAttr::get(elementTy, 1));
    auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
    auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
    auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
    return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
  }
  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat())
      return nullptr;
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
                                   std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
                                     std::nullopt);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
                                     std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
                                    std::nullopt);

    // Unsigned integers need an unrealized cast so they can feed UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          UnrealizedConversionCastOp::create(
              rewriter, loc,
              rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
              .getResult(0);
      return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
                                     unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
                                     std::nullopt);

    // Casting float to boolean: compare as not-equal to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getFloatAttr(srcTy, 0.0));
      return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
                                   args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);
      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();

      // If neither int min nor int max fit in the source exponent range, only
      // the infinities can overflow: replace them by int min / int max.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
        auto posInf = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                  APFloat::getInf(fltSemantics)));
        auto negInf = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = arith::CmpFOp::create(
            rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = arith::CmpFOp::create(
            rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
        return arith::SelectOp::create(rewriter, loc, underflow, intMin,
                                       maxClamped);
      }

      auto intMinFP = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getFloatAttr(
              getElementTypeOrSelf(srcTy),
              APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                  .getSExtValue()));

      // If the mantissa can represent int max, clamp in the float domain.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        auto intMaxFP = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                    .getSExtValue()));
        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
      }

      // Otherwise int max + 1 is representable (a power of two), so clamp the
      // min in the float domain and select int max on overflow.
      auto intMaxPlusOneFP = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getFloatAttr(
              getElementTypeOrSelf(srcTy),
              static_cast<double>(
                  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                      .getSExtValue()) +
                  1.0f));
      auto intMax = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getIntegerAttr(
              getElementTypeOrSelf(dstTy),
              APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
      auto minClamped =
          arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
      auto overflow = arith::CmpFOp::create(
          rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return arith::SelectOp::create(rewriter, loc, overflow, intMax,
                                     minClamped);
    }

    // Casting integer to boolean: compare as not-equal to zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
                                                srcTy.getIntOrFloatBitWidth());
      return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
                                   args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
                                    std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend)
      return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(index));
  return it->second;
}

static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
}

static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}

static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // If any operand has a static size > 1 in this dimension, it fixes the
  // target size.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (ShapedType::isStatic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with a dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If no operand is dynamic here, all sizes were 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code computing the runtime size for this dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // The target size is the maximum of all dynamic sizes.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
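// The Value returned alongside the target size is the "master operand": the
// operand whose dimension directly supplied that size. An operand equal to
// its master needs no runtime broadcast in that dimension; a null master (the
// max-of-dynamic-sizes case above) forces every operand through the runtime
// broadcast check.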
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do for a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size was taken directly from this operand, no broadcast is
  // necessary.
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic'.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);

  // Check whether a broadcast is necessary at runtime (size == 1).
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = arith::CmpIOp::create(
      rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions.
    IndexPool localPool;

    // Emit 'tensor.empty' op.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(opBuilder, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = tensor::EmptyOp::create(
        opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op.
    auto resultTensor =
        linalg::GenericOp::create(
            opBuilder, loc, outputTensor.getType(), operand, outputTensor,
            SmallVector<AffineMap>{broadcastAffineMap, identityAffineMap},
            getNParallelLoopsAttrs(rank),
            [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
              linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
            })
            .getResult(0);

    // Cast back to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);
    scf::YieldOp::create(opBuilder, loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    scf::YieldOp::create(opBuilder, loc, operand);
  };

  // Emit 'scf.if' op.
  auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
                                emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations.
  if (operands.size() == 1)
    return operands;

  // Skip broadcast emission if all operands have static shapes.
  bool hasDynamic = false;
  for (auto op : operands) {
    const auto tType = dyn_cast<RankedTensorType>(op.getType());
    if (tType && !tType.hasStaticShape()) {
      hasDynamic = true;
      break;
    }
  }
  if (!hasDynamic)
    return llvm::to_vector(operands);

  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
static LogicalResult emitElementwiseComputation(
    ConversionPatternRewriter &rewriter, Location loc, Operation *operation,
    ValueRange operands, ArrayRef<OpFoldResult> targetShape,
    const TypeConverter &converter) {
  // Generate output tensor.
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType)
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
                                               resultType.getElementType());

  // Create affine maps: input maps broadcast static size-1 dimensions, the
  // output map is the identity.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = linalg::GenericOp::create(
      rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        linalg::YieldOp::create(opBuilder, loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast the 'linalg.generic' result into the original result type if needed.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
static ValueRange getBroadcastableOperands(Operation *operation,
                                           ValueRange operands) {
  // The shift operand of tosa.mul must not broadcast: a constant shift is
  // folded into the generic body, so only the two data operands participate;
  // a dynamic shift rides along as a third input (read as args[2] above).
  // The exact constant check here is reconstructed from that usage.
  if (isa<tosa::MulOp>(operation)) {
    if (matchPattern(cast<tosa::MulOp>(operation).getShift(), m_Constant()))
      return operands.take_front(2);
    return operands.take_front(3);
  }
  // The zero-point operands of tosa.negate must not broadcast either.
  if (isa<tosa::NegateOp>(operation))
    return operands.take_front(1);
  return operands;
}
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower the operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
  auto broadcastOperands =
      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                 targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape, converter);
}
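// Reduction support: createInitialValueForReduceOp below returns the identity
// element the accumulator is seeded with (0 for sum, 1 for product, +/-
// largest for min/max), and createLinalgBodyCalculationForReduceOp maps each
// tosa.reduce_* onto its combining arith op.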
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return arith::AddFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::AddIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MulFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MulIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, args);

  return {};
}
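// reduceMatchAndRewriteHelper lowers a tosa reduction to linalg.reduce: the
// output is seeded with the init value above, combined along `axis`, and the
// dropped dimension is re-expanded afterwards. For min/max in NaN-IGNORE mode
// a second boolean accumulator tracks whether every element seen so far was
// NaN, so the spec's "result is NaN iff all inputs are NaN" rule can be
// applied in a post-processing select.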
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // Widen the accumulator for bf16 sums to reduce rounding error.
  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
                    isa<FloatType>(elementTy) &&
                    cast<FloatType>(elementTy).isBF16();
  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
  auto filledTensor =
      linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                             ValueRange{emptyTensor})
          .result();
  outputs.push_back(filledTensor);
  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN propagation has no meaning for non floating point types.
    if (isa<FloatType>(elementTy) &&
        op.getNanMode() == NanPropagationMode::IGNORE) {
      isNanIgnoreMode = true;
      // Because the TOSA spec requires the result be NaN iff all elements in
      // the reduction are NaN, maintain a boolean flag tensor alongside the
      // reduction that tracks that state.
      auto trueAttr = rewriter.getBoolAttr(true);
      auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
      auto emptyBoolTensor =
          tensor::EmptyOp::create(rewriter, loc, reduceShape,
                                  trueValue.getType(), dynDims)
              .getResult();
      auto allResultsNaNTensor =
          linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
                                 ValueRange{emptyBoolTensor})
              .result();
      inputs.push_back(input);
      outputs.push_back(allResultsNaNTensor);
    }
  }
  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
      rewriter, loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
        // Widen the input operand to the accumulator type before combining.
        if (binaryArgs[0].getType() != accTy)
          binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
                                                binaryArgs[0]);

        auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
                                                             accTy, rewriter);
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself is always true.
          Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
                                              arith::CmpFPredicate::UNO,
                                              inputValue, inputValue);
          // On a NaN input, keep the running value instead.
          auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
                                                  isNaN, initialValue, result);
          // Track whether every element so far has been NaN.
          auto newAllResultsNanFlagValue = arith::AndIOp::create(
              nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");
  if (isNanIgnoreMode) {
    // Materialize a NaN constant and select it wherever the all-NaN flag is
    // still set: result = NaN if "all results NaN" else result.
    auto nanValueAttr = rewriter.getFloatAttr(
        elementTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
    auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
    auto emptyNanTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();
    auto nanFilledTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
                               ValueRange{emptyNanTensor})
            .result();

    // No need to fill this tensor; the select overwrites it entirely.
    auto finalEmptyTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();

    SmallVector<Value> ins, outs;
    ins.push_back(linalgOp->getOpResult(1));
    ins.push_back(nanFilledTensor);
    ins.push_back(linalgOp->getResult(0));
    outs.push_back(finalEmptyTensor);
    auto linalgSelect =
        linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }
  Value reducedRes = linalgOp->getResult(0);
  if (widenAccTy) {
    // Truncate the widened accumulator back to the original element type.
    auto truncEmpty =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
            .getResult();
    const unsigned reducedRank =
        cast<ShapedType>(reducedRes.getType()).getRank();
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    auto truncOp = linalg::GenericOp::create(
        rewriter, loc, truncEmpty.getType(), ValueRange{reducedRes},
        ValueRange{truncEmpty}, ArrayRef<AffineMap>{identityMap, identityMap},
        getNParallelLoopsAttrs(reducedRank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
                                                 elementTy, args[0]);
          linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
        });
    reducedRes = truncOp.getResult(0);
  }

  // Re-expand the reduced dimension; we know exactly which dimension to
  // expand, so lower directly to tensor.expand_shape.
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
                                                     reassociationMap);
  return success();
}
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // Reject illegal or unsupported configurations.
    if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
              "currently supported");
    if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    // The shift and multiplier values must be constants.
    DenseElementsAttr shiftElems;
    if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant shift input values");

    DenseElementsAttr multiplierElems;
    if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires constant multiplier input values");

    llvm::SmallVector<int8_t> shiftValues =
        llvm::to_vector(shiftElems.getValues<int8_t>());
    // An explicit cast is required here.
    llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
        llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
                        [](IntegerAttr attr) -> int32_t {
                          return static_cast<int32_t>(attr.getInt());
                        }));

    // Shifting by more than the bitwidth just produces 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only differs when shift exceeds 31; check whether that
    // ever happens.
    bool doubleRound =
        op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
    RoundingMode roundingMode =
        doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
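    // The per-element computation that follows implements, per the TOSA spec:
    //   value32 = apply_scale(value - input_zp, multiplier, shift)
    //   output  = clamp(value32 + output_zp, output_min, output_max)
    // with DOUBLE_ROUND engaged only when some channel shifts by more than 31.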
    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // For per-channel rescaling, store the multipliers in a buffer indexed by
    // the channel (last) dimension; otherwise use a single constant.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(arith::ConstantOp::create(
          rewriter, loc,
          DenseIntElementsAttr::get(multiplierType, multiplierValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));
      multiplierArg = indexingMaps.size() - 1;
    }

    // Same scheme for the shifts.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = arith::ConstantOp::create(
          rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(arith::ConstantOp::create(
          rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing map for the output.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    Value emptyTensor = tensor::EmptyOp::create(
        rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = linalg::GenericOp::create(
        rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
        indexingMaps, getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
          if (failed(maybeIZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "input zero point cannot be statically determined");
            return;
          }

          const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
          // Extend the zero point attribute for sub-32-bit widths.
          const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
          auto inputZp = arith::ConstantOp::create(
              nestedBuilder, loc,
              IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
                               *maybeIZp));

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
          if (failed(maybeOZp)) {
            (void)rewriter.notifyMatchFailure(
                op, "output zero point cannot be statically determined");
            return;
          }

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          const int32_t outAttrBitwidth = 32;
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
          auto outputZp = arith::ConstantOp::create(
              nestedBuilder, loc,
              IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
                               *maybeOZp));

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Cast unsigned values to signless for the computation.
          if (valueTy.isUnsignedInteger()) {
            value = UnrealizedConversionCastOp::create(
                        nestedBuilder, nestedLoc,
                        nestedBuilder.getIntegerType(
                            valueTy.getIntOrFloatBitWidth()),
                        value)
                        .getResult(0);
          }
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
                                             nestedBuilder.getI32Type(), value);
            } else {
              value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
                                             nestedBuilder.getI32Type(), value);
            }
          }

          value =
              arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);

          value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
                                             nestedBuilder.getI32Type(), value,
                                             multiplier, shift, roundingMode);

          // Move to the new zero-point.
          value =
              arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);

          // Saturate to the output size.
          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = arith::ConstantOp::create(
              nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = arith::ConstantOp::create(
              nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, /*isUnsigned=*/false);

          if (outIntType.getWidth() < 32) {
            value = arith::TruncIOp::create(
                nestedBuilder, nestedLoc,
                rewriter.getIntegerType(outIntType.getWidth()), value);
          }

          if (outIntType.isUnsignedInteger()) {
            value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
                                                       outIntType, value)
                        .getResult(0);
          }
          linalg::YieldOp::create(nestedBuilder, loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
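// The next pattern handles the degenerate resize whose image is 1x1 on both
// input and output: no resampling is needed, only an optional element-type
// change (plus the scale multiplication quantized bilinear requires), so it
// collapses the unit dims, runs one elementwise generic, and expands back.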
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    SmallVector<int64_t> scale;
    if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale))
      return failure();

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
                                                     reassociationMap);

    // Get any dynamic sizes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    // Generate the elementwise operation casting/scaling the input value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty =
        tensor::EmptyOp::create(builder, genericTy.getShape(),
                                resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = linalg::GenericOp::create(
        builder, genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
                                           value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[0]));
              value = arith::MulIOp::create(b, loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[2]));
              value = arith::MulIOp::create(b, loc, value, scaleX);
            }
          }

          linalg::YieldOp::create(b, loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
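// Resizing a unit input dimension up to a larger output is a pure broadcast,
// so the pattern below first emits a smaller tosa.resize whose broadcastable
// axes stay at size 1, then collapses the unit dims and broadcasts them back
// up with an identity linalg.generic.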
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any broadcastable dimension, keep a width of 1 on the resize
    // output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
                               op.getOffset(), op.getBorder(), op.getMode());

    // Collapse any unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
                                                     resize, reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = tensor::EmptyOp::create(
        builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());
    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          linalg::YieldOp::create(b, loc, value);
        });

    return success();
  }
};
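// GenericResizeConverter is the general lowering: one linalg.generic over the
// output where each iteration computes the source coordinates (an integer
// index plus a fractional delta) and either picks the nearest pixel or blends
// the four neighbors bilinearly, in floating-point or fixed-point arithmetic
// depending on the element type.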
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = isa<FloatType>(resultETy);
    auto floatTy = resultETy;

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
                                               resultETy, *dynamicDimsOr);
    auto genericOp = linalg::GenericOp::create(
        b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = linalg::IndexOp::create(b, 0);
      Value y = linalg::IndexOp::create(b, 1);
      Value x = linalg::IndexOp::create(b, 2);
      Value channel = linalg::IndexOp::create(b, 3);

      Value zeroI32 =
          arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
      Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
      Value hMax =
          arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
      Value wMax =
          arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));

      Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
      Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);

      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");
      }

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
      yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
      xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
      xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
      xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
      yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
      xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));
      // Split the continuous source position into an integer index and a
      // fractional delta: x = in * scale_d + offset, ix = floor(x / scale_n),
      // dx = (x mod scale_n) / scale_n.
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        Value val = arith::MulIOp::create(b, in, scaleD);
        val = arith::AddIOp::create(b, val, offset);
        index = arith::FloorDivSIOp::create(b, val, scaleN);
        Value r = arith::RemSIOp::create(b, val, scaleN);
        Value rFp = arith::SIToFPOp::create(b, floatTy, r);
        Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
        delta = arith::DivFOp::create(b, rFp, scaleNfp);
      };

      // Same split for the quantized case, keeping dx = x - ix * scale_n as
      // an integer.
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        Value val = arith::MulIOp::create(b, in, scaleD);
        val = arith::AddIOp::create(b, val, offset);
        index = arith::DivSIOp::create(b, val, scaleN);
        delta = arith::MulIOp::create(b, index, scaleN);
        delta = arith::SubIOp::create(b, val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }
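      // Worked example of the index/delta split (illustrative): a 2x upscale
      // has scale_n/scale_d = 4/2, so output x = 3 with offset 0 gives
      // val = 3 * 2 = 6, ix = 6 / 4 = 1 and dx = (6 mod 4) / 4 = 0.5, i.e.
      // the sample sits halfway between input columns 1 and 2.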
      if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
        auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1)
            return arith::ConstantIndexOp::create(b, 0);

          // Round to the nearest index: bump by one when the fractional part
          // is at least one half.
          Value pred;
          if (floatingPointMode) {
            auto h =
                arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
            pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = arith::ShLIOp::create(b, dval, one);
            pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
                                         dvalDouble, scale);
          }

          auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
          val = arith::AddIOp::create(b, val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
          return arith::IndexCastOp::create(b, b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = tensor::ExtractOp::create(
            b, input, ValueRange{batch, iy, ix, channel});

        linalg::YieldOp::create(b, result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == ResizeMode::BILINEAR);

        auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = arith::AddIOp::create(b, val0, oneVal);
          val0 =
              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
          val1 =
              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
          val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
          val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
        };

        // Clamp the pairs of neighboring indices into the image bounds.
        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
          // Standard lerp: val0 * (1 - delta) + val1 * delta.
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
            Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
            Value mul1 = arith::MulFOp::create(b, val1, delta);
            return arith::AddFOp::create(b, mul0, mul1);
          };

          // Interpolate the top and bottom rows along x, then blend the two
          // accumulators along y.
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          linalg::YieldOp::create(b, result);
        } else {
          // Perform the interpolation in quantized (integer) space.
          y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
          y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
          y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
          y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = arith::ExtSIOp::create(b, resultETy, dx);
            dy = arith::ExtSIOp::create(b, resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
            xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
          }

          // Fixed-point lerp with weights in units of scale_n.
          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return arith::MulIOp::create(b, val0, scale);
            Value weight0 = arith::SubIOp::create(b, scale, weight1);
            Value mul0 = arith::MulIOp::create(b, val0, weight0);
            Value mul1 = arith::MulIOp::create(b, val1, weight1);
            return arith::AddIOp::create(b, mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          linalg::YieldOp::create(b, result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
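// Note on the quantized path: the interpolation weights stay in units of
// scale_n, so the yielded value carries an extra scale_n_y * scale_n_x
// factor. Per the TOSA spec that factor is expected in quantized bilinear
// resize and is removed by the scale/shift of a following rescale.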
// Forwards identity-like ops directly to their operands.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);

    // The reversed element along `axis` lives at (size - 1) - i.
    auto emptyTensor = tensor::EmptyOp::create(
                           rewriter, loc, inputTy.getShape(),
                           inputTy.getElementType(), ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
            if (i == axis) {
              auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
              auto sizeMinusOne =
                  arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
              index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
                                            index);
            }
            indices.push_back(index);
          }

          auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
                                                   input, indices);
          linalg::YieldOp::create(nestedBuilder, op.getLoc(),
                                  extract.getResult());
        });
    return success();
  }
};
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    auto emptyTensor = tensor::EmptyOp::create(
        rewriter, op.getLoc(), genericShape, elementTy, dynDims);

    // Map the input shape to the non-broadcast (odd) dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
        });

    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                outElementTy, dynDims)
            .getResult();
    auto fillValueIdx = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
                               ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                inElementTy, dynDims)
            .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
    auto filledTensorMax =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
                               ValueRange{emptyTensorMax})
            .result();

    // Reduce along the arg-max axis, parallel along all others.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = arith::IndexCastOp::create(
              rewriter, nestedLoc, oldIndex.getType(),
              linalg::IndexOp::create(rewriter, loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
              // Only update the index & max for non-NaN values. If all values
              // are NaN, the initial index (0) is returned.
              predicate = arith::CmpFOp::create(rewriter, nestedLoc,
                                                arith::CmpFPredicate::OGT,
                                                newValue, oldValue);
            } else {
              // Update the max if the new value is bigger, or if the current
              // max is not NaN and the new value is NaN.
              Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
                                               arith::CmpFPredicate::UGT,
                                               newValue, oldValue);
              Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
                                                      arith::CmpFPredicate::ORD,
                                                      oldValue, oldValue);
              predicate = arith::AndIOp::create(
                  rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = arith::CmpIOp::create(rewriter, nestedLoc,
                                              arith::CmpIPredicate::sgt,
                                              newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = arith::SelectOp::create(
              rewriter, nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = arith::SelectOp::create(
              rewriter, nestedLoc, predicate, newIndex, oldIndex);
          linalg::YieldOp::create(nestedBuilder, nestedLoc,
                                  ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
          Value index1 = arith::IndexCastOp::create(
              rewriter, loc, rewriter.getIndexType(), indexValue);
          auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
          Value extract = tensor::ExtractOp::create(
              rewriter, loc, input, ValueRange{index0, index1, index2});
          linalg::YieldOp::create(rewriter, loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
        affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);

      // i8 input, i8 table: a direct lookup with the input re-biased by 128.
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = arith::IndexCastOp::create(
            rewriter, loc, rewriter.getIndexType(), inputValue);
        Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
        index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
                                      index, offset);
        Value extract =
            tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
        linalg::YieldOp::create(rewriter, loc, extract);
        return success();
      }

      // i16 input, i16 table, i32 result: interpolated lookup.
      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = arith::ExtSIOp::create(
            rewriter, loc, rewriter.getI32Type(), inputValue);

        auto offset = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32IntegerAttr(32768));
        auto seven = arith::ConstantOp::create(rewriter, loc,
                                               rewriter.getI32IntegerAttr(7));
        auto one = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getI32IntegerAttr(1));
        auto b1111111 = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part of the input:
        //   value = value + 32768
        //   index = value >> 7
        //   fraction = value & 0x7f
        auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
        Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
        Value fraction =
            arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);

        // Extract the base and next values from the table:
        //   base = (int32_t) table[index]
        //   next = (int32_t) table[index + 1]
        Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);

        index = arith::IndexCastOp::create(rewriter, loc,
                                           rewriter.getIndexType(), index);
        indexPlusOne = arith::IndexCastOp::create(
            rewriter, loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
        Value next = tensor::ExtractOp::create(rewriter, loc, table,
                                               ValueRange{indexPlusOne});

        base =
            arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
        next =
            arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the table values:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
        Value diff = arith::SubIOp::create(rewriter, loc, next, base);
        Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
        Value result =
            arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);

        linalg::YieldOp::create(rewriter, loc, result);
        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
2591 static bool isRankedTensor(
Type type) {
return isa<RankedTensorType>(type); }
2599 auto divBy2 = builder.
createOrFold<arith::DivUIOp>(loc, value, two);
2600 auto plusOne = builder.
createOrFold<arith::AddIOp>(loc, divBy2, one);
2604 static RankedTensorType
2612 dims[2] = halfPlusOne(builder, loc, dims[2]);
2617 auto elementType = cast<RankedTensorType>(input.
getType()).getElementType();
2622 RankedTensorType type,
2625 tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
2626 auto fillValueAttr = rewriter.
getZeroAttr(type.getElementType());
2627 auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
2629 linalg::FillOp::create(rewriter, loc,
ValueRange{fillValue},
2632 return filledTensor;
  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = arith::IndexCastUIOp::create(
        builder, loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);
    return arith::UIToFPOp::create(builder, loc, type, integerVal);
  }

  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = linalg::IndexOp::create(builder, loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }
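  // There is no direct index-to-float cast, hence the detour through an
  // integer type; element types wider than 32 bits take the i64 path so
  // large dimension sizes reach f64 without truncation.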
  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }
  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType = dyn_cast<FloatType>(
        cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and the set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for the input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());
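    // In IR notation the inferred maps read (illustrative):
    //   input:   (d0, d1, d2, d3, d4) -> (d0, d3, d4)
    //   outputs: (d0, d1, d2, d3, d4) -> (d0, d1, d2)
    // so each output element (n, oy, ox) reduces over every (iy, ix).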
    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for the angle computation.
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);
      // Compute the angle from the remainders only; sin and cos are periodic,
      // so the integer parts of the components can be dropped:
      //   angle = 2*pi * (((iy * oy) mod H) / H + ((ix * ox) mod W) / W)
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);

      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);
      auto realComponent =
          arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto imagComponent =
          arith::MulFOp::create(builder, loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::SubFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
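// In math terms, the generic op built above accumulates the real-input DFT
// for each output element (n, oy, ox):
//   Real[n, oy, ox] =  sum over (iy, ix) of in[n, iy, ix] * cos(angle)
//   Imag[n, oy, ox] = -sum over (iy, ix) of in[n, iy, ix] * sin(angle)
// with angle = 2*pi * (((iy * oy) mod H) / H + ((ix * ox) mod W) / W).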
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();
    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    assert(real_el_ty == imag_el_ty &&
           "real and imaginary inputs must share an element type");

    // Compute the output type and the set of dynamic sizes.
    SmallVector<Value> dynamicSizes;
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);
    SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);
    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation.
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    SmallVector<Value> genericOpInputs = {input_real, input_imag};
    SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for the input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());
    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Indices for the angle computation.
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);
      // angle = sign * 2*pi * (((iy * oy) mod H) / H + ((ix * ox) mod W) / W)
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);

      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);

      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);

      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);

      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
      // The inverse transform negates the angle.
      if (inverse.getValue()) {
        angle = arith::MulFOp::create(
            builder, loc, angle,
            arith::ConstantOp::create(rewriter, loc,
                                      rewriter.getFloatAttr(real_el_ty, -1.0)));
      }
      // realComponent = valReal * cos(angle) + valImag * sin(angle)
      // imagComponent = valImag * cos(angle) - valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);

      auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
      auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);

      auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
      auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
      auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);

      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::AddFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
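// For reference, the body above factors the complex multiply
// (valReal + i*valImag) * exp(-i*angle) into real arithmetic:
//   Real += valReal * cos(angle) + valImag * sin(angle)
//   Imag += valImag * cos(angle) - valReal * sin(angle)
// and setting `inverse` flips the sign of the angle, yielding exp(+i*angle).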
} // namespace

void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns->add<
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>>(converter, patterns->getContext());
  patterns->add<
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProductOp>,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
}
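// Example use of this entry point (a sketch, not from this file; the pass
// class and conversion-target setup are assumptions for illustration):
//
//   void TosaToLinalgSketchPass::runOnOperation() {
//     RewritePatternSet patterns(&getContext());
//     TypeConverter converter;
//     converter.addConversion([](Type type) { return type; });
//     mlir::tosa::populateTosaToLinalgConversionPatterns(converter,
//                                                        &patterns);
//     ConversionTarget target(getContext());
//     target.addIllegalDialect<tosa::TosaDialect>();
//     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
//     if (failed(applyFullConversion(getOperation(), target,
//                                    std::move(patterns))))
//       signalPassFailure();
//   }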