#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
  return rewriter.create<arith::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = rewriter.create<arith::ConstantOp>(
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::MaxSIOp>(loc, args[0], neg);

  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
    if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
                                        "Cannot have shift value for float");
    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);

  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);

  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
        cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
          rewriter.create<arith::ConstantIntOp>(loc, shift, 8);
      auto result = rewriter.create<tosa::ApplyScaleOp>(
      if (elementTy.isInteger(32))
      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();
      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);

  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
      !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
    return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);

  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
      cast<tosa::NegateOp>(op).getQuantizationInfo()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp = quantizationInfo.value().getInputZp();
    int64_t outZp = quantizationInfo.value().getOutputZp();
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    Value zpAddValue = rewriter.create<arith::ConstantOp>(
    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);

  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);

  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);

  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);

  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);

  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);

  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::SinOp>(loc, resultTypes, args);

  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::CosOp>(loc, resultTypes, args);

  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);

  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,

  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,

  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,

  if (isa<tosa::SelectOp>(op)) {
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);

  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);

  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);

  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));

  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
                          intTy.isUnsignedInteger());

  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);

  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,

    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,

      auto unrealizedCast =
              .create<UnrealizedConversionCastOp>(
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],

    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,

    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
          APFloat::semanticsMaxExponent(fltSemantics)) {
        auto conv = rewriter.create<arith::FPToSIOp>(loc, dstTy, rounded);
        auto posInf = rewriter.create<arith::ConstantOp>(
                                    APFloat::getInf(fltSemantics)));
        auto negInf = rewriter.create<arith::ConstantOp>(
                                    APFloat::getInf(fltSemantics, true)));
        auto overflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = rewriter.create<arith::CmpFOp>(
            loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = rewriter.create<arith::ConstantOp>(
                APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = rewriter.create<arith::ConstantOp>(
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
            rewriter.create<arith::SelectOp>(loc, overflow, intMax, conv);
        return rewriter.create<arith::SelectOp>(loc, underflow, intMin,

      auto intMinFP = rewriter.create<arith::ConstantOp>(
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
        auto intMaxFP = rewriter.create<arith::ConstantOp>(
        return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);

      auto intMaxPlusOneFP = rewriter.create<arith::ConstantOp>(
      auto intMax = rewriter.create<arith::ConstantOp>(
          rewriter.create<arith::MaximumFOp>(loc, rounded, intMinFP);
          rewriter.create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
      auto overflow = rewriter.create<arith::CmpFOp>(
          loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return rewriter.create<arith::SelectOp>(loc, overflow, intMax,

    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);

      op, "unhandled op for linalg body calculation for elementwise op");
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  int64_t numExtraDims = rank - shapedType.getRank();
  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
                                                        shapedType.getRank());
  for (index = 0; index <= numExtraDims; index++)
    reassociationIndices[0].push_back(index);
  for (size_t position = 1; position < reassociationIndices.size(); position++)
    reassociationIndices[position].push_back(index++);
  for (index = 0; index < numExtraDims; index++)
    resultShape.push_back(1);
  for (auto size : shapedType.getShape())
    resultShape.push_back(size);
  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
                                                reassociationIndices);

  return llvm::map_to_vector(operands, [&](Value operand) {
    return expandRank(rewriter, loc, operand, rank);

  auto [it, inserted] = indexPool.try_emplace(index);

  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();

  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));

  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);

static std::pair<OpFoldResult, Value>
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)

  auto operandsWithDynamicDim =
      llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);

  if (operandsWithDynamicDim.empty())
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  return {targetSize, nullptr};
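// Note: computeTargetSize resolves one broadcast dimension of an elementwise
// op: a static size > 1 wins outright, a single dynamic operand becomes the
// "master" operand for that dimension, and several dynamic operands are
// reconciled at runtime with arith.maxui (returning a null master operand).
// Minimal illustrative sketch (hypothetical shapes, not taken from this file):
// for operands tensor<1x?xf32> and tensor<4x?xf32>, dim 0 resolves statically
// to 4, while dim 1 resolves to
//   %d1a = tensor.dim %arg0, %c1   %d1b = tensor.dim %arg1, %c1
//   %t1  = arith.maxui %d1a, %d1b
// with no master operand, so either operand may still need broadcasting.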
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  return {targetShape, masterOperands};

                                         Value masterOperand) {
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
  if (operand == masterOperand)

  auto rank = rankedTensorType.getRank();
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    affineExprs.push_back(affineExpr);
  auto broadcastAffineMap =

  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
      outputTensorShape.push_back(size);
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);
    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
    opBuilder.create<scf::YieldOp>(loc, operand);
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);

  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
        targetShape[index], masterOperands[index]);

  if (operands.size() == 1)
  return llvm::map_to_vector(operands, [&](Value operand) {
                                      targetShape, masterOperands);

  auto resultType = cast_or_null<RankedTensorType>(
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    affineExprs.push_back(affineExpr);
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
                {resultType.getElementType()}, rewriter);
          encounteredError = true;
        opBuilder.create<linalg::YieldOp>(loc, opResult);
  if (encounteredError)
        operation, "unable to create linalg.generic body for elementwise op");
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);

  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
         "elementwise op expects at least 1 operand");
         "Unranked tensors not supported");
  auto loc = operation->getLoc();
  auto [targetShape, masterOperands] =
      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
                                    targetShape, converter);
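// Note: taken together, elementwiseMatchAndRewrite expands all operands to
// the result rank, computes the broadcast target shape, broadcasts dynamic
// size-1 dimensions where needed (guarded by scf.if), and emits one
// linalg.generic whose body comes from
// createLinalgBodyCalculationForElementwiseOp. Illustrative sketch
// (hypothetical types, not taken from this file): a tosa.add of two
// tensor<4xf32> values lowers roughly to
//   %init = tensor.empty() : tensor<4xf32>
//   %0 = linalg.generic {indexing_maps = [...], iterator_types = ["parallel"]}
//          ins(%a, %b : tensor<4xf32>, tensor<4xf32>)
//          outs(%init : tensor<4xf32>) {
//        ^bb0(%x: f32, %y: f32, %out: f32):
//          %s = arith.addf %x, %y : f32
//          linalg.yield %s : f32
//        } -> tensor<4xf32>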
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))

  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::AddFOp>(loc, args);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::AddIOp>(loc, args);

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MulFOp>(loc, args);

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MulIOp>(loc, args);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MinSIOp>(loc, args[0], args[1]);

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MaxSIOp>(loc, args[0], args[1]);

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  auto elementTy = resultTy.getElementType();
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
      reduceShape.push_back(inputTy.getDimSize(i));
    if (inputTy.isDynamicDim(i))
      dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));

          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter

  bool didEncounterError = false;
  auto linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, input, filledTensor, axis,
            op, blockArgs, elementTy, rewriter);
          didEncounterError = true;
        nestedBuilder.create<linalg::YieldOp>(loc, result);

  if (!didEncounterError)
        op, "unable to create linalg.generic body for reduce op");

  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);
  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
      op, resultTy, linalgOp.getResults()[0], reassociationMap);
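// Note: reduceMatchAndRewriteHelper builds a linalg.reduce over the requested
// axis: the accumulator is filled from createInitialValueForReduceOp (e.g. 0
// for reduce_sum, the largest/smallest representable value for
// reduce_min/reduce_max), each step applies the combinator from
// createLinalgBodyCalculationForReduceOp, and the rank-reduced result is
// expanded back to the TOSA result type with tensor.expand_shape.
// Illustrative sketch (hypothetical types, not taken from this file): a
// reduce_sum of tensor<4x8xf32> over axis 1 becomes roughly
//   %fill = linalg.fill ins(%cst0 : f32) outs(%empty : tensor<4xf32>)
//   %red  = linalg.reduce { arith.addf } ins(%in) outs(%fill) dimensions = [1]
//   %res  = tensor.expand_shape %red [[0, 1]]
//           : tensor<4xf32> into tensor<4x1xf32>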
template <typename SrcOp>
  matchAndRewrite(SrcOp op, OpAdaptor operands,
        op, operands.getOperands(), rewriter, *this->getTypeConverter());

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));

    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        multiplierValues[i] = 0;

        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

        rewriter.getMultiDimIdentityMap(rank)};

    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
                                        rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType =
                              rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
                                             rewriter.getContext()));
      multiplierArg = indexingMaps.size() - 1;

    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
                                        rewriter.getAffineDimExpr(rank - 1)};
                                           rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
                                             rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;

    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
          Value value = blockArgs[0];

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(
                            nestedBuilder.getIntegerType(
            value = nestedBuilder.create<arith::ExtUIOp>(
                nestedLoc, nestedBuilder.getI32Type(), value);
            value = nestedBuilder.create<arith::ExtSIOp>(
                nestedLoc, nestedBuilder.getI32Type(), value);

              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          if (outIntType.isUnsignedInteger()) {
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));
                             nestedBuilder, false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),

          if (outIntType.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(nestedLoc,

          nestedBuilder.create<linalg::YieldOp>(loc, value);

    rewriter.replaceOp(op, linalgOp->getResults());
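// Note: the rescale lowering computes, per element, roughly
//   clamp(apply_scale(ext(value) - input_zp, multiplier, shift) + output_zp)
// inside a linalg.generic; per-channel multiplier/shift tensors are added as
// extra generic inputs when they are not splat constants, and the clamped
// i32 value is truncated (and, for unsigned outputs, reinterpreted through an
// unrealized_conversion_cast) to the output integer type.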
  LogicalResult matchAndRewrite(tosa::ResizeOp op,
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);

    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));
                                          inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,

    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
          Value value = args[0];
          if (inputTy.getElementType() != resultTy.getElementType()) {
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
          b.create<linalg::YieldOp>(loc, value);

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
  LogicalResult matchAndRewrite(tosa::ResizeOp op,
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
        builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());

    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

      collapseShape.push_back(outputH);
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,

    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

      inputExprs.push_back(rewriter.getAffineDimExpr(1));
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
  LogicalResult matchAndRewrite(tosa::ResizeOp op,
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
    auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
    auto genericOp = b.create<linalg::GenericOp>(
    Value resize = genericOp.getResult(0);

    b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
    Value batch = b.create<linalg::IndexOp>(0);
    Value y = b.create<linalg::IndexOp>(1);
    Value x = b.create<linalg::IndexOp>(2);
    Value channel = b.create<linalg::IndexOp>(3);

        b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
    Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
    Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
    Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

    Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
    Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

    Value yScaleN, yScaleD, xScaleN, xScaleD;
    yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
    yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
    xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
    xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

    Value yOffset, xOffset, yBorder, xBorder;
    yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
    xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
    yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
    xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      Value val = b.create<arith::MulIOp>(in, scaleD);
      val = b.create<arith::AddIOp>(val, offset);
      index = b.create<arith::FloorDivSIOp>(val, scaleN);

      Value r = b.create<arith::RemSIOp>(val, scaleN);
      Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
      Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
      delta = b.create<arith::DivFOp>(rFp, scaleNfp);

      Value val = b.create<arith::MulIOp>(in, scaleD);
      val = b.create<arith::AddIOp>(val, offset);
      index = b.create<arith::DivSIOp>(val, scaleN);
      delta = b.create<arith::MulIOp>(index, scaleN);
      delta = b.create<arith::SubIOp>(val, delta);

    Value ix, iy, dx, dy;
    if (floatingPointMode) {
      getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
      getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
      getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);

    if (op.getMode() == "NEAREST_NEIGHBOR") {
      auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

      auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
          return b.create<arith::ConstantIndexOp>(0);

        if (floatingPointMode) {
          auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
          pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
          pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,

        auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
        val = b.create<arith::AddIOp>(val, offset);
        return b.create<arith::IndexCastOp>(b.getIndexType(), val);

      iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
      ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

      Value result = b.create<tensor::ExtractOp>(
      b.create<linalg::YieldOp>(result);

    assert(op.getMode() == "BILINEAR");
    auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

    auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
      val1 = b.create<arith::AddIOp>(val0, oneVal);
      val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
      val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);

    Value x0, x1, y0, y1;
    getClampedIdxs(y0, y1, imageH, iy, hMax, b);
    getClampedIdxs(x0, x1, imageW, ix, wMax, b);

    Value y0x0 = b.create<tensor::ExtractOp>(
    Value y0x1 = b.create<tensor::ExtractOp>(
    Value y1x0 = b.create<tensor::ExtractOp>(
    Value y1x1 = b.create<tensor::ExtractOp>(

    if (floatingPointMode) {
          b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
        Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
        Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
        Value mul1 = b.create<arith::MulFOp>(val1, delta);
        return b.create<arith::AddFOp>(mul0, mul1);

      Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
      Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
      Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
      b.create<linalg::YieldOp>(result);

      y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
      y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
      y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
      y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

      if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
        dx = b.create<arith::ExtSIOp>(resultETy, dx);
        dy = b.create<arith::ExtSIOp>(resultETy, dy);

      Value yScaleNExt = yScaleN;
      Value xScaleNExt = xScaleN;

      const int64_t scaleBitwidth =
      if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
        yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
        xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);

                             Value scale, int inputSize,
          return b.create<arith::MulIOp>(val0, scale);
        Value weight0 = b.create<arith::SubIOp>(scale, weight1);
        Value mul0 = b.create<arith::MulIOp>(val0, weight0);
        Value mul1 = b.create<arith::MulIOp>(val1, weight1);
        return b.create<arith::AddIOp>(mul0, mul1);

      Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
      Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
      b.create<linalg::YieldOp>(result);

    rewriter.replaceOp(op, resize);
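// Note: GenericResizeConverter materializes tosa.resize as a linalg.generic
// over the output: for each output pixel it derives the source index and
// fractional delta from the scale/offset/border attributes, then either picks
// the nearest pixel (NEAREST_NEIGHBOR) or blends the four neighbors
// (BILINEAR), using floating-point weights for f16/f32 results and widened
// integer arithmetic otherwise.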
template <typename SrcOp>
  LogicalResult matchAndRewrite(SrcOp op,
    rewriter.replaceOp(op, op.getOperation()->getOperands());

template <typename SrcOp>
  LogicalResult matchAndRewrite(SrcOp reduceOp,

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
    Value input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));

    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
            indices.push_back(index);

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);

    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)

    auto readAffineMap =

    auto genericOp = rewriter.create<linalg::GenericOp>(
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
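// Note: the tile lowering builds an intermediate tensor whose shape
// interleaves each multiple with the corresponding input dimension
// (multiple0, dim0, multiple1, dim1, ...) and fills it with a broadcasting
// linalg.generic that simply yields the input element; a reshape of that
// interleaved tensor then produces the tiled result.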
  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();

    if (!isa<IntegerType>(outElementTy))
          "tosa.arg_max to linalg.* requires integer-like result type");

    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));

    auto emptyTensorIdx =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                     outElementTy, dynDims)
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
    auto filledTensorIdx =

    auto emptyTensorMax =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                     inElementTy, dynDims)
    auto fillValueMaxAttr =

    if (!fillValueMaxAttr)
          argmaxOp, "unsupported tosa.argmax element type");

        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =

    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {

    bool didEncounterError = false;
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          if (isa<FloatType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpFOp>(
                nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
            didEncounterError = true;

          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));

    if (didEncounterError)
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
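// Note: arg_max is lowered as a two-output reduction. One accumulator carries
// the running maximum value and the other the index at which it was seen;
// each step compares the new value against the current maximum and selects
// both the value and the index accordingly, and only the index tensor
// replaces the original op's result.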
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

        dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
            resultTy.getRank(), 0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},

    auto genericOp = rewriter.create<linalg::GenericOp>(
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
    rewriter.replaceOp(op, genericOp.getResult(0));

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
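// Note: gather becomes a linalg.generic over the (batch, index, channel)
// result space whose body reads the index operand, casts it to index type,
// and uses it as the middle coordinate of a tensor.extract from the values
// tensor, i.e. result[b, i, c] = values[b, indices[b, i], c].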
  LogicalResult matchAndRewrite(tosa::TableOp op,
    Value input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)

    auto genericOp = rewriter.create<linalg::GenericOp>(
    rewriter.replaceOp(op, genericOp.getResult(0));

        &genericOp.getRegion(), genericOp.getRegion().end(),
        TypeRange({inputElementTy, resultElementTy}), {loc, loc});

    if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
        resultElementTy.isInteger(8)) {
      Value index = rewriter.create<arith::IndexCastOp>(
      Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
      rewriter.create<linalg::YieldOp>(loc, extract);

    if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
        resultElementTy.isInteger(32)) {
      auto offset = rewriter.create<arith::ConstantOp>(
      auto seven = rewriter.create<arith::ConstantOp>(
      auto one = rewriter.create<arith::ConstantOp>(
      auto b1111111 = rewriter.create<arith::ConstantOp>(

      auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
      Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
          rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

      Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

      index = rewriter.create<arith::IndexCastOp>(
      indexPlusOne = rewriter.create<arith::IndexCastOp>(

      Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
      Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
      Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
          rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);
      rewriter.create<linalg::YieldOp>(loc, result);

        op, "unable to create body for tosa.table op");
  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);

  static RankedTensorType
    dims[2] = halfPlusOne(builder, loc, dims[2]);
    auto elementType =
        cast<RankedTensorType>(input.getType()).getElementType();

                                RankedTensorType type,
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
    return filledTensor;

    auto integerVal = builder.create<arith::IndexCastUIOp>(
    return builder.create<arith::UIToFPOp>(loc, type, integerVal);

    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);

  template <typename... Args>
  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
                                         "only supports ranked tensors");

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInput();
        dyn_cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());
                                         "only supports float element types");

    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

                 affineDimsExpr(rewriter, 0, 1, 2),
                 affineDimsExpr(rewriter, 0, 1, 2)},

    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      Value oy = builder.create<linalg::IndexOp>(loc, 1);
      Value ox = builder.create<linalg::IndexOp>(loc, 2);
      Value iy = builder.create<linalg::IndexOp>(loc, 3);
      Value ix = builder.create<linalg::IndexOp>(loc, 4);

      auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
      auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

      auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
      auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag = builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);
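// Note: the RFFT2d lowering is a naive DFT expressed as a linalg.generic with
// two reduction dimensions: for every output bin (oy, ox) it accumulates,
// over all input positions (iy, ix),
//   real += in[iy, ix] * cos(2*pi*(iy*oy/H + ix*ox/W))
//   imag -= in[iy, ix] * sin(2*pi*(iy*oy/H + ix*ox/W))
// starting from zero-filled accumulators whose last dimension is W/2 + 1.
// The FFT2d converter below follows the same structure, with a complex input
// and an optional sign flip of the angle for the inverse transform.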
2449 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2451 if (!llvm::all_of(fft2d->getOperandTypes(),
2452 RFFT2dConverter::isRankedTensor) ||
2453 !llvm::all_of(fft2d->getResultTypes(),
2454 RFFT2dConverter::isRankedTensor)) {
2459 Value input_real = fft2d.getInputReal();
2460 Value input_imag = fft2d.getInputImag();
2461 BoolAttr inverse = fft2d.getInverseAttr();
2463 auto real_el_ty = cast<FloatType>(
2464 cast<ShapedType>(input_real.
getType()).getElementType());
2465 [[maybe_unused]]
auto imag_el_ty = cast<FloatType>(
2466 cast<ShapedType>(input_imag.
getType()).getElementType());
2468 assert(real_el_ty == imag_el_ty);
2483 utils::IteratorType::parallel, utils::IteratorType::parallel,
2484 utils::IteratorType::parallel, utils::IteratorType::reduction,
2485 utils::IteratorType::reduction};
2490 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2492 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2497 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2498 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2499 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2500 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2504 auto dimH = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 1);
2505 auto dimW = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 2);
2508 auto twoPiAttr = rewriter.
getFloatAttr(real_el_ty, 6.283185307179586);
2509 auto twoPi = rewriter.
create<arith::ConstantOp>(loc, twoPiAttr);
2511 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2513 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2516 Value valReal = args[0];
2517 Value valImag = args[1];
2518 Value sumReal = args[2];
2519 Value sumImag = args[3];
    // Indices for the angle computation.
    Value oy = builder.create<linalg::IndexOp>(loc, 1);
    Value ox = builder.create<linalg::IndexOp>(loc, 2);
    Value iy = builder.create<linalg::IndexOp>(loc, 3);
    Value ix = builder.create<linalg::IndexOp>(loc, 4);

    // angle = 2 * pi * (((iy * oy) mod H) / H + ((ix * ox) mod W) / W)
    auto iyXoy = builder.create<index::MulOp>(loc, iy, oy);
    auto ixXox = builder.create<index::MulOp>(loc, ix, ox);

    auto iyRem = builder.create<index::RemUOp>(loc, iyXoy, dimH);
    auto ixRem = builder.create<index::RemUOp>(loc, ixXox, dimW);

    auto iyRemFloat =
        RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
    auto ixRemFloat =
        RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);

    auto yComponent = builder.create<arith::DivFOp>(loc, iyRemFloat, constH);
    auto xComponent = builder.create<arith::DivFOp>(loc, ixRemFloat, constW);

    auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
    auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

    // For the inverse transform, negate the angle.
    if (inverse.getValue()) {
      angle = builder.create<arith::MulFOp>(
          loc, angle,
          rewriter.create<arith::ConstantOp>(
              loc, rewriter.getFloatAttr(real_el_ty, -1.0)));
    }

    // realComponent = valReal * cos(angle) + valImag * sin(angle)
    // imagComponent = valImag * cos(angle) - valReal * sin(angle)
    auto cosAngle = builder.create<math::CosOp>(loc, angle);
    auto sinAngle = builder.create<math::SinOp>(loc, angle);

    auto rcos = builder.create<arith::MulFOp>(loc, valReal, cosAngle);
    auto rsin = builder.create<arith::MulFOp>(loc, valImag, sinAngle);
    auto realComponent = builder.create<arith::AddFOp>(loc, rcos, rsin);

    auto icos = builder.create<arith::MulFOp>(loc, valImag, cosAngle);
    auto isin = builder.create<arith::MulFOp>(loc, valReal, sinAngle);
    auto imagComponent = builder.create<arith::SubFOp>(loc, icos, isin);

    // Accumulate into the running sums and yield the updated values.
    auto outReal = builder.create<arith::AddFOp>(loc, sumReal, realComponent);
    auto outImag = builder.create<arith::AddFOp>(loc, sumImag, imagComponent);

    builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
  };

  rewriter.replaceOpWithNewOp<linalg::GenericOp>(
      fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
      indexingMaps, iteratorTypes, buildBody);

  return success();
}
};
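Taken together, the body above performs a direct (naive) 2-D DFT accumulation per output element. As a plain-math summary of the code, with H and W the input height and width and the angle sign flipped when the inverse attribute is set:

  out_real[n, oy, ox] = \sum_{iy=0}^{H-1} \sum_{ix=0}^{W-1} \big( in_real[n, iy, ix] \cos a + in_imag[n, iy, ix] \sin a \big)
  out_imag[n, oy, ox] = \sum_{iy=0}^{H-1} \sum_{ix=0}^{W-1} \big( in_imag[n, iy, ix] \cos a - in_real[n, iy, ix] \sin a \big)
  where  a = 2\pi \left( \frac{(iy \cdot oy) \bmod H}{H} + \frac{(ix \cdot ox) \bmod W}{W} \right)

Every output position re-walks the full H x W input plane, so the lowering does O(H^2 * W^2) work per batch; it is a direct DFT rather than a fast (butterfly) FFT factorization.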
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    TypeConverter &converter, RewritePatternSet *patterns) {

  // Resize lowerings are registered with explicit benefits so that the more
  // specialized patterns are tried before the generic one.
  patterns->add<GenericResizeConverter>(patterns->getContext(),
                                        /*benefit=*/100);
  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
                                            /*benefit=*/300);
  patterns->add<
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>>(converter, patterns->getContext());
  patterns->add<
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>,
      // ... further op-specific converters elided in this listing ...
      RFFT2dConverter,
      FFT2dConverter>(patterns->getContext());
}
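As a usage sketch (not part of this file): a conversion pass typically builds a TypeConverter and a ConversionTarget, calls the populate function above, and runs a full conversion. The driver below is an illustrative assumption — the function name, the identity type conversion, and the legality setup are not taken from this source — and it sticks to standard MLIR dialect-conversion APIs.

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical driver: lower all TOSA ops in `func` using the patterns
// registered by populateTosaToLinalgConversionPatterns.
static LogicalResult runTosaToLinalgOnFunc(func::FuncOp func) {
  MLIRContext *ctx = func.getContext();

  // An identity type conversion is sufficient for these converters.
  TypeConverter converter;
  converter.addConversion([](Type type) { return type; });

  RewritePatternSet patterns(ctx);
  tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);

  // TOSA ops must be rewritten; everything else the patterns emit stays legal.
  ConversionTarget target(*ctx);
  target.addIllegalDialect<tosa::TosaDialect>();
  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

  return applyFullConversion(func.getOperation(), target, std::move(patterns));
}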