29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/Sequence.h"
31#include "llvm/ADT/SmallVectorExtras.h"
60template <
typename OpTy>
68 auto nanMode = op.getNanMode();
69 if (nanMode == NanPropagationMode::PROPAGATE)
73 Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
74 arith::CmpFPredicate::UNO,
lhs,
lhs);
75 Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
76 arith::CmpFPredicate::UNO,
rhs,
rhs);
78 arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN,
rhs,
result);
79 return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN,
lhs,
85 ConversionPatternRewriter &rewriter) {
91 if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
92 return math::AbsFOp::create(rewriter, loc, resultTypes, args);
94 if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
95 auto zero = arith::ConstantOp::create(rewriter, loc,
96 rewriter.getZeroAttr(elementTy));
97 auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
98 return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
102 if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
103 return arith::AddFOp::create(rewriter, loc, resultTypes, args);
105 if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
106 return arith::AddIOp::create(rewriter, loc, resultTypes, args);
109 if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
110 return arith::SubFOp::create(rewriter, loc, resultTypes, args);
112 if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
113 return arith::SubIOp::create(rewriter, loc, resultTypes, args);
116 if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
117 return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
120 if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
122 arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
123 return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
127 if (isa<tosa::MulOp>(op)) {
128 auto shiftVal = cast<tosa::MulOp>(op).getShift();
130 bool shiftIsConstant =
true;
133 shift = shiftElem.
getValues<IntegerAttr>()[0].getInt();
135 shiftIsConstant =
false;
137 if (isa<FloatType>(elementTy)) {
139 (
void)rewriter.notifyMatchFailure(op,
140 "Cannot have shift value for float");
143 return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
147 if (isa<IntegerType>(elementTy)) {
151 if (shift > 0 || !shiftIsConstant) {
158 a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
160 if (!
b.getType().isInteger(32))
161 b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(),
b);
163 auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
164 auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
165 RoundingMode::SINGLE_ROUND);
167 tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
168 b, shiftAmount, roundingAttr);
174 int bWidth =
b.getType().getIntOrFloatBitWidth();
175 int cWidth = resultTypes[0].getIntOrFloatBitWidth();
178 a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
180 b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0],
b);
182 return arith::MulIOp::create(rewriter, loc, resultTypes, a,
b);
187 if (isa<tosa::NegateOp>(op)) {
188 auto negate = cast<tosa::NegateOp>(op);
191 FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
192 FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
193 bool hasInZp = !failed(maybeInZp);
194 bool hasOutZp = !failed(maybeOutZp);
200 if (isa<FloatType>(elementTy))
201 return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
203 if (isa<IntegerType>(elementTy)) {
205 Type intermediateType;
208 int intermediateBitWidth = 64;
210 if (hasInZp && hasOutZp) {
212 const int64_t zpAdd = inZp + outZp;
214 APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
219 if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
220 intermediateBitWidth = 16;
221 }
else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
222 intermediateBitWidth = 32;
225 intermediateType = rewriter.getIntegerType(intermediateBitWidth);
226 zpAddValue = arith::ConstantOp::create(
227 rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
229 intermediateType = rewriter.getIntegerType(intermediateBitWidth);
230 Value arg1 = args[1];
231 Value arg2 = args[2];
233 if (arg1.
getType() != intermediateType)
234 arg1 = arith::ExtSIOp::create(rewriter, loc, intermediateType, arg1);
235 if (arg2.
getType() != intermediateType)
236 arg2 = arith::ExtSIOp::create(rewriter, loc, intermediateType, arg2);
238 arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
244 if (ext.
getType() != intermediateType)
245 ext = arith::ExtSIOp::create(rewriter, loc, intermediateType, ext);
246 auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);
250 rewriter, loc, intermediateType,
251 APInt::getSignedMinValue(inputBitWidth).getSExtValue());
253 rewriter, loc, intermediateType,
254 APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
258 if (
clamp.getType() == elementTy)
260 return arith::TruncIOp::create(rewriter, loc, elementTy,
clamp);
265 if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
266 return arith::AndIOp::create(rewriter, loc, resultTypes, args);
269 if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
270 return arith::OrIOp::create(rewriter, loc, resultTypes, args);
273 if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
274 auto allOnesAttr = rewriter.getIntegerAttr(
275 elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
276 auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
277 return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
281 if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
282 return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
285 if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
286 return arith::ShLIOp::create(rewriter, loc, resultTypes, args);
289 if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
290 return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
293 if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
294 auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
295 auto round = cast<BoolAttr>(op->
getAttr(
"round")).getValue();
300 Type i1Ty = IntegerType::get(rewriter.getContext(), 1);
301 auto one = arith::ConstantOp::create(rewriter, loc,
302 IntegerAttr::get(elementTy, 1));
303 auto zero = arith::ConstantOp::create(rewriter, loc,
304 IntegerAttr::get(elementTy, 0));
306 arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
308 arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
311 auto shiftValueGreaterThanZero = arith::CmpIOp::create(
312 rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);
316 arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
318 arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
320 auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
323 arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
325 auto shouldRound = arith::SelectOp::create(
326 rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
328 arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
329 return arith::AddIOp::create(rewriter, loc, resultTypes,
result, extended);
333 if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
334 return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
338 if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
339 return arith::AndIOp::create(rewriter, loc, resultTypes, args);
342 if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
343 auto one = arith::ConstantOp::create(rewriter, loc,
344 rewriter.getIntegerAttr(elementTy, 1));
345 return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
349 if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
350 return arith::OrIOp::create(rewriter, loc, resultTypes, args);
353 if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
354 return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
357 if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
358 return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);
361 if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
362 return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);
365 if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
366 return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);
369 if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
370 return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);
373 if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
374 return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);
377 if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
378 return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);
381 if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
382 return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);
385 if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
386 return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
389 if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
390 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
393 if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
394 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
398 if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
399 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
402 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
403 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
407 if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
408 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
411 if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
412 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
416 if (isa<tosa::SelectOp>(op)) {
418 if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
419 return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
423 if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
424 auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
426 rewriter, args[0], args[1],
max);
429 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
430 return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
434 if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
435 auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
437 rewriter, args[0], args[1],
min);
440 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
441 return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
445 if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
446 return math::CeilOp::create(rewriter, loc, resultTypes, args);
449 if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
450 return math::FloorOp::create(rewriter, loc, resultTypes, args);
453 if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
454 bool losesInfo =
false;
455 APFloat minApf = cast<FloatAttr>(op->
getAttr(
"min_val")).getValue();
456 APFloat maxApf = cast<FloatAttr>(op->
getAttr(
"max_val")).getValue();
457 minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
458 APFloat::rmNearestTiesToEven, &losesInfo);
459 maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
460 APFloat::rmNearestTiesToEven, &losesInfo);
461 auto min = arith::ConstantOp::create(
462 rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
463 auto max = arith::ConstantOp::create(
464 rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
467 auto clampOp = llvm::cast<tosa::ClampOp>(op);
468 const auto nanMode = clampOp.getNanMode();
471 if (!isa<FloatType>(elementTy))
476 if (nanMode == NanPropagationMode::PROPAGATE)
490 Value isNaN = arith::CmpFOp::create(
491 rewriter, op->
getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
494 return arith::SelectOp::create(rewriter, op->
getLoc(), isNaN,
min,
result);
497 if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
498 auto intTy = cast<IntegerType>(elementTy);
500 cast<IntegerAttr>(op->
getAttr(
"min_val")).getValue().getSExtValue();
502 cast<IntegerAttr>(op->
getAttr(
"max_val")).getValue().getSExtValue();
504 int64_t minRepresentable = std::numeric_limits<int64_t>::min();
505 int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
506 if (intTy.isUnsignedInteger()) {
507 minRepresentable = 0;
508 if (intTy.getIntOrFloatBitWidth() <= 63) {
510 (
int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
513 }
else if (intTy.getIntOrFloatBitWidth() <= 64) {
515 minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
517 maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
522 min = std::max(
min, minRepresentable);
523 max = std::max(
max, minRepresentable);
524 min = std::min(
min, maxRepresentable);
525 max = std::min(
max, maxRepresentable);
528 intTy.getIntOrFloatBitWidth());
530 intTy.getIntOrFloatBitWidth());
532 intTy.isUnsignedInteger());
536 if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
538 arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
539 auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
540 auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
541 auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
542 return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
546 if (isa<tosa::CastOp>(op)) {
547 Type srcTy = elementTy;
548 Type dstTy = resultTypes.front();
550 (
void)rewriter.notifyMatchFailure(op,
"unsupported type");
560 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
561 return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
564 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
565 return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
569 if (srcTy.
isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
570 return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
573 if (srcTy.
isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
574 return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
580 auto unrealizedCast =
581 UnrealizedConversionCastOp::create(
585 return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
590 if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
591 return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
595 if (isa<FloatType>(srcTy) && dstTy.
isInteger(1)) {
596 Value zero = arith::ConstantOp::create(rewriter, loc,
597 rewriter.getFloatAttr(srcTy, 0.0));
598 return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
602 if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
603 auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);
605 const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
609 APFloat::semanticsMaxExponent(fltSemantics)) {
612 auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
613 auto posInf = arith::ConstantOp::create(
616 APFloat::getInf(fltSemantics)));
617 auto negInf = arith::ConstantOp::create(
619 rewriter.getFloatAttr(
621 APFloat::getInf(fltSemantics,
true)));
622 auto overflow = arith::CmpFOp::create(
623 rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
624 auto underflow = arith::CmpFOp::create(
625 rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
626 auto intMin = arith::ConstantOp::create(
628 rewriter.getIntegerAttr(
631 auto intMax = arith::ConstantOp::create(
633 rewriter.getIntegerAttr(
637 arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
638 return arith::SelectOp::create(rewriter, loc, underflow, intMin,
642 auto intMinFP = arith::ConstantOp::create(
644 rewriter.getFloatAttr(
650 if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
656 auto intMaxFP = arith::ConstantOp::create(
658 rewriter.getFloatAttr(
665 return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
672 auto intMaxPlusOneFP = arith::ConstantOp::create(
674 rewriter.getFloatAttr(
681 auto intMax = arith::ConstantOp::create(
683 rewriter.getIntegerAttr(
687 arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
689 arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
690 auto overflow = arith::CmpFOp::create(
691 rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
692 return arith::SelectOp::create(rewriter, loc, overflow, intMax,
698 if (isa<IntegerType>(srcTy) && dstTy.
isInteger(1)) {
701 return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
705 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
706 return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
709 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
710 return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
714 (
void)rewriter.notifyMatchFailure(
715 op,
"unhandled op for linalg body calculation for elementwise op");
736 return tensor::DimOp::create(rewriter, loc,
tensor, indexValue).getResult();
742 auto shapedType = dyn_cast<ShapedType>(
tensor.getType());
743 assert(shapedType && shapedType.hasRank() &&
"expected a ranked shaped type");
744 assert(
index >= 0 &&
index < shapedType.getRank() &&
"index out of bounds");
745 if (shapedType.isDynamicDim(
index))
751 auto isRanked = [](
Value value) {
752 return isa<RankedTensorType>(value.getType());
754 return llvm::all_of(operation->
getOperands(), isRanked) &&
755 llvm::all_of(operation->
getResults(), isRanked);
768static std::pair<OpFoldResult, Value>
774 for (
auto operand : operands) {
775 auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
776 if (ShapedType::isStatic(size) && size > 1)
781 auto operandsWithDynamicDim =
782 llvm::filter_to_vector(operands, [&](
Value operand) {
783 return cast<RankedTensorType>(operand.
getType()).isDynamicDim(dim);
787 if (operandsWithDynamicDim.empty())
794 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
795 if (operandsWithDynamicDim.size() == 1)
796 return {targetSize, operandsWithDynamicDim[0]};
799 for (
size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
801 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
802 targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
804 return {targetSize,
nullptr};
812 assert(!operands.empty());
813 auto rank = cast<RankedTensorType>(operands.front().
getType()).getRank();
816 for (
auto dim : llvm::seq<int64_t>(0, rank)) {
817 auto [targetSize, masterOperand] =
819 targetShape.push_back(targetSize);
820 masterOperands.push_back(masterOperand);
822 return {targetShape, masterOperands};
828 Value masterOperand) {
830 auto rankedTensorType = cast<RankedTensorType>(operand.
getType());
831 if (!rankedTensorType.isDynamicDim(dim))
838 if (operand == masterOperand)
842 auto rank = rankedTensorType.getRank();
844 for (
auto index : llvm::seq<int64_t>(0, rank)) {
847 affineExprs.push_back(affineExpr);
849 auto broadcastAffineMap =
855 auto one =
createIndex(rewriter, loc, indexPool, 1);
856 auto runtimeSize =
getTensorDim(rewriter, loc, indexPool, operand, dim);
857 auto broadcastNecessary = arith::CmpIOp::create(
858 rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);
868 for (
auto index : llvm::seq<int64_t>(0, rank)) {
869 auto size =
index == dim ? targetSize
872 outputTensorShape.push_back(size);
874 Value outputTensor = tensor::EmptyOp::create(
875 opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());
879 linalg::GenericOp::create(
880 opBuilder, loc, outputTensor.
getType(), operand, outputTensor,
884 linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
889 auto castResultTensor = rewriter.
createOrFold<tensor::CastOp>(
890 loc, operand.
getType(), resultTensor);
893 scf::YieldOp::create(opBuilder, loc, castResultTensor);
898 scf::YieldOp::create(opBuilder, loc, operand);
902 auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
903 emitThenRegion, emitElseRegion);
904 return ifOp.getResult(0);
911 int64_t rank = cast<RankedTensorType>(operand.
getType()).getRank();
912 assert((
int64_t)targetShape.size() == rank);
913 assert((
int64_t)masterOperands.size() == rank);
914 for (
auto index : llvm::seq<int64_t>(0, rank))
927 if (operands.size() == 1)
931 bool hasDynamic =
false;
932 for (
auto op : operands) {
933 const auto tType = dyn_cast<RankedTensorType>(op.getType());
934 if (tType && !tType.hasStaticShape()) {
943 return llvm::map_to_vector(operands, [&](
Value operand) {
945 targetShape, masterOperands);
955 auto resultType = cast_or_null<RankedTensorType>(
958 return rewriter.notifyMatchFailure(operation,
"failed to convert type");
960 Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
961 resultType.getElementType());
966 auto rank = resultType.getRank();
967 auto affineMaps = llvm::map_to_vector(operands, [&](
Value operand) {
968 auto shape = cast<ShapedType>(operand.
getType()).getShape();
970 for (
auto it : llvm::enumerate(
shape)) {
974 bool requiresBroadcast =
975 (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
976 auto affineExpr = requiresBroadcast
977 ? rewriter.getAffineConstantExpr(0)
978 : rewriter.getAffineDimExpr(it.index());
979 affineExprs.push_back(affineExpr);
981 return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
983 affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
986 bool encounteredError =
false;
987 auto linalgOp = linalg::GenericOp::create(
988 rewriter, loc, outputTensor.
getType(), operands, outputTensor, affineMaps,
993 {resultType.getElementType()}, rewriter);
995 encounteredError =
true;
998 linalg::YieldOp::create(opBuilder, loc, opResult);
1000 if (encounteredError)
1001 return rewriter.notifyMatchFailure(
1002 operation,
"unable to create linalg.generic body for elementwise op");
1005 auto castResult = rewriter.createOrFold<tensor::CastOp>(
1006 loc, resultType, linalgOp->getResult(0));
1007 rewriter.replaceOp(operation, castResult);
1014 if (isa<tosa::MulOp>(operation)) {
1018 return operands.take_front(2);
1020 return operands.take_front(3);
1022 if (
auto negate = dyn_cast<tosa::NegateOp>(operation)) {
1023 FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
1024 FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
1025 if (failed(maybeOutZp) && failed(maybeInZp))
1028 return operands.take_front(1);
1035 ConversionPatternRewriter &rewriter,
1039 assert(operation->
getNumResults() == 1 &&
"elementwise op expects 1 result");
1041 "elementwise op expects at least 1 operand");
1043 return rewriter.notifyMatchFailure(operation,
1044 "Unranked tensors not supported");
1048 auto loc = operation->
getLoc();
1050 auto [targetShape, masterOperands] =
1052 auto broadcastOperands =
1054 targetShape, masterOperands);
1056 targetShape, converter);
1063 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1066 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1069 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1072 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1075 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1077 elementTy, APFloat::getLargest(
1078 cast<FloatType>(elementTy).getFloatSemantics(),
false));
1080 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1084 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1086 elementTy, APFloat::getLargest(
1087 cast<FloatType>(elementTy).getFloatSemantics(),
true));
1089 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1093 if (isa<tosa::ReduceAllOp>(op) && elementTy.
isInteger(1))
1096 if (isa<tosa::ReduceAnyOp>(op) && elementTy.
isInteger(1))
1099 if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1101 elementTy, APFloat::getLargest(
1102 cast<FloatType>(elementTy).getFloatSemantics(),
true));
1104 if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
1118 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1119 return arith::AddFOp::create(rewriter, loc, args);
1122 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1123 return arith::AddIOp::create(rewriter, loc, args);
1126 if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
1127 return arith::MulFOp::create(rewriter, loc, args);
1130 if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
1131 return arith::MulIOp::create(rewriter, loc, args);
1134 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1135 return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
1138 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1139 return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
1142 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1143 return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
1146 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1147 return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
1150 if (isa<tosa::ReduceAllOp>(op) && elementTy.
isInteger(1))
1151 return arith::AndIOp::create(rewriter, loc, args);
1153 if (isa<tosa::ReduceAnyOp>(op) && elementTy.
isInteger(1))
1154 return arith::OrIOp::create(rewriter, loc, args);
1162template <
typename OpTy>
1165 auto loc = op->getLoc();
1166 auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1167 auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1168 if (!inputTy || !resultTy)
1171 auto elementTy = resultTy.getElementType();
1172 Value input = op->getOperand(0);
1175 bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
1176 isa<FloatType>(elementTy) &&
1177 cast<FloatType>(elementTy).isBF16();
1182 for (
unsigned i = 0; i < inputTy.getRank(); i++) {
1184 reduceShape.push_back(inputTy.getDimSize(i));
1185 if (inputTy.isDynamicDim(i))
1186 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1191 inputs.push_back(input);
1195 tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1201 op,
"No initial value found for reduction operation");
1203 auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
1205 linalg::FillOp::create(rewriter, loc,
ValueRange{fillValue},
1208 outputs.push_back(filledTensor);
1210 bool isNanIgnoreMode =
false;
1211 if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1212 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1214 if (isa<FloatType>(elementTy) &&
1215 op.getNanMode() == NanPropagationMode::IGNORE) {
1216 isNanIgnoreMode =
true;
1222 auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
1223 auto emptyBoolTensor =
1224 tensor::EmptyOp::create(rewriter, loc, reduceShape,
1225 trueValue.getType(), dynDims)
1227 auto allResultsNaNTensor =
1228 linalg::FillOp::create(rewriter, loc,
ValueRange{trueValue},
1240 inputs.push_back(input);
1241 outputs.push_back(allResultsNaNTensor);
1245 bool didEncounterError =
false;
1246 linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
1247 rewriter, loc, inputs, outputs, axis,
1249 std::array<Value, 2> binaryArgs{
1250 blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1253 if (binaryArgs[0].
getType() != accTy)
1254 binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
1260 didEncounterError =
true;
1263 if (isNanIgnoreMode) {
1264 auto inputValue = blockArgs[0];
1265 auto initialValue = blockArgs[2];
1266 auto oldAllResultsNanFlagValue = blockArgs[3];
1269 Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
1270 arith::CmpFPredicate::UNO,
1271 inputValue, inputValue);
1273 auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
1274 isNaN, initialValue,
result);
1277 auto newAllResultsNanFlagValue = arith::AndIOp::create(
1278 nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1279 resultsToYield.push_back(selectOp);
1280 resultsToYield.push_back(newAllResultsNanFlagValue);
1282 resultsToYield.push_back(
result);
1284 linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
1287 if (!didEncounterError)
1289 op,
"unable to create linalg.generic body for reduce op");
1291 if (isNanIgnoreMode) {
1300 APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(),
false));
1301 auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
1302 auto emptyNanTensor =
1303 tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1305 auto nanFilledTensor =
1306 linalg::FillOp::create(rewriter, loc,
ValueRange{nanValue},
1312 auto finalEmptyTensor =
1313 tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1319 ins.push_back(linalgOp->getOpResult(1));
1320 ins.push_back(nanFilledTensor);
1321 ins.push_back(linalgOp->getResult(0));
1322 outs.push_back(finalEmptyTensor);
1324 linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
1325 linalgOp = linalgSelect;
1329 Value reducedRes = linalgOp->getResult(0);
1332 tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
1335 const unsigned reducedRank =
1336 cast<ShapedType>(reducedRes.
getType()).getRank();
1339 linalg::GenericOp::create(
1345 Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
1346 elementTy, args[0]);
1347 linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
1353 uint64_t expandInputRank = cast<ShapedType>(reducedRes.
getType()).getRank();
1354 reassociationMap.resize(expandInputRank);
1356 for (uint64_t i = 0; i < expandInputRank; i++) {
1357 int32_t dimToPush = i > axis ? i + 1 : i;
1361 if (expandInputRank != 0) {
1362 int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1363 reassociationMap[expandedDim].push_back(
1378template <
typename SrcOp>
1379class PointwiseConverter :
public OpConversionPattern<SrcOp> {
1381 using OpConversionPattern<SrcOp>::OpConversionPattern;
1382 using typename OpConversionPattern<SrcOp>::OpAdaptor;
1385 matchAndRewrite(SrcOp op, OpAdaptor operands,
1386 ConversionPatternRewriter &rewriter)
const final {
1388 op, operands.getOperands(), rewriter, *this->getTypeConverter());
1398 auto inputType = cast<RankedTensorType>(input.
getType());
1399 auto elemType = inputType.getElementType();
1400 auto collapsedType = RankedTensorType::get({}, elemType);
1402 return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
1409 output.reserve(input.size());
1411 for (
auto v : llvm::map_range(
1412 input, [](int32_t val) {
return static_cast<int8_t
>(val); })) {
1413 output.push_back(v);
1425static void setupLinalgGenericOpInputAndIndexingMap(
1428 bool isConstant, tosa::RescaleOp op,
Value &constant,
int64_t &arg,
1429 bool isShift =
false) {
1431 auto loc = op.getLoc();
1432 auto inputTy = cast<ShapedType>(op.getInput().getType());
1433 unsigned rank = inputTy.getRank();
1439 if (values.size() == 1) {
1440 IntegerAttr intAttr = isShift
1443 constant = arith::ConstantOp::create(rewriter, loc, intAttr);
1447 auto tensorType = RankedTensorType::get(
1448 {
static_cast<int64_t>(values.size())}, elementType);
1454 genericInputs.push_back(
1455 arith::ConstantOp::create(rewriter, loc, EltAttr));
1463 auto operand = isShift ? op.getShift() : op.getMultiplier();
1464 auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
1465 if (tensorType && tensorType.hasStaticShape() &&
1466 tensorType.getShape()[0] == 1) {
1471 genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
1472 indexingMaps.push_back(broadcastMap);
1474 genericInputs.push_back(operand);
1480 arg = indexingMaps.size() - 1;
1485 FailureOr<int64_t> maybeZp,
Location loc,
1487 bool isOutputZp =
false) {
1490 const uint32_t attrBitwidth =
1491 isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
1498 result = blockArgs[zpArg];
1499 auto zpTy =
result.getType();
1500 if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
1503 if (zpTy.isUnsignedInteger()) {
1505 UnrealizedConversionCastOp::create(
1510 if (zpTy.isUnsignedInteger()) {
1511 return arith::ExtUIOp::create(builder, loc, extendType,
result);
1513 return arith::ExtSIOp::create(builder, loc, extendType,
result);
1517 return arith::ConstantOp::create(builder, loc,
1518 IntegerAttr::get(extendType, *maybeZp));
1525 using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
1527 LogicalResult matchAndRewrite(tosa::RescaleOp op,
1528 PatternRewriter &rewriter)
const final {
1529 auto loc = op.getLoc();
1530 auto input = op.getInput();
1531 auto inputTy = cast<ShapedType>(op.getInput().getType());
1532 auto outputTy = cast<ShapedType>(op.getOutput().getType());
1533 unsigned rank = inputTy.getRank();
1536 if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
1538 op,
"tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
1539 "currently supported");
1540 if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
1542 op,
"tosa.rescale requires scale32 for double_round to be true");
1544 if (!isa<IntegerType>(inputTy.getElementType()))
1547 SmallVector<Value> dynDims;
1548 for (
int i = 0; i < outputTy.getRank(); i++) {
1549 if (outputTy.isDynamicDim(i)) {
1550 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1554 DenseElementsAttr shiftElems;
1555 bool isShiftConstant =
false;
1557 isShiftConstant =
true;
1559 DenseElementsAttr multiplierElems;
1560 bool isMultiplierConstant =
false;
1562 isMultiplierConstant =
true;
1564 llvm::SmallVector<int32_t> shiftValues;
1565 llvm::SmallVector<int32_t> multiplierValues;
1568 if (isMultiplierConstant && isShiftConstant) {
1570 shiftValues = llvm::map_to_vector(
1571 shiftElems.
getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
1572 return static_cast<int32_t>(attr.getInt());
1575 llvm::map_to_vector(multiplierElems.
getValues<IntegerAttr>(),
1576 [](IntegerAttr attr) -> int32_t {
1577 return static_cast<int32_t>(attr.getInt());
1581 for (
int i = 0, s = multiplierValues.size(); i < s; i++) {
1582 if (shiftValues[i] > 63) {
1584 multiplierValues[i] = 0;
1589 doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1590 llvm::any_of(shiftValues, [](int32_t v) {
return v > 31; });
1592 doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
1594 RoundingMode roundingMode =
1595 doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
1597 SmallVector<AffineMap> indexingMaps = {
1599 SmallVector<Value, 4> genericInputs = {input};
1603 Value multiplierConstant;
1604 int64_t multiplierArg = 0;
1605 setupLinalgGenericOpInputAndIndexingMap(
1606 rewriter, multiplierValues, genericInputs, indexingMaps,
1607 isMultiplierConstant, op, multiplierConstant, multiplierArg);
1611 Value shiftConstant;
1612 int64_t shiftArg = 0;
1613 setupLinalgGenericOpInputAndIndexingMap(
1614 rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1615 shiftConstant, shiftArg,
true);
1620 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1621 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1631 genericInputs.push_back(
1632 collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
1633 indexingMaps.push_back(broadcastMap);
1634 iZpArg = indexingMaps.size() - 1;
1638 genericInputs.push_back(
1639 collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
1640 indexingMaps.push_back(broadcastMap);
1641 oZpArg = indexingMaps.size() - 1;
1648 Value emptyTensor = tensor::EmptyOp::create(
1649 rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
1650 ArrayRef<Value>({dynDims}));
1652 auto linalgOp = linalg::GenericOp::create(
1653 rewriter, loc, outputTy, genericInputs,
ValueRange{emptyTensor},
1655 [&](OpBuilder &nestedBuilder, Location nestedLoc,
1657 Value value = blockArgs[0];
1658 Type valueTy = value.
getType();
1660 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1661 auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
1662 nestedLoc, blockArgs, iZpArg);
1664 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1665 auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
1666 nestedLoc, blockArgs, oZpArg,
true);
1668 IntegerType outIntType =
1669 cast<IntegerType>(blockArgs.back().
getType());
1670 unsigned outBitWidth = outIntType.getWidth();
1671 assert(outBitWidth <= 32 &&
"Unexpected output zeropoint bitwidth");
1673 Value multiplier = multiplierConstant ? multiplierConstant
1674 : blockArgs[multiplierArg];
1675 Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1678 value = UnrealizedConversionCastOp::create(
1679 nestedBuilder, nestedLoc,
1680 nestedBuilder.getIntegerType(
1686 if (op.getInputUnsigned()) {
1687 value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
1688 nestedBuilder.getI32Type(), value);
1690 value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
1691 nestedBuilder.getI32Type(), value);
1696 arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);
1698 value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
1699 nestedBuilder.getI32Type(), value,
1700 multiplier, shift, roundingMode);
1704 arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);
1707 int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1708 int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1711 if (op.getOutputUnsigned()) {
1713 intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1716 auto intMinVal = arith::ConstantOp::create(
1717 nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
1718 auto intMaxVal = arith::ConstantOp::create(
1719 nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));
1722 nestedBuilder,
false);
1724 if (outIntType.getWidth() < 32) {
1725 value = arith::TruncIOp::create(
1726 nestedBuilder, nestedLoc,
1730 if (outIntType.isUnsignedInteger()) {
1731 value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
1735 linalg::YieldOp::create(nestedBuilder, loc, value);
1738 rewriter.
replaceOp(op, linalgOp->getResults());
1748 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1750 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1751 PatternRewriter &rewriter)
const final {
1752 Location loc = op.getLoc();
1753 ImplicitLocOpBuilder builder(loc, rewriter);
1754 auto input = op.getInput();
1755 auto inputTy = cast<RankedTensorType>(input.getType());
1756 auto resultTy = cast<RankedTensorType>(op.getType());
1757 const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;
1759 auto inputH = inputTy.getDimSize(1);
1760 auto inputW = inputTy.getDimSize(2);
1761 auto outputH = resultTy.getDimSize(1);
1762 auto outputW = resultTy.getDimSize(2);
1764 if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1766 op,
"tosa.resize is not a pure 1x1->1x1 image operation");
1768 if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1769 op.getMode() != ResizeMode::BILINEAR)
1771 op,
"tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1773 if (inputTy == resultTy) {
1778 SmallVector<int64_t> scale;
1784 SmallVector<ReassociationExprs, 4> reassociationMap(2);
1791 RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
1792 inputTy.getElementType());
1793 Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
1797 llvm::SmallVector<Value> outputDynSize;
1798 if (inputTy.isDynamicDim(0))
1799 outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
1800 if (inputTy.isDynamicDim(3))
1801 outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
1804 auto genericTy = collapseTy.clone(resultTy.getElementType());
1806 tensor::EmptyOp::create(builder, genericTy.getShape(),
1807 resultTy.getElementType(), outputDynSize);
1809 SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
1810 utils::IteratorType::parallel);
1812 auto generic = linalg::GenericOp::create(
1814 ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
1815 [=](OpBuilder &
b, Location loc,
ValueRange args) {
1816 Value value = args[0];
1818 if (inputTy.getElementType() != resultTy.getElementType()) {
1819 value = arith::ExtSIOp::create(
b, loc, resultTy.getElementType(),
1822 if (isBilinear && scale[0] != 0) {
1823 Value scaleY = arith::ConstantOp::create(
1824 b, loc,
b.getI32IntegerAttr(scale[0]));
1825 value = arith::MulIOp::create(
b, loc, value, scaleY);
1828 if (isBilinear && scale[2] != 0) {
1829 Value scaleX = arith::ConstantOp::create(
1830 b, loc,
b.getI32IntegerAttr(scale[2]));
1831 value = arith::MulIOp::create(
b, loc, value, scaleX);
1835 linalg::YieldOp::create(
b, loc, value);
1839 op, resultTy,
generic.getResults()[0], reassociationMap);
1851 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1855 auto input = op.getInput();
1856 auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1857 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1859 if (!inputTy || !resultTy)
1861 "requires ranked input/output types");
1863 auto batch = inputTy.getDimSize(0);
1864 auto channels = inputTy.getDimSize(3);
1865 auto inputH = inputTy.getDimSize(1);
1866 auto inputW = inputTy.getDimSize(2);
1867 auto outputH = resultTy.getDimSize(1);
1868 auto outputW = resultTy.getDimSize(2);
1870 if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1872 op,
"tosa.resize has no broadcasting behavior");
1877 resizeShape.push_back(batch);
1878 resizeShape.push_back(inputH == 1 ? 1 : outputH);
1879 resizeShape.push_back(inputW == 1 ? 1 : outputW);
1880 resizeShape.push_back(channels);
1882 auto resizeTy = resultTy.clone(resizeShape);
1884 tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
1885 op.getOffset(), op.getBorder(), op.getMode());
1892 reassociationMap.push_back({});
1895 reassociationMap.push_back({});
1900 collapseShape.push_back(outputH);
1902 collapseShape.push_back(outputW);
1903 collapseShape.push_back(channels);
1905 auto collapseTy = resultTy.clone(collapseShape);
1906 Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
1907 resize, reassociationMap);
1911 if (inputTy.isDynamicDim(0))
1912 outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
1913 if (inputTy.isDynamicDim(3))
1914 outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
1917 utils::IteratorType::parallel);
1918 Value empty = tensor::EmptyOp::create(
1919 builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1936 Value value = args[0];
1937 linalg::YieldOp::create(
b, loc, value);
1946 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1948 LogicalResult matchAndRewrite(tosa::ResizeOp op,
1949 PatternRewriter &rewriter)
const final {
1950 Location loc = op.getLoc();
1951 ImplicitLocOpBuilder
b(loc, rewriter);
1952 auto input = op.getInput();
1953 auto inputTy = cast<ShapedType>(input.getType());
1954 auto resultTy = cast<ShapedType>(op.getType());
1955 auto resultETy = resultTy.getElementType();
1957 bool floatingPointMode = isa<FloatType>(resultETy);
1958 auto floatTy = resultETy;
1960 auto imageH = inputTy.getShape()[1];
1961 auto imageW = inputTy.getShape()[2];
1963 auto dynamicDimsOr =
1965 if (!dynamicDimsOr.has_value())
1967 op,
"unable to get dynamic dimensions of tosa.resize");
1969 if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
1970 op.getMode() != ResizeMode::BILINEAR)
1972 op,
"tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1974 SmallVector<AffineMap, 2> affineMaps = {
1976 auto emptyTensor = tensor::EmptyOp::create(
b, resultTy.getShape(),
1977 resultETy, *dynamicDimsOr);
1978 auto genericOp = linalg::GenericOp::create(
1981 Value resize = genericOp.getResult(0);
1984 OpBuilder::InsertionGuard regionGuard(
b);
1985 b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1987 Value batch = linalg::IndexOp::create(
b, 0);
1988 Value y = linalg::IndexOp::create(
b, 1);
1989 Value x = linalg::IndexOp::create(
b, 2);
1990 Value channel = linalg::IndexOp::create(
b, 3);
1993 arith::ConstantOp::create(
b,
b.getZeroAttr(
b.getI32Type()));
1994 Value zeroFp = arith::ConstantOp::create(
b,
b.getZeroAttr(floatTy));
1996 arith::ConstantOp::create(
b,
b.getI32IntegerAttr(imageH - 1));
1998 arith::ConstantOp::create(
b,
b.getI32IntegerAttr(imageW - 1));
2000 Value inY = arith::IndexCastOp::create(
b,
b.getI32Type(), y);
2001 Value inX = arith::IndexCastOp::create(
b,
b.getI32Type(), x);
2003 SmallVector<int64_t> scale, offset, border;
2008 op,
"tosa.resize scale/offset/border should have compile time "
2009 "constant values.");
2012 Value yScaleN, yScaleD, xScaleN, xScaleD;
2013 yScaleN = arith::ConstantOp::create(
b,
b.getI32IntegerAttr(scale[0]));
2014 yScaleD = arith::ConstantOp::create(
b,
b.getI32IntegerAttr(scale[1]));
2015 xScaleN = arith::ConstantOp::create(
b,
b.getI32IntegerAttr(scale[2]));
2016 xScaleD = arith::ConstantOp::create(
b,
b.getI32IntegerAttr(scale[3]));
2018 Value yOffset, xOffset, yBorder, xBorder;
2019 yOffset = arith::ConstantOp::create(
b,
b.getI32IntegerAttr(offset[0]));
2020 xOffset = arith::ConstantOp::create(
b,
b.getI32IntegerAttr(offset[1]));
2021 yBorder = arith::ConstantOp::create(
b,
b.getI32IntegerAttr(border[0]));
2022 xBorder = arith::ConstantOp::create(
b,
b.getI32IntegerAttr(border[1]));
2025 auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
2026 Value scaleN, Value scaleD, Value offset,
2027 int size, ImplicitLocOpBuilder &
b) {
2035 Value val = arith::MulIOp::create(
b, in, scaleD);
2036 val = arith::AddIOp::create(
b, val, offset);
2037 index = arith::FloorDivSIOp::create(
b, val, scaleN);
2040 Value scaledIndex = arith::MulIOp::create(
b, index, scaleN);
2041 Value r = arith::SubIOp::create(
b, val, scaledIndex);
2042 Value rFp = arith::SIToFPOp::create(
b, floatTy, r);
2045 Value scaleNfp = arith::UIToFPOp::create(
b, floatTy, scaleN);
2046 delta = arith::DivFOp::create(
b, rFp, scaleNfp);
2050 auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
2051 Value scaleN, Value scaleD, Value offset,
2052 int size, ImplicitLocOpBuilder &
b) {
2061 Value val = arith::MulIOp::create(
b, in, scaleD);
2062 val = arith::AddIOp::create(
b, val, offset);
2063 index = arith::DivSIOp::create(
b, val, scaleN);
2064 delta = arith::MulIOp::create(
b, index, scaleN);
2065 delta = arith::SubIOp::create(
b, val, delta);
2068 Value ix, iy, dx, dy;
2069 if (floatingPointMode) {
2070 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH,
b);
2071 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW,
b);
2073 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH,
b);
2074 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW,
b);
2077 if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
2078 auto one = arith::ConstantOp::create(
b,
b.getI32IntegerAttr(1));
2080 auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
2081 Value
max,
int size,
2082 ImplicitLocOpBuilder &
b) -> Value {
2088 if (floatingPointMode) {
2090 arith::ConstantOp::create(
b,
b.getFloatAttr(floatTy, 0.5f));
2091 pred = arith::CmpFOp::create(
b, arith::CmpFPredicate::OGE, dval, h);
2093 Value dvalDouble = arith::ShLIOp::create(
b, dval, one);
2094 pred = arith::CmpIOp::create(
b, arith::CmpIPredicate::sge,
2098 auto offset = arith::SelectOp::create(
b, pred, one, zeroI32);
2099 val = arith::AddIOp::create(
b, val, offset);
2101 return arith::IndexCastOp::create(
b,
b.getIndexType(), val);
2104 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH,
b);
2105 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW,
b);
2107 Value
result = tensor::ExtractOp::create(
2110 linalg::YieldOp::create(
b,
result);
2113 assert(op.getMode() == ResizeMode::BILINEAR);
2115 auto oneVal = arith::ConstantOp::create(
b,
b.getI32IntegerAttr(1));
2117 auto getClampedIdxs = [&](Value &val0, Value &val1,
int size, Value in,
2118 Value
max, ImplicitLocOpBuilder &
b) {
2120 val1 = arith::AddIOp::create(
b, val0, oneVal);
2125 val0 = arith::IndexCastOp::create(
b,
b.getIndexType(), val0);
2126 val1 = arith::IndexCastOp::create(
b,
b.getIndexType(), val1);
2134 Value x0, x1, y0, y1;
2135 getClampedIdxs(y0, y1, imageH, iy, hMax,
b);
2136 getClampedIdxs(x0, x1, imageW, ix, wMax,
b);
2138 Value y0x0 = tensor::ExtractOp::create(
2140 Value y0x1 = tensor::ExtractOp::create(
2142 Value y1x0 = tensor::ExtractOp::create(
2144 Value y1x1 = tensor::ExtractOp::create(
2147 if (floatingPointMode) {
2149 arith::ConstantOp::create(
b,
b.getFloatAttr(floatTy, 1.0f));
2150 auto interpolate = [&](Value val0, Value val1, Value delta,
2152 ImplicitLocOpBuilder &
b) -> Value {
2155 Value oneMinusDelta = arith::SubFOp::create(
b, oneVal, delta);
2156 Value mul0 = arith::MulFOp::create(
b, val0, oneMinusDelta);
2157 Value mul1 = arith::MulFOp::create(
b, val1, delta);
2158 return arith::AddFOp::create(
b, mul0, mul1);
2164 Value topAcc = interpolate(y0x0, y0x1, dx, imageW,
b);
2169 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW,
b);
2173 Value
result = interpolate(topAcc, bottomAcc, dy, imageH,
b);
2174 linalg::YieldOp::create(
b,
result);
2177 y0x0 = arith::ExtSIOp::create(
b, resultETy, y0x0);
2178 y0x1 = arith::ExtSIOp::create(
b, resultETy, y0x1);
2179 y1x0 = arith::ExtSIOp::create(
b, resultETy, y1x0);
2180 y1x1 = arith::ExtSIOp::create(
b, resultETy, y1x1);
2183 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
2184 dx = arith::ExtSIOp::create(
b, resultETy, dx);
2185 dy = arith::ExtSIOp::create(
b, resultETy, dy);
2188 Value yScaleNExt = yScaleN;
2189 Value xScaleNExt = xScaleN;
2191 const int64_t scaleBitwidth =
2193 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
2194 yScaleNExt = arith::ExtSIOp::create(
b, resultETy, yScaleN);
2195 xScaleNExt = arith::ExtSIOp::create(
b, resultETy, xScaleN);
2198 auto interpolate = [](Value val0, Value val1, Value weight1,
2199 Value scale,
int inputSize,
2200 ImplicitLocOpBuilder &
b) -> Value {
2202 return arith::MulIOp::create(
b, val0, scale);
2203 Value weight0 = arith::SubIOp::create(
b, scale, weight1);
2204 Value mul0 = arith::MulIOp::create(
b, val0, weight0);
2205 Value mul1 = arith::MulIOp::create(
b, val1, weight1);
2206 return arith::AddIOp::create(
b, mul0, mul1);
2209 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW,
b);
2210 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW,
b);
2212 interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH,
b);
2213 linalg::YieldOp::create(
b,
result);
2226template <
typename SrcOp>
2229 using OpRewritePattern<SrcOp>::OpRewritePattern;
2231 LogicalResult matchAndRewrite(SrcOp op,
2232 PatternRewriter &rewriter)
const final {
2233 rewriter.
replaceOp(op, op.getOperation()->getOperands());
2238template <
typename SrcOp>
2241 using OpRewritePattern<SrcOp>::OpRewritePattern;
2243 LogicalResult matchAndRewrite(SrcOp reduceOp,
2244 PatternRewriter &rewriter)
const final {
2251 using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
2253 LogicalResult matchAndRewrite(tosa::ReverseOp op,
2254 PatternRewriter &rewriter)
const final {
2255 auto loc = op.getLoc();
2256 Value input = op.getInput1();
2257 auto inputTy = cast<ShapedType>(input.
getType());
2258 auto resultTy = cast<ShapedType>(op.getType());
2259 auto axis = op.getAxis();
2261 SmallVector<Value> dynDims;
2262 for (
int i = 0; i < inputTy.getRank(); i++) {
2263 if (inputTy.isDynamicDim(i)) {
2264 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2268 Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
2271 auto emptyTensor = tensor::EmptyOp::create(
2272 rewriter, loc, inputTy.getShape(),
2273 inputTy.getElementType(), ArrayRef<Value>({dynDims}))
2275 SmallVector<AffineMap, 2> affineMaps = {
2279 op, resultTy, ArrayRef<Value>({}),
ValueRange{emptyTensor}, affineMaps,
2281 [&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange args) {
2282 llvm::SmallVector<Value>
indices;
2283 for (
unsigned int i = 0; i < inputTy.getRank(); i++) {
2285 linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
2289 arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
2290 index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
2297 auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
2299 linalg::YieldOp::create(nestedBuilder, op.getLoc(),
2300 extract.getResult());
2310struct TileConverter :
public OpConversionPattern<tosa::TileOp> {
2311 using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
2314 matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
2315 ConversionPatternRewriter &rewriter)
const override {
2316 auto loc = op.getLoc();
2317 auto input = op.getInput1();
2318 auto inputTy = cast<ShapedType>(input.
getType());
2319 auto inputShape = inputTy.getShape();
2320 auto resultTy = cast<ShapedType>(op.getType());
2321 auto elementTy = inputTy.getElementType();
2322 int64_t rank = inputTy.getRank();
2324 SmallVector<int64_t> multiples;
2325 if (
failed(op.getConstantMultiples(multiples)))
2329 SmallVector<int64_t, 2> genericShape;
2330 for (
int i = 0; i < rank; i++) {
2331 int64_t dim = multiples[i];
2332 genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
2333 genericShape.push_back(inputShape[i]);
2336 SmallVector<Value> dynDims;
2337 for (
int i = 0; i < inputTy.getRank(); i++) {
2338 if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
2339 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2343 auto emptyTensor = tensor::EmptyOp::create(
2344 rewriter, op.getLoc(), genericShape, elementTy, dynDims);
2347 SmallVector<AffineExpr, 4> dimExprs;
2348 dimExprs.reserve(rank);
2349 for (
unsigned i = 0; i < rank; ++i)
2350 dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
2352 auto readAffineMap =
2354 rewriter.getContext());
2356 SmallVector<AffineMap, 2> affineMaps = {
2357 readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
2359 auto genericOp = linalg::GenericOp::create(
2360 rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
2363 [&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange args) {
2364 linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
2369 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
2370 op, resultTy, genericOp.getResult(0), shapeValue);
2390 using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
2392 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2393 PatternRewriter &rewriter)
const final {
2394 auto loc = argmaxOp.getLoc();
2395 Value input = argmaxOp.getInput();
2396 auto inputTy = cast<ShapedType>(input.
getType());
2397 auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
2398 auto inElementTy = inputTy.getElementType();
2399 auto outElementTy = resultTy.getElementType();
2400 int axis = argmaxOp.getAxis();
2401 auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
2403 if (!isa<IntegerType>(outElementTy))
2404 return rewriter.notifyMatchFailure(
2406 "tosa.arg_max to linalg.* requires integer-like result type");
2408 SmallVector<Value> dynDims;
2409 for (
int i = 0; i < inputTy.getRank(); i++) {
2410 if (inputTy.isDynamicDim(i) && i != axis) {
2411 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2416 auto emptyTensorIdx =
2417 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2418 outElementTy, dynDims)
2420 auto fillValueIdx = arith::ConstantOp::create(
2421 rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
2422 auto filledTensorIdx =
2423 linalg::FillOp::create(rewriter, loc,
ValueRange{fillValueIdx},
2428 auto emptyTensorMax =
2429 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
2432 auto fillValueMaxAttr =
2435 if (!fillValueMaxAttr)
2436 return rewriter.notifyMatchFailure(
2437 argmaxOp,
"unsupported tosa.argmax element type");
2440 arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
2441 auto filledTensorMax =
2442 linalg::FillOp::create(rewriter, loc,
ValueRange{fillValueMax},
2448 SmallVector<utils::IteratorType, 4> iteratorTypes;
2449 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2450 iteratorTypes[axis] = utils::IteratorType::reduction;
2452 SmallVector<AffineExpr, 2> srcExprs;
2453 SmallVector<AffineExpr, 2> dstExprs;
2454 for (
int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2460 bool didEncounterError =
false;
2462 rewriter.getContext());
2463 auto linalgOp = linalg::GenericOp::create(
2464 rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
2465 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2466 [&](OpBuilder &nestedBuilder, Location nestedLoc,
2468 auto newValue = blockArgs[0];
2469 auto oldIndex = blockArgs[1];
2470 auto oldValue = blockArgs[2];
2472 Value newIndex = arith::IndexCastOp::create(
2473 rewriter, nestedLoc, oldIndex.getType(),
2474 linalg::IndexOp::create(rewriter, loc, axis));
2477 if (isa<FloatType>(inElementTy)) {
2478 if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
2481 predicate = arith::CmpFOp::create(rewriter, nestedLoc,
2482 arith::CmpFPredicate::OGT,
2483 newValue, oldValue);
2488 Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
2489 arith::CmpFPredicate::UGT,
2490 newValue, oldValue);
2491 Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
2492 arith::CmpFPredicate::ORD,
2493 oldValue, oldValue);
2494 predicate = arith::AndIOp::create(
2495 rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
2497 }
else if (isa<IntegerType>(inElementTy)) {
2498 predicate = arith::CmpIOp::create(rewriter, nestedLoc,
2499 arith::CmpIPredicate::sgt,
2500 newValue, oldValue);
2502 didEncounterError =
true;
2506 auto resultMax = arith::SelectOp::create(
2507 rewriter, nestedLoc, predicate, newValue, oldValue);
2508 auto resultIndex = arith::SelectOp::create(
2509 rewriter, nestedLoc, predicate, newIndex, oldIndex);
2510 linalg::YieldOp::create(nestedBuilder, nestedLoc,
2514 if (didEncounterError)
2515 return rewriter.notifyMatchFailure(
2516 argmaxOp,
"unsupported tosa.argmax element type");
2518 rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
2523class GatherConverter :
public OpConversionPattern<tosa::GatherOp> {
2525 using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
2527 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2528 ConversionPatternRewriter &rewriter)
const final {
2529 auto input = adaptor.getOperands()[0];
2530 auto indices = adaptor.getOperands()[1];
2532 auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
2533 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
2534 if (!valuesTy || !resultTy)
2535 return rewriter.notifyMatchFailure(op,
"unranked tensors not supported");
2537 auto dynamicDims = inferDynamicDimsForGather(
2538 rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
2540 auto resultElementTy = resultTy.getElementType();
2542 auto loc = op.getLoc();
2544 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2545 resultElementTy, dynamicDims)
2548 SmallVector<AffineMap, 2> affineMaps = {
2550 resultTy.getRank(), 0,
2551 {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2552 rewriter.getContext()),
2553 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2555 auto genericOp = linalg::GenericOp::create(
2559 [&](OpBuilder &
b, Location loc,
ValueRange args) {
2560 auto indexValue = args[0];
2561 auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
2562 Value index1 = arith::IndexCastOp::create(
2563 rewriter, loc, rewriter.getIndexType(), indexValue);
2564 auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
2565 Value extract = tensor::ExtractOp::create(
2566 rewriter, loc, input,
ValueRange{index0, index1, index2});
2567 linalg::YieldOp::create(rewriter, loc, extract);
2569 rewriter.replaceOp(op, genericOp.getResult(0));
2573 static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
2577 llvm::SmallVector<Value> results;
2579 auto addDynamicDimension = [&](Value source, int64_t dim) {
2581 if (
auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2582 results.push_back(dimValue);
2585 addDynamicDimension(values, 0);
2586 addDynamicDimension(
indices, 1);
2587 addDynamicDimension(values, 2);
2597 using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
2599 LogicalResult matchAndRewrite(tosa::TableOp op,
2600 PatternRewriter &rewriter)
const final {
2601 auto loc = op.getLoc();
2602 Value input = op.getInput1();
2603 Value table = op.getTable();
2604 auto inputTy = cast<ShapedType>(input.
getType());
2605 auto tableTy = cast<ShapedType>(table.
getType());
2606 auto resultTy = cast<ShapedType>(op.getType());
2608 auto inputElementTy = inputTy.getElementType();
2609 auto tableElementTy = tableTy.getElementType();
2610 auto resultElementTy = resultTy.getElementType();
2612 SmallVector<Value> dynDims;
2613 for (
int i = 0; i < resultTy.getRank(); ++i) {
2614 if (inputTy.isDynamicDim(i)) {
2616 tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
2621 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2622 resultElementTy, dynDims)
2625 SmallVector<AffineMap, 2> affineMaps = {
2626 rewriter.getMultiDimIdentityMap(resultTy.getRank()),
2627 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
2629 auto genericOp = linalg::GenericOp::create(
2632 rewriter.replaceOp(op, genericOp.getResult(0));
2635 OpBuilder::InsertionGuard regionGuard(rewriter);
2636 Block *block = rewriter.createBlock(
2637 &genericOp.getRegion(), genericOp.getRegion().end(),
2638 TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2641 rewriter.setInsertionPointToStart(block);
2642 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2643 resultElementTy.isInteger(8)) {
2644 Value index = arith::IndexCastOp::create(
2645 rewriter, loc, rewriter.getIndexType(), inputValue);
2647 index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
2650 tensor::ExtractOp::create(rewriter, loc, table,
ValueRange{index});
2651 linalg::YieldOp::create(rewriter, loc, extract);
2655 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2656 resultElementTy.isInteger(32)) {
2657 Value extend = arith::ExtSIOp::create(
2658 rewriter, loc, rewriter.getI32Type(), inputValue);
2660 auto offset = arith::ConstantOp::create(
2661 rewriter, loc, rewriter.getI32IntegerAttr(32768));
2662 auto seven = arith::ConstantOp::create(rewriter, loc,
2663 rewriter.getI32IntegerAttr(7));
2664 auto one = arith::ConstantOp::create(rewriter, loc,
2665 rewriter.getI32IntegerAttr(1));
2666 auto b1111111 = arith::ConstantOp::create(
2667 rewriter, loc, rewriter.getI32IntegerAttr(127));
2673 auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
2674 Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
2676 arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);
2681 Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);
2683 index = arith::IndexCastOp::create(rewriter, loc,
2684 rewriter.getIndexType(), index);
2685 indexPlusOne = arith::IndexCastOp::create(
2686 rewriter, loc, rewriter.getIndexType(), indexPlusOne);
2689 tensor::ExtractOp::create(rewriter, loc, table,
ValueRange{index});
2690 Value next = tensor::ExtractOp::create(rewriter, loc, table,
2694 arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
2696 arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);
2700 Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
2701 Value diff = arith::SubIOp::create(rewriter, loc, next, base);
2702 Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
2704 arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);
2706 linalg::YieldOp::create(rewriter, loc,
result);
2712 return rewriter.notifyMatchFailure(
2713 op,
"unable to create body for tosa.table op");
2718 using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
2720 static bool isRankedTensor(Type type) {
return isa<RankedTensorType>(type); }
2722 static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
2728 auto divBy2 = builder.
createOrFold<arith::DivUIOp>(loc, value, two);
2729 auto plusOne = builder.
createOrFold<arith::AddIOp>(loc, divBy2, one);
2733 static RankedTensorType
2734 computeOutputShape(OpBuilder &builder, Location loc, Value input,
2735 llvm::SmallVectorImpl<Value> &dynamicSizes) {
2741 dims[2] = halfPlusOne(builder, loc, dims[2]);
2743 llvm::SmallVector<int64_t, 3> staticSizes;
2746 auto elementType = cast<RankedTensorType>(input.
getType()).getElementType();
2747 return RankedTensorType::get(staticSizes, elementType);
2750 static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
2751 RankedTensorType type,
2752 llvm::ArrayRef<Value> dynamicSizes) {
2754 tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
2755 auto fillValueAttr = rewriter.
getZeroAttr(type.getElementType());
2756 auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
2758 linalg::FillOp::create(rewriter, loc,
ValueRange{fillValue},
2761 return filledTensor;
2764 static Value castIndexToFloat(OpBuilder &builder, Location loc,
2765 FloatType type, Value value) {
2766 auto integerVal = arith::IndexCastUIOp::create(
2768 type.getIntOrFloatBitWidth() > 32 ? builder.
getI64Type()
2772 return arith::UIToFPOp::create(builder, loc, type, integerVal);
2775 static Value createLinalgIndex(OpBuilder &builder, Location loc,
2776 FloatType type, int64_t index) {
2777 auto indexVal = linalg::IndexOp::create(builder, loc, index);
2778 return castIndexToFloat(builder, loc, type, indexVal);
2781 template <
typename... Args>
2782 static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
// Lowers tosa.rfft2d to a linalg.generic that directly evaluates the real
// DFT: for each output point (n, oy, ox) it reduces over (iy, ix) with
//   real += in * cos(angle),  imag -= in * sin(angle),
// where angle = 2*pi*((iy*oy mod H)/H + (ix*ox mod W)/W)  (i.e. e^{-j*angle}).
// NOTE(review): several source lines are missing between the numbered
// fragments below (extraction gaps); comments state only what the visible
// code establishes.
2787 LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2788 PatternRewriter &rewriter)
const override {
// Bail out unless every operand and result is a ranked tensor.
2789 if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2790 !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2792 "only supports ranked tensors");
2795 auto loc = rfft2d.getLoc();
2796 auto input = rfft2d.getInputReal();
// Only float element types are handled (match failure otherwise).
2798 dyn_cast<FloatType>(cast<ShapedType>(input.
getType()).getElementType());
2801 "only supports float element types");
// Result type: input shape with W -> W/2 + 1 (see computeOutputShape);
// dynamic sizes are collected for the zero-tensor outputs below.
2804 llvm::SmallVector<Value> dynamicSizes;
2805 auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
// Iteration space: (N, oy, ox) parallel, (iy, ix) reduction.
2808 llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
2809 utils::IteratorType::parallel, utils::IteratorType::parallel,
2810 utils::IteratorType::parallel, utils::IteratorType::reduction,
2811 utils::IteratorType::reduction};
// Two outputs (real/imag accumulators), each seeded with zeros.
2814 llvm::SmallVector<Value> genericOpInputs = {input};
2815 llvm::SmallVector<Value> genericOpOutputs = {
2816 createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2817 createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
// Indexing maps: input indexed by (d0, d3, d4); both outputs by (d0, d1, d2).
2821 llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2822 affineDimsExpr(rewriter, 0, 1, 2),
2823 affineDimsExpr(rewriter, 0, 1, 2)},
// Materialize H and W (input dims 1 and 2) for the angle computation.
2827 auto dimH = rewriter.
createOrFold<tensor::DimOp>(loc, input, 1);
2828 auto dimW = rewriter.
createOrFold<tensor::DimOp>(loc, input, 2);
// 2*pi as a constant of the element type.
2831 auto twoPiAttr = rewriter.
getFloatAttr(elementType, 6.283185307179586);
2832 auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2833 auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2834 auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
// Body of the generic: fold one (iy, ix) term into the carried sums.
2836 auto buildBody = [&](OpBuilder &builder, Location loc,
ValueRange args) {
2837 Value valReal = args[0];
2838 Value sumReal = args[1];
2839 Value sumImag = args[2];
// Loop indices: oy/ox are output coordinates, iy/ix reduction coordinates.
2842 Value oy = linalg::IndexOp::create(builder, loc, 1);
2843 Value ox = linalg::IndexOp::create(builder, loc, 2);
2844 Value iy = linalg::IndexOp::create(builder, loc, 3);
2845 Value ix = linalg::IndexOp::create(builder, loc, 4);
// The index products are reduced mod H / mod W before the int->float cast,
// keeping the values small (cos/sin are 2*pi-periodic, so this is exact).
2850 auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2851 auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2853 auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2854 auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2856 auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2857 auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
// angle = 2*pi * (iyRem/H + ixRem/W).
2859 auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2860 auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2861 auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2862 auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
// real += val*cos(angle); imag -= val*sin(angle)  (forward transform).
2866 auto cosAngle = math::CosOp::create(builder, loc, angle);
2867 auto sinAngle = math::SinOp::create(builder, loc, angle);
2868 auto realComponent =
2869 arith::MulFOp::create(builder, loc, valReal, cosAngle);
2870 auto imagComponent =
2871 arith::MulFOp::create(builder, loc, valReal, sinAngle);
2876 arith::AddFOp::create(builder, loc, sumReal, realComponent);
2878 arith::SubFOp::create(builder, loc, sumImag, imagComponent);
2880 linalg::YieldOp::create(builder, loc,
ValueRange{outReal, outImag});
// Replace the op; presumably replaceOpWithNewOp<linalg::GenericOp> — the
// call head itself is not visible in this fragment.
2884 rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2885 indexingMaps, iteratorTypes, buildBody);
// Lowers tosa.fft2d (complex input: separate real/imag tensors) to a
// linalg.generic evaluating the full complex DFT, reusing the
// RFFT2dConverter helpers. The inner term is
//   real += vR*cos + vI*sin,  imag += vI*cos - vR*sin,
// with the angle optionally scaled for the inverse transform.
// NOTE(review): several source lines are missing between the numbered
// fragments below (extraction gaps); comments state only what the visible
// code establishes.
2894 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2895 PatternRewriter &rewriter)
const override {
// All operands and results must be ranked tensors.
2896 if (!llvm::all_of(fft2d->getOperandTypes(),
2897 RFFT2dConverter::isRankedTensor) ||
2898 !llvm::all_of(fft2d->getResultTypes(),
2899 RFFT2dConverter::isRankedTensor)) {
2903 Location loc = fft2d.getLoc();
2904 Value input_real = fft2d.getInputReal();
2905 Value input_imag = fft2d.getInputImag();
// `inverse` selects the inverse transform (used for the angle scaling below).
2906 BoolAttr inverse = fft2d.getInverseAttr();
2908 auto real_el_ty = cast<FloatType>(
2909 cast<ShapedType>(input_real.
getType()).getElementType());
// imag_el_ty is only read by the assert, hence [[maybe_unused]].
2910 [[maybe_unused]]
auto imag_el_ty = cast<FloatType>(
2911 cast<ShapedType>(input_imag.
getType()).getElementType());
2913 assert(real_el_ty == imag_el_ty);
// Output type shares the (possibly dynamic) input shape and element type.
2916 SmallVector<Value> dynamicSizes;
2921 SmallVector<int64_t, 3> staticSizes;
2924 auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
// Iteration space: (N, oy, ox) parallel, (iy, ix) reduction.
2927 SmallVector<utils::IteratorType, 5> iteratorTypes = {
2928 utils::IteratorType::parallel, utils::IteratorType::parallel,
2929 utils::IteratorType::parallel, utils::IteratorType::reduction,
2930 utils::IteratorType::reduction};
// Two inputs (real/imag) and two zero-seeded accumulator outputs.
2933 SmallVector<Value> genericOpInputs = {input_real, input_imag};
2934 SmallVector<Value> genericOpOutputs = {
2935 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2937 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
// Indexing maps: both inputs read (d0, d3, d4); both outputs (d0, d1, d2).
2942 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2943 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2944 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2945 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
// H and W from the real input, plus 2*pi, for the angle computation.
2949 auto dimH = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 1);
2950 auto dimW = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 2);
2953 auto twoPiAttr = rewriter.
getFloatAttr(real_el_ty, 6.283185307179586);
2954 auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2956 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2958 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
// Body of the generic: fold one (iy, ix) term into the carried sums.
2960 auto buildBody = [&](OpBuilder &builder, Location loc,
ValueRange args) {
2961 Value valReal = args[0];
2962 Value valImag = args[1];
2963 Value sumReal = args[2];
2964 Value sumImag = args[3];
2967 Value oy = linalg::IndexOp::create(builder, loc, 1);
2968 Value ox = linalg::IndexOp::create(builder, loc, 2);
2969 Value iy = linalg::IndexOp::create(builder, loc, 3);
2970 Value ix = linalg::IndexOp::create(builder, loc, 4);
// Index products reduced mod H / mod W before the int->float cast
// (exact, since cos/sin are 2*pi-periodic).
2974 auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2975 auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2977 auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2978 auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2981 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2983 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
// angle = 2*pi * (iyRem/H + ixRem/W).
2985 auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2986 auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2988 auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2989 auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
// Angle rescaled by a constant — presumably negated when `inverse` is set;
// the guarding condition and constant attribute are not visible here.
2992 angle = arith::MulFOp::create(
2993 builder, loc, angle,
2994 arith::ConstantOp::create(rewriter, loc,
2996 3000 auto cosAngle = math::CosOp::create(builder, loc, angle);
3001 auto sinAngle = math::SinOp::create(builder, loc, angle);
// realComponent = vR*cos + vI*sin.
3003 auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
3004 auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
3005 auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);
// imagComponent = vI*cos - vR*sin.
3007 auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
3008 auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
3010 auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);
// Accumulate into the carried real/imag sums and yield.
3015 arith::AddFOp::create(builder, loc, sumReal, realComponent);
3017 arith::AddFOp::create(builder, loc, sumImag, imagComponent);
3019 linalg::YieldOp::create(builder, loc,
ValueRange{outReal, outImag});
// Replace the op; presumably replaceOpWithNewOp<linalg::GenericOp> — the
// call head itself is not visible in this fragment.
3023 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
3024 indexingMaps, iteratorTypes, buildBody);
// Registers the TOSA -> Linalg conversion patterns on `patterns`: the resize
// converters, all pointwise/elementwise op converters, identity, and the
// reduction converters.
// NOTE(review): the function signature, benefit arguments, and the tail of
// the pattern list were lost in extraction (the list continues past this
// fragment); only comments are added here.
3036 patterns->
add<GenericResizeConverter>(patterns->
getContext(),
3040 patterns->
add<MaterializeResizeBroadcast>(patterns->
getContext(),
// Elementwise/pointwise TOSA ops, all lowered via PointwiseConverter.
3045 PointwiseConverter<tosa::AddOp>,
3046 PointwiseConverter<tosa::SubOp>,
3047 PointwiseConverter<tosa::MulOp>,
3048 PointwiseConverter<tosa::IntDivOp>,
3049 PointwiseConverter<tosa::NegateOp>,
3050 PointwiseConverter<tosa::PowOp>,
3051 PointwiseConverter<tosa::ReciprocalOp>,
3052 PointwiseConverter<tosa::RsqrtOp>,
3053 PointwiseConverter<tosa::LogOp>,
3054 PointwiseConverter<tosa::ExpOp>,
3055 PointwiseConverter<tosa::AbsOp>,
3056 PointwiseConverter<tosa::SinOp>,
3057 PointwiseConverter<tosa::CosOp>,
3058 PointwiseConverter<tosa::TanhOp>,
3059 PointwiseConverter<tosa::ErfOp>,
3060 PointwiseConverter<tosa::BitwiseAndOp>,
3061 PointwiseConverter<tosa::BitwiseOrOp>,
3062 PointwiseConverter<tosa::BitwiseNotOp>,
3063 PointwiseConverter<tosa::BitwiseXorOp>,
3064 PointwiseConverter<tosa::LogicalAndOp>,
3065 PointwiseConverter<tosa::LogicalNotOp>,
3066 PointwiseConverter<tosa::LogicalOrOp>,
3067 PointwiseConverter<tosa::LogicalXorOp>,
3068 PointwiseConverter<tosa::CastOp>,
3069 PointwiseConverter<tosa::LogicalLeftShiftOp>,
3070 PointwiseConverter<tosa::LogicalRightShiftOp>,
3071 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
3072 PointwiseConverter<tosa::ClzOp>,
3073 PointwiseConverter<tosa::SelectOp>,
3074 PointwiseConverter<tosa::GreaterOp>,
3075 PointwiseConverter<tosa::GreaterEqualOp>,
3076 PointwiseConverter<tosa::EqualOp>,
3077 PointwiseConverter<tosa::MaximumOp>,
3078 PointwiseConverter<tosa::MinimumOp>,
3079 PointwiseConverter<tosa::CeilOp>,
3080 PointwiseConverter<tosa::FloorOp>,
3081 PointwiseConverter<tosa::ClampOp>,
3082 PointwiseConverter<tosa::SigmoidOp>
// Identity and reduction converters.
3086 IdentityNConverter<tosa::IdentityOp>,
3087 ReduceConverter<tosa::ReduceAllOp>,
3088 ReduceConverter<tosa::ReduceAnyOp>,
3089 ReduceConverter<tosa::ReduceMinOp>,
3090 ReduceConverter<tosa::ReduceMaxOp>,
3091 ReduceConverter<tosa::ReduceSumOp>,
3092 ReduceConverter<tosa::ReduceProductOp>,
If copies could not be generated due to yet-unimplemented cases, `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the insertion point), since `begin` can itself be invalidated due to the memref rewriting done from within this method.
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static LogicalResult emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape, const TypeConverter &converter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
DenseMap< int64_t, Value > IndexPool
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, PatternRewriter &rewriter)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, ConversionPatternRewriter &rewriter)
static ValueRange getBroadcastableOperands(Operation *operation, ValueRange operands)
static Value materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, Value lhs, Value rhs, Value result)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static bool operandsAndResultsRanked(Operation *operation)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), and as many dims as the largest dimension referenced in the expression lists.
BlockArgument getArgument(unsigned i)
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
AffineMap getMultiDimIdentityMap(unsigned rank)
FloatAttr getFloatAttr(Type type, double value)
AffineExpr getAffineConstantExpr(int64_t constant)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
BoolAttr getBoolAttr(bool value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
IntegerAttr getI8IntegerAttr(int8_t value)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...