#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVectorExtras.h"
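// Helper that implements TOSA's "IGNORE" NaN-propagation mode for binary ops
// that select between their operands (maximum/minimum): when exactly one
// operand is NaN, the other operand is returned. In the default "PROPAGATE"
// mode the computed result is passed through unchanged.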
template <typename OpTy>
static Value materializeBinaryNanCheckIfRequired(OpTy op,
                                                 PatternRewriter &rewriter,
                                                 Value lhs, Value rhs,
                                                 Value result) {
  auto nanMode = op.getNanMode();
  if (nanMode == NanPropagationMode::PROPAGATE)
    return result;

  // Unordered comparison of NaN against itself will always return true.
  Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
                                         arith::CmpFPredicate::UNO, lhs, lhs);
  Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
                                         arith::CmpFPredicate::UNO, rhs, rhs);
  Value rhsOrResult =
      arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
  return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
                                 rhsOrResult);
}
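// Builds the scalar computation for one element of an element-wise TOSA op
// inside the body of a linalg.generic. Returns a null Value and records a
// match-failure diagnostic for unsupported op/element-type combinations.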
static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return math::AbsFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = arith::ConstantOp::create(rewriter, loc,
                                          rewriter.getZeroAttr(elementTy));
    auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
    return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
  }
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return arith::AddFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return arith::AddIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return arith::SubFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return arith::SubIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
    return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
  }
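// tosa.mul carries a shift operand: with integer operands, a positive (or
// non-constant) shift routes the product through tosa.apply_scale so the
// result is rescaled as it is narrowed back to the element type.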
  if (isa<tosa::MulOp>(op)) {
    auto shiftVal = cast<tosa::MulOp>(op).getShift();
    DenseElementsAttr shiftElem;
    int64_t shift = 0;
    bool shiftIsConstant = true;
    if (matchPattern(shiftVal, m_Constant(&shiftElem)))
      shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
    else
      shiftIsConstant = false;

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
                                   args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0 || !shiftIsConstant) {
        Value shiftConst;
        if (shiftIsConstant)
          shiftConst = arith::ConstantIntOp::create(rewriter, loc, shift,
                                                    /*width=*/8);
        if (!a.getType().isInteger(32))
          a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);

        if (!b.getType().isInteger(32))
          b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);

        auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
        auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
                                                  RoundingMode::SINGLE_ROUND);
        auto result =
            tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
                                       b, shiftAmount, roundingAttr)
                .getResult();

        if (elementTy.isInteger(32))
          return result;

        return arith::TruncIOp::create(rewriter, loc, elementTy, result);
      }

      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);

      return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
    }
  }
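// tosa.negate computes output = (input_zp + output_zp) - input. The
// intermediate type is widened (up to i64) so that the zero-point sum and the
// subtraction cannot overflow before the result is clamped and truncated.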
  if (isa<tosa::NegateOp>(op)) {
    auto negate = cast<tosa::NegateOp>(op);

    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    bool hasInZp = !failed(maybeInZp);
    bool hasOutZp = !failed(maybeOutZp);

    if (isa<FloatType>(elementTy))
      return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);

    if (isa<IntegerType>(elementTy)) {
      Value zpAddValue;
      Type intermediateType;
      const unsigned inputBitWidth = elementTy.getIntOrFloatBitWidth();
      int intermediateBitWidth = 64;

      if (hasInZp && hasOutZp) {
        const int64_t inZp = *maybeInZp;
        const int64_t outZp = *maybeOutZp;
        const int64_t zpAdd = inZp + outZp;
        const int64_t maxValue =
            APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
            std::abs(zpAdd) + 1;

        // Pick the narrowest intermediate bitwidth that can hold the result.
        if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
          intermediateBitWidth = 16;
        } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
          intermediateBitWidth = 32;
        }
        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
        zpAddValue = arith::ConstantOp::create(
            rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
      } else {
        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
        Value arg1 = args[1];
        Value arg2 = args[2];

        if (arg1.getType() != intermediateType)
          arg1 = arith::ExtSIOp::create(rewriter, loc, intermediateType, arg1);
        if (arg2.getType() != intermediateType)
          arg2 = arith::ExtSIOp::create(rewriter, loc, intermediateType, arg2);
        zpAddValue =
            arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
      }

      // The negation is computed as: outputValue = (inZp + outZp) - input.
      Value ext = args[0];
      if (ext.getType() != intermediateType)
        ext = arith::ExtSIOp::create(rewriter, loc, intermediateType, ext);
      auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);

      // Clamp to the negation range.
      Value min = arith::ConstantIntOp::create(
          rewriter, loc, intermediateType,
          APInt::getSignedMinValue(inputBitWidth).getSExtValue());
      Value max = arith::ConstantIntOp::create(
          rewriter, loc, intermediateType,
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
      auto clamp =
          clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);

      // Truncate to the final value.
      if (clamp.getType() == elementTy)
        return clamp;
      return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
    }
  }
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return arith::AndIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return arith::OrIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
  }

  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return arith::ShLIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
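// With the "round" attribute set, arithmetic right shift adds one to the
// result whenever the last bit shifted out was set and the shift amount was
// positive.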
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), 1);
    auto one = arith::ConstantOp::create(rewriter, loc,
                                         IntegerAttr::get(elementTy, 1));
    auto zero = arith::ConstantOp::create(rewriter, loc,
                                          IntegerAttr::get(elementTy, 0));
    auto i1zero =
        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
    auto i1one =
        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));

    // Checking that the shift value is greater than zero.
    auto shiftValueGreaterThanZero = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Extract the last bit of the original value that was shifted out.
    auto subtract =
        arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
    auto shifted =
        arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
                                             std::nullopt);
    auto isInputOdd =
        arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);

    auto shouldRound = arith::SelectOp::create(
        rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
    auto extended =
        arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
    return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
  }
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
  }

  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIntegerAttr(elementTy, 1));
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
  }

  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
                                 args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
                                 args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
                                 args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
                                 args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
                                 args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                 args[0], args[1]);
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
  }

  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
                                               rewriter, args[0], args[1], max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
                                               rewriter, args[0], args[1], min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return math::CeilOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return math::FloorOp::create(rewriter, loc, resultTypes, args);
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = arith::ConstantOp::create(
        rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = arith::ConstantOp::create(
        rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    auto result = clampFloatHelper(loc, args[0], min, max, rewriter);

    auto clampOp = llvm::cast<tosa::ClampOp>(op);
    const auto nanMode = clampOp.getNanMode();

    // NaN propagation has no meaning for non floating point types.
    if (!isa<FloatType>(elementTy))
      return result;

    // In the default "PROPAGATE" mode, clamp propagates NaNs to the output.
    if (nanMode == NanPropagationMode::PROPAGATE)
      return result;

    // In "IGNORE" mode, NaN inputs are replaced with the clamp minimum.
    Value isNaN = arith::CmpFOp::create(
        rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
    return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
  }
  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                             .getSExtValue();
    }
    // Clip the clamp bounds to the representable range of the element type.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
                                               intTy.getIntOrFloatBitWidth());
    auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
                                               intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
    auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
    auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
    auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
    return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
  }
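// tosa.cast handles every int/float conversion pair. The float-to-int cases
// must saturate; the strategy depends on whether the destination's min/max
// values are exactly representable in the source floating-point type.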
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
      (void)rewriter.notifyMatchFailure(op, "unsupported type");
      return nullptr;
    }

    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return arith::ExtFOp::create(rewriter, loc, resultTypes, args);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return arith::TruncFOp::create(rewriter, loc, resultTypes, args);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return arith::UIToFPOp::create(rewriter, loc, resultTypes, args);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return arith::ExtUIOp::create(rewriter, loc, resultTypes, args);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          UnrealizedConversionCastOp::create(
              rewriter, loc,
              rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
              .getResult(0);
      return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
                                     unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return arith::SIToFPOp::create(rewriter, loc, resultTypes, args);

    // Casting to boolean: floats need only be checked as not-equal to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getFloatAttr(srcTy, 0.0));
      return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
                                   args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // Check whether neither int min nor int max can be represented in the
      // input floating-point type due to a too-short exponent range.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        // Use cmp + select to replace infinities by int min / int max. Other
        // integral values can be represented in the integer space.
        auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
        auto posInf = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                  APFloat::getInf(fltSemantics)));
        auto negInf = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = arith::CmpFOp::create(
            rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = arith::CmpFOp::create(
            rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
        return arith::SelectOp::create(rewriter, loc, underflow, intMin,
                                       maxClamped);
      }

      auto intMinFP = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getFloatAttr(
              getElementTypeOrSelf(srcTy),
              APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                  .getSExtValue()));

      // Check whether the mantissa has enough bits to represent int max.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        // Int min can also be represented since it is a power of two, so the
        // input can be clamped in the floating-point domain.
        auto intMaxFP = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                    .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
      }

      // The earlier check guarantees the exponent range can represent int
      // min, hence int max + 1 is representable as well: clamp the minimum
      // and select int max where the input reaches it.
      auto intMaxPlusOneFP = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getFloatAttr(
              getElementTypeOrSelf(srcTy),
              static_cast<double>(
                  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                      .getSExtValue()) +
                  1.0f));

      auto intMax = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getIntegerAttr(
              getElementTypeOrSelf(dstTy),
              APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
      auto minClamped =
          arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
      auto overflow = arith::CmpFOp::create(
          rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return arith::SelectOp::create(rewriter, loc, overflow, intMax,
                                     minClamped);
    }

    // Casting to boolean: integers need only be checked as not-equal to zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
                                                srcTy.getIntOrFloatBitWidth());
      return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
                                   args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return arith::ExtSIOp::create(rewriter, loc, resultTypes, args);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
}

static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}

static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc,
                  IndexPool &indexPool, ValueRange operands, int64_t dim) {
  // If any operand has a static dimension greater than 1, it is the target
  // size for this dimension.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (ShapedType::isStatic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with a dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If no operand has a dynamic dimension, all sizes were 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension.
  auto targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Compute the maximum of all dynamic sizes.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size was taken directly from this operand, it already has
  // the right size and no broadcast is needed.
  if (operand == masterOperand)
    return operand;

  // Affine maps for the 'linalg.generic' op.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = arith::CmpIOp::create(
      rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions: new constants could
    // violate dominance requirements.
    IndexPool localPool;

    // Emit 'tensor.empty' op.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = tensor::EmptyOp::create(
        opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op.
    auto resultTensor =
        linalg::GenericOp::create(
            opBuilder, loc, outputTensor.getType(), operand, outputTensor,
            affineMaps, getNParallelLoopsAttrs(rank),
            [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
              linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
            })
            .getResult(0);

    // Cast to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    // Emit 'scf.yield' op.
    scf::YieldOp::create(opBuilder, loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    scf::YieldOp::create(opBuilder, loc, operand);
  };

  // Emit 'scf.if' op.
  auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
                                emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
static Value broadcastDynamicDimensions(PatternRewriter &rewriter,
                                        Location loc, IndexPool &indexPool,
                                        Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand = broadcastDynamicDimension(rewriter, loc, indexPool, operand,
                                        index, targetShape[index],
                                        masterOperands[index]);
  return operand;
}
static SmallVector<Value> broadcastDynamicDimensions(
    PatternRewriter &rewriter, Location loc, IndexPool &indexPool,
    ValueRange operands, ArrayRef<OpFoldResult> targetShape,
    ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations.
  if (operands.size() == 1)
    return operands;

  // Short-circuit if all operands are statically shaped.
  bool hasDynamic = false;
  for (auto op : operands) {
    const auto tType = dyn_cast<RankedTensorType>(op.getType());
    if (tType && !tType.hasStaticShape()) {
      hasDynamic = true;
      break;
    }
  }
  if (!hasDynamic)
    return operands;

  // Broadcast dynamic dimensions operand by operand.
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
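// Emits the linalg.generic that implements a broadcast element-wise op: input
// affine maps pin size-1 dimensions to constant 0, the output map is the
// identity, and the body is built by
// createLinalgBodyCalculationForElementwiseOp.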
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate output tensor.
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
                                               resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of size
  // 1; the output affine map is an identity map.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // Prefer producing identity maps whenever possible (i.e. no broadcast
      // needed) because some transforms (like reshape folding) do not support
      // affine constant exprs.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = linalg::GenericOp::create(
      rewriter, loc, outputTensor.getType(), operands, outputTensor,
      affineMaps, getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        linalg::YieldOp::create(opBuilder, loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast 'linalg.generic' result into the original result type if needed.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
static ValueRange getBroadcastableOperands(Operation *operation,
                                           ValueRange operands) {
  // The shift operand of tosa.mul cannot broadcast.
  if (isa<tosa::MulOp>(operation)) {
    // A constant shift is materialized inside the linalg body, so only the
    // two data operands participate in broadcasting.
    DenseElementsAttr shift;
    if (matchPattern(operation->getOperand(2), m_Constant(&shift)))
      return operands.take_front(2);
    return operands.take_front(3);
  }
  // The zero-point operands of tosa.negate cannot broadcast.
  if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    if (failed(maybeOutZp) && failed(maybeInZp))
      return operands.take_front(3);

    return operands.take_front(1);
  }
  return operands;
}
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                 ConversionPatternRewriter &rewriter,
                                 const TypeConverter &converter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
  auto broadcastOperands =
      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                 targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape, converter);
}
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy,
        APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy,
        APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy,
        APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return arith::AddFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::AddIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MulFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MulIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, args);

  return {};
}
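// Lowers a TOSA reduction: fill an accumulator tensor with the reduction's
// identity value, reduce along the axis with linalg.reduce, then restore the
// collapsed axis with tensor.expand_shape. NaN-ignoring min/max reductions
// additionally track whether every reduced element was NaN.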
template <typename OpTy>
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  // Sums of bf16 values are accumulated in f32 and truncated afterwards.
  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
                    isa<FloatType>(elementTy) &&
                    cast<FloatType>(elementTy).isBF16();
  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
  auto filledTensor =
      linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                             ValueRange{emptyTensor})
          .result();
  outputs.push_back(filledTensor);

  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN propagation has no meaning for non floating point types.
    if (isa<FloatType>(elementTy) &&
        op.getNanMode() == NanPropagationMode::IGNORE) {
      isNanIgnoreMode = true;
      // Because the TOSA spec requires the result to be NaN iff all elements
      // in the reduction are NaN, maintain a boolean flag alongside the
      // reduction recording whether every value seen so far was NaN.
      auto trueAttr = rewriter.getBoolAttr(true);
      auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
      auto emptyBoolTensor =
          tensor::EmptyOp::create(rewriter, loc, reduceShape,
                                  trueValue.getType(), dynDims)
              .getResult();
      auto allResultsNaNTensor =
          linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
                                 ValueRange{emptyBoolTensor})
              .result();
      // linalg.reduce requires the same number of inputs and outputs, so the
      // input is repeated for the flag accumulator.
      inputs.push_back(input);
      outputs.push_back(allResultsNaNTensor);
    }
  }

  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
      rewriter, loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};

        if (binaryArgs[0].getType() != accTy)
          binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
                                                binaryArgs[0]);

        auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
                                                             accTy, rewriter);
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself will always return
          // true.
          Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
                                              arith::CmpFPredicate::UNO,
                                              inputValue, inputValue);
          // If we've encountered a NaN, take the non-NaN value.
          auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
                                                  isNaN, initialValue, result);
          // Keep track of whether all values seen so far were NaN.
          auto newAllResultsNanFlagValue = arith::AndIOp::create(
              nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  if (isNanIgnoreMode) {
    // If all values in the reduction were NaN, the result must be NaN:
    // select between the reduction result and a NaN-filled tensor using the
    // all-NaN flag computed above.
    auto nanValueAttr = rewriter.getFloatAttr(
        elementTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(),
                        false));
    auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
    auto emptyNanTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();
    auto nanFilledTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
                               ValueRange{emptyNanTensor})
            .result();

    // No need to fill this tensor since it will be overwritten by the select.
    auto finalEmptyTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();

    SmallVector<Value> ins, outs;
    ins.push_back(linalgOp->getOpResult(1));
    ins.push_back(nanFilledTensor);
    ins.push_back(linalgOp->getResult(0));
    outs.push_back(finalEmptyTensor);
    auto linalgSelect =
        linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }

  // If the accumulator type was widened, truncate back to the element type.
  Value reducedRes = linalgOp->getResult(0);
  if (widenAccTy) {
    auto emptyResultTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
            .getResult();
    const unsigned reducedRank =
        cast<ShapedType>(reducedRes.getType()).getRank();
    SmallVector<AffineMap, 2> identityMaps(
        2, rewriter.getMultiDimIdentityMap(reducedRank));
    reducedRes = linalg::GenericOp::create(
                     rewriter, loc, emptyResultTensor.getType(), reducedRes,
                     ValueRange{emptyResultTensor}, identityMaps,
                     getNParallelLoopsAttrs(reducedRank),
                     [&](OpBuilder &nestedBuilder, Location nestedLoc,
                         ValueRange args) {
                       Value truncf = arith::TruncFOp::create(
                           nestedBuilder, nestedLoc, elementTy, args[0]);
                       linalg::YieldOp::create(nestedBuilder, nestedLoc,
                                               truncf);
                     })
                     .getResult(0);
  }

  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  // Lower directly to tensor.expand_shape instead of tosa.reshape, since here
  // we know which dimension to expand.
  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, reducedRes, reassociationMap);
  return success();
}
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
                                  Location loc) {
  auto inputType = cast<RankedTensorType>(input.getType());
  auto elemType = inputType.getElementType();
  auto collapsedType = RankedTensorType::get({}, elemType);
  // A rank-0 result has an empty reassociation map.
  SmallVector<ReassociationIndices> reassociation;
  return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
                                         reassociation);
}

static llvm::SmallVector<int8_t> convertToI8(ArrayRef<int32_t> input) {
  llvm::SmallVector<int8_t> output;
  output.reserve(input.size());

  for (auto v : llvm::map_range(
           input, [](int32_t val) { return static_cast<int8_t>(val); })) {
    output.push_back(v);
  }
  return output;
}
// The multiplier may be either a constant or an operand of tosa.rescale;
// this helper sets up the corresponding linalg.generic input and, where
// needed, its indexing map.
static void setupLinalgGenericOpInputAndIndexingMap(
    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
    bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
    bool isShift = false) {

  auto loc = op.getLoc();
  auto inputTy = cast<ShapedType>(op.getInput().getType());
  unsigned rank = inputTy.getRank();
  SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};

  if (isConstant) {
    // A single constant value becomes an arith.constant scalar; a per-channel
    // list becomes a constant 1-D tensor input.
    if (values.size() == 1) {
      IntegerAttr intAttr = isShift
                                ? rewriter.getI8IntegerAttr(values.front())
                                : rewriter.getI32IntegerAttr(values.front());
      constant = arith::ConstantOp::create(rewriter, loc, intAttr);
    } else {
      auto elementType =
          isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
      auto tensorType = RankedTensorType::get(
          {static_cast<int64_t>(values.size())}, elementType);
      DenseIntElementsAttr EltAttr;
      if (isShift)
        EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
      else
        EltAttr = DenseIntElementsAttr::get(tensorType, values);
      genericInputs.push_back(
          arith::ConstantOp::create(rewriter, loc, EltAttr));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, exprs,
                                            rewriter.getContext()));
    }
  } else {
    // Non-constant values are passed in as operands.
    auto operand = isShift ? op.getShift() : op.getMultiplier();
    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
    if (tensorType && tensorType.hasStaticShape() &&
        tensorType.getShape()[0] == 1) {
      // Broadcast a single per-tensor value across all dimensions.
      auto broadcastMap =
          AffineMap::get(rank, 0, {}, rewriter.getContext());
      genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
      indexingMaps.push_back(broadcastMap);
    } else {
      genericInputs.push_back(operand);
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, exprs,
                                            rewriter.getContext()));
    }
  }
  arg = indexingMaps.size() - 1;
}
static Value getExtendZp(OpBuilder &builder, Type valueTy,
                         FailureOr<int64_t> maybeZp, Location loc,
                         ValueRange blockArgs, int64_t zpArg,
                         bool isOutputZp = false) {
  Value result;
  const uint32_t bitwidth = valueTy.getIntOrFloatBitWidth();
  const uint32_t attrBitwidth =
      isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
  auto extendType = builder.getIntegerType(attrBitwidth);
  // A failed zero point is not a compile-time constant and comes in as a
  // block argument of the surrounding linalg.generic op.
  if (failed(maybeZp)) {
    result = blockArgs[zpArg];
    auto zpTy = result.getType();
    if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
      // An unsigned zero point must first be cast to a signless type.
      if (zpTy.isUnsignedInteger()) {
        result =
            UnrealizedConversionCastOp::create(
                builder, loc,
                builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
                .getResult(0);
      }
      if (zpTy.isUnsignedInteger()) {
        return arith::ExtUIOp::create(builder, loc, extendType, result);
      }
      return arith::ExtSIOp::create(builder, loc, extendType, result);
    }
    return result;
  }
  return arith::ConstantOp::create(builder, loc,
                                   IntegerAttr::get(extendType, *maybeZp));
}
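// Lowers tosa.rescale to a linalg.generic whose body sign- or zero-extends
// the input to i32, recenters it on the input zero point, applies
// tosa.apply_scale, adds the output zero point, and saturates to the output
// type's range.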
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
              "currently supported");
    if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    DenseElementsAttr shiftElems;
    bool isShiftConstant = false;
    if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
      isShiftConstant = true;

    DenseElementsAttr multiplierElems;
    bool isMultiplierConstant = false;
    if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
      isMultiplierConstant = true;

    llvm::SmallVector<int32_t> shiftValues;
    llvm::SmallVector<int32_t> multiplierValues;
    bool doubleRound;

    if (isMultiplierConstant && isShiftConstant) {
      // Explicit cast is required here.
      shiftValues = llvm::map_to_vector(
          shiftElems.getValues<IntegerAttr>(),
          [](IntegerAttr attr) -> int32_t {
            return static_cast<int32_t>(attr.getInt());
          });
      multiplierValues =
          llvm::map_to_vector(multiplierElems.getValues<IntegerAttr>(),
                              [](IntegerAttr attr) -> int32_t {
                                return static_cast<int32_t>(attr.getInt());
                              });

      // Shifting by more than the bitwidth just sets the result to 0.
      for (int i = 0, s = multiplierValues.size(); i < s; i++) {
        if (shiftValues[i] > 63) {
          shiftValues[i] = 0;
          multiplierValues[i] = 0;
        }
      }

      // Double rounding only occurs if the shift is greater than 31; check
      // whether that is ever true.
      doubleRound =
          op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
          llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
    } else {
      doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
    }

    RoundingMode roundingMode =
        doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If rescaling per-channel, the multiplier values are stored in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    setupLinalgGenericOpInputAndIndexingMap(
        rewriter, multiplierValues, genericInputs, indexingMaps,
        isMultiplierConstant, op, multiplierConstant, multiplierArg);

    // If rescaling per-channel, the shift values are stored in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    setupLinalgGenericOpInputAndIndexingMap(
        rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
        shiftConstant, shiftArg, true);

    // Zero points that are not compile-time constants become extra inputs.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
    AffineMap broadcastMap =
        AffineMap::get(rank, 0, {}, rewriter.getContext());
    int64_t iZpArg = 0;
    int64_t oZpArg = 0;
    if (failed(maybeIZp)) {
      genericInputs.push_back(
          collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
      indexingMaps.push_back(broadcastMap);
      iZpArg = indexingMaps.size() - 1;
    }
    if (failed(maybeOZp)) {
      genericInputs.push_back(
          collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
      indexingMaps.push_back(broadcastMap);
      oZpArg = indexingMaps.size() - 1;
    }

    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the indexing maps needed for linalg.generic ops.
    Value emptyTensor = tensor::EmptyOp::create(
        rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = linalg::GenericOp::create(
        rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
        indexingMaps, getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
          auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
                                     nestedLoc, blockArgs, iZpArg);

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
          auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
                                      nestedLoc, blockArgs, oZpArg, true);

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // An unsigned input is first cast to a signless type.
          if (valueTy.isUnsignedInteger()) {
            value = UnrealizedConversionCastOp::create(
                        nestedBuilder, nestedLoc,
                        nestedBuilder.getIntegerType(
                            valueTy.getIntOrFloatBitWidth()),
                        value)
                        .getResult(0);
          }
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
                                             nestedBuilder.getI32Type(),
                                             value);
            } else {
              value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
                                             nestedBuilder.getI32Type(),
                                             value);
            }
          }

          value =
              arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);

          value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
                                             nestedBuilder.getI32Type(), value,
                                             multiplier, shift, roundingMode);

          // Move to the new zero-point.
          value =
              arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);

          // Saturate to the output size.
          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different saturation range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = arith::ConstantOp::create(
              nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = arith::ConstantOp::create(
              nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, false);

          if (outIntType.getWidth() < 32) {
            value = arith::TruncIOp::create(
                nestedBuilder, nestedLoc,
                nestedBuilder.getIntegerType(outIntType.getWidth()), value);
          }

          if (outIntType.isUnsignedInteger()) {
            value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
                                                       outIntType, value)
                        .getResult(0);
          }
          linalg::YieldOp::create(nestedBuilder, loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    SmallVector<int64_t> scale;
    if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale)) {
      return failure();
    }

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[0].push_back(builder.getAffineDimExpr(1));
    reassociationMap[0].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
                                                     reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    // Generate the elementwise operation for casting and scaling the input
    // value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty =
        tensor::EmptyOp::create(builder, genericTy.getShape(),
                                resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = linalg::GenericOp::create(
        builder, genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
                                           value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[0]));
              value = arith::MulIOp::create(b, loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[2]));
              value = arith::MulIOp::create(b, loc, value, scaleX);
            }
          }

          linalg::YieldOp::create(b, loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
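// When tosa.resize broadcasts along a unit height or width dimension, resize
// the non-broadcast dimensions first and then materialize the broadcast with
// a linalg.generic copy.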
struct MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any broadcastable dimension, generate a width of 1 on the output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
                               op.getOffset(), op.getBorder(), op.getMode());

    // Collapse the unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
                                                     resize, reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = tensor::EmptyOp::create(
        builder, resultTy.getShape(), resultTy.getElementType(),
        outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());
    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          linalg::YieldOp::create(b, loc, value);
        });

    return success();
  }
};
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = isa<FloatType>(resultETy);
    auto floatTy = resultETy;

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
                                               resultETy, *dynamicDimsOr);
    auto genericOp = linalg::GenericOp::create(
        b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = linalg::IndexOp::create(b, 0);
      Value y = linalg::IndexOp::create(b, 1);
      Value x = linalg::IndexOp::create(b, 2);
      Value channel = linalg::IndexOp::create(b, 3);

      Value zeroI32 =
          arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
      Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
      Value hMax =
          arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
      Value wMax =
          arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));

      Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
      Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);

      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) {
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");
      }

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
      yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
      xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
      xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
      xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
      yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
      xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));

      // Compute the index and delta for a dimension - floating-point case.
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        Value val = arith::MulIOp::create(b, in, scaleD);
        val = arith::AddIOp::create(b, val, offset);
        index = arith::FloorDivSIOp::create(b, val, scaleN);

        // rx = x % scale_n
        // dx = rx / scale_n
        Value r = arith::RemSIOp::create(b, val, scaleN);
        Value rFp = arith::SIToFPOp::create(b, floatTy, r);
        Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
        delta = arith::DivFOp::create(b, rFp, scaleNfp);
      };

      // Compute the index and delta for a dimension - integer case.
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        // x = x * scale_d + offset;
        // ix = floor(x / scale_n)
        // dx = x - ix * scale_n;
        Value val = arith::MulIOp::create(b, in, scaleD);
        val = arith::AddIOp::create(b, val, offset);
        index = arith::DivSIOp::create(b, val, scaleN);
        delta = arith::MulIOp::create(b, index, scaleN);
        delta = arith::SubIOp::create(b, val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
        auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return arith::ConstantIndexOp::create(b, 0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h =
                arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
            pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = arith::ShLIOp::create(b, dval, one);
            pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
                                         dvalDouble, scale);
          }

          auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
          val = arith::AddIOp::create(b, val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
          return arith::IndexCastOp::create(b, b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = tensor::ExtractOp::create(
            b, input, ValueRange{batch, iy, ix, channel});

        linalg::YieldOp::create(b, result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == ResizeMode::BILINEAR);

        auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = arith::AddIOp::create(b, val0, oneVal);
          val0 =
              clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
          val1 =
              clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
          val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
          val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
        };

        // Linalg equivalent to the section below:
        //    int16_t iy0 = apply_max(iy, 0);
        //    int16_t iy1 = apply_min(iy + 1, IH - 1);
        //    int16_t ix0 = apply_max(ix, 0);
        //    int16_t ix1 = apply_min(ix + 1, IW - 1);
        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
            Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
            Value mul1 = arith::MulFOp::create(b, val1, delta);
            return arith::AddFOp::create(b, mul0, mul1);
          };

          // topAcc = v00 * (unit - dx) + v01 * dx
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);

          // bottomAcc = v10 * (unit - dx) + v11 * dx
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);

          // result = topAcc * (unit - dy) + bottomAcc * dy
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          linalg::YieldOp::create(b, result);
        } else {
          // Perform in quantized space.
          y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
          y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
          y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
          y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = arith::ExtSIOp::create(b, resultETy, dx);
            dy = arith::ExtSIOp::create(b, resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
            xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return arith::MulIOp::create(b, val0, scale);
            Value weight0 = arith::SubIOp::create(b, scale, weight1);
            Value mul0 = arith::MulIOp::create(b, val0, weight0);
            Value mul1 = arith::MulIOp::create(b, val1, weight1);
            return arith::AddIOp::create(b, mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          linalg::YieldOp::create(b, result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
// Identity-like ops are trivially removed by forwarding their operands.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};

template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);

    // First fill the output buffer with the init value.
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, inputTy.getShape(),
                                inputTy.getElementType(),
                                ArrayRef<Value>({dynDims}))
            .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
            if (i == axis) {
              auto one =
                  arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
              auto sizeMinusOne =
                  arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
              index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
                                            index);
            }
            indices.push_back(index);
          }

          auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
                                                   input, indices);
          linalg::YieldOp::create(nestedBuilder, op.getLoc(),
                                  extract.getResult());
        });
    return success();
  }
};
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    auto emptyTensor = tensor::EmptyOp::create(
        rewriter, op.getLoc(), genericShape, elementTy, dynDims);

    // We map the input shape to the non-broadcasted dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
        });

    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
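// tosa.argmax becomes a two-result linalg.generic that carries the running
// maximum and its index through the reduction, with explicit handling for
// both TOSA NaN-propagation modes.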
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    // First fill the output buffer for the indices.
    auto emptyTensorIdx =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                outElementTy, dynDims)
            .getResult();
    auto fillValueIdx = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx},
                               ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max value.
    auto emptyTensorMax =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
                                dynDims)
            .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
    auto filledTensorMax =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax},
                               ValueRange{emptyTensorMax})
            .result();

    // Reduce along the arg-max axis; all other dimensions are parallel.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr, 2> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs},
                                             rewriter.getContext());
    auto linalgOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = arith::IndexCastOp::create(
              rewriter, nestedLoc, oldIndex.getType(),
              linalg::IndexOp::create(rewriter, loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
              // Only update the index and max value for non-NaN inputs. If
              // all values are NaN, the initial index, 0, is returned.
              predicate = arith::CmpFOp::create(rewriter, nestedLoc,
                                                arith::CmpFPredicate::OGT,
                                                newValue, oldValue);
            } else {
              // Update the max value if either the new value is larger, or
              // the current max is not NaN and the new value is NaN.
              Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
                                               arith::CmpFPredicate::UGT,
                                               newValue, oldValue);
              Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
                                                      arith::CmpFPredicate::ORD,
                                                      oldValue, oldValue);
              predicate = arith::AndIOp::create(
                  rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = arith::CmpIOp::create(rewriter, nestedLoc,
                                              arith::CmpIPredicate::sgt,
                                              newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = arith::SelectOp::create(
              rewriter, nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = arith::SelectOp::create(
              rewriter, nestedLoc, predicate, newIndex, oldIndex);
          linalg::YieldOp::create(nestedBuilder, nestedLoc,
                                  ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
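// Worked example of the NaN-propagating predicate above: if oldValue is
// ordinary and newValue is NaN, UGT(new, old) is true (unordered) and
// ORD(old, old) is true, so the NaN and its index win the select; once
// oldValue is NaN, ORD(old, old) is false and the accumulator sticks to
// that NaN. Under IGNORE, OGT is false whenever either operand is NaN,
// so NaN lanes never win and an all-NaN slice yields the initial index 0.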
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
          Value index1 = arith::IndexCastOp::create(
              rewriter, loc, rewriter.getIndexType(), indexValue);
          auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
          Value extract = tensor::ExtractOp::create(
              rewriter, loc, input, ValueRange{index0, index1, index2});
          linalg::YieldOp::create(rewriter, loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
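// For reference, the semantics implemented above: with values of shape
// [N, K, C] and indices of shape [N, W],
//   output[n, w, c] = values[n, indices[n, w], c]
// which is exactly the (index0, index1, index2) triple used by the
// tensor.extract in the generic body.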
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
        affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = arith::IndexCastOp::create(
            rewriter, loc, rewriter.getIndexType(), inputValue);
        Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
        index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
                                      index, offset);
        Value extract =
            tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
        linalg::YieldOp::create(rewriter, loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = arith::ExtSIOp::create(
            rewriter, loc, rewriter.getI32Type(), inputValue);

        auto offset = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32IntegerAttr(32768));
        auto seven = arith::ConstantOp::create(rewriter, loc,
                                               rewriter.getI32IntegerAttr(7));
        auto one = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getI32IntegerAttr(1));
        auto b1111111 = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part of the input value:
        //   value = value + 32768
        //   index = value >> 7
        //   fraction = value & 0x7f
        auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
        Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
        Value fraction =
            arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);

        // Extract the base and next values from the table:
        //   base = (int32_t)table[index]
        //   next = (int32_t)table[index + 1]
        Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);

        index = arith::IndexCastOp::create(rewriter, loc,
                                           rewriter.getIndexType(), index);
        indexPlusOne = arith::IndexCastOp::create(
            rewriter, loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
        Value next = tensor::ExtractOp::create(rewriter, loc, table,
                                               ValueRange{indexPlusOne});

        base =
            arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
        next =
            arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between table entries:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
        Value diff = arith::SubIOp::create(rewriter, loc, next, base);
        Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
        Value result =
            arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);

        linalg::YieldOp::create(rewriter, loc, result);
        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = arith::ConstantIndexOp::create(builder, loc, 1);
    auto two = arith::ConstantIndexOp::create(builder, loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }
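  // For example, a static width of 8 becomes 8 / 2 + 1 = 5: the RFFT of a
  // real signal is conjugate-symmetric, so only the first W/2 + 1 output
  // bins need to be materialized.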
  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W].
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType =
        cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
    auto filledTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                               ValueRange{emptyTensor})
            .result();
    return filledTensor;
  }

  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = arith::IndexCastUIOp::create(
        builder, loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);
    return arith::UIToFPOp::create(builder, loc, type, integerVal);
  }

  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = linalg::IndexOp::create(builder, loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }
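  // For instance, affineDimsExpr(rewriter, 0, 3, 4) yields {d0, d3, d4}:
  // in the generic below the input is indexed by (batch, iy, ix) while
  // both accumulators use (batch, oy, ox), i.e. {d0, d1, d2}.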
  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType = dyn_cast<FloatType>(
        cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for the input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for the angle computation.
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // Compute the angle without the integer parts of the components, as
      // sin/cos are periodic:
      //   angle = 2 * pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);

      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);
      auto realComponent =
          arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto imagComponent =
          arith::MulFOp::create(builder, loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::SubFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
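// The generic computes, per output element (n, oy, ox), the real-input DFT
//   outReal[n, oy, ox] =  sum_{iy, ix} in[n, iy, ix] * cos(2*pi*(iy*oy/H + ix*ox/W))
//   outImag[n, oy, ox] = -sum_{iy, ix} in[n, iy, ix] * sin(2*pi*(iy*oy/H + ix*ox/W))
// which is why the imaginary accumulator is updated with a subtraction.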
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    assert(real_el_ty == imag_el_ty);

    // Compute the output type and set of dynamic sizes.
    SmallVector<Value> dynamicSizes;

    // Get [N, H, W].
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);

    SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);

    // Iterator types for the linalg.generic implementation.
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    SmallVector<Value> genericOpInputs = {input_real, input_imag};
    SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for the input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Indices for the angle computation.
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // angle = sign * 2 * pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);

      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);

      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);

      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);

      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      if (inverse.getValue()) {
        angle = arith::MulFOp::create(
            builder, loc, angle,
            arith::ConstantOp::create(rewriter, loc,
                                      rewriter.getFloatAttr(real_el_ty, -1.0)));
      }

      // realComponent = valReal * cos(angle) + valImag * sin(angle)
      // imagComponent = valImag * cos(angle) - valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);

      auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
      auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);

      auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
      auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);

      auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);

      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::AddFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
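// The body above is the complex rotation (valReal + i*valImag) *
// exp(-i*angle) expanded into real arithmetic:
//   realComponent = valReal * cos(angle) + valImag * sin(angle)
//   imagComponent = valImag * cos(angle) - valReal * sin(angle)
// Flipping the sign of the angle when `inverse` is set turns the forward
// transform into the inverse one; note the lowering applies no 1/(H*W)
// normalization anywhere.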
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>
      // clang-format on
      >(converter, patterns->getContext());

  patterns->add<
      // clang-format off
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProductOp>,
      ArgMaxConverter,
      ConcatConverter,
      GatherConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
  // clang-format on
}
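// A minimal sketch of driving this conversion over a function (the target
// and type-converter setup below are illustrative assumptions; the in-tree
// TosaToLinalg pass configures its own ConversionTarget):
//
//   RewritePatternSet patterns(ctx);
//   TypeConverter typeConverter;
//   typeConverter.addConversion([](Type type) { return type; });  // identity
//   mlir::tosa::populateTosaToLinalgConversionPatterns(typeConverter,
//                                                      &patterns);
//
//   ConversionTarget target(*ctx);
//   target.addIllegalDialect<tosa::TosaDialect>();
//   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
//   if (failed(applyFullConversion(func, target, std::move(patterns))))
//     signalPassFailure();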