#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
template <typename OpTy>
static Value materializeBinaryNanCheckIfRequired(OpTy op,
                                                 PatternRewriter &rewriter,
                                                 Value lhs, Value rhs,
                                                 Value result) {
  auto nanMode = op.getNanMode();
  if (nanMode == NanPropagationMode::PROPAGATE)
    return result;

  // Unordered comparison of NaN against itself will always return true.
  Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
                                         arith::CmpFPredicate::UNO, lhs, lhs);
  Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
                                         arith::CmpFPredicate::UNO, rhs, rhs);
  Value rhsOrResult =
      arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
  return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
                                 rhsOrResult);
}
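
// Creates the body of a 'linalg.generic' region for the given elementwise
// TOSA operation, mapping it onto scalar arith/math ops on the element type.
// Returns a null Value (after notifying a match failure) when the op and
// element-type combination is unsupported.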
static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();
  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return math::AbsFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = arith::ConstantOp::create(rewriter, loc,
                                          rewriter.getZeroAttr(elementTy));
    auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
    return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
  }
  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return arith::AddFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return arith::AddIOp::create(rewriter, loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return arith::SubFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return arith::SubIOp::create(rewriter, loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
    return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
  }
  // tosa::MulOp
  if (isa<tosa::MulOp>(op)) {
    auto shiftVal = cast<tosa::MulOp>(op).getShift();
    DenseElementsAttr shiftElem;
    int32_t shift = 0;
    bool shiftIsConstant = true;
    // When the shift is not a compile-time constant it is read from the third
    // block argument at runtime.
    if (matchPattern(shiftVal, m_Constant(&shiftElem)))
      shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
    else
      shiftIsConstant = false;

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
                                   args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0 || !shiftIsConstant) {
        Value shiftConst;
        if (shiftIsConstant)
          shiftConst = arith::ConstantIntOp::create(rewriter, loc, shift,
                                                    /*width=*/8);
        if (!a.getType().isInteger(32))
          a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
        if (!b.getType().isInteger(32))
          b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
        auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
        auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
                                                  RoundingMode::SINGLE_ROUND);
        auto result =
            tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
                                       b, shiftAmount, roundingAttr);
        if (elementTy.isInteger(32))
          return result;
        return arith::TruncIOp::create(rewriter, loc, elementTy, result);
      }

      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);

      return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
    }
  }
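
  // tosa::NegateOp
  // For quantized integer types the negation is computed on a wider
  // intermediate type as
  //   output = (input_zp + output_zp) - input
  // and the result is then clamped back to the range of the original
  // element type.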
  if (isa<tosa::NegateOp>(op)) {
    auto negate = cast<tosa::NegateOp>(op);

    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    bool hasInZp = !failed(maybeInZp);
    bool hasOutZp = !failed(maybeOutZp);

    if (isa<FloatType>(elementTy))
      return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);

    if (isa<IntegerType>(elementTy)) {
      Value zpAddValue;
      Type intermediateType;
      const int inputBitWidth = elementTy.getIntOrFloatBitWidth();
      // Use a 64-bit intermediate by default; narrow it below when the
      // zero-point sum provably fits into fewer bits.
      int intermediateBitWidth = 64;

      if (hasInZp && hasOutZp) {
        const int64_t inZp = *maybeInZp;
        const int64_t outZp = *maybeOutZp;
        const int64_t zpAdd = inZp + outZp;
        const int64_t maxValue =
            APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
            std::abs(zpAdd) + 1;

        if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
          intermediateBitWidth = 16;
        } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
          intermediateBitWidth = 32;
        }

        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
        zpAddValue = arith::ConstantOp::create(
            rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
      } else {
        // Dynamic zero points arrive as extra block arguments; widen them to
        // the intermediate type and add them together.
        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
        Value arg1 = args[1];
        Value arg2 = args[2];

        if (arg1.getType() != intermediateType)
          arg1 = arith::ExtSIOp::create(rewriter, loc, intermediateType, arg1);
        if (arg2.getType() != intermediateType)
          arg2 = arith::ExtSIOp::create(rewriter, loc, intermediateType, arg2);
        zpAddValue =
            arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
      }

      Value ext = args[0];
      if (ext.getType() != intermediateType)
        ext = arith::ExtSIOp::create(rewriter, loc, intermediateType, ext);
      auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);

      // Clamp to the negation range of the original element type.
      Value min = arith::ConstantIntOp::create(
          rewriter, loc, intermediateType,
          APInt::getSignedMinValue(inputBitWidth).getSExtValue());
      Value max = arith::ConstantIntOp::create(
          rewriter, loc, intermediateType,
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
      auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);

      // Truncate to the final value.
      if (clamp.getType() == elementTy)
        return clamp;
      return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
    }
  }
  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return arith::AndIOp::create(rewriter, loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return arith::OrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return arith::ShLIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), 1);
    auto one = arith::ConstantOp::create(rewriter, loc,
                                         IntegerAttr::get(elementTy, 1));
    auto zero = arith::ConstantOp::create(rewriter, loc,
                                          IntegerAttr::get(elementTy, 0));
    auto i1zero =
        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
    auto i1one =
        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));

    // Rounding only applies when the shift amount is non-zero.
    auto shiftValueGreaterThanZero = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Check the last bit shifted out: round up iff it is set.
    auto subtract =
        arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
    auto shifted =
        arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
                                             std::nullopt);
    auto isInputOdd =
        arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);

    auto shouldRound = arith::SelectOp::create(
        rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
    auto extended =
        arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
    return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
  }
  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
  }

  // tosa::LogicalAndOp
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalNotOp
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = arith::ConstantOp::create(rewriter, loc,
                                         rewriter.getIntegerAttr(elementTy, 1));
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOrOp
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalXorOp
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);

  // tosa::SinOp
  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);

  // tosa::CosOp
  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
                                 args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
                                 args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
                                 args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
                                 args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
                                 args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                 args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
  }
  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
                                               rewriter, args[0], args[1], max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
                                               rewriter, args[0], args[1], min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return math::CeilOp::create(rewriter, loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return math::FloorOp::create(rewriter, loc, resultTypes, args);
  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = arith::ConstantOp::create(
        rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = arith::ConstantOp::create(
        rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);

    auto clampOp = llvm::cast<tosa::ClampOp>(op);
    const auto nanMode = clampOp.getNanMode();

    // NaN propagation has no meaning for non floating point types.
    if (!isa<FloatType>(elementTy))
      return result;

    // In the case of "PROPAGATE" semantics no compare and selection is
    // required.
    if (nanMode == NanPropagationMode::PROPAGATE)
      return result;

    // In the case of "IGNORE" semantics, compare the operand against itself:
    // an unordered comparison of NaN against itself always returns true, and
    // a NaN input then clamps to the minimum.
    Value isNaN = arith::CmpFOp::create(
        rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
    return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
  }
  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min and max fit into signed n-bit constants.
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }
    // Clamp the bounds to the representable range of the element type.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
                                               intTy.getIntOrFloatBitWidth());
    auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
                                               intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }
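
  // tosa::SigmoidOp
  // Decomposed as 1.0 / (1.0 + exp(-x)).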
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
    auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
    auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
    auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
    return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
  }
  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
      (void)rewriter.notifyMatchFailure(op, "unsupported type");
      return nullptr;
    }

    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
                                   std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
                                     std::nullopt);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
                                     std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
                                    std::nullopt);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          UnrealizedConversionCastOp::create(
              rewriter, loc,
              rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
              .getResult(0);
      return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
                                     unrealizedCast);
    }

    // All other signed-int-to-float conversions go through SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
                                     std::nullopt);

    // Casting float to boolean only needs a not-equal-to-zero check.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getFloatAttr(srcTy, 0.0));
      return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
                                   args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // Check whether neither int min nor int max can be represented in the
      // input floating-point type due to a too-short exponent range.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        // Use cmp + select to replace infinities by int min / int max. Other
        // integral values can be represented in the integer space.
        auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
        auto posInf = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                  APFloat::getInf(fltSemantics)));
        auto negInf = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = arith::CmpFOp::create(
            rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = arith::CmpFOp::create(
            rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
        return arith::SelectOp::create(rewriter, loc, underflow, intMin,
                                       maxClamped);
      }

      auto intMinFP = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getFloatAttr(
              getElementTypeOrSelf(srcTy),
              APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                  .getSExtValue()));

      // If the mantissa is long enough to represent int max exactly, a simple
      // float clamp followed by fptosi suffices.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        auto intMaxFP = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                    .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
      }

      // Otherwise, compare against int max + 1 (representable as a power of
      // two) and select int max on overflow.
      auto intMaxPlusOneFP = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getFloatAttr(
              getElementTypeOrSelf(srcTy),
              static_cast<double>(
                  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                      .getSExtValue()) +
                  1.0f));

      auto intMax = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getIntegerAttr(
              getElementTypeOrSelf(dstTy),
              APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
      auto minClamped =
          arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
      auto overflow = arith::CmpFOp::create(
          rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return arith::SelectOp::create(rewriter, loc, overflow, intMax,
                                     minClamped);
    }

    // Casting integer to boolean only needs a not-equal-to-zero check.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
                                                srcTy.getIntOrFloatBitWidth());
      return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
                                   args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
                                    std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
static Value getTensorDim(OpBuilder &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
}

static OpFoldResult getOrFoldTensorDim(OpBuilder &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}

static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
// Compute the broadcast target size of the given operands along a dimension
// and return the pair {target size, master operand}. The master operand is
// the operand whose size already equals the target size, or null when the
// target is a runtime maximum over several dynamic operands.
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
                  ValueRange operands, int64_t dim) {
  // If any operand has a static size greater than 1 along 'dim', that is the
  // target size and that operand is the master operand.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (ShapedType::isStatic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with a dynamic dimension along 'dim'.
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If no operand has a dynamic dimension, all static sizes were 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension.
  Value targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // The target size is the runtime maximum of all dynamic sizes.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    Value nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
// Compute the broadcast target shape for all dimensions of the given operands.
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size was taken directly from this operand, it already has
  // the right size and no broadcast is needed.
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic': the broadcast dimension reads from
  // constant index 0 of the operand.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check at runtime whether a broadcast is necessary (size == 1).
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = arith::CmpIOp::create(
      rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // Emit 'tensor.empty' op with the broadcast target shape.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, indexPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = tensor::EmptyOp::create(
        opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op that copies the operand into the broadcast
    // shape.
    auto resultTensor =
        linalg::GenericOp::create(
            opBuilder, loc, outputTensor.getType(), operand, outputTensor,
            affineMaps, getNParallelLoopsAttrs(rank),
            [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
              linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
            })
            .getResult(0);

    // Cast to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    scf::YieldOp::create(opBuilder, loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    scf::YieldOp::create(opBuilder, loc, operand);
  };

  // Emit 'scf.if' op.
  auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
                                emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                        IndexPool &indexPool, Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}
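
// Broadcast all dynamic dimensions of every operand to the target shape.
// Operand lists with fully static shapes are returned unchanged.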
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations.
  if (operands.size() == 1)
    return llvm::to_vector(operands);

  // Skip the work entirely when every operand already has a static shape.
  bool hasDynamic = false;
  for (auto op : operands) {
    const auto tType = dyn_cast<RankedTensorType>(op.getType());
    if (tType && !tType.hasStaticShape()) {
      hasDynamic = true;
      break;
    }
  }
  if (!hasDynamic)
    return llvm::to_vector(operands);

  // Broadcast dynamic dimensions operand by operand.
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor.
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
                                               resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of
  // size 1.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Prefer producing identity maps whenever possible (i.e. no
      // broadcasting needed) because some transforms (like reshape folding)
      // do not support affine constant exprs.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = linalg::GenericOp::create(
      rewriter, loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        linalg::YieldOp::create(opBuilder, loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast 'linalg.generic' result into original result type if needed.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
static ValueRange getBroadcastableOperands(Operation *operation,
                                           ValueRange operands) {
  // The shift operand of tosa.mul cannot broadcast.
  if (isa<tosa::MulOp>(operation)) {
    DenseElementsAttr shift;
    // A constant shift is materialized in the body, so only the two
    // multiplicands participate in broadcasting.
    if (matchPattern(operation->getOperand(2), m_Constant(&shift)))
      return operands.take_front(2);
    return operands.take_front(3);
  }
  // The input1_zp and output_zp operands of tosa.negate cannot broadcast.
  if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    // Non-constant zero points stay as operands of the generic op.
    if (failed(maybeOutZp) && failed(maybeInZp))
      return operands.take_front(3);
    return operands.take_front(1);
  }
  return operands;
}
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
  auto broadcastOperands =
      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                 targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape, converter);
}
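
// Returns the identity value attribute associated with a reduction: the value
// that leaves the accumulator unchanged (0 for sum, 1 for product, the
// largest/lowest representable value for min/max, and so on). Returns a null
// attribute for unsupported op/type combinations.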
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return arith::AddFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::AddIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MulFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MulIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, args);

  return {};
}
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // Widen the accumulator to f32 for bf16 sum reductions to avoid losing
  // precision while accumulating.
  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
                    isa<FloatType>(elementTy) &&
                    cast<FloatType>(elementTy).isBF16();
  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
  auto filledTensor = linalg::FillOp::create(rewriter, loc,
                                             ValueRange{fillValue},
                                             ValueRange{emptyTensor})
                          .result();
  outputs.push_back(filledTensor);

  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN propagation has no meaning for non floating point types.
    if (isa<FloatType>(elementTy) &&
        op.getNanMode() == NanPropagationMode::IGNORE) {
      isNanIgnoreMode = true;
      // The TOSA spec requires the result to be NaN iff all elements in the
      // reduction are NaN, so track an "all results NaN" flag alongside the
      // running reduction value.
      auto trueAttr = rewriter.getBoolAttr(true);
      auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
      auto emptyBoolTensor =
          tensor::EmptyOp::create(rewriter, loc, reduceShape,
                                  trueValue.getType(), dynDims)
              .getResult();
      auto allResultsNaNTensor =
          linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
                                 ValueRange{emptyBoolTensor})
              .result();
      inputs.push_back(input);
      outputs.push_back(allResultsNaNTensor);
    }
  }

  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
      rewriter, loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};

        // Widen the element to the accumulator type if necessary.
        if (binaryArgs[0].getType() != accTy)
          binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
                                                binaryArgs[0]);

        auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
                                                             accTy, rewriter);
        // Despite its name, this flag records that a body WAS created; the
        // check after the reduce op tests for its absence.
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself will always return
          // true.
          Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
                                              arith::CmpFPredicate::UNO,
                                              inputValue, inputValue);
          // If we've encountered a NaN, keep the initial (non-NaN) value.
          auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
                                                  isNaN, initialValue, result);
          // Update the all-NaN flag.
          auto newAllResultsNanFlagValue = arith::AndIOp::create(
              nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  if (isNanIgnoreMode) {
    // If all elements along the reduced axis were NaN, the result is NaN:
    // select between the reduction result and a NaN-filled tensor using the
    // all-NaN flag computed above.
    auto nanValueAttr = rewriter.getFloatAttr(
        elementTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
    auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
    auto emptyNanTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();
    auto nanFilledTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
                               ValueRange{emptyNanTensor})
            .result();

    // No need to fill this tensor: it is fully overwritten by the select.
    auto finalEmptyTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();

    SmallVector<Value> ins, outs;
    ins.push_back(linalgOp->getOpResult(1));
    ins.push_back(nanFilledTensor);
    ins.push_back(linalgOp->getResult(0));
    outs.push_back(finalEmptyTensor);
    auto linalgSelect =
        linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }

  // Truncate a widened accumulator back to the original element type.
  Value reducedRes = linalgOp->getResult(0);
  if (widenAccTy) {
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
            .getResult();
    const unsigned reducedRank =
        cast<ShapedType>(reducedRes.getType()).getRank();
    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
    reducedRes =
        linalg::GenericOp::create(
            rewriter, loc, emptyTensor.getType(), ValueRange{reducedRes},
            ValueRange{emptyTensor},
            ArrayRef<AffineMap>{identityMap, identityMap},
            getNParallelLoopsAttrs(reducedRank),
            [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
              Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
                                                     elementTy, args[0]);
              linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
            })
            .getResult(0);
  }

  // Expand the reduced result back to the original rank with a unit dimension
  // along the reduction axis.
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, reducedRes, reassociationMap);
  return success();
}
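
// Converts an elementwise TOSA op into an equivalent 'linalg.generic' by
// delegating to elementwiseMatchAndRewriteHelper.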
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
// Collapse a 1xN tensor so it can be used as a scalar-like input of a
// 'linalg.generic'.
static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
                                  Location loc) {
  auto inputType = cast<RankedTensorType>(input.getType());
  auto elemType = inputType.getElementType();
  auto collapsedType = RankedTensorType::get({}, elemType);
  // An empty reassociation collapses all unit dimensions.
  SmallVector<ReassociationExprs, 1> reassociation;
  return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
                                         reassociation);
}

static llvm::SmallVector<int8_t> convertToI8(ArrayRef<int32_t> input) {
  llvm::SmallVector<int8_t> output;
  output.reserve(input.size());

  for (auto v : llvm::map_range(
           input, [](int32_t val) { return static_cast<int8_t>(val); })) {
    output.push_back(v);
  }
  return output;
}
// The shift or multiplier of tosa.rescale may be either constant or dynamic.
// - If it is not constant, add it as an input of the 'linalg.generic' by
//   pushing it into 'genericInputs' and appending a matching affine map to
//   'indexingMaps', recording its argument position in 'arg'.
// - If it is constant, materialize it and hand it back through 'constant'.
static void setupLinalgGenericOpInputAndIndexingMap(
    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
    bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
    bool isShift = false) {

  auto loc = op.getLoc();
  auto inputTy = cast<ShapedType>(op.getInput().getType());
  unsigned rank = inputTy.getRank();
  SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
  auto broadcastMap = AffineMap::get(rank, 0, rewriter.getContext());

  if (isConstant) {
    // A single constant value becomes an arith.constant scalar.
    if (values.size() == 1) {
      IntegerAttr intAttr = isShift
                                ? rewriter.getI8IntegerAttr(values.front())
                                : rewriter.getI32IntegerAttr(values.front());
      constant = arith::ConstantOp::create(rewriter, loc, intAttr);
    } else {
      // Multiple constant values become a dense constant tensor input.
      auto elementType =
          isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
      auto tensorType = RankedTensorType::get(
          {static_cast<int64_t>(values.size())}, elementType);
      DenseIntElementsAttr EltAttr;
      if (isShift)
        EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
      else
        EltAttr = DenseIntElementsAttr::get(tensorType, values);
      genericInputs.push_back(
          arith::ConstantOp::create(rewriter, loc, EltAttr));
      indexingMaps.push_back(AffineMap::get(rank, /*symbolCount=*/0, exprs,
                                            rewriter.getContext()));
    }
  } else {
    // The value is not constant; use the rescale operand directly.
    auto operand = isShift ? op.getShift() : op.getMultiplier();
    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
    if (tensorType && tensorType.hasStaticShape() &&
        tensorType.getShape()[0] == 1) {
      // A 1-element tensor is collapsed and broadcast to all lanes.
      genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
      indexingMaps.push_back(broadcastMap);
    } else {
      genericInputs.push_back(operand);
      indexingMaps.push_back(AffineMap::get(rank, /*symbolCount=*/0, exprs,
                                            rewriter.getContext()));
    }
  }
  arg = indexingMaps.size() - 1;
}
// Emit a value for the (input or output) zero point of tosa.rescale. A
// constant zero point becomes an arith.constant; a dynamic one is read from
// the corresponding block argument and extended to the attribute bitwidth.
static Value getExtendZp(OpBuilder &builder, Type valueTy,
                         FailureOr<int64_t> maybeZp, Location loc,
                         ValueRange blockArgs, int64_t zpArg,
                         bool isOutputZp = false) {
  Value result;
  const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
  const uint32_t attrBitwidth =
      isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
  auto extendType = builder.getIntegerType(attrBitwidth);
  if (failed(maybeZp)) {
    result = blockArgs[zpArg];
    auto zpTy = result.getType();
    if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
      // Unsigned zero points are first cast to signless before extension.
      if (zpTy.isUnsignedInteger()) {
        result =
            UnrealizedConversionCastOp::create(
                builder, loc,
                builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
                .getResult(0);
      }
      if (zpTy.isUnsignedInteger()) {
        return arith::ExtUIOp::create(builder, loc, extendType, result);
      }
      return arith::ExtSIOp::create(builder, loc, extendType, result);
    }
    return result;
  }
  return arith::ConstantOp::create(builder, loc,
                                   IntegerAttr::get(extendType, *maybeZp));
}
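
// Lowers tosa.rescale to a 'linalg.generic' whose body computes, per element,
//   out = clamp(apply_scale(in - input_zp, multiplier, shift) + output_zp)
// where apply_scale is the TOSA fixed-point scaling primitive.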
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // These are illegal configurations: terminate and log an error.
    if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
              "currently supported");
    if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    // The shift and multiplier values.
    DenseElementsAttr shiftElems;
    bool isShiftConstant = false;
    if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
      isShiftConstant = true;

    DenseElementsAttr multiplierElems;
    bool isMultiplierConstant = false;
    if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
      isMultiplierConstant = true;

    llvm::SmallVector<int32_t> shiftValues;
    llvm::SmallVector<int32_t> multiplierValues;
    bool doubleRound;

    if (isMultiplierConstant && isShiftConstant) {
      shiftValues = llvm::to_vector(llvm::map_range(
          shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
            return static_cast<int32_t>(attr.getInt());
          }));
      multiplierValues = llvm::to_vector(
          llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
                          [](IntegerAttr attr) -> int32_t {
                            return static_cast<int32_t>(attr.getInt());
                          }));

      // Shifting by more than the bitwidth just sets the result to 0.
      for (int i = 0, s = multiplierValues.size(); i < s; i++) {
        if (shiftValues[i] > 63) {
          shiftValues[i] = 0;
          multiplierValues[i] = 0;
        }
      }

      // Double rounding only occurs when the shift is greater than 31, so
      // only apply it then.
      doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
                    llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
    } else {
      doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
    }

    RoundingMode roundingMode =
        doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If the multiplier is not constant, add it as a generic input.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    setupLinalgGenericOpInputAndIndexingMap(
        rewriter, multiplierValues, genericInputs, indexingMaps,
        isMultiplierConstant, op, multiplierConstant, multiplierArg);

    // If the shift is not constant, add it as a generic input.
    Value shiftConstant;
    int64_t shiftArg = 0;
    setupLinalgGenericOpInputAndIndexingMap(
        rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
        shiftConstant, shiftArg, true);

    // Dynamic input and output zero points are also passed as generic inputs.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
    auto broadcastMap = AffineMap::get(rank, 0, rewriter.getContext());
    int64_t iZpArg = 0;
    int64_t oZpArg = 0;
    if (failed(maybeIZp)) {
      genericInputs.push_back(
          collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
      indexingMaps.push_back(broadcastMap);
      iZpArg = indexingMaps.size() - 1;
    }
    if (failed(maybeOZp)) {
      genericInputs.push_back(
          collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
      indexingMaps.push_back(broadcastMap);
      oZpArg = indexingMaps.size() - 1;
    }

    // Indexing map for the output.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the destination tensor for the generic op.
    Value emptyTensor = tensor::EmptyOp::create(
        rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = linalg::GenericOp::create(
        rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
        indexingMaps, getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
          auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
                                     nestedLoc, blockArgs, iZpArg);

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
          auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
                                      nestedLoc, blockArgs, oZpArg, true);

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Unsigned inputs are handled as signless.
          if (op.getInputUnsigned()) {
            value = UnrealizedConversionCastOp::create(
                        nestedBuilder, nestedLoc,
                        nestedBuilder.getIntegerType(
                            valueTy.getIntOrFloatBitWidth()),
                        value)
                        .getResult(0);
          }

          // Extend narrow inputs to 32 bits before scaling.
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
                                             nestedBuilder.getI32Type(), value);
            } else {
              value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
                                             nestedBuilder.getI32Type(), value);
            }
          }

          value =
              arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);

          value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
                                             nestedBuilder.getI32Type(), value,
                                             multiplier, shift, roundingMode);

          // Move to the new zero-point.
          value =
              arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);

          // Saturate to the output size.
          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different saturation range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = arith::ConstantOp::create(
              nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = arith::ConstantOp::create(
              nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, false);

          if (outIntType.getWidth() < 32) {
            value = arith::TruncIOp::create(
                nestedBuilder, nestedLoc,
                rewriter.getIntegerType(outIntType.getWidth()), value);
          }

          if (outIntType.isUnsignedInteger()) {
            value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
                                                       outIntType, value)
                        .getResult(0);
          }
          linalg::YieldOp::create(nestedBuilder, loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
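
// Handle the degenerate case where tosa.resize maps a 1x1 image onto a 1x1
// output: the op reduces to (at most) a per-element scale, so it can be
// rewritten with a collapse, an optional elementwise multiply, and an expand.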
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    SmallVector<int64_t> scale;
    if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale))
      return failure();

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[0].push_back(builder.getAffineDimExpr(1));
    reassociationMap[0].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
                                                     reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    // Generate the elementwise operation for casting and scaling the input
    // value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty =
        tensor::EmptyOp::create(builder, genericTy.getShape(),
                                resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = linalg::GenericOp::create(
        builder, genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
                                           value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[0]));
              value = arith::MulIOp::create(b, loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[2]));
              value = arith::MulIOp::create(b, loc, value, scaleX);
            }
          }

          linalg::YieldOp::create(b, loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
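
// When tosa.resize broadcasts along a unit height or width dimension, first
// resize the non-broadcast dimensions with a smaller tosa.resize, then
// broadcast the result to the requested output shape.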
struct MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
                               op.getOffset(), op.getBorder(), op.getMode());

    // Collapse the unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(1);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
                                                     resize, reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = tensor::EmptyOp::create(
        builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());
    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          linalg::YieldOp::create(b, loc, value);
        });

    return success();
  }
};
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = isa<FloatType>(resultETy);
    auto floatTy = resultETy;

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
                                               resultETy, *dynamicDimsOr);
    auto genericOp = linalg::GenericOp::create(
        b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = linalg::IndexOp::create(b, 0);
      Value y = linalg::IndexOp::create(b, 1);
      Value x = linalg::IndexOp::create(b, 2);
      Value channel = linalg::IndexOp::create(b, 3);

      Value zeroI32 =
          arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
      Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
      Value hMax =
          arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
      Value wMax =
          arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));

      Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
      Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);

      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border))
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
      yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
      xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
      xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
      xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
      yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
      xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));

      // Compute the source index and floating-point delta:
      //   val = in * scale_d + offset
      //   index = floordiv(val, scale_n)
      //   delta = (val mod scale_n) / scale_n
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        Value val = arith::MulIOp::create(b, in, scaleD);
        val = arith::AddIOp::create(b, val, offset);
        index = arith::FloorDivSIOp::create(b, val, scaleN);

        Value r = arith::RemSIOp::create(b, val, scaleN);
        Value rFp = arith::SIToFPOp::create(b, floatTy, r);
        Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
        delta = arith::DivFOp::create(b, rFp, scaleNfp);
      };

      // Integer mode keeps the remainder as a fixed-point delta:
      //   index = val / scale_n
      //   delta = val - index * scale_n
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        Value val = arith::MulIOp::create(b, in, scaleD);
        val = arith::AddIOp::create(b, val, offset);
        index = arith::DivSIOp::create(b, val, scaleN);
        delta = arith::MulIOp::create(b, index, scaleN);
        delta = arith::SubIOp::create(b, val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
        auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1)
            return arith::ConstantIndexOp::create(b, 0);

          // Round to the nearest source element: advance by one when the
          // delta is at least half a source step.
          Value pred;
          if (floatingPointMode) {
            auto h =
                arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
            pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = arith::ShLIOp::create(b, dval, one);
            pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
                                         dvalDouble, scale);
          }

          auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
          val = arith::AddIOp::create(b, val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, false);
          return arith::IndexCastOp::create(b, b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = tensor::ExtractOp::create(
            b, input, ValueRange{batch, iy, ix, channel});

        linalg::YieldOp::create(b, result);
      } else {
        // The remaining mode must be BILINEAR.
        assert(op.getMode() == ResizeMode::BILINEAR);

        auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = arith::AddIOp::create(b, val0, oneVal);
          val0 = clampIntHelper(loc, val0, zeroI32, max, b, false);
          val1 = clampIntHelper(loc, val1, zeroI32, max, b, false);
          val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
          val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
        };

        // Compute the clamped pairs of neighboring indices along each axis.
        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
            Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
            Value mul1 = arith::MulFOp::create(b, val1, delta);
            return arith::AddFOp::create(b, mul0, mul1);
          };

          // Interpolate the rows first, then blend the two row accumulators.
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          linalg::YieldOp::create(b, result);
        } else {
          // Perform the interpolation in quantized space.
          y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
          y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
          y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
          y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = arith::ExtSIOp::create(b, resultETy, dx);
            dy = arith::ExtSIOp::create(b, resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
            xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return arith::MulIOp::create(b, val0, scale);
            Value weight0 = arith::SubIOp::create(b, scale, weight1);
            Value mul0 = arith::MulIOp::create(b, val0, weight0);
            Value mul1 = arith::MulIOp::create(b, val1, weight1);
            return arith::AddIOp::create(b, mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          linalg::YieldOp::create(b, result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
// Identity-like TOSA ops simply forward their operands.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);

    // First fill the output buffer with the init value.
    auto emptyTensor = tensor::EmptyOp::create(
                           rewriter, loc, inputTy.getShape(),
                           inputTy.getElementType(), ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
            if (i == axis) {
              auto one = arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
              auto sizeMinusOne =
                  arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
              index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
                                            index);
            }
            indices.push_back(index);
          }

          auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
                                                   input, indices);
          linalg::YieldOp::create(nestedBuilder, op.getLoc(),
                                  extract.getResult());
        });
    return success();
  }
};
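
// Lowers tosa.tile by materializing a rank-2N 'linalg.generic' whose shape
// interleaves (multiple, size) pairs, then reshaping to the final result.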
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    auto emptyTensor = tensor::EmptyOp::create(
        rewriter, op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcasted dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
        });

    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
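
// Lowers tosa.argmax as a two-output 'linalg.generic' reduction that carries
// both the running maximum value and the index at which it was found.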
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    // First fill the output buffer for the indices.
    auto emptyTensorIdx =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                outElementTy, dynDims)
            .getResult();
    auto fillValueIdx = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
                               ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
                                dynDims)
            .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
    auto filledTensorMax =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
                               ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations
    // along the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = arith::IndexCastOp::create(
              rewriter, nestedLoc, oldIndex.getType(),
              linalg::IndexOp::create(rewriter, loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
              // Only update the index and max value for non-NaN values. If
              // all values are NaN, the initial index of 0 is returned.
              predicate = arith::CmpFOp::create(rewriter, nestedLoc,
                                                arith::CmpFPredicate::OGT,
                                                newValue, oldValue);
            } else {
              // Update the max value if either of the following is true:
              // - the new value is bigger
              // - the current max is not NaN and the new value is NaN
              Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
                                               arith::CmpFPredicate::UGT,
                                               newValue, oldValue);
              Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
                                                      arith::CmpFPredicate::ORD,
                                                      oldValue, oldValue);
              predicate = arith::AndIOp::create(
                  rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = arith::CmpIOp::create(rewriter, nestedLoc,
                                              arith::CmpIPredicate::sgt,
                                              newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = arith::SelectOp::create(
              rewriter, nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = arith::SelectOp::create(
              rewriter, nestedLoc, predicate, newIndex, oldIndex);
          linalg::YieldOp::create(nestedBuilder, nestedLoc,
                                  ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
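// Shape of the IR produced by the argmax pattern above (illustrative sketch,
// not verbatim pass output). Reducing a tensor<3x4xf32> along axis = 1:
//   %idx_init = linalg.fill ins(%c0_i32 : i32) outs(... : tensor<3xi32>)
//   %max_init = linalg.fill ins(%min_f32 : f32) outs(... : tensor<3xf32>)
//   %res:2 = linalg.generic
//       {iterator_types = ["parallel", "reduction"], ...}
//       ins(%input : tensor<3x4xf32>)
//       outs(%idx_init, %max_init : tensor<3xi32>, tensor<3xf32>) { ... }
// Each iteration compares the incoming element against the running max and
// selects both the max and its index. Only result #0 (the indices) replaces
// the original tosa.argmax; the max tensor is a helper accumulator.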
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
          Value index1 = arith::IndexCastOp::create(
              rewriter, loc, rewriter.getIndexType(), indexValue);
          auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
          Value extract = tensor::ExtractOp::create(
              rewriter, loc, input, ValueRange{index0, index1, index2});
          linalg::YieldOp::create(rewriter, loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
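// Illustrative gather lowering (sketch; shapes assumed, IR not verbatim).
// For values tensor<2x5x3xf32> and indices tensor<2x4xi32> the result is
// tensor<2x4x3xf32>; the generic iterates over the result space and loads
// each element from the values tensor at (batch, index, channel):
//   %res = linalg.generic
//       {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>,
//                         affine_map<(d0, d1, d2) -> (d0, d1, d2)>], ...}
//       ins(%indices : tensor<2x4xi32>) outs(%empty : tensor<2x4x3xf32>) {
//     ^bb0(%idx: i32, %out: f32):
//       // tensor.extract %values[linalg.index 0, index_cast(%idx),
//       //                        linalg.index 2]
//       ...
//   }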
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
        affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = arith::IndexCastOp::create(
            rewriter, loc, rewriter.getIndexType(), inputValue);
        Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
        index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
                                      index, offset);
        Value extract =
            tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
        linalg::YieldOp::create(rewriter, loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = arith::ExtSIOp::create(
            rewriter, loc, rewriter.getI32Type(), inputValue);

        auto offset = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32IntegerAttr(32768));
        auto seven = arith::ConstantOp::create(rewriter, loc,
                                               rewriter.getI32IntegerAttr(7));
        auto one = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getI32IntegerAttr(1));
        auto b1111111 = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part of the input value:
        // value = value + 32768
        // index = value >> 7
        // fraction = value & 0x7f
        auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
        Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
        Value fraction =
            arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);

        // Extract the base and next values from the table:
        // base = (int32_t) table[index]
        // next = (int32_t) table[index + 1]
        Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);

        index = arith::IndexCastOp::create(rewriter, loc,
                                           rewriter.getIndexType(), index);
        indexPlusOne = arith::IndexCastOp::create(
            rewriter, loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
        Value next = tensor::ExtractOp::create(rewriter, loc, table,
                                               ValueRange{indexPlusOne});

        base =
            arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
        next =
            arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the table values:
        // result = (base << 7) + (next - base) * fraction
        Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
        Value diff = arith::SubIOp::create(rewriter, loc, next, base);
        Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
        Value result =
            arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);

        linalg::YieldOp::create(rewriter, loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = arith::ConstantIndexOp::create(builder, loc, 1);
    auto two = arith::ConstantIndexOp::create(builder, loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }
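  // A real-input FFT keeps only the non-redundant half of the spectrum, so
  // the output width is W/2 + 1. For example, W = 8 produces 8/2 + 1 = 5
  // frequency bins; the remaining bins are complex conjugates of bins 1..3
  // and are not materialized.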
  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W].
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType =
        cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
    auto filledTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                               ValueRange{emptyTensor})
            .result();
    return filledTensor;
  }

  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = arith::IndexCastUIOp::create(
        builder, loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return arith::UIToFPOp::create(builder, loc, type, integerVal);
  }
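  // Note: arith.uitofp does not operate on the index type, so the value is
  // first converted with arith.index_castui. An i64 intermediate is used
  // when the destination float is wider than 32 bits so that large dimension
  // sizes do not lose precision in the cast.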
  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = linalg::IndexOp::create(builder, loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType = dyn_cast<FloatType>(
        cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for the input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for the angle computation.
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // Compute the angle without the integer parts of the components, as
      // sin/cos are periodic:
      // angle = 2 * pi() * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);

      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);
      auto realComponent =
          arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto imagComponent =
          arith::MulFOp::create(builder, loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::SubFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
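// Summary of the accumulation above, matching the TOSA RFFT2D definition:
//   out_real[n, oy, ox] = sum_{iy, ix}  in[n, iy, ix] * cos(angle)
//   out_imag[n, oy, ox] = sum_{iy, ix} -in[n, iy, ix] * sin(angle)
//   where angle = 2*pi * ((iy*oy mod H)/H + (ix*ox mod W)/W)
// Reducing iy*oy and ix*ox modulo H and W keeps the operand of sin/cos small
// without changing the result, since both functions are 2*pi-periodic.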
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d,
                                         "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    assert(real_el_ty == imag_el_ty);

    // Compute the output type and set of dynamic sizes.
    SmallVector<Value> dynamicSizes;

    // Get [N, H, W].
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);

    SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation.
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    SmallVector<Value> genericOpInputs = {input_real, input_imag};
    SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for the input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Indices for the angle computation.
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // angle = sign * 2 * pi() * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);

      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);

      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);

      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);

      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      if (inverse.getValue()) {
        angle = arith::MulFOp::create(
            builder, loc, angle,
            arith::ConstantOp::create(rewriter, loc,
                                      rewriter.getFloatAttr(real_el_ty, -1.0)));
      }

      // realComponent = valReal * cos(angle) + valImag * sin(angle)
      // imagComponent = valImag * cos(angle) - valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);

      auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
      auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);

      auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
      auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);

      auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);

      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::AddFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
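// Summary of the accumulation above, matching the TOSA FFT2D definition:
// with a = 2*pi * ((iy*oy mod H)/H + (ix*ox mod W)/W), negated when the
// inverse attribute is set,
//   out_real += in_real * cos(a) + in_imag * sin(a)
//   out_imag += in_imag * cos(a) - in_real * sin(a)
// which is the real/imaginary expansion of in[iy, ix] * e^(-i*a).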
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns->add<
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>
        >(converter, patterns->getContext());

  patterns->add<
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProductOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
}