31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/Sequence.h"
35 #include <type_traits>
62 template <
typename OpTy>
70 auto nanMode = op.getNanMode();
71 if (nanMode ==
"PROPAGATE")
76 op.getLoc(), arith::CmpFPredicate::UNO, lhs, lhs);
78 op.getLoc(), arith::CmpFPredicate::UNO, rhs, rhs);
80 rewriter.
create<arith::SelectOp>(op.getLoc(), lhsIsNaN, rhs, result);
81 return rewriter.
create<arith::SelectOp>(op.getLoc(), rhsIsNaN, lhs,
86 static arith::ConstantOp
89 auto castedN =
static_cast<T
>(zp);
90 return rewriter.
create<arith::ConstantOp>(
102 if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
103 return rewriter.
create<math::AbsFOp>(loc, resultTypes, args);
105 if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
106 auto zero = rewriter.
create<arith::ConstantOp>(
108 auto neg = rewriter.
create<arith::SubIOp>(loc, zero, args[0]);
109 return rewriter.
create<arith::MaxSIOp>(loc, args[0], neg);
113 if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
114 return rewriter.
create<arith::AddFOp>(loc, resultTypes, args);
116 if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
117 return rewriter.
create<arith::AddIOp>(loc, resultTypes, args);
120 if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
121 return rewriter.
create<arith::SubFOp>(loc, resultTypes, args);
123 if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
124 return rewriter.
create<arith::SubIOp>(loc, resultTypes, args);
127 if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
128 return rewriter.
create<arith::DivSIOp>(loc, resultTypes, args);
131 if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
134 return rewriter.
create<arith::DivFOp>(loc, resultTypes, one, args[0]);
138 if (isa<tosa::MulOp>(op)) {
139 auto shiftVal = cast<tosa::MulOp>(op).getShift();
146 int32_t shift = shiftElem.
getValues<IntegerAttr>()[0].getInt();
148 if (isa<FloatType>(elementTy)) {
151 "Cannot have shift value for float");
154 return rewriter.
create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
157 if (isa<IntegerType>(elementTy)) {
163 rewriter.
create<arith::ConstantIntOp>(loc, shift, 8);
170 auto result = rewriter.
create<tosa::ApplyScaleOp>(
174 if (elementTy.isInteger(32))
177 return rewriter.
create<arith::TruncIOp>(loc, elementTy, result);
182 int cWidth = resultTypes[0].getIntOrFloatBitWidth();
185 a = rewriter.
create<arith::ExtSIOp>(loc, resultTypes[0], a);
187 b = rewriter.
create<arith::ExtSIOp>(loc, resultTypes[0], b);
189 return rewriter.
create<arith::MulIOp>(loc, resultTypes, a, b);
194 if (isa<tosa::NegateOp>(op)) {
195 auto negate = cast<tosa::NegateOp>(op);
197 FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
198 if (failed(maybeInZp)) {
200 op,
"input1 zero point cannot be statically determined");
204 FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
205 if (failed(maybeOutZp)) {
207 op,
"output zero point cannot be statically determined");
211 int64_t inZp = *maybeInZp;
212 int64_t outZp = *maybeOutZp;
214 if (isa<FloatType>(elementTy))
215 return rewriter.
create<arith::NegFOp>(loc, resultTypes, args[0]);
217 if (isa<IntegerType>(elementTy)) {
218 if (!inZp && !outZp) {
219 auto constant = rewriter.
create<arith::ConstantOp>(
221 return rewriter.
create<arith::SubIOp>(loc, resultTypes, constant,
226 const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
227 const int64_t zpAdd = inZp + outZp;
228 const int64_t maxValue =
229 APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
235 int intermediateBitWidth = 64;
236 if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
237 intermediateBitWidth = 16;
238 }
else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
239 intermediateBitWidth = 32;
240 }
else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
241 intermediateBitWidth = 48;
245 Value zpAddValue = rewriter.
create<arith::ConstantOp>(
251 rewriter.
create<arith::ExtSIOp>(loc, intermediateType, args[0]);
252 auto sub = rewriter.
create<arith::SubIOp>(loc, zpAddValue, ext);
256 loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
259 loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
264 return rewriter.
create<arith::TruncIOp>(loc, elementTy,
clamp);
269 if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
270 return rewriter.
create<arith::AndIOp>(loc, resultTypes, args);
273 if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
274 return rewriter.
create<arith::OrIOp>(loc, resultTypes, args);
277 if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
279 elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
280 auto allOnes = rewriter.
create<arith::ConstantOp>(loc, allOnesAttr);
281 return rewriter.
create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
285 if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
286 return rewriter.
create<arith::XOrIOp>(loc, resultTypes, args);
289 if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
290 return rewriter.
create<arith::ShLIOp>(loc, resultTypes, args);
293 if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
294 return rewriter.
create<arith::ShRUIOp>(loc, resultTypes, args);
297 if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
298 auto result = rewriter.
create<arith::ShRSIOp>(loc, resultTypes, args);
299 auto round = cast<BoolAttr>(op->
getAttr(
"round")).getValue();
313 auto shiftValueGreaterThanZero = rewriter.
create<arith::CmpIOp>(
314 loc, arith::CmpIPredicate::sgt, args[1], zero);
318 rewriter.
create<arith::SubIOp>(loc, resultTypes, args[1], one);
320 rewriter.
create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
323 rewriter.
create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
325 rewriter.
create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
327 auto shouldRound = rewriter.
create<arith::AndIOp>(
328 loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
330 rewriter.
create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
331 return rewriter.
create<arith::AddIOp>(loc, resultTypes, result, extended);
335 if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
336 return rewriter.
create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
340 if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
341 return rewriter.
create<arith::AndIOp>(loc, resultTypes, args);
344 if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
345 auto one = rewriter.
create<arith::ConstantOp>(
347 return rewriter.
create<arith::XOrIOp>(loc, resultTypes, args[0], one);
351 if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
352 return rewriter.
create<arith::OrIOp>(loc, resultTypes, args);
355 if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
356 return rewriter.
create<arith::XOrIOp>(loc, resultTypes, args);
359 if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
360 return rewriter.
create<mlir::math::PowFOp>(loc, resultTypes, args);
363 if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
364 return rewriter.
create<mlir::math::RsqrtOp>(loc, resultTypes, args);
367 if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
368 return rewriter.
create<mlir::math::LogOp>(loc, resultTypes, args);
371 if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
372 return rewriter.
create<mlir::math::ExpOp>(loc, resultTypes, args);
375 if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
376 return rewriter.
create<mlir::math::SinOp>(loc, resultTypes, args);
379 if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
380 return rewriter.
create<mlir::math::CosOp>(loc, resultTypes, args);
383 if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
384 return rewriter.
create<mlir::math::TanhOp>(loc, resultTypes, args);
387 if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
388 return rewriter.
create<mlir::math::ErfOp>(loc, resultTypes, args);
391 if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
392 return rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
395 if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
396 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
400 if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
401 return rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
404 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
405 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
409 if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
410 return rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
413 if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
414 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
418 if (isa<tosa::SelectOp>(op)) {
420 if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
421 return rewriter.
create<arith::SelectOp>(loc, args[0], args[1], args[2]);
425 if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
426 auto max = rewriter.
create<arith::MaximumFOp>(loc, args[0], args[1]);
428 rewriter, args[0], args[1],
max);
431 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
432 return rewriter.
create<arith::MaxSIOp>(loc, args[0], args[1]);
436 if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
437 auto min = rewriter.
create<arith::MinimumFOp>(loc, args[0], args[1]);
439 rewriter, args[0], args[1],
min);
442 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
443 return rewriter.
create<arith::MinSIOp>(loc, args[0], args[1]);
447 if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
448 return rewriter.
create<math::CeilOp>(loc, resultTypes, args);
451 if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
452 return rewriter.
create<math::FloorOp>(loc, resultTypes, args);
455 if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
456 bool losesInfo =
false;
457 APFloat minApf = cast<FloatAttr>(op->
getAttr(
"min_val")).getValue();
458 APFloat maxApf = cast<FloatAttr>(op->
getAttr(
"max_val")).getValue();
459 minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
460 APFloat::rmNearestTiesToEven, &losesInfo);
461 maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
462 APFloat::rmNearestTiesToEven, &losesInfo);
463 auto min = rewriter.
create<arith::ConstantOp>(
464 loc, elementTy, rewriter.
getFloatAttr(elementTy, minApf));
465 auto max = rewriter.
create<arith::ConstantOp>(
466 loc, elementTy, rewriter.
getFloatAttr(elementTy, maxApf));
469 auto clampOp = llvm::cast<tosa::ClampOp>(op);
470 const auto nanMode = clampOp.getNanMode();
473 if (!isa<FloatType>(elementTy))
478 if (nanMode ==
"PROPAGATE")
493 op->
getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
496 return rewriter.
create<arith::SelectOp>(op->
getLoc(), isNaN,
min, result);
499 if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
500 auto intTy = cast<IntegerType>(elementTy);
502 cast<IntegerAttr>(op->
getAttr(
"min_val")).getValue().getSExtValue();
504 cast<IntegerAttr>(op->
getAttr(
"max_val")).getValue().getSExtValue();
508 if (intTy.isUnsignedInteger()) {
509 minRepresentable = 0;
510 if (intTy.getIntOrFloatBitWidth() <= 63) {
512 (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
515 }
else if (intTy.getIntOrFloatBitWidth() <= 64) {
517 minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
519 maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
529 auto minVal = rewriter.
create<arith::ConstantIntOp>(
530 loc,
min, intTy.getIntOrFloatBitWidth());
531 auto maxVal = rewriter.
create<arith::ConstantIntOp>(
532 loc,
max, intTy.getIntOrFloatBitWidth());
534 intTy.isUnsignedInteger());
538 if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
541 auto negate = rewriter.
create<arith::NegFOp>(loc, resultTypes, args[0]);
542 auto exp = rewriter.
create<mlir::math::ExpOp>(loc, resultTypes, negate);
543 auto added = rewriter.
create<arith::AddFOp>(loc, resultTypes, exp, one);
544 return rewriter.
create<arith::DivFOp>(loc, resultTypes, one, added);
548 if (isa<tosa::CastOp>(op)) {
549 Type srcTy = elementTy;
550 Type dstTy = resultTypes.front();
562 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
563 return rewriter.
create<arith::ExtFOp>(loc, resultTypes, args,
566 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
567 return rewriter.
create<arith::TruncFOp>(loc, resultTypes, args,
571 if (srcTy.
isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
572 return rewriter.
create<arith::UIToFPOp>(loc, resultTypes, args,
575 if (srcTy.
isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
576 return rewriter.
create<arith::ExtUIOp>(loc, resultTypes, args,
582 auto unrealizedCast =
584 .
create<UnrealizedConversionCastOp>(
588 return rewriter.
create<arith::UIToFPOp>(loc, resultTypes[0],
593 if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
594 return rewriter.
create<arith::SIToFPOp>(loc, resultTypes, args,
598 if (isa<FloatType>(srcTy) && dstTy.
isInteger(1)) {
601 return rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
605 if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
606 auto rounded = rewriter.
create<math::RoundEvenOp>(loc, args[0]);
608 const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
612 APFloat::semanticsMaxExponent(fltSemantics)) {
615 auto conv = rewriter.
create<arith::FPToSIOp>(loc, dstTy, rounded);
616 auto posInf = rewriter.
create<arith::ConstantOp>(
618 APFloat::getInf(fltSemantics)));
619 auto negInf = rewriter.
create<arith::ConstantOp>(
622 APFloat::getInf(fltSemantics,
true)));
623 auto overflow = rewriter.
create<arith::CmpFOp>(
624 loc, arith::CmpFPredicate::UEQ, rounded, posInf);
625 auto underflow = rewriter.
create<arith::CmpFOp>(
626 loc, arith::CmpFPredicate::UEQ, rounded, negInf);
627 auto intMin = rewriter.
create<arith::ConstantOp>(
630 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
631 auto intMax = rewriter.
create<arith::ConstantOp>(
634 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
636 rewriter.
create<arith::SelectOp>(loc, overflow, intMax, conv);
637 return rewriter.
create<arith::SelectOp>(loc, underflow, intMin,
641 auto intMinFP = rewriter.
create<arith::ConstantOp>(
648 if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
654 auto intMaxFP = rewriter.
create<arith::ConstantOp>(
662 return rewriter.
create<arith::FPToSIOp>(loc, dstTy, clamped);
669 auto intMaxPlusOneFP = rewriter.
create<arith::ConstantOp>(
677 auto intMax = rewriter.
create<arith::ConstantOp>(
682 rewriter.
create<arith::MaximumFOp>(loc, rounded, intMinFP);
684 rewriter.
create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
685 auto overflow = rewriter.
create<arith::CmpFOp>(
686 loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
687 return rewriter.
create<arith::SelectOp>(loc, overflow, intMax,
693 if (isa<IntegerType>(srcTy) && dstTy.
isInteger(1)) {
694 Value zero = rewriter.
create<arith::ConstantIntOp>(
696 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
700 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
701 return rewriter.
create<arith::ExtSIOp>(loc, resultTypes, args,
704 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
705 return rewriter.
create<arith::TruncIOp>(loc, dstTy, args[0]);
710 op,
"unhandled op for linalg body calculation for elementwise op");
721 auto [it, inserted] = indexPool.try_emplace(index);
730 auto indexValue =
createIndex(rewriter, loc, indexPool, index);
731 return rewriter.
create<tensor::DimOp>(loc, tensor, indexValue).getResult();
737 auto shapedType = dyn_cast<ShapedType>(tensor.
getType());
738 assert(shapedType && shapedType.hasRank() &&
"expected a ranked shaped type");
739 assert(index >= 0 && index < shapedType.getRank() &&
"index out of bounds");
740 if (shapedType.isDynamicDim(index))
741 return getTensorDim(rewriter, loc, indexPool, tensor, index);
742 return rewriter.
getIndexAttr(shapedType.getDimSize(index));
746 auto isRanked = [](
Value value) {
747 return isa<RankedTensorType>(value.getType());
749 return llvm::all_of(operation->
getOperands(), isRanked) &&
750 llvm::all_of(operation->
getResults(), isRanked);
763 static std::pair<OpFoldResult, Value>
769 for (
auto operand : operands) {
770 auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
771 if (!ShapedType::isDynamic(size) && size > 1)
776 auto operandsWithDynamicDim =
777 llvm::filter_to_vector(operands, [&](
Value operand) {
778 return cast<RankedTensorType>(operand.
getType()).isDynamicDim(dim);
782 if (operandsWithDynamicDim.empty())
789 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
790 if (operandsWithDynamicDim.size() == 1)
791 return {targetSize, operandsWithDynamicDim[0]};
794 for (
size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
796 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
797 targetSize = rewriter.
create<arith::MaxUIOp>(loc, targetSize, nextSize);
799 return {targetSize,
nullptr};
807 assert(!operands.empty());
808 auto rank = cast<RankedTensorType>(operands.front().
getType()).getRank();
811 for (
auto dim : llvm::seq<int64_t>(0, rank)) {
812 auto [targetSize, masterOperand] =
814 targetShape.push_back(targetSize);
815 masterOperands.push_back(masterOperand);
817 return {targetShape, masterOperands};
823 Value masterOperand) {
825 auto rankedTensorType = cast<RankedTensorType>(operand.
getType());
826 if (!rankedTensorType.isDynamicDim(dim))
833 if (operand == masterOperand)
837 auto rank = rankedTensorType.getRank();
839 for (
auto index : llvm::seq<int64_t>(0, rank)) {
842 affineExprs.push_back(affineExpr);
844 auto broadcastAffineMap =
850 auto one =
createIndex(rewriter, loc, indexPool, 1);
851 auto runtimeSize =
getTensorDim(rewriter, loc, indexPool, operand, dim);
852 auto broadcastNecessary = rewriter.
create<arith::CmpIOp>(
853 loc, arith::CmpIPredicate::eq, runtimeSize, one);
863 for (
auto index : llvm::seq<int64_t>(0, rank)) {
864 auto size = index == dim ? targetSize
867 outputTensorShape.push_back(size);
869 Value outputTensor = opBuilder.
create<tensor::EmptyOp>(
870 loc, outputTensorShape, rankedTensorType.getElementType());
875 .
create<linalg::GenericOp>(
876 loc, outputTensor.
getType(), operand, outputTensor, affineMaps,
880 opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
885 auto castResultTensor = rewriter.
createOrFold<tensor::CastOp>(
886 loc, operand.
getType(), resultTensor);
889 opBuilder.create<scf::YieldOp>(loc, castResultTensor);
894 opBuilder.
create<scf::YieldOp>(loc, operand);
898 auto ifOp = rewriter.
create<scf::IfOp>(loc, broadcastNecessary,
899 emitThenRegion, emitElseRegion);
907 int64_t rank = cast<RankedTensorType>(operand.
getType()).getRank();
908 assert((int64_t)targetShape.size() == rank);
909 assert((int64_t)masterOperands.size() == rank);
910 for (
auto index : llvm::seq<int64_t>(0, rank))
913 targetShape[index], masterOperands[index]);
923 if (operands.size() == 1)
927 return llvm::map_to_vector(operands, [&](
Value operand) {
929 targetShape, masterOperands);
939 auto resultType = cast_or_null<RankedTensorType>(
944 Value outputTensor = rewriter.
create<tensor::EmptyOp>(
945 loc, targetShape, resultType.getElementType());
950 auto rank = resultType.getRank();
951 auto affineMaps = llvm::map_to_vector(operands, [&](
Value operand) {
952 auto shape = cast<ShapedType>(operand.
getType()).getShape();
958 bool requiresBroadcast =
959 (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
960 auto affineExpr = requiresBroadcast
963 affineExprs.push_back(affineExpr);
970 bool encounteredError =
false;
971 auto linalgOp = rewriter.
create<linalg::GenericOp>(
972 loc, outputTensor.
getType(), operands, outputTensor, affineMaps,
977 {resultType.getElementType()}, rewriter);
979 encounteredError =
true;
982 opBuilder.create<linalg::YieldOp>(loc, opResult);
984 if (encounteredError)
986 operation,
"unable to create linalg.generic body for elementwise op");
989 auto castResult = rewriter.
createOrFold<tensor::CastOp>(
990 loc, resultType, linalgOp->getResult(0));
991 rewriter.
replaceOp(operation, castResult);
998 if (isa<tosa::MulOp>(operation))
999 return operands.take_front(2);
1001 if (isa<tosa::NegateOp>(operation))
1002 return operands.take_front(1);
1006 static LogicalResult
1012 assert(operation->
getNumResults() == 1 &&
"elementwise op expects 1 result");
1014 "elementwise op expects at least 1 operand");
1017 "Unranked tensors not supported");
1021 auto loc = operation->
getLoc();
1023 auto [targetShape, masterOperands] =
1025 auto broadcastOperands =
1027 targetShape, masterOperands);
1029 targetShape, converter);
1036 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1039 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1042 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1045 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1048 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1050 elementTy, APFloat::getLargest(
1051 cast<FloatType>(elementTy).getFloatSemantics(),
false));
1053 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1057 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1059 elementTy, APFloat::getLargest(
1060 cast<FloatType>(elementTy).getFloatSemantics(),
true));
1062 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1066 if (isa<tosa::ReduceAllOp>(op) && elementTy.
isInteger(1))
1069 if (isa<tosa::ReduceAnyOp>(op) && elementTy.
isInteger(1))
1072 if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1074 elementTy, APFloat::getLargest(
1075 cast<FloatType>(elementTy).getFloatSemantics(),
true));
1077 if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1091 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1092 return rewriter.
create<arith::AddFOp>(loc, args);
1095 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1096 return rewriter.
create<arith::AddIOp>(loc, args);
1099 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
1100 return rewriter.
create<arith::MulFOp>(loc, args);
1103 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
1104 return rewriter.
create<arith::MulIOp>(loc, args);
1107 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1108 return rewriter.
create<arith::MinimumFOp>(loc, args[0], args[1]);
1111 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1112 return rewriter.
create<arith::MinSIOp>(loc, args[0], args[1]);
1115 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1116 return rewriter.
create<arith::MaximumFOp>(loc, args[0], args[1]);
1119 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1120 return rewriter.
create<arith::MaxSIOp>(loc, args[0], args[1]);
1123 if (isa<tosa::ReduceAllOp>(op) && elementTy.
isInteger(1))
1124 return rewriter.
create<arith::AndIOp>(loc, args);
1126 if (isa<tosa::ReduceAnyOp>(op) && elementTy.
isInteger(1))
1127 return rewriter.
create<arith::OrIOp>(loc, args);
1135 template <
typename OpTy>
1138 auto loc = op->getLoc();
1139 auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1140 auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1141 if (!inputTy || !resultTy)
1144 auto elementTy = resultTy.getElementType();
1145 Value input = op->getOperand(0);
1149 for (
unsigned i = 0; i < inputTy.getRank(); i++) {
1151 reduceShape.push_back(inputTy.getDimSize(i));
1152 if (inputTy.isDynamicDim(i))
1153 dynDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, i));
1158 inputs.push_back(input);
1163 .
create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
1170 op,
"No initial value found for reduction operation");
1172 auto fillValue = rewriter.
create<arith::ConstantOp>(loc, fillValueAttr);
1173 auto filledTensor = rewriter
1177 outputs.push_back(filledTensor);
1179 bool isNanIgnoreMode =
false;
1180 if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1181 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1183 if (isa<FloatType>(elementTy) && op.getNanMode() ==
"IGNORE") {
1184 isNanIgnoreMode =
true;
1190 auto trueValue = rewriter.
create<arith::ConstantOp>(loc, trueAttr);
1191 auto emptyBoolTensor =
1193 .
create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(),
1196 auto allResultsNaNTensor =
1210 inputs.push_back(input);
1211 outputs.push_back(allResultsNaNTensor);
1215 bool didEncounterError =
false;
1216 linalg::LinalgOp linalgOp = rewriter.
create<linalg::ReduceOp>(
1217 loc, inputs, outputs, axis,
1219 std::array<Value, 2> binaryArgs{
1220 blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1222 op, binaryArgs, elementTy, rewriter);
1224 didEncounterError =
true;
1227 if (isNanIgnoreMode) {
1228 auto inputValue = blockArgs[0];
1229 auto initialValue = blockArgs[2];
1230 auto oldAllResultsNanFlagValue = blockArgs[3];
1233 Value isNaN = nestedBuilder.create<arith::CmpFOp>(
1234 op->getLoc(), arith::CmpFPredicate::UNO, inputValue, inputValue);
1236 auto selectOp = nestedBuilder.create<arith::SelectOp>(
1237 op->getLoc(), isNaN, initialValue, result);
1240 auto newAllResultsNanFlagValue = nestedBuilder.create<arith::AndIOp>(
1241 op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1242 resultsToYield.push_back(selectOp);
1243 resultsToYield.push_back(newAllResultsNanFlagValue);
1245 resultsToYield.push_back(result);
1247 nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield);
1250 if (!didEncounterError)
1252 op,
"unable to create linalg.generic body for reduce op");
1254 if (isNanIgnoreMode) {
1263 APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(),
false));
1264 auto nanValue = rewriter.
create<arith::ConstantOp>(loc, nanValueAttr);
1265 auto emptyNanTensor =
1267 .
create<tensor::EmptyOp>(loc, reduceShape,
1268 resultTy.getElementType(), dynDims)
1270 auto nanFilledTensor =
1278 auto finalEmptyTensor =
1280 .
create<tensor::EmptyOp>(loc, reduceShape,
1281 resultTy.getElementType(), dynDims)
1287 ins.push_back(linalgOp->getOpResult(1));
1288 ins.push_back(nanFilledTensor);
1289 ins.push_back(linalgOp->getResult(0));
1290 outs.push_back(finalEmptyTensor);
1292 rewriter.
create<linalg::SelectOp>(op->getLoc(), ins, outs);
1293 linalgOp = linalgSelect;
1297 uint64_t expandInputRank =
1298 cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
1299 reassociationMap.resize(expandInputRank);
1301 for (uint64_t i = 0; i < expandInputRank; i++) {
1302 int32_t dimToPush = i > axis ? i + 1 : i;
1306 if (expandInputRank != 0) {
1307 int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1308 reassociationMap[expandedDim].push_back(
1317 op, resultTy, linalgOp->getResults()[0], reassociationMap);
1323 template <
typename SrcOp>
1330 matchAndRewrite(SrcOp op, OpAdaptor operands,
1333 op, operands.getOperands(), rewriter, *this->getTypeConverter());
1341 LogicalResult matchAndRewrite(tosa::RescaleOp op,
1343 auto loc = op.getLoc();
1344 auto input = op.getInput();
1345 auto inputTy = cast<ShapedType>(op.getInput().getType());
1346 auto outputTy = cast<ShapedType>(op.getOutput().getType());
1347 unsigned rank = inputTy.getRank();
1350 if (op.getRoundingMode() ==
"INEXACT_ROUND")
1351 return rewriter.notifyMatchFailure(
1352 op,
"tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1353 "currently supported");
1354 if (op.getRoundingMode() ==
"DOUBLE_ROUND" && !op.getScale32())
1355 return rewriter.notifyMatchFailure(
1356 op,
"tosa.rescale requires scale32 for double_round to be true");
1358 if (!isa<IntegerType>(inputTy.getElementType()))
1359 return rewriter.notifyMatchFailure(op,
"only support integer type");
1362 for (
int i = 0; i < outputTy.getRank(); i++) {
1363 if (outputTy.isDynamicDim(i)) {
1364 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1371 return rewriter.notifyMatchFailure(
1372 op,
"tosa.rescale requires constant shift input values");
1376 return rewriter.notifyMatchFailure(
1377 op,
"tosa.rescale requires constant multiplier input values");
1380 llvm::to_vector(shiftElems.
getValues<int8_t>());
1383 llvm::map_range(multiplierElems.
getValues<IntegerAttr>(),
1384 [](IntegerAttr attr) -> int32_t {
1385 return static_cast<int32_t>(attr.getInt());
1389 for (
int i = 0, s = multiplierValues.size(); i < s; i++) {
1390 if (shiftValues[i] > 63) {
1392 multiplierValues[i] = 0;
1400 op.getRoundingMode() ==
"DOUBLE_ROUND" &&
1401 llvm::any_of(shiftValues, [](int32_t v) {
return v > 31; });
1402 StringAttr roundingMode = doubleRound
1403 ? rewriter.getStringAttr(
"DOUBLE_ROUND")
1404 : rewriter.getStringAttr(
"SINGLE_ROUND");
1407 rewriter.getMultiDimIdentityMap(rank)};
1412 Value multiplierConstant;
1413 int64_t multiplierArg = 0;
1414 if (multiplierValues.size() == 1) {
1415 multiplierConstant = rewriter.create<arith::ConstantOp>(
1416 loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1419 rewriter.getAffineDimExpr(rank - 1)};
1420 auto multiplierType =
1422 rewriter.getI32Type());
1423 genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1428 rewriter.getContext()));
1430 multiplierArg = indexingMaps.size() - 1;
1435 Value shiftConstant;
1436 int64_t shiftArg = 0;
1437 if (shiftValues.size() == 1) {
1438 shiftConstant = rewriter.create<arith::ConstantOp>(
1439 loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1442 rewriter.getAffineDimExpr(rank - 1)};
1445 rewriter.getIntegerType(8));
1446 genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1450 rewriter.getContext()));
1451 shiftArg = indexingMaps.size() - 1;
1455 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1458 Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1459 loc, outputTy.getShape(), outputTy.getElementType(),
1462 auto linalgOp = rewriter.create<linalg::GenericOp>(
1463 loc, outputTy, genericInputs,
ValueRange{emptyTensor}, indexingMaps,
1467 Value value = blockArgs[0];
1475 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1476 if (failed(maybeIZp)) {
1477 (void)rewriter.notifyMatchFailure(
1478 op,
"input zero point cannot be statically determined");
1482 auto inputZp = createConstOpFromZpVal<int32_t>(
1483 op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
1486 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1487 if (failed(maybeOZp)) {
1488 (void)rewriter.notifyMatchFailure(
1489 op,
"output zero point cannot be statically determined");
1493 auto outputZp = createConstOpFromZpVal<int32_t>(
1494 op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);
1496 Value multiplier = multiplierConstant ? multiplierConstant
1497 : blockArgs[multiplierArg];
1498 Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1501 if (op.getInputUnsigned()) {
1502 value = nestedBuilder.create<arith::ExtUIOp>(
1503 nestedLoc, nestedBuilder.getI32Type(), value);
1505 value = nestedBuilder.create<arith::ExtSIOp>(
1506 nestedLoc, nestedBuilder.getI32Type(), value);
1511 nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
1513 value = nestedBuilder.create<tosa::ApplyScaleOp>(
1514 loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1519 nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
1522 IntegerType outIntType =
1523 cast<IntegerType>(blockArgs.back().getType());
1524 unsigned outBitWidth = outIntType.getWidth();
1526 int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1527 int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1530 if (op.getOutputUnsigned()) {
1532 intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1535 auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1536 loc, nestedBuilder.getI32IntegerAttr(intMin));
1537 auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1538 loc, nestedBuilder.getI32IntegerAttr(intMax));
1541 nestedBuilder,
false);
1543 if (outIntType.getWidth() < 32) {
1544 value = nestedBuilder.create<arith::TruncIOp>(
1545 nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
1549 nestedBuilder.create<linalg::YieldOp>(loc, value);
1552 rewriter.replaceOp(op, linalgOp->getResults());
1564 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1568 auto input = op.getInput();
1569 auto inputTy = cast<RankedTensorType>(input.getType());
1570 auto resultTy = cast<RankedTensorType>(op.getType());
1571 const bool isBilinear = op.getMode() ==
"BILINEAR";
1573 auto inputH = inputTy.getDimSize(1);
1574 auto inputW = inputTy.getDimSize(2);
1575 auto outputH = resultTy.getDimSize(1);
1576 auto outputW = resultTy.getDimSize(2);
1578 if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1579 return rewriter.notifyMatchFailure(
1580 op,
"tosa.resize is not a pure 1x1->1x1 image operation");
1583 if (op.getMode() !=
"NEAREST_NEIGHBOR" && op.getMode() !=
"BILINEAR")
1584 return rewriter.notifyMatchFailure(
1585 op,
"tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1587 if (inputTy == resultTy) {
1588 rewriter.replaceOp(op, input);
1599 reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1600 reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1601 reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1602 reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1606 inputTy.getElementType());
1607 Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
1612 if (inputTy.isDynamicDim(0))
1613 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1614 if (inputTy.isDynamicDim(3))
1615 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1618 auto genericTy = collapseTy.clone(resultTy.getElementType());
1619 Value empty = builder.create<tensor::EmptyOp>(
1620 genericTy.getShape(), resultTy.getElementType(), outputDynSize);
1621 auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1623 utils::IteratorType::parallel);
1625 auto generic = builder.create<linalg::GenericOp>(
1629 Value value = args[0];
1631 if (inputTy.getElementType() != resultTy.getElementType()) {
1633 b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
1635 if (isBilinear && scale[0] != 0) {
1636 Value scaleY = b.create<arith::ConstantOp>(
1637 loc, b.getI32IntegerAttr(scale[0]));
1638 value = b.create<arith::MulIOp>(loc, value, scaleY);
1641 if (isBilinear && scale[2] != 0) {
1642 Value scaleX = b.create<arith::ConstantOp>(
1643 loc, b.getI32IntegerAttr(scale[2]));
1644 value = b.create<arith::MulIOp>(loc, value, scaleX);
1648 b.create<linalg::YieldOp>(loc, value);
1651 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1652 op, resultTy,
generic.getResults()[0], reassociationMap);
1664 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1668 auto input = op.getInput();
1669 auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1670 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1672 if (!inputTy || !resultTy)
1673 return rewriter.notifyMatchFailure(op,
1674 "requires ranked input/output types");
1676 auto batch = inputTy.getDimSize(0);
1677 auto channels = inputTy.getDimSize(3);
1678 auto inputH = inputTy.getDimSize(1);
1679 auto inputW = inputTy.getDimSize(2);
1680 auto outputH = resultTy.getDimSize(1);
1681 auto outputW = resultTy.getDimSize(2);
1683 if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1684 return rewriter.notifyMatchFailure(
1685 op,
"tosa.resize has no broadcasting behavior");
1690 resizeShape.push_back(batch);
1691 resizeShape.push_back(inputH == 1 ? 1 : outputH);
1692 resizeShape.push_back(inputW == 1 ? 1 : outputW);
1693 resizeShape.push_back(channels);
1695 auto resizeTy = resultTy.clone(resizeShape);
1696 auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
1697 op.getOffset(), op.getBorder(),
1702 reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1703 reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1705 reassociationMap.push_back({});
1706 reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1708 reassociationMap.push_back({});
1709 reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1713 collapseShape.push_back(outputH);
1715 collapseShape.push_back(outputW);
1716 collapseShape.push_back(channels);
1718 auto collapseTy = resultTy.clone(collapseShape);
1719 Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
1724 if (inputTy.isDynamicDim(0))
1725 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1726 if (inputTy.isDynamicDim(3))
1727 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1730 utils::IteratorType::parallel);
1731 Value empty = builder.create<tensor::EmptyOp>(
1732 resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1736 inputExprs.push_back(rewriter.getAffineDimExpr(1));
1738 inputExprs.push_back(rewriter.getAffineDimExpr(2));
1739 inputExprs.push_back(rewriter.getAffineDimExpr(3));
1742 inputExprs, rewriter.getContext());
1744 auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1745 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1749 Value value = args[0];
1750 b.create<linalg::YieldOp>(loc, value);
1761 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1765 auto input = op.getInput();
1766 auto inputTy = cast<ShapedType>(input.getType());
1767 auto resultTy = cast<ShapedType>(op.getType());
1768 auto resultETy = resultTy.getElementType();
1770 bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1771 auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1773 auto imageH = inputTy.getShape()[1];
1774 auto imageW = inputTy.getShape()[2];
1776 auto dynamicDimsOr =
1778 if (!dynamicDimsOr.has_value())
1779 return rewriter.notifyMatchFailure(
1780 op,
"unable to get dynamic dimensions of tosa.resize");
1782 if (op.getMode() !=
"NEAREST_NEIGHBOR" && op.getMode() !=
"BILINEAR")
1783 return rewriter.notifyMatchFailure(
1784 op,
"tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1787 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1788 auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
1790 auto genericOp = b.create<linalg::GenericOp>(
1793 Value resize = genericOp.getResult(0);
1797 b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1799 Value batch = b.create<linalg::IndexOp>(0);
1800 Value y = b.create<linalg::IndexOp>(1);
1801 Value x = b.create<linalg::IndexOp>(2);
1802 Value channel = b.create<linalg::IndexOp>(3);
1805 b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1806 Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
1807 Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
1808 Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
1810 Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
1811 Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
1817 return rewriter.notifyMatchFailure(
1818 op,
"tosa.resize scale/offset/border should have compile time "
1819 "constant values.");
1822 Value yScaleN, yScaleD, xScaleN, xScaleD;
1823 yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
1824 yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
1825 xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
1826 xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
1828 Value yOffset, xOffset, yBorder, xBorder;
1829 yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
1830 xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
1831 yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
1832 xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
1845 Value val = b.create<arith::MulIOp>(in, scaleD);
1846 val = b.create<arith::AddIOp>(val, offset);
1847 index = b.create<arith::FloorDivSIOp>(val, scaleN);
1851 Value r = b.create<arith::RemSIOp>(val, scaleN);
1852 Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
1853 Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
1854 delta = b.create<arith::DivFOp>(rFp, scaleNfp);
1869 Value val = b.create<arith::MulIOp>(in, scaleD);
1870 val = b.create<arith::AddIOp>(val, offset);
1871 index = b.create<arith::DivSIOp>(val, scaleN);
1872 delta = b.create<arith::MulIOp>(index, scaleN);
1873 delta = b.create<arith::SubIOp>(val, delta);
1876 Value ix, iy, dx, dy;
1877 if (floatingPointMode) {
1878 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1879 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1881 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1882 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1885 if (op.getMode() ==
"NEAREST_NEIGHBOR") {
1886 auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1888 auto getNearestIndexAndClamp = [&](
Value val,
Value dval,
Value scale,
1892 return b.create<arith::ConstantIndexOp>(0);
1896 if (floatingPointMode) {
1897 auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
1898 pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1900 Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
1901 pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
1905 auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
1906 val = b.create<arith::AddIOp>(val, offset);
1908 return b.create<arith::IndexCastOp>(b.getIndexType(), val);
1911 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1912 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1914 Value result = b.create<tensor::ExtractOp>(
1917 b.create<linalg::YieldOp>(result);
1920 assert(op.getMode() ==
"BILINEAR");
1922 auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1924 auto getClampedIdxs = [&](
Value &val0,
Value &val1,
int size,
Value in,
1927 val1 = b.create<arith::AddIOp>(val0, oneVal);
1932 val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
1933 val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
1941 Value x0, x1, y0, y1;
1942 getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1943 getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1945 Value y0x0 = b.create<tensor::ExtractOp>(
1947 Value y0x1 = b.create<tensor::ExtractOp>(
1949 Value y1x0 = b.create<tensor::ExtractOp>(
1951 Value y1x1 = b.create<tensor::ExtractOp>(
1954 if (floatingPointMode) {
1956 b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
1962 Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
1963 Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
1964 Value mul1 = b.create<arith::MulFOp>(val1, delta);
1965 return b.create<arith::AddFOp>(mul0, mul1);
1971 Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1976 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1980 Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1981 b.create<linalg::YieldOp>(result);
1984 y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
1985 y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
1986 y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
1987 y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
1990 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1991 dx = b.create<arith::ExtSIOp>(resultETy, dx);
1992 dy = b.create<arith::ExtSIOp>(resultETy, dy);
1995 Value yScaleNExt = yScaleN;
1996 Value xScaleNExt = xScaleN;
1998 const int64_t scaleBitwidth =
2000 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
2001 yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
2002 xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
2006 Value scale,
int inputSize,
2009 return b.create<arith::MulIOp>(val0, scale);
2010 Value weight0 = b.create<arith::SubIOp>(scale, weight1);
2011 Value mul0 = b.create<arith::MulIOp>(val0, weight0);
2012 Value mul1 = b.create<arith::MulIOp>(val1, weight1);
2013 return b.create<arith::AddIOp>(mul0, mul1);
2016 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
2017 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
2019 interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
2020 b.create<linalg::YieldOp>(result);
2025 rewriter.replaceOp(op, resize);
2033 template <
typename SrcOp>
2038 LogicalResult matchAndRewrite(SrcOp op,
2040 rewriter.replaceOp(op, op.getOperation()->getOperands());
2045 template <
typename SrcOp>
2050 LogicalResult matchAndRewrite(SrcOp reduceOp,
2060 LogicalResult matchAndRewrite(tosa::ReverseOp op,
2062 auto loc = op.getLoc();
2063 Value input = op.getInput1();
2064 auto inputTy = cast<ShapedType>(input.
getType());
2065 auto resultTy = cast<ShapedType>(op.getType());
2066 auto axis = op.getAxis();
2069 for (
int i = 0; i < inputTy.getRank(); i++) {
2070 if (inputTy.isDynamicDim(i)) {
2071 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
2075 Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
2078 auto emptyTensor = rewriter
2079 .create<tensor::EmptyOp>(loc, inputTy.getShape(),
2080 inputTy.getElementType(),
2084 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2086 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
2091 for (
unsigned int i = 0; i < inputTy.getRank(); i++) {
2093 rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
2095 auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
2097 rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
2098 index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
2102 indices.push_back(index);
2105 auto extract = nestedBuilder.create<tensor::ExtractOp>(
2106 nestedLoc, input, indices);
2107 nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
2108 extract.getResult());
2122 matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
2124 auto loc = op.getLoc();
2125 auto input = op.getInput1();
2126 auto inputTy = cast<ShapedType>(input.getType());
2127 auto inputShape = inputTy.getShape();
2128 auto resultTy = cast<ShapedType>(op.getType());
2129 auto elementTy = inputTy.getElementType();
2130 int64_t rank = inputTy.getRank();
2133 if (failed(op.getConstantMultiples(multiples)))
2138 for (
int i = 0; i < rank; i++) {
2139 int64_t dim = multiples[i];
2140 genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
2141 genericShape.push_back(inputShape[i]);
2145 for (
int i = 0; i < inputTy.getRank(); i++) {
2146 if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
2147 dynDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, i));
2151 auto emptyTensor = rewriter.
create<tensor::EmptyOp>(
2152 op.getLoc(), genericShape, elementTy, dynDims);
2156 dimExprs.reserve(rank);
2157 for (
unsigned i = 0; i < rank; ++i)
2160 auto readAffineMap =
2167 auto genericOp = rewriter.
create<linalg::GenericOp>(
2172 nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
2178 op, resultTy, genericOp.getResult(0), shapeValue);
2200 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2202 auto loc = argmaxOp.getLoc();
2203 Value input = argmaxOp.getInput();
2204 auto inputTy = cast<ShapedType>(input.
getType());
2205 auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
2206 auto inElementTy = inputTy.getElementType();
2207 auto outElementTy = resultTy.getElementType();
2208 int axis = argmaxOp.getAxis();
2211 if (!isa<IntegerType>(outElementTy))
2214 "tosa.arg_max to linalg.* requires integer-like result type");
2217 for (
int i = 0; i < inputTy.getRank(); i++) {
2218 if (inputTy.isDynamicDim(i) && i != axis) {
2219 dynDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, i));
2224 auto emptyTensorIdx = rewriter
2225 .
create<tensor::EmptyOp>(loc, resultTy.getShape(),
2226 outElementTy, dynDims)
2228 auto fillValueIdx = rewriter.
create<arith::ConstantOp>(
2230 auto filledTensorIdx =
2237 auto emptyTensorMax = rewriter
2238 .
create<tensor::EmptyOp>(loc, resultTy.getShape(),
2239 inElementTy, dynDims)
2241 auto fillValueMaxAttr =
2244 if (!fillValueMaxAttr)
2246 argmaxOp,
"unsupported tosa.argmax element type");
2249 rewriter.
create<arith::ConstantOp>(loc, fillValueMaxAttr);
2250 auto filledTensorMax =
2259 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2260 iteratorTypes[axis] = utils::IteratorType::reduction;
2264 for (
int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2270 bool didEncounterError =
false;
2273 auto linalgOp = rewriter.
create<linalg::GenericOp>(
2275 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2278 auto newValue = blockArgs[0];
2279 auto oldIndex = blockArgs[1];
2280 auto oldValue = blockArgs[2];
2282 Value newIndex = rewriter.
create<arith::IndexCastOp>(
2283 nestedLoc, oldIndex.getType(),
2284 rewriter.
create<linalg::IndexOp>(loc, axis));
2287 if (isa<FloatType>(inElementTy)) {
2288 if (argmaxOp.getNanMode() ==
"IGNORE") {
2291 predicate = rewriter.
create<arith::CmpFOp>(
2292 nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2298 nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
2300 nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
2301 predicate = rewriter.
create<arith::AndIOp>(
2302 nestedLoc, rewriter.
getI1Type(), gt, oldNonNaN);
2304 }
else if (isa<IntegerType>(inElementTy)) {
2305 predicate = rewriter.
create<arith::CmpIOp>(
2306 nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2308 didEncounterError =
true;
2312 auto resultMax = rewriter.
create<arith::SelectOp>(
2313 nestedLoc, predicate, newValue, oldValue);
2314 auto resultIndex = rewriter.
create<arith::SelectOp>(
2315 nestedLoc, predicate, newIndex, oldIndex);
2316 nestedBuilder.
create<linalg::YieldOp>(
2317 nestedLoc,
ValueRange({resultIndex, resultMax}));
2320 if (didEncounterError)
2322 argmaxOp,
"unsupported tosa.argmax element type");
2324 rewriter.
replaceOp(argmaxOp, linalgOp.getResult(0));
2333 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2335 auto input = adaptor.getOperands()[0];
2336 auto indices = adaptor.getOperands()[1];
2338 auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
2339 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
2340 if (!valuesTy || !resultTy)
2343 auto dynamicDims = inferDynamicDimsForGather(
2344 rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2346 auto resultElementTy = resultTy.getElementType();
2348 auto loc = op.getLoc();
2351 .
create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
2357 resultTy.getRank(), 0,
2358 {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2362 auto genericOp = rewriter.
create<linalg::GenericOp>(
2367 auto indexValue = args[0];
2368 auto index0 = rewriter.
create<linalg::IndexOp>(loc, 0);
2369 Value index1 = rewriter.
create<arith::IndexCastOp>(
2371 auto index2 = rewriter.
create<linalg::IndexOp>(loc, 2);
2372 Value extract = rewriter.
create<tensor::ExtractOp>(
2373 loc, input,
ValueRange{index0, index1, index2});
2374 rewriter.
create<linalg::YieldOp>(loc, extract);
2376 rewriter.
replaceOp(op, genericOp.getResult(0));
2386 auto addDynamicDimension = [&](
Value source, int64_t dim) {
2388 if (
auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2389 results.push_back(dimValue);
2392 addDynamicDimension(values, 0);
2393 addDynamicDimension(indices, 1);
2394 addDynamicDimension(values, 2);
2406 LogicalResult matchAndRewrite(tosa::TableOp op,
2408 auto loc = op.getLoc();
2409 Value input = op.getInput1();
2411 auto inputTy = cast<ShapedType>(input.
getType());
2412 auto tableTy = cast<ShapedType>(
table.getType());
2413 auto resultTy = cast<ShapedType>(op.getType());
2415 auto inputElementTy = inputTy.getElementType();
2416 auto tableElementTy = tableTy.getElementType();
2417 auto resultElementTy = resultTy.getElementType();
2420 for (
int i = 0; i < resultTy.getRank(); ++i) {
2421 if (inputTy.isDynamicDim(i)) {
2423 rewriter.
create<tensor::DimOp>(loc, op.getOperand(0), i));
2427 auto emptyTensor = rewriter
2428 .
create<tensor::EmptyOp>(loc, resultTy.getShape(),
2429 resultElementTy, dynDims)
2436 auto genericOp = rewriter.
create<linalg::GenericOp>(
2439 rewriter.
replaceOp(op, genericOp.getResult(0));
2444 &genericOp.getRegion(), genericOp.getRegion().end(),
2445 TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2449 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2450 resultElementTy.isInteger(8)) {
2451 Value index = rewriter.
create<arith::IndexCastOp>(
2453 Value offset = rewriter.
create<arith::ConstantIndexOp>(loc, 128);
2458 rewriter.
create<linalg::YieldOp>(loc, extract);
2462 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2463 resultElementTy.isInteger(32)) {
2467 auto offset = rewriter.
create<arith::ConstantOp>(
2469 auto seven = rewriter.
create<arith::ConstantOp>(
2471 auto one = rewriter.
create<arith::ConstantOp>(
2473 auto b1111111 = rewriter.
create<arith::ConstantOp>(
2480 auto extendAdd = rewriter.
create<arith::AddIOp>(loc, extend, offset);
2481 Value index = rewriter.
create<arith::ShRUIOp>(loc, extendAdd, seven);
2483 rewriter.
create<arith::AndIOp>(loc, extendAdd, b1111111);
2488 Value indexPlusOne = rewriter.
create<arith::AddIOp>(loc, index, one);
2490 index = rewriter.
create<arith::IndexCastOp>(
2492 indexPlusOne = rewriter.
create<arith::IndexCastOp>(
2507 Value baseScaled = rewriter.
create<arith::ShLIOp>(loc, base, seven);
2508 Value diff = rewriter.
create<arith::SubIOp>(loc, next, base);
2509 Value diffScaled = rewriter.
create<arith::MulIOp>(loc, diff, fraction);
2511 rewriter.
create<arith::AddIOp>(loc, baseScaled, diffScaled);
2513 rewriter.
create<linalg::YieldOp>(loc, result);
2520 op,
"unable to create body for tosa.table op");
2527 static bool isRankedTensor(
Type type) {
return isa<RankedTensorType>(type); }
2531 auto one = builder.
create<arith::ConstantIndexOp>(loc, 1);
2532 auto two = builder.
create<arith::ConstantIndexOp>(loc, 2);
2535 auto divBy2 = builder.
createOrFold<arith::DivUIOp>(loc, value, two);
2536 auto plusOne = builder.
createOrFold<arith::AddIOp>(loc, divBy2, one);
2540 static RankedTensorType
2548 dims[2] = halfPlusOne(builder, loc, dims[2]);
2553 auto elementType = cast<RankedTensorType>(input.
getType()).getElementType();
2558 RankedTensorType type,
2561 rewriter.
create<tensor::EmptyOp>(loc, type, dynamicSizes);
2562 auto fillValueAttr = rewriter.
getZeroAttr(type.getElementType());
2563 auto fillValue = rewriter.
create<arith::ConstantOp>(loc, fillValueAttr);
2564 auto filledTensor = rewriter
2568 return filledTensor;
2572 FloatType type,
Value value) {
2573 auto integerVal = builder.
create<arith::IndexCastUIOp>(
2575 type.getIntOrFloatBitWidth() > 32 ? builder.
getI64Type()
2579 return builder.
create<arith::UIToFPOp>(loc, type, integerVal);
2583 FloatType type, int64_t index) {
2584 auto indexVal = builder.
create<linalg::IndexOp>(loc, index);
2585 return castIndexToFloat(builder, loc, type, indexVal);
2588 template <
typename... Args>
2594 LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2596 if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2597 !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2599 "only supports ranked tensors");
2602 auto loc = rfft2d.getLoc();
2603 auto input = rfft2d.getInputReal();
2605 dyn_cast<FloatType>(cast<ShapedType>(input.
getType()).getElementType());
2608 "only supports float element types");
2612 auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2616 utils::IteratorType::parallel, utils::IteratorType::parallel,
2617 utils::IteratorType::parallel, utils::IteratorType::reduction,
2618 utils::IteratorType::reduction};
2623 createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2624 createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2629 affineDimsExpr(rewriter, 0, 1, 2),
2630 affineDimsExpr(rewriter, 0, 1, 2)},
2634 auto dimH = rewriter.
createOrFold<tensor::DimOp>(loc, input, 1);
2635 auto dimW = rewriter.
createOrFold<tensor::DimOp>(loc, input, 2);
2638 auto twoPiAttr = rewriter.
getFloatAttr(elementType, 6.283185307179586);
2639 auto twoPi = rewriter.
create<arith::ConstantOp>(loc, twoPiAttr);
2640 auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2641 auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2644 Value valReal = args[0];
2645 Value sumReal = args[1];
2646 Value sumImag = args[2];
2649 Value oy = builder.
create<linalg::IndexOp>(loc, 1);
2650 Value ox = builder.
create<linalg::IndexOp>(loc, 2);
2651 Value iy = builder.
create<linalg::IndexOp>(loc, 3);
2652 Value ix = builder.
create<linalg::IndexOp>(loc, 4);
2657 auto iyXoy = builder.
create<index::MulOp>(loc, iy, oy);
2658 auto ixXox = builder.
create<index::MulOp>(loc, ix, ox);
2660 auto iyRem = builder.
create<index::RemUOp>(loc, iyXoy, dimH);
2661 auto ixRem = builder.
create<index::RemUOp>(loc, ixXox, dimW);
2663 auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2664 auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2666 auto yComponent = builder.
create<arith::DivFOp>(loc, iyRemFloat, constH);
2667 auto xComponent = builder.
create<arith::DivFOp>(loc, ixRemFloat, constW);
2668 auto sumXY = builder.
create<arith::AddFOp>(loc, yComponent, xComponent);
2669 auto angle = builder.
create<arith::MulFOp>(loc, twoPi, sumXY);
2673 auto cosAngle = builder.
create<math::CosOp>(loc, angle);
2674 auto sinAngle = builder.
create<math::SinOp>(loc, angle);
2675 auto realComponent =
2676 builder.
create<arith::MulFOp>(loc, valReal, cosAngle);
2677 auto imagComponent =
2678 builder.
create<arith::MulFOp>(loc, valReal, sinAngle);
2682 auto outReal = builder.
create<arith::AddFOp>(loc, sumReal, realComponent);
2683 auto outImag = builder.
create<arith::SubFOp>(loc, sumImag, imagComponent);
2689 rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2690 indexingMaps, iteratorTypes, buildBody);
2699 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2701 if (!llvm::all_of(fft2d->getOperandTypes(),
2702 RFFT2dConverter::isRankedTensor) ||
2703 !llvm::all_of(fft2d->getResultTypes(),
2704 RFFT2dConverter::isRankedTensor)) {
2709 Value input_real = fft2d.getInputReal();
2710 Value input_imag = fft2d.getInputImag();
2711 BoolAttr inverse = fft2d.getInverseAttr();
2713 auto real_el_ty = cast<FloatType>(
2714 cast<ShapedType>(input_real.
getType()).getElementType());
2715 [[maybe_unused]]
auto imag_el_ty = cast<FloatType>(
2716 cast<ShapedType>(input_imag.
getType()).getElementType());
2718 assert(real_el_ty == imag_el_ty);
2733 utils::IteratorType::parallel, utils::IteratorType::parallel,
2734 utils::IteratorType::parallel, utils::IteratorType::reduction,
2735 utils::IteratorType::reduction};
2740 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2742 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2747 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2748 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2749 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2750 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2754 auto dimH = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 1);
2755 auto dimW = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 2);
2758 auto twoPiAttr = rewriter.
getFloatAttr(real_el_ty, 6.283185307179586);
2759 auto twoPi = rewriter.
create<arith::ConstantOp>(loc, twoPiAttr);
2761 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2763 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2766 Value valReal = args[0];
2767 Value valImag = args[1];
2768 Value sumReal = args[2];
2769 Value sumImag = args[3];
2772 Value oy = builder.
create<linalg::IndexOp>(loc, 1);
2773 Value ox = builder.
create<linalg::IndexOp>(loc, 2);
2774 Value iy = builder.
create<linalg::IndexOp>(loc, 3);
2775 Value ix = builder.
create<linalg::IndexOp>(loc, 4);
2779 auto iyXoy = builder.
create<index::MulOp>(loc, iy, oy);
2780 auto ixXox = builder.
create<index::MulOp>(loc, ix, ox);
2782 auto iyRem = builder.
create<index::RemUOp>(loc, iyXoy, dimH);
2783 auto ixRem = builder.
create<index::RemUOp>(loc, ixXox, dimW);
2786 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2788 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2790 auto yComponent = builder.
create<arith::DivFOp>(loc, iyRemFloat, constH);
2791 auto xComponent = builder.
create<arith::DivFOp>(loc, ixRemFloat, constW);
2793 auto sumXY = builder.
create<arith::AddFOp>(loc, yComponent, xComponent);
2794 auto angle = builder.
create<arith::MulFOp>(loc, twoPi, sumXY);
2797 angle = builder.
create<arith::MulFOp>(
2799 rewriter.
create<arith::ConstantOp>(
2805 auto cosAngle = builder.
create<math::CosOp>(loc, angle);
2806 auto sinAngle = builder.
create<math::SinOp>(loc, angle);
2808 auto rcos = builder.
create<arith::MulFOp>(loc, valReal, cosAngle);
2809 auto rsin = builder.
create<arith::MulFOp>(loc, valImag, sinAngle);
2810 auto realComponent = builder.
create<arith::AddFOp>(loc, rcos, rsin);
2812 auto icos = builder.
create<arith::MulFOp>(loc, valImag, cosAngle);
2813 auto isin = builder.
create<arith::MulFOp>(loc, valReal, sinAngle);
2815 auto imagComponent = builder.
create<arith::SubFOp>(loc, icos, isin);
2819 auto outReal = builder.
create<arith::AddFOp>(loc, sumReal, realComponent);
2820 auto outImag = builder.
create<arith::AddFOp>(loc, sumImag, imagComponent);
2826 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2827 indexingMaps, iteratorTypes, buildBody);
2848 PointwiseConverter<tosa::AddOp>,
2849 PointwiseConverter<tosa::SubOp>,
2850 PointwiseConverter<tosa::MulOp>,
2851 PointwiseConverter<tosa::IntDivOp>,
2852 PointwiseConverter<tosa::NegateOp>,
2853 PointwiseConverter<tosa::PowOp>,
2854 PointwiseConverter<tosa::ReciprocalOp>,
2855 PointwiseConverter<tosa::RsqrtOp>,
2856 PointwiseConverter<tosa::LogOp>,
2857 PointwiseConverter<tosa::ExpOp>,
2858 PointwiseConverter<tosa::AbsOp>,
2859 PointwiseConverter<tosa::SinOp>,
2860 PointwiseConverter<tosa::CosOp>,
2861 PointwiseConverter<tosa::TanhOp>,
2862 PointwiseConverter<tosa::ErfOp>,
2863 PointwiseConverter<tosa::BitwiseAndOp>,
2864 PointwiseConverter<tosa::BitwiseOrOp>,
2865 PointwiseConverter<tosa::BitwiseNotOp>,
2866 PointwiseConverter<tosa::BitwiseXorOp>,
2867 PointwiseConverter<tosa::LogicalAndOp>,
2868 PointwiseConverter<tosa::LogicalNotOp>,
2869 PointwiseConverter<tosa::LogicalOrOp>,
2870 PointwiseConverter<tosa::LogicalXorOp>,
2871 PointwiseConverter<tosa::CastOp>,
2872 PointwiseConverter<tosa::LogicalLeftShiftOp>,
2873 PointwiseConverter<tosa::LogicalRightShiftOp>,
2874 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2875 PointwiseConverter<tosa::ClzOp>,
2876 PointwiseConverter<tosa::SelectOp>,
2877 PointwiseConverter<tosa::GreaterOp>,
2878 PointwiseConverter<tosa::GreaterEqualOp>,
2879 PointwiseConverter<tosa::EqualOp>,
2880 PointwiseConverter<tosa::MaximumOp>,
2881 PointwiseConverter<tosa::MinimumOp>,
2882 PointwiseConverter<tosa::CeilOp>,
2883 PointwiseConverter<tosa::FloorOp>,
2884 PointwiseConverter<tosa::ClampOp>,
2885 PointwiseConverter<tosa::SigmoidOp>
2886 >(converter,
patterns->getContext());
2889 IdentityNConverter<tosa::IdentityOp>,
2890 ReduceConverter<tosa::ReduceAllOp>,
2891 ReduceConverter<tosa::ReduceAnyOp>,
2892 ReduceConverter<tosa::ReduceMinOp>,
2893 ReduceConverter<tosa::ReduceMaxOp>,
2894 ReduceConverter<tosa::ReduceSumOp>,
2895 ReduceConverter<tosa::ReduceProductOp>,
2903 TileConverter>(
patterns->getContext());
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static LogicalResult emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape, const TypeConverter &converter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, PatternRewriter &rewriter)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, ConversionPatternRewriter &rewriter)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static ValueRange getBroadcastableOperands(Operation *operation, ValueRange operands)
static Value materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, Value lhs, Value rhs, Value result)
static arith::ConstantOp createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType, OpBuilder &rewriter)
static bool operandsAndResultsRanked(Operation *operation)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the largest dim in exprs and as many symbols as the largest symbol in exprs.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
AffineMap getMultiDimIdentityMap(unsigned rank)
FloatAttr getFloatAttr(Type type, double value)
AffineExpr getAffineConstantExpr(int64_t constant)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
StringAttr getStringAttr(const Twine &bytes)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location argument.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
typename SourceOp::Adaptor OpAdaptor
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
DynamicAPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult above for a single OpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...