#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
  return rewriter.create<arith::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
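// The helpers below build the scalar body of each linalg.generic: given a
// TOSA op and its block arguments, they emit the equivalent arith/math ops
// and return the value to yield, or a null Value when the op/type
// combination is unsupported.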
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);
  }

  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }
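  // For instance, integer tosa.abs above is computed as max(x, 0 - x). Inside
  // the generated linalg.generic body this yields roughly (illustrative IR,
  // not verbatim compiler output):
  //
  //   %zero = arith.constant 0 : i32
  //   %neg  = arith.subi %zero, %x : i32
  //   %abs  = arith.maxsi %x, %neg : i32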
  if (isa<tosa::MulOp>(op)) {
    auto shift_val = cast<tosa::MulOp>(op).getShift();
    ElementsAttr shift_elem;
    if (!shift_val.getImpl() ||
        !matchPattern(shift_val, m_Constant(&shift_elem))) {
      (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
      return nullptr;
    }

    int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0) {
        auto shiftConst =
            rewriter.create<arith::ConstantIntOp>(loc, shift, /*width=*/8);
        if (!a.getType().isInteger(32))
          a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
        if (!b.getType().isInteger(32))
          b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

        auto result = rewriter.create<tosa::ApplyScaleOp>(
            loc, rewriter.getI32Type(), a, b, shiftConst,
            rewriter.getBoolAttr(false));

        if (elementTy.isInteger(32))
          return result;

        return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
      }

      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

      return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
    }
  }
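  // For quantized multiplies (shift > 0) the product is computed by
  // tosa.apply_scale at i32 and truncated back afterwards; plain integer
  // multiplies extend both operands to the result width first so the multiply
  // cannot overflow its operand types.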
  if (isa<tosa::NegateOp>(op)) {
    if (isa<FloatType>(elementTy))
      return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

    if (isa<IntegerType>(elementTy)) {
      auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1ZpAttr();
      auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZpAttr();

      const int64_t inZp =
          inputZpAttr ? inputZpAttr.getValue().getSExtValue() : 0;
      const int64_t outZp =
          outputZpAttr ? outputZpAttr.getValue().getSExtValue() : 0;

      if (!inZp && !outZp) {
        auto constant = rewriter.create<arith::ConstantOp>(
            loc, IntegerAttr::get(elementTy, 0));
        return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
                                              args[0]);
      }

      const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
      const int64_t zpAdd = inZp + outZp;
      const int64_t maxValue =
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
          std::abs(zpAdd) + 1;

      int intermediateBitWidth = 64;
      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
        intermediateBitWidth = 16;
      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
        intermediateBitWidth = 32;
      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
        intermediateBitWidth = 48;
      }

      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
      Value zpAddValue = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

      auto ext =
          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
      auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

      Value min = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
          intermediateType);
      Value max = rewriter.create<arith::ConstantIntOp>(
          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
          intermediateType);
      auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);

      return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
    }
  }
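  // Quantized negation above computes (in_zp + out_zp) - x. The intermediate
  // bitwidth is chosen (16/32/48/64) so that |in_zp + out_zp| plus the signed
  // max of the input type fits without overflow, and the result is clamped
  // back to the input type's signed range before truncation.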
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one = rewriter.create<arith::ConstantOp>(
        loc, IntegerAttr::get(elementTy, 1));
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }
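  // With round=true, the result of x >> shift is incremented by 1 exactly
  // when shift > 0 and the last bit shifted out (bit shift-1 of x) is set,
  // i.e. the shifted value rounds to nearest instead of toward negative
  // infinity.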
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);

  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);

  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }

  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    return clampFloatHelper(loc, args[0], min, max, rewriter);
  }
  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }

    // Clip the clamp bounds to what the element type can represent.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }
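  // This matches sigmoid(x) = 1 / (1 + exp(-x)): negate, exponentiate, add
  // one, divide.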
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
                                            std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              std::nullopt);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             std::nullopt);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    // Casting to boolean, floats are compared against zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }
    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
        auto posInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                       APFloat::getInf(fltSemantics)));
        auto negInf = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIntegerAttr(
                     getElementTypeOrSelf(dstTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,
                                                maxClamped);
      }

      auto intMinFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                       .getSExtValue()));

      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        auto intMaxFP = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getFloatAttr(
                     getElementTypeOrSelf(srcTy),
                     APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                         .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
      }

      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(
                   getElementTypeOrSelf(srcTy),
                   static_cast<double>(
                       APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                           .getSExtValue()) +
                       1.0f));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(
                   getElementTypeOrSelf(dstTy),
                   APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
      auto minClamped =
          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
      auto overflow = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,
                                              minClamped);
    }
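    // Summary of the saturating float-to-int cast above: if the float
    // exponent range cannot even represent int min/max, only the infinities
    // need fixing up via compare-and-select; if the mantissa can represent
    // int max exactly, clamping can happen entirely in the float domain;
    // otherwise the code clamps the minimum in floats and selects int max
    // whenever rounded >= int max + 1, which is representable because it
    // equals -(int min).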
    // Casting to boolean: integers only need to be checked as not-equal to
    // zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
                        int64_t rank) {
  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
  assert(tensorType && "expected a ranked tensor type");
  int64_t tensorRank = tensorType.getRank();
  int64_t numExtraDims = rank - tensorRank;
  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
  if (!numExtraDims)
    return tensor;

  // Compute reassociation indices.
  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
  int64_t index = 0;
  if (tensorRank != 0) {
    for (index = 0; index <= numExtraDims; index++)
      reassociationIndices[0].push_back(index);
    for (size_t position = 1; position < reassociationIndices.size();
         position++)
      reassociationIndices[position].push_back(index++);
  }

  // Compute result type.
  SmallVector<int64_t> resultShape;
  for (index = 0; index < numExtraDims; index++)
    resultShape.push_back(1);
  for (auto size : tensorType.getShape())
    resultShape.push_back(size);
  auto resultType =
      RankedTensorType::get(resultShape, tensorType.getElementType());

  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
                                                reassociationIndices);
}
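// Example: expanding a tensor<3x4xf32> to rank 4 produces
// tensor<1x1x3x4xf32> with reassociation [[0, 1, 2], [3]]; all added unit
// dimensions are grouped with the first original dimension.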
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
                                           Location loc, ValueRange operands,
                                           int64_t rank) {
  return llvm::map_to_vector(operands, [&](Value operand) {
    return expandRank(rewriter, loc, operand, rank);
  });
}
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second = rewriter.create<arith::ConstantIndexOp>(loc, index);
  return it->second;
}

static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
}
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}
static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // If any operand has a static size != 1 in the given dimension, that size
  // wins and the operand providing it becomes the "master" operand.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with a dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If no operand has a dynamic dimension, all sizes were 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
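// Target-size resolution per dimension: any operand with a static size > 1
// fixes the size (and becomes the master operand); if every size is a static
// 1 the target is 1; otherwise the size is the runtime maximum of the
// dynamic candidates, computed with arith.maxui, and no master exists.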
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was taken directly from this
  // operand, it is already as big as it needs to be.
  if (operand == masterOperand)
    return operand;

  // Affine maps for the 'linalg.generic' op.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if a broadcast is necessary.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // Constants cannot be cached across regions without violating dominance.
    IndexPool localPool;

    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(opBuilder, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
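// The broadcast is guarded by scf.if on `runtimeSize == 1`: only then is a
// broadcasting linalg.generic materialized; otherwise the operand is yielded
// unchanged, keeping the common non-broadcast case free of extra copies.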
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}

static SmallVector<Value> broadcastDynamicDimensions(
    PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
    ValueRange operands, ArrayRef<OpFoldResult> targetShape,
    ArrayRef<Value> masterOperands) {
  // Unary operations need no broadcasting.
  if (operands.size() == 1)
    return llvm::to_vector(operands);
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
static LogicalResult emitElementwiseComputation(
    ConversionPatternRewriter &rewriter, Location loc, Operation *operation,
    ValueRange operands, ArrayRef<OpFoldResult> targetShape,
    const TypeConverter &converter) {
  // Generate output tensor.
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps: input maps broadcast static unit dimensions, the
  // output map is the identity.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit the 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast the result if necessary.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower the operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto rank =
      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
  // For the mul op we need to avoid expanding the rank of the optional shift
  // input.
  auto operandsToExpand =
      isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;

  auto expandedOperands =
      expandInputRanks(rewriter, loc, operandsToExpand, rank);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
  auto broadcastOperands = broadcastDynamicDimensions(
      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape, converter);
}
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}
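// Reduction combinators by op and element type: reduce_sum -> addf/addi,
// reduce_prod -> mulf/muli, reduce_min -> minimumf/minsi,
// reduce_max -> maximumf/maxsi, reduce_all -> andi, reduce_any -> ori.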
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, input, filledTensor, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp.getResults()[0], reassociationMap);
  return success();
}
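// linalg.reduce drops the reduced axis, but TOSA reductions keep it as a unit
// dimension, so the result is re-expanded with tensor.expand_shape; the
// reassociation map attaches the restored unit dim to its neighbor.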
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration: terminate and log an error.
    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues(op.getMultiplier());
    SmallVector<int8_t> shiftValues(op.getShift());

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only occurs if shift is greater than 31; check whether
    // that can ever be true.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType = RankedTensorType::get(
          {static_cast<int64_t>(multiplierValues.size())},
          rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));
      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing map for the output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the linalg.generic op.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(), dynDims);

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different min and max.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);
          }

          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
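// Per-element rescale pipeline: extend to i32 (unsigned or signed per
// input_unsigned), subtract input_zp, tosa.apply_scale with the (possibly
// per-channel) multiplier and shift, add output_zp, clamp to the output
// type's range, then truncate when the output is narrower than 32 bits.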
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    // `scale` below holds the op's constant scale parameters
    // {y_n, y_d, x_n, x_d}, extracted from the op before this point.

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));
    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation for casting and scaling the input.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());

    // Collapse any unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());
    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      // `scale`, `offset` and `border` below hold the op's constant resize
      // parameters, extracted from the op before this point.
      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        // x = x * scale_d + offset; ix = floor(x / scale_n)
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::FloorDivSIOp>(val, scaleN);

        // dx = (x % scale_n) / scale_n
        Value r = b.create<arith::RemSIOp>(val, scaleN);
        Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
        Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
        delta = b.create<arith::DivFOp>(rFp, scaleNfp);
      };

      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::DivSIOp>(val, scaleN);
        delta = b.create<arith::MulIOp>(index, scaleN);
        delta = b.create<arith::SubIOp>(val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == "NEAREST_NEIGHBOR") {
        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return b.create<arith::ConstantIndexOp>(0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
                                           dvalDouble, scale);
          }

          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
          val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, false);
          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, iy, ix, channel});

        b.create<linalg::YieldOp>(result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == "BILINEAR");

        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 = clampIntHelper(loc, val0, zeroI32, max, b, false);
          val1 = clampIntHelper(loc, val1, zeroI32, max, b, false);
          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
        };

        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
            Value mul1 = b.create<arith::MulFOp>(val1, delta);
            return b.create<arith::AddFOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          b.create<linalg::YieldOp>(result);
        } else {
          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = b.create<arith::ExtSIOp>(resultETy, dx);
            dy = b.create<arith::ExtSIOp>(resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return b.create<arith::MulIOp>(val0, scale);
            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
            return b.create<arith::AddIOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          b.create<linalg::YieldOp>(result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
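// In the integer bilinear path the interpolation weights are kept in units of
// the scale numerators, so the accumulated result carries a fixed-point
// factor of xScaleN * yScaleN; a later rescale typically removes that
// scaling.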
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // First fill the output buffer.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    dynDims)
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    // Map the input shape to the interleaved (non-tiled) output dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       outElementTy, dynDims)
                              .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax = rewriter
                              .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                       inElementTy, dynDims)
                              .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations
    // along the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpFOp>(
                nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy =
        dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
    auto resultTy = cast<ShapedType>(op.getType());

    if (!valuesTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)
                           .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W].
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType =
        cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }
  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInput();
    auto elementType = dyn_cast<FloatType>(
        cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for the angle computation.
      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
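// The body accumulates a direct DFT:
//   out_real(oy, ox) += val * cos(angle)
//   out_imag(oy, ox) -= val * sin(angle)
// with angle = 2*pi * ((iy*oy mod H)/H + (ix*ox mod W)/W); taking the
// products modulo H and W keeps the index arithmetic small without changing
// the result, since cos/sin are periodic.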
2476 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2478 if (!llvm::all_of(fft2d->getOperandTypes(),
2479 RFFT2dConverter::isRankedTensor) ||
2480 !llvm::all_of(fft2d->getResultTypes(),
2481 RFFT2dConverter::isRankedTensor)) {
2486 Value input_real = fft2d.getInputReal();
2487 Value input_imag = fft2d.getInputImag();
2488 BoolAttr inverse = fft2d.getInverseAttr();
2490 auto real_el_ty = cast<FloatType>(
2491 cast<ShapedType>(input_real.
getType()).getElementType());
2492 [[maybe_unused]]
auto imag_el_ty = cast<FloatType>(
2493 cast<ShapedType>(input_imag.
getType()).getElementType());
2495 assert(real_el_ty == imag_el_ty);
2510 utils::IteratorType::parallel, utils::IteratorType::parallel,
2511 utils::IteratorType::parallel, utils::IteratorType::reduction,
2512 utils::IteratorType::reduction};
2517 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2519 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2524 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2525 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2526 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2527 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2531 auto dimH = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 1);
2532 auto dimW = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 2);
2535 auto twoPiAttr = rewriter.
getFloatAttr(real_el_ty, 6.283185307179586);
2536 auto twoPi = rewriter.
create<arith::ConstantOp>(loc, twoPiAttr);
2538 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2540 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2543 Value valReal = args[0];
2544 Value valImag = args[1];
2545 Value sumReal = args[2];
2546 Value sumImag = args[3];
2549 Value oy = builder.
create<linalg::IndexOp>(loc, 1);
2550 Value ox = builder.
create<linalg::IndexOp>(loc, 2);
2551 Value iy = builder.
create<linalg::IndexOp>(loc, 3);
2552 Value ix = builder.
create<linalg::IndexOp>(loc, 4);
2556 auto iyXoy = builder.
create<index::MulOp>(loc, iy, oy);
2557 auto ixXox = builder.
create<index::MulOp>(loc, ix, ox);
2559 auto iyRem = builder.
create<index::RemUOp>(loc, iyXoy, dimH);
2560 auto ixRem = builder.
create<index::RemUOp>(loc, ixXox, dimW);
2563 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2565 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2567 auto yComponent = builder.
create<arith::DivFOp>(loc, iyRemFloat, constH);
2568 auto xComponent = builder.
create<arith::DivFOp>(loc, ixRemFloat, constW);
2570 auto sumXY = builder.
create<arith::AddFOp>(loc, yComponent, xComponent);
2571 auto angle = builder.
create<arith::MulFOp>(loc, twoPi, sumXY);
2574 angle = builder.
create<arith::MulFOp>(
2576 rewriter.
create<arith::ConstantOp>(
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);

      // realComponent = valReal * cos(angle) + valImag * sin(angle)
      auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
      auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);

      // imagComponent = valImag * cos(angle) - valReal * sin(angle)
      auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
      auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
      auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);

      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);
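      // Together these form one complex multiply-accumulate per reduction
      // step: (sumReal + i*sumImag) += (valReal + i*valImag) * e^{-i*angle},
      // using (a + i*b)(cos t - i*sin t) = (a*cos t + b*sin t)
      //                                  + i*(b*cos t - a*sin t);
      // the inverse path negates the angle, giving e^{+i*angle} instead.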
      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns->add<
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>>(converter,
                                           patterns->getContext());
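  // The pointwise converters above receive the TypeConverter because they are
  // dialect-conversion patterns whose operand and result types may be
  // rewritten; the patterns registered below only need the MLIRContext.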
  patterns->add<
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>,
      // ...
      TileConverter>(patterns->getContext());
}
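// A minimal sketch of how these patterns are typically driven from a pass,
// kept as a comment because the pass definition is an assumption and not part
// of this file; only populateTosaToLinalgConversionPatterns is defined above.
//
//   void HypotheticalTosaToLinalgPass::runOnOperation() {
//     MLIRContext *ctx = &getContext();
//     ConversionTarget target(*ctx);
//     target.addIllegalDialect<tosa::TosaDialect>();
//     target.addLegalDialect<linalg::LinalgDialect, arith::ArithDialect,
//                            math::MathDialect, index::IndexDialect,
//                            tensor::TensorDialect>();
//     TypeConverter converter;
//     converter.addConversion([](Type type) { return type; }); // identity
//     RewritePatternSet patterns(ctx);
//     mlir::tosa::populateTosaToLinalgConversionPatterns(converter,
//                                                        &patterns);
//     if (failed(applyFullConversion(getOperation(), target,
//                                    std::move(patterns))))
//       signalPassFailure();
//   }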