#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
// Materializes the integer attribute `attrName` of `op` as an arith.constant
// of the required type.
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
  return rewriter.create<arith::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
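// Illustrative use (a sketch of a call site, mirroring how the rescale
// lowering further below materializes its zero points):
//   auto inputZp = createConstFromIntAttribute<int32_t>(
//       op, "input_zp", builder.getI32Type(), builder);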
static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
  }
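  // The integer branch computes |x| as max(x, 0 - x). For an i32 element the
  // generated body is, roughly:
  //   %zero = arith.constant 0 : i32
  //   %neg  = arith.subi %zero, %x : i32
  //   %abs  = arith.maxsi %x, %neg : i32
  // Note that the minimum value (e.g. -2^31) maps to itself, since its
  // two's-complement negation wraps.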
  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }
  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);

  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
      if (!b.getType().isInteger(32))
        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
  }
  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
    int64_t inZp = 0, outZp = 0;

    if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
      auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
      inZp = quantizationInfo.value().getInputZp();
      outZp = quantizationInfo.value().getOutputZp();
    }

    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    if (!inZp && !outZp) {
      auto constant = rewriter.create<arith::ConstantOp>(
          loc, IntegerAttr::get(elementTy, 0));
      return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
                                            args[0]);
    }

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to
    // represent it.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied as: outputValue = inZp + outZp - inputValue.
    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    Value min = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
        intermediateType);
    Value max = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);

    // Truncate to the final value.
    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
  }
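  // In the quantized case the negation is computed as
  //   output = (input_zp + output_zp) - input
  // in a wider intermediate type, then clamped back to the input's signed
  // range before truncation. For i8 data, for example, the needed magnitude
  // is 127 + |inZp + outZp| + 1, so the intermediate width stays at 16 bits
  // for any zero points that fit that bound.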
  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that input2 != 0.
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Checking for the last bit of the shifted value in input1.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }
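  // With round = true the lowering adds one to the shifted result exactly
  // when the shift amount is positive and the last bit shifted out was set:
  //   result = (x >> s) + ((s > 0 && ((x >> (s - 1)) & 1)) ? 1 : 0)
  // i.e. the discarded half-bit rounds the arithmetic shift upward.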
  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);
  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::SinOp
  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);

  // tosa::CosOp
  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }
  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);
  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    return clampFloatHelper(loc, args[0], min, max, rewriter);
  }
  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min & max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }
    // Clamp the given min/max to the representable range of the element type.
    min = std::max(min, minRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }
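  // Example: clamping an i8 tensor with min_int = -200 and max_int = 300
  // first narrows the bounds to the representable range [-128, 127], so the
  // emitted arith.constant values are -128 and 127.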
  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }
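  // This expands sigmoid(x) = 1 / (1 + exp(-x)) directly:
  //   %neg = arith.negf %x
  //   %exp = math.exp %neg
  //   %add = arith.addf %exp, %one
  //   %res = arith.divf %one, %add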
  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
                                            std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              std::nullopt);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             std::nullopt);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    // Casting to boolean, floats need to only be checked as not-equal to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // Check whether neither int min nor int max can be represented in the
      // input floating-point type due to a too-short exponent range.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        // Use cmp + select to replace infinites by int min / int max. Other
        // integer values can be represented in the integer space.
        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
        auto posInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                       APFloat::getInf(fltSemantics)));
        auto negInf = rewriter.create<arith::ConstantOp>(
            loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                maxClamped);
      }

      auto intMinFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      // Check whether the mantissa is wide enough to represent int max.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        // Int min can also be represented since it is a power of two, so we
        // can clamp the input in the floating-point domain.
        auto intMaxFP = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                         .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
      }

      // Int max + 1 is representable (it is int min with a positive sign), so
      // clamp the lower end in the floating-point domain and detect overflow
      // at the upper end after conversion.
      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   static_cast<double>(
                       APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                           .getSExtValue()) +
                       1.0f));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(
                   getElementTypeOrSelf(dstTy),
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
      auto minClamped =
          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
      auto overflow = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
                                              minClamped);
    }

    // Casting to boolean, integers need to only be checked as not-equal to
    // zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
// Expand the rank of `tensor` to `rank` by prepending unit dimensions.
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
                        int64_t rank) {
  // No need to expand if we are already at the desired rank.
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  int64_t numExtraDims = rank - shapedType.getRank();
  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
  if (!numExtraDims)
    return tensor;

  // Compute reassociation indices.
  SmallVector<SmallVector<int64_t, 2>> reassociationIndices(
      shapedType.getRank());
  int64_t index = 0;
  for (index = 0; index <= numExtraDims; index++)
    reassociationIndices[0].push_back(index);
  for (size_t position = 1; position < reassociationIndices.size(); position++)
    reassociationIndices[position].push_back(index++);

  // Compute result type.
  SmallVector<int64_t> resultShape;
  for (index = 0; index < numExtraDims; index++)
    resultShape.push_back(1);
  for (auto size : shapedType.getShape())
    resultShape.push_back(size);
  auto resultType =
      RankedTensorType::get(resultShape, shapedType.getElementType());

  // Emit 'tensor.expand_shape' op.
  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
                                                reassociationIndices);
}
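// Example: expanding a tensor<2x3xf32> to rank 4 prepends two unit
// dimensions, producing tensor<1x1x2x3xf32> with reassociation indices
// [[0, 1, 2], [3]]: the original leading dimension absorbs the new unit
// dimensions, and each remaining dimension maps one-to-one.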
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
                                           Location loc, ValueRange operands,
                                           int64_t rank) {
  return llvm::map_to_vector(operands, [&](Value operand) {
    return expandRank(rewriter, loc, operand, rank);
  });
}
using IndexPool = DenseMap<int64_t, Value>;

static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        rewriter.create<arith::ConstantIndexOp>(loc, index).getResult();
  return it->second;
}

static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
}
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}
static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
// Returns the target size of dimension `dim` for a broadcast, together with
// the operand that determines it (the "master operand"), if any.
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // If any operand has a static size greater than 1 in this dimension, that
  // is the target size; dynamic operands must broadcast to it at runtime.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with a dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      }));

  // If no operand is dynamic here, all static sizes were 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // With multiple dynamic operands the target size is their maximum.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
// Compute the target shape for all dimensions of a broadcast.
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
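// The broadcast rule implemented above: a static size greater than 1 wins
// outright; otherwise the target size is the single dynamic size, or the
// arith.maxui of all dynamic sizes, in which case no operand is the "master"
// and every dynamic operand may need a runtime broadcast.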
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was taken directly from this
  // operand, it is already the right size and no broadcast is needed.
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check whether broadcasting is necessary at runtime.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // Emit 'tensor.empty' op with the broadcast target shape.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, indexPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op that copies the operand along the broadcast
    // dimension.
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op.
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit 'scf.if' op.
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
static Value broadcastDynamicDimensions(PatternRewriter &rewriter,
                                        Location loc, IndexPool &indexPool,
                                        Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}

static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations.
  if (operands.size() == 1)
    return {operands.front()};

  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor.
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input maps broadcast static dimensions of size 1 via
  // a constant-0 expression; the output map is the identity map.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      auto affineExpr = it.value() == 1
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast the 'linalg.generic' result into the original result type if needed.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}

static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower the operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto rank =
      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
  auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
  auto broadcastOperands = broadcastDynamicDimensions(
      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape, converter);
}
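// End-to-end, a binary elementwise op such as
//   %0 = tosa.add %a, %b : (tensor<2x3xf32>, tensor<2x3xf32>)
//            -> tensor<2x3xf32>
// becomes a linalg.generic over two parallel dimensions whose body is the
// scalar op produced by createLinalgBodyCalculationForElementwiseOp:
//   ^bb0(%in0: f32, %in1: f32, %out: f32):
//     %1 = arith.addf %in0, %in1 : f32
//     linalg.yield %1 : f32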
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, input, filledTensor, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp.getResults()[0], reassociationMap);
  return success();
}
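// Example: tosa.reduce_sum with axis = 1 on tensor<4x5xf32> produces a
// linalg.reduce into a tensor<4xf32> accumulator filled with 0.0, followed by
// a tensor.expand_shape back to the TOSA result type tensor<4x1xf32>, since
// TOSA keeps the reduced dimension with size 1.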
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration: terminate and log an error.
    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues(op.getMultiplier());
    SmallVector<int8_t> shiftValues(op.getShift());

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only occurs if the shift is greater than 31; check
    // whether that is ever true.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
          RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
                                rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));

      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));

      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing map for the output value.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the linalg.generic op for the rescale computation.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // The computation is done in 32-bit (or wider) integer math.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Unsigned inputs pass through an unrealized cast to the signless
          // form before extension.
          if (isa<IntegerType>(valueTy) && valueTy.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(
                            nestedLoc,
                            nestedBuilder.getIntegerType(
                                valueTy.getIntOrFloatBitWidth()),
                            value)
                        .getResult(0);
          }
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (valueTy.isUnsignedInteger()) {
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output value range.
          if (outIntType.isUnsignedInteger()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);
          }

          if (outIntType.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(nestedLoc,
                                                            outIntType, value)
                        .getResult(0);
          }
          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
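// Per element, the rescale computes (in i32 or wider):
//   value = apply_scale(value - input_zp, multiplier, shift)
//   value = clamp(value + output_zp, output_min, output_max)
// where multiplier/shift are either splat constants or per-channel tensors
// indexed along the innermost dimension.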
// Handle the resize case where the input is a 1x1 image, which can avoid the
// extract-based lowering entirely.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    auto scale = op.getScale();

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation for casting and scaling the input
    // value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
// Materialize the broadcasting behavior of tosa.resize: run the resize on the
// non-broadcast dimensions, then broadcast along the unit dimensions.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());

    // Collapse the unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      auto scale = op.getScale();
      auto offset = op.getOffset();
      auto border = op.getBorder();

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      // Compute the index and fractional delta, floating-point case:
      //   x = x * scale_d + offset
      //   ix = floor(x / scale_n), dx = (x % scale_n) / scale_n
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::FloorDivSIOp>(val, scaleN);

        Value r = b.create<arith::RemSIOp>(val, scaleN);
        Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
        Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
        delta = b.create<arith::DivFOp>(rFp, scaleNfp);
      };

      // Compute the index and fractional delta, integer case:
      //   x = x * scale_d + offset
      //   ix = x / scale_n, dx = x - ix * scale_n
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::DivSIOp>(val, scaleN);
        delta = b.create<arith::MulIOp>(index, scaleN);
        delta = b.create<arith::SubIOp>(val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == "NEAREST_NEIGHBOR") {
        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return b.create<arith::ConstantIndexOp>(0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
                                           dvalDouble, scale);
          }

          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
          val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, false);
          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, iy, ix, channel});

        b.create<linalg::YieldOp>(result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == "BILINEAR");

        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 = clampIntHelper(loc, val0, zeroI32, max, b, false);
          val1 = clampIntHelper(loc, val1, zeroI32, max, b, false);
          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
        };

        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
            Value mul1 = b.create<arith::MulFOp>(val1, delta);
            return b.create<arith::AddFOp>(mul0, mul1);
          };

          // Interpolate the two rows, then interpolate between them.
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          b.create<linalg::YieldOp>(result);
        } else {
          // Perform the interpolation in the quantized space.
          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = b.create<arith::ExtSIOp>(resultETy, dx);
            dy = b.create<arith::ExtSIOp>(resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return b.create<arith::MulIOp>(val0, scale);
            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
            return b.create<arith::AddIOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          b.create<linalg::YieldOp>(result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
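// In the integer bilinear path the interpolation weights are fixed-point
// fractions of scale_n, so the yielded value carries a factor of
// y_scale_n * x_scale_n; the consumer is expected to rescale that factor
// away afterwards, as TOSA defines for quantized bilinear resize.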
// Identity-like ops simply forward their operands.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // First fill the output buffer with the init value.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
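// The reverse is a pure index permutation: along the reversed axis the body
// reads input[..., size - 1 - i, ...] and yields it, leaving every other
// index unchanged.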
// This converter translates a tile operation to a reshape, broadcast, and
// reshape. The first reshape minimally expands each tiled dimension to
// include a preceding size-1 dim; this dim is then broadcast to the
// appropriate multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    ArrayRef<int64_t> multiples = op.getMultiples();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcast dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0),
        rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
    return success();
  }
};
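// Example: tiling tensor<2x3xf32> with multiples = [3, 1] builds the
// interleaved genericShape [3, 2, 1, 3]; the read map only binds the odd
// (input) dimensions, so the broadcast happens along the even (multiple)
// dimensions, and the final tosa.reshape collapses the pairs into the
// tensor<6x3xf32> result.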
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       outElementTy, dynDims)
                              .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       inElementTy, dynDims)
                              .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations
    // along the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpFOp>(
                nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy =
        dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
    auto resultTy = cast<ShapedType>(op.getType());

    if (!valuesTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
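// The gather body above implements
//   output[batch, index, channel] = values[batch, indices[batch, index],
//                                          channel]
// with linalg.index supplying the batch and channel coordinates and the
// indices operand supplying the middle coordinate.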
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)
                           .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
        affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        // Compute the table index and the fractional part:
        //   value = value + 32768
        //   index = value >> 7
        //   fraction = value & 0x7f
        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table.
        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the table values:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W].
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType =
        cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }
  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInput();
    auto elementType = dyn_cast<FloatType>(
        cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for angle computation.
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      // Calculate the angle without the integer parts of the components, as
      // sin/cos are periodic:
      //   angle = 2*pi*( ((iy*oy) % H)/H + ((ix*ox) % W)/W )
      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
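// The body accumulates a direct (naive) 2-D real DFT:
//   out_real[oy, ox] =  sum_{iy, ix} in[iy, ix] * cos(a)
//   out_imag[oy, ox] = -sum_{iy, ix} in[iy, ix] * sin(a)
//   a = 2*pi*( (iy*oy mod H)/H + (ix*ox mod W)/W )
// The reduction dimensions (iy, ix) walk the full input image per output
// element, so this lowering does O(H*W) work per output point.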
2452 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2454 if (!llvm::all_of(fft2d->getOperandTypes(),
2455 RFFT2dConverter::isRankedTensor) ||
2456 !llvm::all_of(fft2d->getResultTypes(),
2457 RFFT2dConverter::isRankedTensor)) {
2462 Value input_real = fft2d.getInputReal();
2463 Value input_imag = fft2d.getInputImag();
2464 BoolAttr inverse = fft2d.getInverseAttr();
2466 auto real_el_ty = cast<FloatType>(
2467 cast<ShapedType>(input_real.
getType()).getElementType());
2468 [[maybe_unused]]
auto imag_el_ty = cast<FloatType>(
2469 cast<ShapedType>(input_imag.
getType()).getElementType());
2471 assert(real_el_ty == imag_el_ty);
2486 utils::IteratorType::parallel, utils::IteratorType::parallel,
2487 utils::IteratorType::parallel, utils::IteratorType::reduction,
2488 utils::IteratorType::reduction};
2493 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2495 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2500 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2501 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2502 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2503 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2507 auto dimH = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 1);
2508 auto dimW = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 2);
2511 auto twoPiAttr = rewriter.
getFloatAttr(real_el_ty, 6.283185307179586);
2512 auto twoPi = rewriter.
create<arith::ConstantOp>(loc, twoPiAttr);
2514 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2516 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2519 Value valReal = args[0];
2520 Value valImag = args[1];
2521 Value sumReal = args[2];
2522 Value sumImag = args[3];
2525 Value oy = builder.
create<linalg::IndexOp>(loc, 1);
2526 Value ox = builder.
create<linalg::IndexOp>(loc, 2);
2527 Value iy = builder.
create<linalg::IndexOp>(loc, 3);
2528 Value ix = builder.
create<linalg::IndexOp>(loc, 4);
2532 auto iyXoy = builder.
create<index::MulOp>(loc, iy, oy);
2533 auto ixXox = builder.
create<index::MulOp>(loc, ix, ox);
2535 auto iyRem = builder.
create<index::RemUOp>(loc, iyXoy, dimH);
2536 auto ixRem = builder.
create<index::RemUOp>(loc, ixXox, dimW);
2539 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2541 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2543 auto yComponent = builder.
create<arith::DivFOp>(loc, iyRemFloat, constH);
2544 auto xComponent = builder.
create<arith::DivFOp>(loc, ixRemFloat, constW);
2546 auto sumXY = builder.
create<arith::AddFOp>(loc, yComponent, xComponent);
2547 auto angle = builder.
create<arith::MulFOp>(loc, twoPi, sumXY);
2550 angle = builder.
create<arith::MulFOp>(
2552 rewriter.
create<arith::ConstantOp>(
2558 auto cosAngle = builder.
create<math::CosOp>(loc, angle);
2559 auto sinAngle = builder.
create<math::SinOp>(loc, angle);
      auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
      auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);

      auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
      auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
      auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);

      // Accumulate into the running sums; unlike rfft2d, both accumulators
      // use AddFOp because the sign is already carried by `angle`.
      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
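// Editorial usage sketch (hypothetical, not part of this file): FFT2dConverter
// is a plain OpRewritePattern, so it can also be exercised in isolation with
// the greedy rewrite driver, e.g. from a test pass:
//
//   RewritePatternSet fftPatterns(ctx);
//   fftPatterns.add<FFT2dConverter>(ctx);
//   if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(fftPatterns))))
//     signalPassFailure();
//
// In-tree, both FFT converters are instead registered below alongside the
// other TOSA-to-Linalg patterns.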
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {

  // We have multiple resize converters to handle degenerate cases.
  patterns->add<GenericResizeConverter>(patterns->getContext(),
                                        /*benefit=*/100);
  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
                                      /*benefit=*/200);
  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
                                            /*benefit=*/300);

  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>
      // clang-format on
      >(converter, patterns->getContext());

  patterns->add<
      // clang-format off
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
  // clang-format on
}
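// Editorial sketch (not in the original file): a minimal dialect-conversion
// driver showing how the populate function above is typically wired up. The
// in-tree TosaToLinalg pass is the canonical driver; the identity
// TypeConverter and the function name here are illustrative assumptions.
static LogicalResult runTosaToLinalgSketch(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // Identity type conversion, for illustration only; the real pass installs
  // TOSA-specific type conversions.
  TypeConverter typeConverter;
  typeConverter.addConversion([](Type type) { return type; });

  RewritePatternSet patterns(ctx);
  mlir::tosa::populateTosaToLinalgConversionPatterns(typeConverter, &patterns);

  // All TOSA ops must be rewritten away; any other op is left untouched.
  ConversionTarget target(*ctx);
  target.addIllegalDialect<tosa::TosaDialect>();
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  return applyFullConversion(module, target, std::move(patterns));
}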