#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
  return rewriter.create<arith::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
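// The cases below form the scalar body of the linalg.generic that each TOSA
// elementwise op lowers to. As an illustrative sketch (not emitted verbatim),
// a tosa.add on f32 tensors becomes roughly:
//
//   linalg.generic ... {
//   ^bb0(%a: f32, %b: f32, %out: f32):
//     %sum = arith.addf %a, %b : f32
//     linalg.yield %sum : f32
//   }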
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  // Integer abs is computed as max(x, 0 - x).
  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
  }
  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);
  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }
  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);

  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);

      if (!b.getType().isInteger(32))
        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
  }
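  // Note: tosa.apply_scale is TOSA's fixed-point scaling primitive; it
  // computes roughly round((a * b) >> shift) in wide integer arithmetic.
  // For example, with shift = 8, a = 300, b = 171 it yields
  // round(51300 / 256) = 200.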
  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
    int64_t inZp = 0, outZp = 0;

    if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
      auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
      inZp = quantizationInfo.value().getInputZp();
      outZp = quantizationInfo.value().getOutputZp();
    }

    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    if (!inZp && !outZp) {
      auto constant = rewriter.create<arith::ConstantOp>(
          loc, IntegerAttr::get(elementTy, 0));
      return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
                                            args[0]);
    }

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to
    // represent it.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied as: outputValue = inZp + outZp - inputValue.
    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    Value min = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
        intermediateType);
    Value max = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto clamp =
        clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);

    // Truncate to the final value.
    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
  }
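  // Worked example (illustrative): for i8 with inZp = 3, outZp = 0 and
  // input x = -5, this computes 3 + 0 - (-5) = 8 in the wider intermediate
  // type, clamps to [-128, 127], and truncates back to i8.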
  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that the shift amount is greater than zero.
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Checking the last bit of the original value that is dropped.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResult(0);
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }
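  // Rounding rule: with round = true, x >> s is incremented by one exactly
  // when s > 0 and bit (s - 1) of x is set; e.g. 7 >> 1 = 3 rounds up to 4.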
  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAndOp
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNotOp
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOrOp
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXorOp
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::SinOp
  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);

  // tosa::CosOp
  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }
  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);
  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    return clampFloatHelper(loc, args[0], min, max, rewriter);
  }
  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min & max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }
    // Ensure that the bounds are representable as n-bit signed/unsigned
    // integers.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }
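  // Example (illustrative): clamping an i8 tensor with min_int = -200 first
  // tightens the requested bound to the representable range, i.e. to -128,
  // before materializing the constants fed to clampIntHelper.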
  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }
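  // The sequence above expands sigmoid(x) directly as 1 / (1 + exp(-x)).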
  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
                                            std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              std::nullopt);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             std::nullopt);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    // Casting to boolean, floats are compared to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }
    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // Check whether neither int min nor int max can be represented in the
      // input floating-point type due to too short exponent range.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        // Use cmp + select to replace infinites by int min / int max. Other
        // integer values can be represented in the integer space.
        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
        auto posInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                       APFloat::getInf(fltSemantics)));
        auto negInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                maxClamped);
      }

      auto intMinFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      // Check whether the mantissa has enough bits to represent int max.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        // Int min can also be represented since it is a power of two and thus
        // consists of a single leading bit, so the input can be clamped in
        // the floating-point domain.
        auto intMaxFP = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                         .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
      }

      // The exponent range is large enough to represent int min, so
      // int max + 1 is representable as well (it is int min with a positive
      // sign). Clamp the lower limit with intMinFP and compare against
      // intMaxPlusOneFP for the upper limit.
      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   static_cast<double>(
                       APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                           .getSExtValue()) +
                       1.0f));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(
                   getElementTypeOrSelf(dstTy),
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
      auto minClamped =
          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
      auto overflow = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
                                              minClamped);
    }
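    // Saturating-cast strategy in short: round to nearest-even first, then
    // either clamp in the float domain (when the mantissa can represent
    // int max exactly) or detect out-of-range values with unordered compares
    // and select int min / int max; the unordered predicates make NaN take
    // the saturating path as well.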
    // Casting to boolean, integers are compared to zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
                        int64_t rank) {
  // No need to expand if we are already at the desired rank.
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  int64_t numExtraDims = rank - shapedType.getRank();
  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
  if (!numExtraDims)
    return tensor;

  // Compute reassociation indices.
  SmallVector<ReassociationIndices> reassociationIndices(
      shapedType.getRank());
  int64_t index = 0;
  for (index = 0; index <= numExtraDims; index++)
    reassociationIndices[0].push_back(index);
  for (size_t position = 1; position < reassociationIndices.size(); position++)
    reassociationIndices[position].push_back(index++);

  // Compute result type.
  SmallVector<int64_t> resultShape;
  for (index = 0; index < numExtraDims; index++)
    resultShape.push_back(1);
  for (auto size : shapedType.getShape())
    resultShape.push_back(size);
  auto resultType =
      RankedTensorType::get(resultShape, shapedType.getElementType());

  // Emit 'tensor.expand_shape' op.
  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
                                                reassociationIndices);
}
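// Example: expanding tensor<2x3xf32> to rank 4 produces reassociation
// indices [[0, 1, 2], [3]] and result type tensor<1x1x2x3xf32>.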
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
                                           Location loc, ValueRange operands,
                                           int64_t rank) {
  return llvm::map_to_vector(operands, [&](Value operand) {
    return expandRank(rewriter, loc, operand, rank);
  });
}
using IndexPool = DenseMap<int64_t, Value>;

static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(index));
  return it->second;
}

static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
}
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}
static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
// Compute the target size of dimension 'dim' for the broadcast result of
// 'operands'. Returns the size together with the operand that fixed it
// (the "master operand"), or nullptr if several dynamic operands compete.
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc,
                  IndexPool &indexPool, ValueRange operands, int64_t dim) {
  // If any operand has a static size greater than 1 in this dimension, the
  // target size is a compile-time constant equal to that size.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with a dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If no operand has a dynamic dimension, all sizes were 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension.
  Value targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // With multiple dynamic operands, take the maximum of all runtime sizes.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
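// Example: for per-operand dim sizes {1, ?, 4} the static 4 wins immediately;
// for {1, ?, ?} the target is max(size0, size1) computed at runtime with
// arith.maxui, and no single master operand exists (nullptr).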
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was directly taken from this
  // operand, it is already broadcast.
  if (operand == masterOperand)
    return operand;

  // Affine maps for the 'linalg.generic' op.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions; new constants could
    // violate dominance requirements.
    IndexPool localPool;

    // Emit 'tensor.empty' op.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(opBuilder, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op.
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op.
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit 'scf.if' op.
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
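// The emitted pattern is conceptually (sketch, names illustrative):
//
//   %bcast = arith.cmpi eq, %runtime_size, %c1 : index
//   %res = scf.if %bcast -> (tensor<...>) {
//     ... tensor.empty + copying linalg.generic with the broadcast map ...
//   } else {
//     scf.yield %operand
//   }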
static Value broadcastDynamicDimensions(PatternRewriter &rewriter,
                                        Location loc, IndexPool &indexPool,
                                        Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}

static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations.
  if (operands.size() == 1)
    return llvm::to_vector(operands);

  // Broadcast dynamic dimensions operand by operand.
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate the output tensor.
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input maps broadcast static size-1 dimensions by
  // mapping them to the constant 0.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0)
                                        : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast 'tensor.cast' op if necessary.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower the operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto rank =
      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
  auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
  auto broadcastOperands = broadcastDynamicDimensions(
      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape, converter);
}
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
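// The initial values are the usual reduction identities: 0 for sum, 1 for
// product, +largest for min, -largest for max, all-ones for reduce_all
// (logical and) and zero for reduce_any (logical or).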
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, input, filledTensor, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  // Lower directly to tensor.expand_shape since the expanded dimension is
  // known here; tosa.reshape would scalarize the input otherwise.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp.getResults()[0], reassociationMap);
  return success();
}
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SrcOp op,
                  typename OpConversionPattern<SrcOp>::OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration: terminate and log an error.
    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues = llvm::to_vector(op.getMultiplier());
    SmallVector<int8_t> shiftValues = llvm::to_vector(op.getShift());

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only matters if the shift is greater than 31; check
    // whether that is ever true.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel, store the multiplier values in a
    // buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType = RankedTensorType::get(
          {static_cast<int64_t>(multiplierValues.size())},
          rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel, store the shift values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing map for the output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the linalg.generic op.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // For now we do all of our math in wide integer types; this is not
          // optimal but is correct.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
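// Per-element rescale pipeline (sketch): widen to i32 (zero- or sign-extend
// depending on input_unsigned), subtract input_zp, tosa.apply_scale with the
// multiplier/shift pair, add output_zp, clamp to the output type's range,
// then truncate when the output is narrower than 32 bits.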
// Handle the resize case where the input is a 1x1 image; this avoids
// emitting tensor.extract ops that are hard to optimize away.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    ArrayRef<int64_t> scale = op.getScale();

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation for casting and scaling the input.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
// A tosa.resize with width or height of 1 may broadcast to a wider
// dimension. This is materialized as a new tosa.resize without the
// broadcast behavior followed by an explicit broadcast.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any broadcastable dimension, generate a size of 1 on the output.
    SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());

    // Collapse any unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);
    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      ArrayRef<int64_t> offset = op.getOffset();
      ArrayRef<int64_t> border = op.getBorder();
      ArrayRef<int64_t> scale = op.getScale();

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
      // Compute the continuous source position and split it into an integer
      // index plus a fractional delta: x = in * scale_d + offset;
      // ix = floor(x / scale_n); dx = x / scale_n - ix.
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::FloorDivSIOp>(val, scaleN);
        Value r = b.create<arith::RemSIOp>(val, scaleN);
        Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
        Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
        delta = b.create<arith::DivFOp>(rFp, scaleNfp);
      };

      // Integer-only variant: dx = x - ix * scale_n.
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::DivSIOp>(val, scaleN);
        delta = b.create<arith::MulIOp>(index, scaleN);
        delta = b.create<arith::SubIOp>(val, delta);
      };
      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }
      if (op.getMode() == "NEAREST_NEIGHBOR") {
        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return b.create<arith::ConstantIndexOp>(0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
                                           dvalDouble, scale);
          }

          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
          val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, iy, ix, channel});

        b.create<linalg::YieldOp>(result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == "BILINEAR");

        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 =
              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
          val1 =
              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
        };

        // Compute the clamped corner indices:
        //   y0 = clamp(iy, 0, IH - 1); y1 = clamp(iy + 1, 0, IH - 1)
        //   x0 = clamp(ix, 0, IW - 1); x1 = clamp(ix + 1, 0, IW - 1)
        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
            Value mul1 = b.create<arith::MulFOp>(val1, delta);
            return b.create<arith::AddFOp>(mul0, mul1);
          };

          // topAcc = v00 * (1 - dx) + v01 * dx
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);

          // bottomAcc = v10 * (1 - dx) + v11 * dx
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);

          // result = topAcc * (1 - dy) + bottomAcc * dy
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          b.create<linalg::YieldOp>(result);
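          // e.g. with dx = 0.25 the row blends are 0.75 * v00 + 0.25 * v01
          // and 0.75 * v10 + 0.25 * v11, then blended along y with dy.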
        } else {
          // Perform the interpolation in the quantized (integer) space.
          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = b.create<arith::ExtSIOp>(resultETy, dx);
            dy = b.create<arith::ExtSIOp>(resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return b.create<arith::MulIOp>(val0, scale);
            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
            return b.create<arith::AddIOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          b.create<linalg::YieldOp>(result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // First fill the output buffer with the init value.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
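// Reversal is a gather: along the reversed axis the read index is
// (size - 1 - i); every other axis uses the identity index.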
// This converter translates a tile operation into reshape + broadcast +
// reshape. The first reshape minimally expands each tiled dimension to
// include a preceding size-1 dim, which is then broadcast to the multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    ArrayRef<int64_t> multiples = op.getMultiples();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcasted dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0),
        rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
    return success();
  }
};
// The tosa.argmax lowering represents the final op as a linalg.generic with
// two outputs: the running index and the running max value.
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), outElementTy,
                                     dynDims)
            .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), inElementTy,
                                     dynDims)
            .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // Reduce along the arg-max axis, with parallel loops along the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpFOp>(
                nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
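// The generic carries two accumulators (index, max). Each step compares the
// incoming value with the running max and selects both outputs under the
// same strict '>' predicate, so ties keep the earliest (lowest) index.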
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy =
        dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
    auto resultTy = cast<ShapedType>(op.getType());

    if (!valuesTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
// Lowers tosa.table to a tensor.extract-based lookup, with fixed-point
// linear interpolation in the 16-bit case.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)
                           .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional parts:
        //   value = value + 32768
        //   index = value >> 7
        //   fraction = value & 0x7f
        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table:
        //   base = (int32_t) table[index]
        //   next = (int32_t) table[index + 1]
        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the table values:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W].
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType = cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                ValueRange dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }
  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInput();
    auto elementType = dyn_cast<FloatType>(
        cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for the angle computation.
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // Calculate the angle without the integer parts of the components, as
      // sin/cos are periodic:
      //   angle = 2*pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
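// This is a direct O(H*W) DFT per output bin rather than a fast transform:
// for each (oy, ox) the reduction accumulates
//   sumReal += x[iy][ix] * cos(a);  sumImag -= x[iy][ix] * sin(a)
// with a = 2*pi * ((iy*oy mod H)/H + (ix*ox mod W)/W).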
2438 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2440 if (!llvm::all_of(fft2d->getOperandTypes(),
2441 RFFT2dConverter::isRankedTensor) ||
2442 !llvm::all_of(fft2d->getResultTypes(),
2443 RFFT2dConverter::isRankedTensor)) {
2448 Value input_real = fft2d.getInputReal();
2449 Value input_imag = fft2d.getInputImag();
2450 BoolAttr inverse = fft2d.getInverseAttr();
2452 auto real_el_ty = cast<FloatType>(
2453 cast<ShapedType>(input_real.
getType()).getElementType());
2454 [[maybe_unused]]
auto imag_el_ty = cast<FloatType>(
2455 cast<ShapedType>(input_imag.
getType()).getElementType());
2457 assert(real_el_ty == imag_el_ty);
2472 utils::IteratorType::parallel, utils::IteratorType::parallel,
2473 utils::IteratorType::parallel, utils::IteratorType::reduction,
2474 utils::IteratorType::reduction};
2479 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2481 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2486 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2487 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2488 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2489 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2493 auto dimH = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 1);
2494 auto dimW = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 2);
2497 auto twoPiAttr = rewriter.
getFloatAttr(real_el_ty, 6.283185307179586);
2498 auto twoPi = rewriter.
create<arith::ConstantOp>(loc, twoPiAttr);
2500 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2502 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2505 Value valReal = args[0];
2506 Value valImag = args[1];
2507 Value sumReal = args[2];
2508 Value sumImag = args[3];
2511 Value oy = builder.
create<linalg::IndexOp>(loc, 1);
2512 Value ox = builder.
create<linalg::IndexOp>(loc, 2);
2513 Value iy = builder.
create<linalg::IndexOp>(loc, 3);
2514 Value ix = builder.
create<linalg::IndexOp>(loc, 4);
2518 auto iyXoy = builder.
create<index::MulOp>(loc, iy, oy);
2519 auto ixXox = builder.
create<index::MulOp>(loc, ix, ox);
2521 auto iyRem = builder.
create<index::RemUOp>(loc, iyXoy, dimH);
2522 auto ixRem = builder.
create<index::RemUOp>(loc, ixXox, dimW);
2525 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2527 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2529 auto yComponent = builder.
create<arith::DivFOp>(loc, iyRemFloat, constH);
2530 auto xComponent = builder.
create<arith::DivFOp>(loc, ixRemFloat, constW);
2532 auto sumXY = builder.
create<arith::AddFOp>(loc, yComponent, xComponent);
2533 auto angle = builder.
create<arith::MulFOp>(loc, twoPi, sumXY);
2536 angle = builder.
create<arith::MulFOp>(
2538 rewriter.
create<arith::ConstantOp>(
2544 auto cosAngle = builder.
create<math::CosOp>(loc, angle);
2545 auto sinAngle = builder.
create<math::SinOp>(loc, angle);
2547 auto rcos = builder.
create<arith::MulFOp>(loc, valReal, cosAngle);
2548 auto rsin = builder.
create<arith::MulFOp>(loc, valImag, sinAngle);
2549 auto realComponent = builder.
create<arith::AddFOp>(loc, rcos, rsin);
2551 auto icos = builder.
create<arith::MulFOp>(loc, valImag, cosAngle);
2552 auto isin = builder.
create<arith::MulFOp>(loc, valReal, sinAngle);
2554 auto imagComponent = builder.
create<arith::SubFOp>(loc, icos, isin);
2558 auto outReal = builder.
create<arith::AddFOp>(loc, sumReal, realComponent);
2559 auto outImag = builder.
create<arith::AddFOp>(loc, sumImag, imagComponent);
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
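// For reference, a schematic of the rewritten IR (abbreviated, not verbatim
// printer output): tosa.fft2d becomes a five-dimensional linalg.generic,
//
//   %res:2 = linalg.generic {
//       indexing_maps = [(n, oy, ox, iy, ix) -> (n, iy, ix),   // 2x inputs
//                        (n, oy, ox, iy, ix) -> (n, oy, ox)],  // 2x outputs
//       iterator_types = ["parallel", "parallel", "parallel",
//                         "reduction", "reduction"]}
//       ins(%input_real, %input_imag) outs(%zero_real, %zero_imag) { ... }
//
// with the region populated by the buildBody lambda above.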
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns->add<
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>
        >(converter, patterns->getContext());
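  // Note the split: the pointwise patterns above are conversion patterns and
  // are constructed with the TypeConverter, while the rewrite patterns below
  // only need the MLIRContext.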
  patterns->add<
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
}
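// A minimal driver sketch, not the in-tree pass: it shows one way the
// patterns registered above can be applied. The pass name and the identity
// TypeConverter are assumptions for illustration only; the sketch presumes
// this file's usual using-directives (using namespace mlir;) plus
// mlir/Pass/Pass.h and mlir/Transforms/DialectConversion.h.
namespace {
struct SketchTosaToLinalgPass
    : PassWrapper<SketchTosaToLinalgPass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SketchTosaToLinalgPass)

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Identity type conversion: leave every type unchanged. The real pass
    // installs more sophisticated conversions.
    TypeConverter converter;
    converter.addConversion([](Type type) { return type; });

    // TOSA ops must be rewritten away; the dialects the patterns produce
    // (linalg, arith, math, index, tensor) remain legal.
    ConversionTarget target(*ctx);
    target.addIllegalDialect<tosa::TosaDialect>();
    target.addLegalDialect<linalg::LinalgDialect, arith::ArithDialect,
                           math::MathDialect, index::IndexDialect,
                           tensor::TensorDialect>();

    RewritePatternSet patterns(ctx);
    tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);

    if (failed(applyFullConversion(getOperation(), target,
                                   std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace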