29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/Sequence.h"
31#include "llvm/ADT/SmallVectorExtras.h"
// NOTE(review): fragment of a NaN-handling helper for binary TOSA ops
// (min/max style); the function header and several interior lines are elided
// from this view — code is reproduced verbatim, comments only added.
template <
typename OpTy>
// If the op's NaN mode is PROPAGATE, no special materialization is needed
// (the elided branch presumably returns `result` unchanged — confirm).
68   auto nanMode = op.getNanMode();
69   if (nanMode == NanPropagationMode::PROPAGATE)
// Self-comparison with the unordered predicate (UNO) is true iff the value
// is NaN — the standard arith-dialect NaN test.
73   Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
74                                          arith::CmpFPredicate::UNO,
lhs,
lhs);
75   Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
76                                          arith::CmpFPredicate::UNO,
rhs,
rhs);
// If lhs is NaN pick rhs; then (returned) if rhs is NaN pick lhs — i.e. a
// NaN operand is ignored in favor of the other operand.
78       arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN,
rhs,
result);
79   return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN,
lhs,
// NOTE(review): fragment (many original lines elided) of the scalar-body
// builder used when lowering TOSA elementwise ops into a linalg.generic
// region. Dispatches on the concrete tosa op type and the element type and
// returns the computed scalar Value; unsupported combinations fall through
// to a notifyMatchFailure at the bottom (return value for that path is
// elided — presumably a null Value).
85                              ConversionPatternRewriter &rewriter) {
// tosa.abs — float uses math.absf; integer uses max(x, 0 - x).
91   if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
92     return math::AbsFOp::create(rewriter, loc, resultTypes, args);
94   if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
95     auto zero = arith::ConstantOp::create(rewriter, loc,
96                                           rewriter.getZeroAttr(elementTy));
97     auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
98     return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
// tosa.add / tosa.sub / tosa.int_div — direct arith equivalents.
102   if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
103     return arith::AddFOp::create(rewriter, loc, resultTypes, args);
105   if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
106     return arith::AddIOp::create(rewriter, loc, resultTypes, args);
109   if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
110     return arith::SubFOp::create(rewriter, loc, resultTypes, args);
112   if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
113     return arith::SubIOp::create(rewriter, loc, resultTypes, args);
116   if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
117     return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
// tosa.reciprocal — 1.0 / x.
120   if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
122         arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
123     return arith::DivFOp::create(rewriter, loc, one, args[0]);
// tosa.mul — the shift operand may be a compile-time constant; a non-zero
// (or non-constant) shift on the integer path goes through tosa.apply_scale.
// Float inputs must not carry a shift.
127   if (isa<tosa::MulOp>(op)) {
128     auto shiftVal = cast<tosa::MulOp>(op).getShift();
130     bool shiftIsConstant =
true;
133       shift = shiftElem.
getValues<IntegerAttr>()[0].getInt();
135       shiftIsConstant =
false;
137     if (isa<FloatType>(elementTy)) {
139         (
void)rewriter.notifyMatchFailure(op,
140                                           "Cannot have shift value for float");
143       return arith::MulFOp::create(rewriter, loc, args[0], args[1]);
146     if (isa<IntegerType>(elementTy)) {
150       if (shift > 0 || !shiftIsConstant) {
// apply_scale operates on i32; sign-extend narrower operands first.
157           a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
159         if (!
b.getType().isInteger(32))
160           b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(),
b);
162         auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
163         auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
164                                                   RoundingMode::SINGLE_ROUND);
166             tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
167                                        b, shiftAmount, roundingAttr);
// Unshifted path: widen both operands to the result width, then multiply.
173       int bWidth =
b.getType().getIntOrFloatBitWidth();
174       int cWidth = resultTypes[0].getIntOrFloatBitWidth();
177         a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
179         b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0],
b);
181       return arith::MulIOp::create(rewriter, loc, resultTypes, a,
b);
// tosa.negate — float is a plain negf; the integer path computes
// (inZp + outZp) - ext(x) in an intermediate width chosen so the zero-point
// arithmetic cannot overflow, then clamps to the input's signed range and
// truncates back to the element type.
186   if (isa<tosa::NegateOp>(op)) {
187     auto negate = cast<tosa::NegateOp>(op);
190     FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
191     FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
192     bool hasInZp = !failed(maybeInZp);
193     bool hasOutZp = !failed(maybeOutZp);
199     if (isa<FloatType>(elementTy))
200       return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
202     if (isa<IntegerType>(elementTy)) {
204       Type intermediateType;
// Default to 64 bits; shrink to 16/32 when the worst-case magnitude fits.
207       int intermediateBitWidth = 64;
209       if (hasInZp && hasOutZp) {
211         const int64_t zpAdd = inZp + outZp;
213             APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
218         if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
219           intermediateBitWidth = 16;
220         }
 else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
221           intermediateBitWidth = 32;
224         intermediateType = rewriter.getIntegerType(intermediateBitWidth);
225         zpAddValue = arith::ConstantOp::create(
226             rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
228         intermediateType = rewriter.getIntegerType(intermediateBitWidth);
// args[1]/args[2] carry dynamic zero points here (see the operand-selection
// helper below) — widen them to the intermediate type before adding.
229         Value arg1 = args[1];
230         Value arg2 = args[2];
232         if (arg1.
getType() != intermediateType)
233           arg1 = arith::ExtSIOp::create(rewriter, loc, intermediateType, arg1);
234         if (arg2.
getType() != intermediateType)
235           arg2 = arith::ExtSIOp::create(rewriter, loc, intermediateType, arg2);
237             arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
243       if (ext.
getType() != intermediateType)
244         ext = arith::ExtSIOp::create(rewriter, loc, intermediateType, ext);
245       auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);
// Clamp to the signed range of the original input bit width.
249           rewriter, loc, intermediateType,
250           APInt::getSignedMinValue(inputBitWidth).getSExtValue());
252           rewriter, loc, intermediateType,
253           APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
257       if (
clamp.getType() == elementTy)
259       return arith::TruncIOp::create(rewriter, loc, elementTy,
clamp);
// Bitwise ops — direct arith equivalents; NOT is x XOR all-ones.
264   if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
265     return arith::AndIOp::create(rewriter, loc, resultTypes, args);
268   if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
269     return arith::OrIOp::create(rewriter, loc, resultTypes, args);
272   if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
273     auto allOnesAttr = rewriter.getIntegerAttr(
274         elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
275     auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
276     return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
280   if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
281     return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
// Shifts — logical left/right map directly; arithmetic right shift may
// additionally round toward nearest by adding back the last shifted-out bit.
284   if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
285     return arith::ShLIOp::create(rewriter, loc, resultTypes, args);
288   if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
289     return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
292   if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
293     auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
294     auto round = cast<BoolAttr>(op->
getAttr(
"round")).getValue();
299     Type i1Ty = IntegerType::get(rewriter.getContext(), 1);
300     auto one = arith::ConstantOp::create(rewriter, loc,
301                                          IntegerAttr::get(elementTy, 1));
302     auto zero = arith::ConstantOp::create(rewriter, loc,
303                                           IntegerAttr::get(elementTy, 0));
305         arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
307         arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));
// Rounding only applies for a positive shift amount: shift by (n-1), take
// the low bit of that intermediate (the bit shifted out last), and add it.
310     auto shiftValueGreaterThanZero = arith::CmpIOp::create(
311         rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);
315         arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
317         arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
319     auto truncated = arith::TruncIOp::create(rewriter, loc, i1Ty, shifted,
322         arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);
324     auto shouldRound = arith::SelectOp::create(
325         rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
327         arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
328     return arith::AddIOp::create(rewriter, loc, resultTypes,
result, extended);
332   if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
333     return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
// Logical ops on i1 — NOT is x XOR 1.
337   if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
338     return arith::AndIOp::create(rewriter, loc, resultTypes, args);
341   if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
342     auto one = arith::ConstantOp::create(rewriter, loc,
343                                          rewriter.getIntegerAttr(elementTy, 1));
344     return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
348   if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
349     return arith::OrIOp::create(rewriter, loc, resultTypes, args);
352   if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
353     return arith::XOrIOp::create(rewriter, loc, resultTypes, args);
// Unary/binary float math — direct math-dialect equivalents.
356   if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
357     return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);
360   if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
361     return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);
364   if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
365     return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);
368   if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
369     return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);
372   if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
373     return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);
376   if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
377     return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);
380   if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
381     return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);
384   if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
385     return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
// Comparisons — ordered float predicates (OGT/OGE/OEQ) and signed int
// predicates; the compared operands on each call are elided in this view.
388   if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
389     return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
392   if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
393     return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
397   if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
398     return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
401   if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
402     return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
406   if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
407     return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
410   if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
411     return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
415   if (isa<tosa::SelectOp>(op)) {
417     if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
418       return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
// tosa.maximum / tosa.minimum — float paths route through the NaN-check
// helper above; integer paths map directly.
422   if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
423     auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
425         rewriter, args[0], args[1],
max);
428   if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
429     return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
433   if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
434     auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
436         rewriter, args[0], args[1],
min);
439   if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
440     return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
444   if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
445     return math::CeilOp::create(rewriter, loc, resultTypes, args);
448   if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
449     return math::FloorOp::create(rewriter, loc, resultTypes, args);
// tosa.clamp (float) — the min_val/max_val attribute APFloats are converted
// into the element type's semantics before materializing constants; with
// NaN PROPAGATE a NaN input selects `min` (elided clamp body in between).
452   if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
453     bool losesInfo =
false;
454     APFloat minApf = cast<FloatAttr>(op->
getAttr(
"min_val")).getValue();
455     APFloat maxApf = cast<FloatAttr>(op->
getAttr(
"max_val")).getValue();
456     minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
457                    APFloat::rmNearestTiesToEven, &losesInfo);
458     maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
459                    APFloat::rmNearestTiesToEven, &losesInfo);
460     auto min = arith::ConstantOp::create(
461         rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
462     auto max = arith::ConstantOp::create(
463         rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
466     auto clampOp = llvm::cast<tosa::ClampOp>(op);
467     const auto nanMode = clampOp.getNanMode();
470     if (!isa<FloatType>(elementTy))
475     if (nanMode == NanPropagationMode::PROPAGATE)
489     Value isNaN = arith::CmpFOp::create(
490         rewriter, op->
getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
493     return arith::SelectOp::create(rewriter, op->
getLoc(), isNaN,
min,
result);
// tosa.clamp (int) — clamp the attribute bounds themselves to the
// representable range of the element type (signed or unsigned) before use.
496   if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
497     auto intTy = cast<IntegerType>(elementTy);
499         cast<IntegerAttr>(op->
getAttr(
"min_val")).getValue().getSExtValue();
501         cast<IntegerAttr>(op->
getAttr(
"max_val")).getValue().getSExtValue();
503     int64_t minRepresentable = std::numeric_limits<int64_t>::min();
504     int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
505     if (intTy.isUnsignedInteger()) {
506       minRepresentable = 0;
507       if (intTy.getIntOrFloatBitWidth() <= 63) {
509             (
int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
512     }
 else if (intTy.getIntOrFloatBitWidth() <= 64) {
514       minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
516       maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
521     min = std::max(
min, minRepresentable);
522     max = std::max(
max, minRepresentable);
523     min = std::min(
min, maxRepresentable);
524     max = std::min(
max, maxRepresentable);
527                                            intTy.getIntOrFloatBitWidth());
529                                            intTy.getIntOrFloatBitWidth());
531                             intTy.isUnsignedInteger());
// tosa.sigmoid — 1 / (1 + exp(-x)).
535   if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
537         arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
538     auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
539     auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
540     auto added = arith::AddFOp::create(rewriter, loc, exp, one);
541     return arith::DivFOp::create(rewriter, loc, one, added);
// tosa.cast — dispatch over ext/trunc and int<->fp conversions. The
// fp->int path rounds to nearest-even first and saturates: infinities map
// to int min/max via compare+select, and finite overflow is clamped either
// with fp-side min/max constants (when the mantissa can represent the
// bound exactly) or with an int-max-plus-one threshold compare.
545   if (isa<tosa::CastOp>(op)) {
546     Type srcTy = elementTy;
547     Type dstTy = resultTypes.front();
549       (
void)rewriter.notifyMatchFailure(op,
"unsupported type");
559     if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
560       return arith::ExtFOp::create(rewriter, loc, resultTypes, args,
563     if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
564       return arith::TruncFOp::create(rewriter, loc, resultTypes, args,
568     if (srcTy.
isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
569       return arith::UIToFPOp::create(rewriter, loc, resultTypes, args,
572     if (srcTy.
isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
573       return arith::ExtUIOp::create(rewriter, loc, resultTypes, args,
579       auto unrealizedCast =
580           UnrealizedConversionCastOp::create(
584       return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
589     if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
590       return arith::SIToFPOp::create(rewriter, loc, resultTypes, args,
// fp -> i1: true iff the value is not (ordered-)equal to 0.0 (UNE).
594     if (isa<FloatType>(srcTy) && dstTy.
isInteger(1)) {
595       Value zero = arith::ConstantOp::create(rewriter, loc,
596                                              rewriter.getFloatAttr(srcTy, 0.0));
597       return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
601     if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
602       auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);
604       const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
608               APFloat::semanticsMaxExponent(fltSemantics)) {
// Destination range exceeds the float's finite range: only infinities can
// overflow; select int max/min on +/-inf respectively.
611         auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
612         auto posInf = arith::ConstantOp::create(
615                                  APFloat::getInf(fltSemantics)));
616         auto negInf = arith::ConstantOp::create(
618             rewriter.getFloatAttr(
620                 APFloat::getInf(fltSemantics,
true)));
621         auto overflow = arith::CmpFOp::create(
622             rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
623         auto underflow = arith::CmpFOp::create(
624             rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
625         auto intMin = arith::ConstantOp::create(
627             rewriter.getIntegerAttr(
630         auto intMax = arith::ConstantOp::create(
632             rewriter.getIntegerAttr(
636             arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
637         return arith::SelectOp::create(rewriter, loc, underflow, intMin,
641       auto intMinFP = arith::ConstantOp::create(
643           rewriter.getFloatAttr(
// Wide-enough mantissa: clamp in the float domain then convert.
649       if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
655         auto intMaxFP = arith::ConstantOp::create(
657             rewriter.getFloatAttr(
664         return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
// Narrow mantissa: int max is not exactly representable — compare against
// (int max + 1) in the float domain and select int max on overflow.
671       auto intMaxPlusOneFP = arith::ConstantOp::create(
673           rewriter.getFloatAttr(
680       auto intMax = arith::ConstantOp::create(
682           rewriter.getIntegerAttr(
686           arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
688           arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
689       auto overflow = arith::CmpFOp::create(
690           rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
691       return arith::SelectOp::create(rewriter, loc, overflow, intMax,
// int -> i1: true iff the value is non-zero.
697     if (isa<IntegerType>(srcTy) && dstTy.
isInteger(1)) {
700       return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
704     if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
705       return arith::ExtSIOp::create(rewriter, loc, resultTypes, args,
708     if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
709       return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
// Fallthrough: unsupported op/type combination.
713   (
void)rewriter.notifyMatchFailure(
714       op,
"unhandled op for linalg body calculation for elementwise op");
// NOTE(review): tail of a helper that materializes a runtime tensor.dim for
// a given dimension index (header elided from this view).
735   return tensor::DimOp::create(rewriter, loc,
tensor, indexValue).getResult();
// NOTE(review): fragment of a helper returning a dimension size either as a
// runtime Value (dynamic dim) or, presumably, as a folded static attribute
// (the static branch is elided from this view).
741   auto shapedType = dyn_cast<ShapedType>(
tensor.getType());
742   assert(shapedType && shapedType.hasRank() &&
"expected a ranked shaped type");
743   assert(
index >= 0 &&
index < shapedType.getRank() &&
"index out of bounds");
744   if (shapedType.isDynamicDim(
index))
// NOTE(review): fragment — true iff every operand and every result of the
// operation has a ranked tensor type.
750   auto isRanked = [](
Value value) {
751     return isa<RankedTensorType>(value.getType());
753   return llvm::all_of(operation->
getOperands(), isRanked) &&
754          llvm::all_of(operation->
getResults(), isRanked);
// NOTE(review): fragment — computes the broadcast target size of one
// dimension across all operands. Returns {size, masterOperand}: if a static
// size > 1 exists (elided branch presumably returns it), or if exactly one
// operand is dynamic in this dim, that operand is the "master"; with several
// dynamic operands the runtime max is taken and no single master exists.
767static std::pair<OpFoldResult, Value>
773   for (
auto operand : operands) {
774     auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
775     if (ShapedType::isStatic(size) && size > 1)
780   auto operandsWithDynamicDim =
781       llvm::filter_to_vector(operands, [&](
Value operand) {
782         return cast<RankedTensorType>(operand.
getType()).isDynamicDim(dim);
786   if (operandsWithDynamicDim.empty())
793       getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
794   if (operandsWithDynamicDim.size() == 1)
795     return {targetSize, operandsWithDynamicDim[0]};
// Multiple dynamic candidates: fold them with a runtime unsigned max.
798   for (
size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
800         getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
801     targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
803   return {targetSize,
nullptr};
// NOTE(review): fragment — builds the full broadcast target shape by
// invoking the per-dimension size computation for every dim, collecting the
// sizes and the per-dim master operands in parallel vectors.
811   assert(!operands.empty());
812   auto rank = cast<RankedTensorType>(operands.front().
getType()).getRank();
815   for (
auto dim : llvm::seq<int64_t>(0, rank)) {
816     auto [targetSize, masterOperand] =
818     targetShape.push_back(targetSize);
819     masterOperands.push_back(masterOperand);
821   return {targetShape, masterOperands};
// NOTE(review): fragment — broadcasts one dynamic dimension of `operand` to
// the target size when necessary. Static dims and the master operand itself
// are returned unchanged (elided early returns). Otherwise a runtime check
// (size == 1) guards an scf.if: the then-region materializes the broadcast
// through a linalg.generic into a fresh tensor.empty and casts back to the
// operand's (dynamic) type; the else-region yields the operand as-is.
827                                   Value masterOperand) {
829   auto rankedTensorType = cast<RankedTensorType>(operand.
getType());
830   if (!rankedTensorType.isDynamicDim(dim))
837   if (operand == masterOperand)
// Build the (broadcasting) indexing map for the generic op.
841   auto rank = rankedTensorType.getRank();
843   for (
auto index : llvm::seq<int64_t>(0, rank)) {
846     affineExprs.push_back(affineExpr);
848   auto broadcastAffineMap =
// Broadcast is only needed at runtime when this dim's size is exactly 1.
854   auto one =
createIndex(rewriter, loc, indexPool, 1);
855   auto runtimeSize =
getTensorDim(rewriter, loc, indexPool, operand, dim);
856   auto broadcastNecessary = arith::CmpIOp::create(
857       rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);
// Then-region: output shape equals the operand's, except `dim` which takes
// the target size.
867     for (
auto index : llvm::seq<int64_t>(0, rank)) {
868       auto size =
index == dim ? targetSize
871       outputTensorShape.push_back(size);
873     Value outputTensor = tensor::EmptyOp::create(
874         opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());
878         linalg::GenericOp::create(
879             opBuilder, loc, outputTensor.
getType(), operand, outputTensor,
883               linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
// Cast the broadcast result back to the operand's original type so both
// scf.if branches yield the same type.
888     auto castResultTensor = rewriter.
createOrFold<tensor::CastOp>(
889         loc, operand.
getType(), resultTensor);
892     scf::YieldOp::create(opBuilder, loc, castResultTensor);
897     scf::YieldOp::create(opBuilder, loc, operand);
901   auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
902                                 emitThenRegion, emitElseRegion);
903   return ifOp.getResult(0);
// NOTE(review): fragment — applies the single-dimension broadcast helper to
// every dimension of one operand (loop body elided from this view).
910   int64_t rank = cast<RankedTensorType>(operand.
getType()).getRank();
911   assert((
int64_t)targetShape.size() == rank);
912   assert((
int64_t)masterOperands.size() == rank);
913   for (
auto index : llvm::seq<int64_t>(0, rank))
// NOTE(review): fragment — broadcasts all operands to the target shape.
// Single-operand ops short-circuit (elided return), and a scan for any
// dynamic shape presumably allows skipping the work when all are static.
926   if (operands.size() == 1)
930   bool hasDynamic =
false;
931   for (
auto op : operands) {
932     const auto tType = dyn_cast<RankedTensorType>(op.getType());
933     if (tType && !tType.hasStaticShape()) {
942   return llvm::map_to_vector(operands, [&](
Value operand) {
944                                        targetShape, masterOperands);
// NOTE(review): fragment — emits the linalg.generic that performs the
// elementwise computation: converts the result type, creates the output
// tensor.empty, builds per-operand indexing maps (constant-0 expressions
// for size-1 dims that broadcast), constructs the generic body via the
// scalar-body builder above, and replaces the op with a (possibly cast)
// result.
954   auto resultType = cast_or_null<RankedTensorType>(
957     return rewriter.notifyMatchFailure(operation,
"failed to convert type");
959   Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
960                                                resultType.getElementType());
// A dim of static size 1 that maps to a larger result dim is broadcast by
// indexing it with the constant 0 instead of the dim expression.
965   auto rank = resultType.getRank();
966   auto affineMaps = llvm::map_to_vector(operands, [&](
Value operand) {
967     auto shape = cast<ShapedType>(operand.
getType()).getShape();
969     for (
auto it : llvm::enumerate(
shape)) {
973       bool requiresBroadcast =
974           (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
975       auto affineExpr = requiresBroadcast
976                             ? rewriter.getAffineConstantExpr(0)
977                             : rewriter.getAffineDimExpr(it.index());
978       affineExprs.push_back(affineExpr);
980     return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
982   affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
// The body builder may fail for unsupported op/type combinations; record
// that and surface it as a match failure after the generic is built.
985   bool encounteredError =
false;
986   auto linalgOp = linalg::GenericOp::create(
987       rewriter, loc, outputTensor.
getType(), operands, outputTensor, affineMaps,
992             {resultType.getElementType()}, rewriter);
994           encounteredError =
true;
997         linalg::YieldOp::create(opBuilder, loc, opResult);
999   if (encounteredError)
1000     return rewriter.notifyMatchFailure(
1001         operation,
"unable to create linalg.generic body for elementwise op");
1004   auto castResult = rewriter.createOrFold<tensor::CastOp>(
1005       loc, resultType, linalgOp->getResult(0));
1006   rewriter.replaceOp(operation, castResult);
// NOTE(review): fragment — selects which leading operands of an op take
// part in broadcasting. tosa.mul keeps 2 or 3 operands depending on an
// elided condition (presumably whether the shift operand is dynamic);
// tosa.negate keeps only the input when neither zero point is available.
1013   if (isa<tosa::MulOp>(operation)) {
1017       return operands.take_front(2);
1019     return operands.take_front(3);
1021   if (
auto negate = dyn_cast<tosa::NegateOp>(operation)) {
1022     FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
1023     FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
1024     if (failed(maybeOutZp) && failed(maybeInZp))
1027       return operands.take_front(1);
// NOTE(review): fragment — top-level driver for lowering an elementwise
// TOSA op: validates single ranked result / ranked operands, computes the
// broadcast target shape and master operands, broadcasts the operands, and
// (elided tail) emits the elementwise computation.
1034                                  ConversionPatternRewriter &rewriter,
1038   assert(operation->
getNumResults() == 1 &&
"elementwise op expects 1 result");
1040          "elementwise op expects at least 1 operand");
1042     return rewriter.notifyMatchFailure(operation,
1043                                        "Unranked tensors not supported");
1047   auto loc = operation->
getLoc();
1049   auto [targetShape, masterOperands] =
1051   auto broadcastOperands =
1053                                   targetShape, masterOperands);
1055                                    targetShape, converter);
// NOTE(review): fragment — produces the identity/initial attribute for each
// reduction: 0 for sum, 1 for product (elided values), +largest for
// reduce_min, -largest for reduce_max (second APFloat::getLargest argument
// is the `Negative` flag), and -largest for argmax's running max.
1062   if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
1065   if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
1068   if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
1071   if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
1074   if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
1076         elementTy, APFloat::getLargest(
1077                        cast<FloatType>(elementTy).getFloatSemantics(),
false));
1079   if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
1083   if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
1085         elementTy, APFloat::getLargest(
1086                        cast<FloatType>(elementTy).getFloatSemantics(),
true));
1088   if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
1092   if (isa<tosa::ReduceAllOp>(op) && elementTy.
isInteger(1))
1095   if (isa<tosa::ReduceAnyOp>(op) && elementTy.
isInteger(1))
1098   if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
1100         elementTy, APFloat::getLargest(
1101                        cast<FloatType>(elementTy).getFloatSemantics(),
true));
1103   if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
// NOTE(review): fragment — builds the scalar combining operation for each
// TOSA reduction inside the linalg reduce body: add for sum, mul for
// product, min/max for reduce_min/max, and/or for reduce_all/any on i1.
1117   if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
1118     return arith::AddFOp::create(rewriter, loc, args);
1121   if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
1122     return arith::AddIOp::create(rewriter, loc, args);
1125   if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
1126     return arith::MulFOp::create(rewriter, loc, args);
1129   if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
1130     return arith::MulIOp::create(rewriter, loc, args);
1133   if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
1134     return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
1137   if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1138     return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
1141   if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1142     return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
1145   if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1146     return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
1149   if (isa<tosa::ReduceAllOp>(op) && elementTy.
isInteger(1))
1150     return arith::AndIOp::create(rewriter, loc, args);
1152   if (isa<tosa::ReduceAnyOp>(op) && elementTy.
isInteger(1))
1153     return arith::OrIOp::create(rewriter, loc, args);
// NOTE(review): fragment (many lines elided) — lowers a TOSA reduce op to
// linalg.reduce. Notable visible behaviors:
//  * bf16 reduce_sum widens the accumulator type (widenAccTy) and truncates
//    back with a trailing linalg.generic;
//  * NaN IGNORE mode for float reduce_min/max threads an extra boolean
//    "all results so far are NaN" output through the reduction, then uses a
//    linalg select afterwards to substitute NaN where every element was NaN;
//  * the reduced dimension is re-expanded at the end via a reassociation map.
1161template <
typename OpTy>
1164   auto loc = op->getLoc();
1165   auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
1166   auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
1167   if (!inputTy || !resultTy)
1170   auto elementTy = resultTy.getElementType();
1171   Value input = op->getOperand(0);
// Widen bf16 sum accumulation (precision); truncation happens after the
// reduction (see trailing generic below).
1174   bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
1175                     isa<FloatType>(elementTy) &&
1176                     cast<FloatType>(elementTy).isBF16();
// Collect the non-reduced output shape plus dynamic-dim sizes for
// tensor.empty.
1181   for (
unsigned i = 0; i < inputTy.getRank(); i++) {
1183       reduceShape.push_back(inputTy.getDimSize(i));
1184       if (inputTy.isDynamicDim(i))
1185         dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
1190   inputs.push_back(input);
1194       tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1200         op,
"No initial value found for reduction operation");
1202   auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
1204       linalg::FillOp::create(rewriter, loc,
ValueRange{fillValue},
1207   outputs.push_back(filledTensor);
// NaN IGNORE mode needs an additional i1 output, initially all-true,
// tracking whether every combined element so far was NaN.
1209   bool isNanIgnoreMode =
false;
1210   if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
1211                 std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
1213     if (isa<FloatType>(elementTy) &&
1214         op.getNanMode() == NanPropagationMode::IGNORE) {
1215       isNanIgnoreMode =
true;
1221       auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
1222       auto emptyBoolTensor =
1223           tensor::EmptyOp::create(rewriter, loc, reduceShape,
1224                                   trueValue.getType(), dynDims)
1226       auto allResultsNaNTensor =
1227           linalg::FillOp::create(rewriter, loc,
ValueRange{trueValue},
// Input is pushed twice in NaN-ignore mode (second copy feeds the NaN
// bookkeeping block arguments).
1239       inputs.push_back(input);
1240       outputs.push_back(allResultsNaNTensor);
1244   bool didEncounterError =
false;
1245   linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
1246       rewriter, loc, inputs, outputs, axis,
// In NaN-ignore mode the accumulator lives in blockArgs[2].
1248         std::array<Value, 2> binaryArgs{
1249             blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
1252         if (binaryArgs[0].
getType() != accTy)
1253           binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
1259           didEncounterError =
true;
1262         if (isNanIgnoreMode) {
1263           auto inputValue = blockArgs[0];
1264           auto initialValue = blockArgs[2];
1265           auto oldAllResultsNanFlagValue = blockArgs[3];
// NaN test (unordered self-compare); a NaN input keeps the prior
// accumulator and ANDs the "all NaN" flag.
1268           Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
1269                                               arith::CmpFPredicate::UNO,
1270                                               inputValue, inputValue);
1272           auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
1273                                                   isNaN, initialValue,
result);
1276           auto newAllResultsNanFlagValue = arith::AndIOp::create(
1277               nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
1278           resultsToYield.push_back(selectOp);
1279           resultsToYield.push_back(newAllResultsNanFlagValue);
1281           resultsToYield.push_back(
result);
1283         linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
1286   if (!didEncounterError)
1288         op,
"unable to create linalg.generic body for reduce op");
// Post-pass for NaN-ignore: where the all-NaN flag is set, substitute a
// NaN-filled tensor via linalg select.
1290   if (isNanIgnoreMode) {
1299         APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(),
false));
1300     auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
1301     auto emptyNanTensor =
1302         tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1304     auto nanFilledTensor =
1305         linalg::FillOp::create(rewriter, loc,
ValueRange{nanValue},
1311     auto finalEmptyTensor =
1312         tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
1318     ins.push_back(linalgOp->getOpResult(1));
1319     ins.push_back(nanFilledTensor);
1320     ins.push_back(linalgOp->getResult(0));
1321     outs.push_back(finalEmptyTensor);
1323         linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
1324     linalgOp = linalgSelect;
1328   Value reducedRes = linalgOp->getResult(0);
// Truncate the widened accumulator back to the element type (bf16 path).
1331         tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
1334     const unsigned reducedRank =
1335         cast<ShapedType>(reducedRes.
getType()).getRank();
1338         linalg::GenericOp::create(
1344               Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
1345                                                      elementTy, args[0]);
1346               linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
// Rebuild the reassociation map so the reduced axis is re-expanded to a
// unit dimension in the final result shape.
1352   uint64_t expandInputRank = cast<ShapedType>(reducedRes.
getType()).getRank();
1353   reassociationMap.resize(expandInputRank);
1355   for (uint64_t i = 0; i < expandInputRank; i++) {
1356     int32_t dimToPush = i > axis ? i + 1 : i;
1360   if (expandInputRank != 0) {
1361     int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1362     reassociationMap[expandedDim].push_back(
// NOTE(review): fragment — generic conversion pattern that forwards any
// pointwise TOSA op to the elementwise rewrite helper above (the return of
// the helper call is elided from this view).
1377template <
typename SrcOp>
1378class PointwiseConverter :
public OpConversionPattern<SrcOp> {
1380   using OpConversionPattern<SrcOp>::OpConversionPattern;
1381   using typename OpConversionPattern<SrcOp>::OpAdaptor;
1384   matchAndRewrite(SrcOp op, OpAdaptor operands,
1385                   ConversionPatternRewriter &rewriter)
const final {
1387         op, operands.getOperands(), rewriter, *this->getTypeConverter());
// NOTE(review): fragment — collapses a tensor into a rank-0 tensor of the
// same element type via tensor.collapse_shape (reassociation argument
// elided; header elided from this view).
1397   auto inputType = cast<RankedTensorType>(input.
getType());
1398   auto elemType = inputType.getElementType();
1399   auto collapsedType = RankedTensorType::get({}, elemType);
1401   return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
// NOTE(review): fragment — narrows a list of int32_t values to int8_t via
// static_cast (truncating), appending into `output` after reserving space.
1408   output.reserve(input.size());
1410   for (
auto v : llvm::map_range(
1411            input, [](int32_t val) {
 return static_cast<int8_t
>(val); })) {
1412     output.push_back(v);
// NOTE(review): fragment — prepares one generic-op input (multiplier or
// shift) for the tosa.rescale lowering. A single constant value becomes a
// scalar arith.constant; a static 1xN tensor is collapsed and given a
// broadcast indexing map; otherwise the operand is used directly. `arg`
// receives the index of the appended input within the indexing maps.
1424static void setupLinalgGenericOpInputAndIndexingMap(
1427     bool isConstant, tosa::RescaleOp op,
Value &constant,
int64_t &arg,
1428     bool isShift =
false) {
1430   auto loc = op.getLoc();
1431   auto inputTy = cast<ShapedType>(op.getInput().getType());
1432   unsigned rank = inputTy.getRank();
// Single constant: materialize as a scalar (shift values presumably use a
// narrower attr type — the isShift attr construction is elided).
1438     if (values.size() == 1) {
1439       IntegerAttr intAttr = isShift
1442       constant = arith::ConstantOp::create(rewriter, loc, intAttr);
1446       auto tensorType = RankedTensorType::get(
1447           {
static_cast<int64_t>(values.size())}, elementType);
1453       genericInputs.push_back(
1454           arith::ConstantOp::create(rewriter, loc, EltAttr));
// Non-constant operand: collapse a static 1xN tensor and broadcast it.
1462     auto operand = isShift ? op.getShift() : op.getMultiplier();
1463     auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
1464     if (tensorType && tensorType.hasStaticShape() &&
1465         tensorType.getShape()[0] == 1) {
1470       genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
1471       indexingMaps.push_back(broadcastMap);
1473       genericInputs.push_back(operand);
1479   arg = indexingMaps.size() - 1;
// NOTE(review): fragment — produces a zero-point value extended to at least
// 32 bits (output zero points are fixed at 32). A dynamic zero point comes
// from a block argument and is zero/sign-extended per its signedness
// (unsigned values pass through an unrealized_conversion_cast first); a
// statically known zero point becomes an arith.constant.
1484                          FailureOr<int64_t> maybeZp,
Location loc,
1486                          bool isOutputZp =
false) {
1489   const uint32_t attrBitwidth =
1490       isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
1497     result = blockArgs[zpArg];
1498     auto zpTy =
result.getType();
1499     if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
1502       if (zpTy.isUnsignedInteger()) {
1504             UnrealizedConversionCastOp::create(
1509       if (zpTy.isUnsignedInteger()) {
1510         return arith::ExtUIOp::create(builder, loc, extendType,
result);
1512       return arith::ExtSIOp::create(builder, loc, extendType,
result);
1516   return arith::ConstantOp::create(builder, loc,
1517                                    IntegerAttr::get(extendType, *maybeZp));
// NOTE(review): garbled extraction — fused line numbers and missing interior
// lines; the enclosing pattern class header (~line 1523, presumably
// `RescaleConverter : OpRewritePattern<tosa::RescaleOp>`) is not visible.
// Lowers tosa.rescale to a linalg.generic whose body applies
// tosa.apply_scale plus zero-point adjustment and output clamping.
 1524 using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
 1526 LogicalResult matchAndRewrite(tosa::RescaleOp op,
 1527 PatternRewriter &rewriter)
 const final {
 1528 auto loc = op.getLoc();
 1529 auto input = op.getInput();
 1530 auto inputTy = cast<ShapedType>(op.getInput().getType());
 1531 auto outputTy = cast<ShapedType>(op.getOutput().getType());
 1532 unsigned rank = inputTy.getRank();
// Unsupported configurations bail out with match-failure diagnostics.
 1535 if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
 1537 op,
 "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
 1538 "currently supported");
 1539 if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
 1541 op,
 "tosa.rescale requires scale32 for double_round to be true");
 1543 if (!isa<IntegerType>(inputTy.getElementType()))
// Collect dynamic output dimensions for tensor.empty below.
 1546 SmallVector<Value> dynDims;
 1547 for (
 int i = 0; i < outputTy.getRank(); i++) {
 1548 if (outputTy.isDynamicDim(i)) {
 1549 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
// Try to fold shift/multiplier operands to compile-time constants
// (matchPattern lines are missing between the flag declarations).
 1553 DenseElementsAttr shiftElems;
 1554 bool isShiftConstant =
 false;
 1556 isShiftConstant =
 true;
 1558 DenseElementsAttr multiplierElems;
 1559 bool isMultiplierConstant =
 false;
 1561 isMultiplierConstant =
 true;
 1563 llvm::SmallVector<int32_t> shiftValues;
 1564 llvm::SmallVector<int32_t> multiplierValues;
 1567 if (isMultiplierConstant && isShiftConstant) {
 1569 shiftValues = llvm::map_to_vector(
 1570 shiftElems.
 getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
 1571 return static_cast<int32_t>(attr.getInt());
 1574 llvm::map_to_vector(multiplierElems.
 getValues<IntegerAttr>(),
 1575 [](IntegerAttr attr) -> int32_t {
 1576 return static_cast<int32_t>(attr.getInt());
// Shifts > 63 would be UB in apply_scale; neutralize those channels.
 1580 for (
 int i = 0, s = multiplierValues.size(); i < s; i++) {
 1581 if (shiftValues[i] > 63) {
 1583 multiplierValues[i] = 0;
// DOUBLE_ROUND only matters when some shift exceeds 31; otherwise it
// degenerates to SINGLE_ROUND.
 1588 doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
 1589 llvm::any_of(shiftValues, [](int32_t v) {
 return v > 31; });
 1591 doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
 1593 RoundingMode roundingMode =
 1594 doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
 1596 SmallVector<AffineMap> indexingMaps = {
 1598 SmallVector<Value, 4> genericInputs = {input};
// Wire up multiplier and shift as constants or generic inputs (see
// setupLinalgGenericOpInputAndIndexingMap).
 1602 Value multiplierConstant;
 1603 int64_t multiplierArg = 0;
 1604 setupLinalgGenericOpInputAndIndexingMap(
 1605 rewriter, multiplierValues, genericInputs, indexingMaps,
 1606 isMultiplierConstant, op, multiplierConstant, multiplierArg);
 1610 Value shiftConstant;
 1611 int64_t shiftArg = 0;
 1612 setupLinalgGenericOpInputAndIndexingMap(
 1613 rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
 1614 shiftConstant, shiftArg,
 true);
// Zero points: when not compile-time constants, operands 3/4 are fed to the
// generic as broadcast inputs (conditional lines missing around these).
 1619 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
 1620 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
 1630 genericInputs.push_back(
 1631 collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
 1632 indexingMaps.push_back(broadcastMap);
 1633 iZpArg = indexingMaps.size() - 1;
 1637 genericInputs.push_back(
 1638 collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
 1639 indexingMaps.push_back(broadcastMap);
 1640 oZpArg = indexingMaps.size() - 1;
// Destination tensor and the rescale body itself.
 1647 Value emptyTensor = tensor::EmptyOp::create(
 1648 rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
 1649 ArrayRef<Value>({dynDims}));
 1651 auto linalgOp = linalg::GenericOp::create(
 1652 rewriter, loc, outputTy, genericInputs,
 ValueRange{emptyTensor},
 1654 [&](OpBuilder &nestedBuilder, Location nestedLoc,
 1656 Value value = blockArgs[0];
 1657 Type valueTy = value.
 getType();
 1659 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
 1660 auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
 1661 nestedLoc, blockArgs, iZpArg);
 1663 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
 1664 auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
 1665 nestedLoc, blockArgs, oZpArg,
 true);
 1667 IntegerType outIntType =
 1668 cast<IntegerType>(blockArgs.back().
 getType());
 1669 unsigned outBitWidth = outIntType.getWidth();
 1670 assert(outBitWidth <= 32 &&
 "Unexpected output zeropoint bitwidth");
// Select scalar constant vs. per-channel block argument for each factor.
 1672 Value multiplier = multiplierConstant ? multiplierConstant
 1673 : blockArgs[multiplierArg];
 1674 Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
 1677 value = UnrealizedConversionCastOp::create(
 1678 nestedBuilder, nestedLoc,
 1679 nestedBuilder.getIntegerType(
// Widen the input element to i32, honoring input signedness.
 1685 if (op.getInputUnsigned()) {
 1686 value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
 1687 nestedBuilder.getI32Type(), value);
 1689 value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
 1690 nestedBuilder.getI32Type(), value);
// Core computation: (value - inputZp) scaled, then + outputZp.
 1695 arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);
 1697 value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
 1698 nestedBuilder.getI32Type(), value,
 1699 multiplier, shift, roundingMode);
 1703 arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);
// Clamp to the representable range of the output type (signed or unsigned).
 1706 int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
 1707 int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
 1710 if (op.getOutputUnsigned()) {
 1712 intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
 1715 auto intMinVal = arith::ConstantOp::create(
 1716 nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
 1717 auto intMaxVal = arith::ConstantOp::create(
 1718 nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));
 1721 nestedBuilder,
 false);
// Narrow back down and, for unsigned outputs, cast signless->unsigned.
 1723 if (outIntType.getWidth() < 32) {
 1724 value = arith::TruncIOp::create(
 1725 nestedBuilder, nestedLoc,
 1729 if (outIntType.isUnsignedInteger()) {
 1730 value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc,
 1734 linalg::YieldOp::create(nestedBuilder, loc, value);
 1737 rewriter.
 replaceOp(op, linalgOp->getResults());
// NOTE(review): garbled extraction — fused line numbers, missing interior
// lines, and the enclosing pattern class header is not visible.
// Special-case lowering for tosa.resize where both the input and output
// image planes are 1x1: collapses H/W away, scales values for BILINEAR,
// and expands back to the result shape.
 1747 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
 1749 LogicalResult matchAndRewrite(tosa::ResizeOp op,
 1750 PatternRewriter &rewriter)
 const final {
 1751 Location loc = op.getLoc();
 1752 ImplicitLocOpBuilder builder(loc, rewriter);
 1753 auto input = op.getInput();
 1754 auto inputTy = cast<RankedTensorType>(input.getType());
 1755 auto resultTy = cast<RankedTensorType>(op.getType());
 1756 const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;
// Layout is NHWC: dims 1/2 are the spatial dims this pattern requires to be 1.
 1758 auto inputH = inputTy.getDimSize(1);
 1759 auto inputW = inputTy.getDimSize(2);
 1760 auto outputH = resultTy.getDimSize(1);
 1761 auto outputW = resultTy.getDimSize(2);
 1763 if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
 1765 op,
 "tosa.resize is not a pure 1x1->1x1 image operation");
 1767 if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
 1768 op.getMode() != ResizeMode::BILINEAR)
 1770 op,
 "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
// Identical types: presumably replaced with the input directly (replacement
// line missing here — TODO confirm).
 1772 if (inputTy == resultTy) {
 1777 SmallVector<int64_t> scale;
// Collapse NHWC -> NC (reassociation construction lines missing).
 1783 SmallVector<ReassociationExprs, 4> reassociationMap(2);
 1790 RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
 1791 inputTy.getElementType());
 1792 Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
 1796 llvm::SmallVector<Value> outputDynSize;
 1797 if (inputTy.isDynamicDim(0))
 1798 outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
 1799 if (inputTy.isDynamicDim(3))
 1800 outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
// Elementwise generic over the collapsed tensor.
 1803 auto genericTy = collapseTy.clone(resultTy.getElementType());
 1805 tensor::EmptyOp::create(builder, genericTy.getShape(),
 1806 resultTy.getElementType(), outputDynSize);
 1808 SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
 1809 utils::IteratorType::parallel);
 1811 auto generic = linalg::GenericOp::create(
 1813 ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
 1814 [=](OpBuilder &
 b, Location loc,
 ValueRange args) {
 1815 Value value = args[0];
// Widen elements when input/result element types differ (sign-extension).
 1817 if (inputTy.getElementType() != resultTy.getElementType()) {
 1818 value = arith::ExtSIOp::create(
 b, loc, resultTy.getElementType(),
// For bilinear integer resize, multiply by the y/x scale numerators.
 1821 if (isBilinear && scale[0] != 0) {
 1822 Value scaleY = arith::ConstantOp::create(
 1823 b, loc,
 b.getI32IntegerAttr(scale[0]));
 1824 value = arith::MulIOp::create(
 b, loc, value, scaleY);
 1827 if (isBilinear && scale[2] != 0) {
 1828 Value scaleX = arith::ConstantOp::create(
 1829 b, loc,
 b.getI32IntegerAttr(scale[2]));
 1830 value = arith::MulIOp::create(
 b, loc, value, scaleX);
 1834 linalg::YieldOp::create(
 b, loc, value);
// Expand NC back to the NHWC result (expand op name on missing line).
 1838 op, resultTy,
 generic.getResults()[0], reassociationMap);
// NOTE(review): garbled extraction — class header and many interior lines are
// missing. This pattern handles tosa.resize whose H or W input dim is 1 but
// whose output dim is larger (pure broadcast along that axis): it resizes to
// an intermediate shape keeping broadcast dims at 1, collapses them away, and
// (below, truncated) re-broadcasts via an elementwise generic.
 1850 LogicalResult matchAndRewrite(tosa::ResizeOp op,
 1854 auto input = op.getInput();
 1855 auto inputTy = dyn_cast<RankedTensorType>(input.getType());
 1856 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
 1858 if (!inputTy || !resultTy)
 1860 "requires ranked input/output types");
 1862 auto batch = inputTy.getDimSize(0);
 1863 auto channels = inputTy.getDimSize(3);
 1864 auto inputH = inputTy.getDimSize(1);
 1865 auto inputW = inputTy.getDimSize(2);
 1866 auto outputH = resultTy.getDimSize(1);
 1867 auto outputW = resultTy.getDimSize(2);
// Applies only when at least one axis is a 1 -> N broadcast.
 1869 if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
 1871 op,
 "tosa.resize has no broadcasting behavior");
// Intermediate resize keeps broadcast axes at size 1.
 1876 resizeShape.push_back(batch);
 1877 resizeShape.push_back(inputH == 1 ? 1 : outputH);
 1878 resizeShape.push_back(inputW == 1 ? 1 : outputW);
 1879 resizeShape.push_back(channels);
 1881 auto resizeTy = resultTy.clone(resizeShape);
 1883 tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
 1884 op.getOffset(), op.getBorder(), op.getMode());
// Reassociation groups fold the size-1 axes into neighbors (group-filling
// lines are missing around these placeholders).
 1891 reassociationMap.push_back({});
 1894 reassociationMap.push_back({});
 1899 collapseShape.push_back(outputH);
 1901 collapseShape.push_back(outputW);
 1902 collapseShape.push_back(channels);
 1904 auto collapseTy = resultTy.clone(collapseShape);
 1905 Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
 1906 resize, reassociationMap);
// Dynamic batch/channel sizes feed the final tensor.empty.
 1910 if (inputTy.isDynamicDim(0))
 1911 outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
 1912 if (inputTy.isDynamicDim(3))
 1913 outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));
 1916 utils::IteratorType::parallel);
 1917 Value empty = tensor::EmptyOp::create(
 1918 builder, resultTy.getShape(), resultTy.getElementType(), outputDynSize);
// Broadcast body: pass each element straight through.
 1935 Value value = args[0];
 1936 linalg::YieldOp::create(
 b, loc, value);
// NOTE(review): garbled extraction — fused line numbers, wrapped tokens
// (builder variable `b` split from its uses), and missing interior lines;
// the enclosing pattern class header is not visible.
// General lowering of tosa.resize (NEAREST_NEIGHBOR / BILINEAR, float or
// integer arithmetic) to a linalg.generic that computes source coordinates
// from scale/offset/border and gathers pixels with tensor.extract.
 1945 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
 1947 LogicalResult matchAndRewrite(tosa::ResizeOp op,
 1948 PatternRewriter &rewriter)
 const final {
 1949 Location loc = op.getLoc();
 1950 ImplicitLocOpBuilder
 b(loc, rewriter);
 1951 auto input = op.getInput();
 1952 auto inputTy = cast<ShapedType>(input.getType());
 1953 auto resultTy = cast<ShapedType>(op.getType());
 1954 auto resultETy = resultTy.getElementType();
 1956 bool floatingPointMode = isa<FloatType>(resultETy);
 1957 auto floatTy = resultETy;
// NHWC spatial extents of the source image.
 1959 auto imageH = inputTy.getShape()[1];
 1960 auto imageW = inputTy.getShape()[2];
 1962 auto dynamicDimsOr =
 1964 if (!dynamicDimsOr.has_value())
 1966 op,
 "unable to get dynamic dimensions of tosa.resize");
 1968 if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
 1969 op.getMode() != ResizeMode::BILINEAR)
 1971 op,
 "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
// Output tensor and the generic whose body is built explicitly below.
 1973 SmallVector<AffineMap, 2> affineMaps = {
 1975 auto emptyTensor = tensor::EmptyOp::create(
 b, resultTy.getShape(),
 1976 resultETy, *dynamicDimsOr);
 1977 auto genericOp = linalg::GenericOp::create(
 1980 Value resize = genericOp.getResult(0);
 1983 OpBuilder::InsertionGuard regionGuard(
 b);
 1984 b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
// Current output coordinates (N, y, x, C) as linalg indices.
 1986 Value batch = linalg::IndexOp::create(
 b, 0);
 1987 Value y = linalg::IndexOp::create(
 b, 1);
 1988 Value x = linalg::IndexOp::create(
 b, 2);
 1989 Value channel = linalg::IndexOp::create(
 b, 3);
 1992 arith::ConstantOp::create(
 b,
 b.getZeroAttr(
 b.getI32Type()));
 1993 Value zeroFp = arith::ConstantOp::create(
 b,
 b.getZeroAttr(floatTy));
// hMax / wMax clamp bounds (last valid row/column index).
 1995 arith::ConstantOp::create(
 b,
 b.getI32IntegerAttr(imageH - 1));
 1997 arith::ConstantOp::create(
 b,
 b.getI32IntegerAttr(imageW - 1));
 1999 Value inY = arith::IndexCastOp::create(
 b,
 b.getI32Type(), y);
 2000 Value inX = arith::IndexCastOp::create(
 b,
 b.getI32Type(), x);
// scale/offset/border must be compile-time constants.
 2002 SmallVector<int64_t> scale, offset, border;
 2007 op,
 "tosa.resize scale/offset/border should have compile time "
 2008 "constant values.");
 2011 Value yScaleN, yScaleD, xScaleN, xScaleD;
 2012 yScaleN = arith::ConstantOp::create(
 b,
 b.getI32IntegerAttr(scale[0]));
 2013 yScaleD = arith::ConstantOp::create(
 b,
 b.getI32IntegerAttr(scale[1]));
 2014 xScaleN = arith::ConstantOp::create(
 b,
 b.getI32IntegerAttr(scale[2]));
 2015 xScaleD = arith::ConstantOp::create(
 b,
 b.getI32IntegerAttr(scale[3]));
 2017 Value yOffset, xOffset, yBorder, xBorder;
 2018 yOffset = arith::ConstantOp::create(
 b,
 b.getI32IntegerAttr(offset[0]));
 2019 xOffset = arith::ConstantOp::create(
 b,
 b.getI32IntegerAttr(offset[1]));
 2020 yBorder = arith::ConstantOp::create(
 b,
 b.getI32IntegerAttr(border[0]));
 2021 xBorder = arith::ConstantOp::create(
 b,
 b.getI32IntegerAttr(border[1]));
// Float path: index = floorDiv(in*scaleD + offset, scaleN);
// delta = remainder / scaleN as a float in [0, 1).
 2024 auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
 2025 Value scaleN, Value scaleD, Value offset,
 2026 int size, ImplicitLocOpBuilder &
 b) {
 2034 Value val = arith::MulIOp::create(
 b, in, scaleD);
 2035 val = arith::AddIOp::create(
 b, val, offset);
 2036 index = arith::FloorDivSIOp::create(
 b, val, scaleN);
 2039 Value scaledIndex = arith::MulIOp::create(
 b, index, scaleN);
 2040 Value r = arith::SubIOp::create(
 b, val, scaledIndex);
 2041 Value rFp = arith::SIToFPOp::create(
 b, floatTy, r);
 2044 Value scaleNfp = arith::UIToFPOp::create(
 b, floatTy, scaleN);
 2045 delta = arith::DivFOp::create(
 b, rFp, scaleNfp);
// Integer path: same index computation; delta kept as the integer
// remainder (weight is applied in fixed point later).
 2049 auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
 2050 Value scaleN, Value scaleD, Value offset,
 2051 int size, ImplicitLocOpBuilder &
 b) {
 2060 Value val = arith::MulIOp::create(
 b, in, scaleD);
 2061 val = arith::AddIOp::create(
 b, val, offset);
 2062 index = arith::FloorDivSIOp::create(
 b, val, scaleN);
 2063 delta = arith::MulIOp::create(
 b, index, scaleN);
 2064 delta = arith::SubIOp::create(
 b, val, delta);
 2067 Value ix, iy, dx, dy;
 2068 if (floatingPointMode) {
 2069 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH,
 b);
 2070 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW,
 b);
 2072 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH,
 b);
 2073 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW,
 b);
// Nearest-neighbor: round to the nearest source pixel, clamp, extract.
 2076 if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
 2077 auto one = arith::ConstantOp::create(
 b,
 b.getI32IntegerAttr(1));
 2079 auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
 2080 Value
 max,
 int size,
 2081 ImplicitLocOpBuilder &
 b) -> Value {
// Round up when the fractional part >= 0.5 (float) or 2*delta >= scaleN
// (integer, via the shift-left-by-one comparison).
 2087 if (floatingPointMode) {
 2089 arith::ConstantOp::create(
 b,
 b.getFloatAttr(floatTy, 0.5f));
 2090 pred = arith::CmpFOp::create(
 b, arith::CmpFPredicate::OGE, dval, h);
 2092 Value dvalDouble = arith::ShLIOp::create(
 b, dval, one);
 2093 pred = arith::CmpIOp::create(
 b, arith::CmpIPredicate::sge,
 2097 auto offset = arith::SelectOp::create(
 b, pred, one, zeroI32);
 2098 val = arith::AddIOp::create(
 b, val, offset);
 2100 return arith::IndexCastOp::create(
 b,
 b.getIndexType(), val);
 2103 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH,
 b);
 2104 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW,
 b);
 2106 Value
 result = tensor::ExtractOp::create(
 2109 linalg::YieldOp::create(
 b,
 result);
// Bilinear: gather the 2x2 neighborhood (y0/y1 x x0/x1) and interpolate.
 2112 assert(op.getMode() == ResizeMode::BILINEAR);
 2114 auto oneVal = arith::ConstantOp::create(
 b,
 b.getI32IntegerAttr(1));
 2116 auto getClampedIdxs = [&](Value &val0, Value &val1,
 int size, Value in,
 2117 Value
 max, ImplicitLocOpBuilder &
 b) {
 2119 val1 = arith::AddIOp::create(
 b, val0, oneVal);
 2124 val0 = arith::IndexCastOp::create(
 b,
 b.getIndexType(), val0);
 2125 val1 = arith::IndexCastOp::create(
 b,
 b.getIndexType(), val1);
 2133 Value x0, x1, y0, y1;
 2134 getClampedIdxs(y0, y1, imageH, iy, hMax,
 b);
 2135 getClampedIdxs(x0, x1, imageW, ix, wMax,
 b);
 2137 Value y0x0 = tensor::ExtractOp::create(
 2139 Value y0x1 = tensor::ExtractOp::create(
 2141 Value y1x0 = tensor::ExtractOp::create(
 2143 Value y1x1 = tensor::ExtractOp::create(
// Float interpolation: lerp along x for top/bottom rows, then along y.
 2146 if (floatingPointMode) {
 2148 arith::ConstantOp::create(
 b,
 b.getFloatAttr(floatTy, 1.0f));
 2149 auto interpolate = [&](Value val0, Value val1, Value delta,
 2151 ImplicitLocOpBuilder &
 b) -> Value {
 2154 Value oneMinusDelta = arith::SubFOp::create(
 b, oneVal, delta);
 2155 Value mul0 = arith::MulFOp::create(
 b, val0, oneMinusDelta);
 2156 Value mul1 = arith::MulFOp::create(
 b, val1, delta);
 2157 return arith::AddFOp::create(
 b, mul0, mul1);
 2163 Value topAcc = interpolate(y0x0, y0x1, dx, imageW,
 b);
 2168 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW,
 b);
 2172 Value
 result = interpolate(topAcc, bottomAcc, dy, imageH,
 b);
 2173 linalg::YieldOp::create(
 b,
 result);
// Integer interpolation: widen corners and weights to the result width,
// then fixed-point lerp val0*(scale-w) + val1*w per axis.
 2176 y0x0 = arith::ExtSIOp::create(
 b, resultETy, y0x0);
 2177 y0x1 = arith::ExtSIOp::create(
 b, resultETy, y0x1);
 2178 y1x0 = arith::ExtSIOp::create(
 b, resultETy, y1x0);
 2179 y1x1 = arith::ExtSIOp::create(
 b, resultETy, y1x1);
 2182 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
 2183 dx = arith::ExtSIOp::create(
 b, resultETy, dx);
 2184 dy = arith::ExtSIOp::create(
 b, resultETy, dy);
 2187 Value yScaleNExt = yScaleN;
 2188 Value xScaleNExt = xScaleN;
 2190 const int64_t scaleBitwidth =
 2192 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
 2193 yScaleNExt = arith::ExtSIOp::create(
 b, resultETy, yScaleN);
 2194 xScaleNExt = arith::ExtSIOp::create(
 b, resultETy, xScaleN);
 2197 auto interpolate = [](Value val0, Value val1, Value weight1,
 2198 Value scale,
 int inputSize,
 2199 ImplicitLocOpBuilder &
 b) -> Value {
// Degenerate axis (size-1 input): just scale val0, no blending.
 2201 return arith::MulIOp::create(
 b, val0, scale);
 2202 Value weight0 = arith::SubIOp::create(
 b, scale, weight1);
 2203 Value mul0 = arith::MulIOp::create(
 b, val0, weight0);
 2204 Value mul1 = arith::MulIOp::create(
 b, val1, weight1);
 2205 return arith::AddIOp::create(
 b, mul0, mul1);
 2208 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW,
 b);
 2209 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW,
 b);
 2211 interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH,
 b);
 2212 linalg::YieldOp::create(
 b,
 result);
// NOTE(review): garbled extraction fragment — the pattern class name (~line
// 2226) is missing. Identity-style conversion: replaces the matched op with
// its own operands, eliminating it (typical for tosa identity/no-op wrappers).
 2225template <
 typename SrcOp>
 2228 using OpRewritePattern<SrcOp>::OpRewritePattern;
 2230 LogicalResult matchAndRewrite(SrcOp op,
 2231 PatternRewriter &rewriter)
 const final {
 2232 rewriter.
 replaceOp(op, op.getOperation()->getOperands());
// NOTE(review): garbled extraction fragment — only the template header and
// method signature survive; the class name and the entire body (original
// lines ~2244-2249) are missing, so no behavior can be documented here.
// The `reduceOp` parameter name suggests this lowers a TOSA reduction op —
// TODO confirm against the full file.
 2237template <
 typename SrcOp>
 2240 using OpRewritePattern<SrcOp>::OpRewritePattern;
 2242 LogicalResult matchAndRewrite(SrcOp reduceOp,
 2243 PatternRewriter &rewriter)
 const final {
// NOTE(review): garbled extraction — fused line numbers and missing interior
// lines; pattern class header not visible.
// Lowers tosa.reverse to a linalg.generic that, for the reversed axis,
// reads element (axisDimSize - 1 - i) via tensor.extract and yields it.
 2250 using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;
 2252 LogicalResult matchAndRewrite(tosa::ReverseOp op,
 2253 PatternRewriter &rewriter)
 const final {
 2254 auto loc = op.getLoc();
 2255 Value input = op.getInput1();
 2256 auto inputTy = cast<ShapedType>(input.
 getType());
 2257 auto resultTy = cast<ShapedType>(op.getType());
 2258 auto axis = op.getAxis();
 2260 SmallVector<Value> dynDims;
 2261 for (
 int i = 0; i < inputTy.getRank(); i++) {
 2262 if (inputTy.isDynamicDim(i)) {
 2263 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
// Runtime size of the reversed axis (works for dynamic dims too).
 2267 Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);
 2270 auto emptyTensor = tensor::EmptyOp::create(
 2271 rewriter, loc, inputTy.getShape(),
 2272 inputTy.getElementType(), ArrayRef<Value>({dynDims}))
 2274 SmallVector<AffineMap, 2> affineMaps = {
 2278 op, resultTy, ArrayRef<Value>({}),
 ValueRange{emptyTensor}, affineMaps,
 2280 [&](OpBuilder &nestedBuilder, Location nestedLoc,
 ValueRange args) {
// Build the source index vector, flipping only the reversed axis.
 2281 llvm::SmallVector<Value>
 indices;
 2282 for (
 unsigned int i = 0; i < inputTy.getRank(); i++) {
 2284 linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
 2288 arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
 2289 index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
 2296 auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
 2298 linalg::YieldOp::create(nestedBuilder, op.getLoc(),
 2299 extract.getResult());
// NOTE(review): garbled extraction — fused line numbers and missing interior
// lines (including the struct's closing lines).
// Lowers tosa.tile: builds a rank-2N "generic shape" interleaving
// (multiple_i, inputDim_i) pairs, broadcasts the input into it with a
// linalg.generic whose read map picks only the odd (input) dims, then
// reshapes the result back to the tiled output type.
 2309struct TileConverter :
 public OpConversionPattern<tosa::TileOp> {
 2310 using OpConversionPattern<tosa::TileOp>::OpConversionPattern;
 2313 matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
 2314 ConversionPatternRewriter &rewriter)
 const override {
 2315 auto loc = op.getLoc();
 2316 auto input = op.getInput1();
 2317 auto inputTy = cast<ShapedType>(input.
 getType());
 2318 auto inputShape = inputTy.getShape();
 2319 auto resultTy = cast<ShapedType>(op.getType());
 2320 auto elementTy = inputTy.getElementType();
 2321 int64_t rank = inputTy.getRank();
// Multiples must be compile-time constants (-1 marks a dynamic multiple).
 2323 SmallVector<int64_t> multiples;
 2324 if (
 failed(op.getConstantMultiples(multiples)))
// Interleaved shape: [m0, d0, m1, d1, ...]; dynamic multiples map to
// ShapedType::kDynamic.
 2328 SmallVector<int64_t, 2> genericShape;
 2329 for (
 int i = 0; i < rank; i++) {
 2330 int64_t dim = multiples[i];
 2331 genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
 2332 genericShape.push_back(inputShape[i]);
 2335 SmallVector<Value> dynDims;
 2336 for (
 int i = 0; i < inputTy.getRank(); i++) {
 2337 if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
 2338 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
 2342 auto emptyTensor = tensor::EmptyOp::create(
 2343 rewriter, op.getLoc(), genericShape, elementTy, dynDims);
// Read map selects dims 1, 3, 5, ... — the input dims of each pair.
 2346 SmallVector<AffineExpr, 4> dimExprs;
 2347 dimExprs.reserve(rank);
 2348 for (
 unsigned i = 0; i < rank; ++i)
 2349 dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));
 2351 auto readAffineMap =
 2353 rewriter.getContext());
 2355 SmallVector<AffineMap, 2> affineMaps = {
 2356 readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};
// Pure broadcast: the body forwards the single input element.
 2358 auto genericOp = linalg::GenericOp::create(
 2359 rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
 2362 [&](OpBuilder &nestedBuilder, Location nestedLoc,
 ValueRange args) {
 2363 linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
// Collapse the interleaved shape down to the tiled result shape.
 2368 rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
 2369 op, resultTy, genericOp.getResult(0), shapeValue);
// NOTE(review): garbled extraction — fused line numbers and missing interior
// lines; pattern class header not visible.
// Lowers tosa.arg_max to a linalg.generic with two outputs: the running max
// value and the index where it occurred, reduced along `axis`. Handles TOSA
// NaN propagation semantics for float inputs.
 2389 using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
 2391 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
 2392 PatternRewriter &rewriter)
 const final {
 2393 auto loc = argmaxOp.getLoc();
 2394 Value input = argmaxOp.getInput();
 2395 auto inputTy = cast<ShapedType>(input.
 getType());
 2396 auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
 2397 auto inElementTy = inputTy.getElementType();
 2398 auto outElementTy = resultTy.getElementType();
 2399 int axis = argmaxOp.getAxis();
// Auxiliary result carrying the max values alongside the indices.
 2400 auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
 2402 if (!isa<IntegerType>(outElementTy))
 2403 return rewriter.notifyMatchFailure(
 2405 "tosa.arg_max to linalg.* requires integer-like result type");
 2407 SmallVector<Value> dynDims;
 2408 for (
 int i = 0; i < inputTy.getRank(); i++) {
 2409 if (inputTy.isDynamicDim(i) && i != axis) {
 2410 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
// Index accumulator, initialized to 0.
 2415 auto emptyTensorIdx =
 2416 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
 2417 outElementTy, dynDims)
 2419 auto fillValueIdx = arith::ConstantOp::create(
 2420 rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
 2421 auto filledTensorIdx =
 2422 linalg::FillOp::create(rewriter, loc,
 ValueRange{fillValueIdx},
// Max accumulator, initialized to the element type's minimum (the attr
// construction lines are missing between these).
 2427 auto emptyTensorMax =
 2428 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
 2431 auto fillValueMaxAttr =
 2434 if (!fillValueMaxAttr)
 2435 return rewriter.notifyMatchFailure(
 2436 argmaxOp,
 "unsupported tosa.argmax element type");
 2439 arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
 2440 auto filledTensorMax =
 2441 linalg::FillOp::create(rewriter, loc,
 ValueRange{fillValueMax},
// Parallel everywhere except the reduced axis.
 2447 SmallVector<utils::IteratorType, 4> iteratorTypes;
 2448 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
 2449 iteratorTypes[axis] = utils::IteratorType::reduction;
 2451 SmallVector<AffineExpr, 2> srcExprs;
 2452 SmallVector<AffineExpr, 2> dstExprs;
 2453 for (
 int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
 2459 bool didEncounterError =
 false;
 2461 rewriter.getContext());
 2462 auto linalgOp = linalg::GenericOp::create(
 2463 rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
 2464 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
 2465 [&](OpBuilder &nestedBuilder, Location nestedLoc,
 2467 auto newValue = blockArgs[0];
 2468 auto oldIndex = blockArgs[1];
 2469 auto oldValue = blockArgs[2];
// Candidate index = current position along the reduction axis.
 2471 Value newIndex = arith::IndexCastOp::create(
 2472 rewriter, nestedLoc, oldIndex.getType(),
 2473 linalg::IndexOp::create(rewriter, loc, axis));
 2476 if (isa<FloatType>(inElementTy)) {
// IGNORE NaN mode: ordered-greater-than skips NaN candidates outright.
 2477 if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
 2480 predicate = arith::CmpFOp::create(rewriter, nestedLoc,
 2481 arith::CmpFPredicate::OGT,
 2482 newValue, oldValue);
// PROPAGATE mode: take NaN candidates (UGT) but only while the current
// max is still non-NaN (ORD), so the first NaN wins and then sticks.
 2487 Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
 2488 arith::CmpFPredicate::UGT,
 2489 newValue, oldValue);
 2490 Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
 2491 arith::CmpFPredicate::ORD,
 2492 oldValue, oldValue);
 2493 predicate = arith::AndIOp::create(
 2494 rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
 2496 }
 else if (isa<IntegerType>(inElementTy)) {
 2497 predicate = arith::CmpIOp::create(rewriter, nestedLoc,
 2498 arith::CmpIPredicate::sgt,
 2499 newValue, oldValue);
// Unsupported element type: flag and report after the builder returns.
 2501 didEncounterError =
 true;
 2505 auto resultMax = arith::SelectOp::create(
 2506 rewriter, nestedLoc, predicate, newValue, oldValue);
 2507 auto resultIndex = arith::SelectOp::create(
 2508 rewriter, nestedLoc, predicate, newIndex, oldIndex);
 2509 linalg::YieldOp::create(nestedBuilder, nestedLoc,
 2513 if (didEncounterError)
 2514 return rewriter.notifyMatchFailure(
 2515 argmaxOp,
 "unsupported tosa.argmax element type");
// Only the index result (result 0) replaces the op; the max tensor is scratch.
 2517 rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
// NOTE(review): garbled extraction — fused line numbers and missing interior
// lines (including the class's closing lines).
// Lowers tosa.gather to a linalg.generic over the (N, W, C) result: for each
// output element it extracts values[n, indices[n, w], c] via tensor.extract.
 2522class GatherConverter :
 public OpConversionPattern<tosa::GatherOp> {
 2524 using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
 2526 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
 2527 ConversionPatternRewriter &rewriter)
 const final {
 2528 auto input = adaptor.getOperands()[0];
 2529 auto indices = adaptor.getOperands()[1];
 2531 auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
 2532 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
 2533 if (!valuesTy || !resultTy)
 2534 return rewriter.notifyMatchFailure(op,
 "unranked tensors not supported");
 2536 auto dynamicDims = inferDynamicDimsForGather(
 2537 rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());
 2539 auto resultElementTy = resultTy.getElementType();
 2541 auto loc = op.getLoc();
 2543 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
 2544 resultElementTy, dynamicDims)
// Indices input reads only dims (0, 1); output uses the identity map.
 2547 SmallVector<AffineMap, 2> affineMaps = {
 2549 resultTy.getRank(), 0,
 2550 {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
 2551 rewriter.getContext()),
 2552 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
 2554 auto genericOp = linalg::GenericOp::create(
 2558 [&](OpBuilder &
 b, Location loc,
 ValueRange args) {
// (batch, gathered-index, channel) triple addressing the values tensor.
 2559 auto indexValue = args[0];
 2560 auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
 2561 Value index1 = arith::IndexCastOp::create(
 2562 rewriter, loc, rewriter.getIndexType(), indexValue);
 2563 auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
 2564 Value extract = tensor::ExtractOp::create(
 2565 rewriter, loc, input,
 ValueRange{index0, index1, index2});
 2566 linalg::YieldOp::create(rewriter, loc, extract);
 2568 rewriter.replaceOp(op, genericOp.getResult(0));
// Collects tensor.dim values for the dynamic dims of the result: batch and
// channel come from `values`, the gather width from `indices`.
 2572 static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
 2576 llvm::SmallVector<Value> results;
 2578 auto addDynamicDimension = [&](Value source, int64_t dim) {
 2580 if (
 auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
 2581 results.push_back(dimValue);
 2584 addDynamicDimension(values, 0);
 2585 addDynamicDimension(
 indices, 1);
 2586 addDynamicDimension(values, 2);
// NOTE(review): garbled extraction — fused line numbers and missing interior
// lines; pattern class header not visible.
// Lowers tosa.table to a linalg.generic doing a table lookup:
//  - i8 in / i8 table / i8 out: direct indexed extract (after re-biasing);
//  - i16 in / i16 table / i32 out: 513-entry table with 7-bit fractional
//    linear interpolation between adjacent entries.
 2596 using OpRewritePattern<tosa::TableOp>::OpRewritePattern;
 2598 LogicalResult matchAndRewrite(tosa::TableOp op,
 2599 PatternRewriter &rewriter)
 const final {
 2600 auto loc = op.getLoc();
 2601 Value input = op.getInput1();
 2602 Value table = op.getTable();
 2603 auto inputTy = cast<ShapedType>(input.
 getType());
 2604 auto tableTy = cast<ShapedType>(table.
 getType());
 2605 auto resultTy = cast<ShapedType>(op.getType());
 2607 auto inputElementTy = inputTy.getElementType();
 2608 auto tableElementTy = tableTy.getElementType();
 2609 auto resultElementTy = resultTy.getElementType();
 2611 SmallVector<Value> dynDims;
 2612 for (
 int i = 0; i < resultTy.getRank(); ++i) {
 2613 if (inputTy.isDynamicDim(i)) {
 2615 tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
 2620 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
 2621 resultElementTy, dynDims)
 2624 SmallVector<AffineMap, 2> affineMaps = {
 2625 rewriter.getMultiDimIdentityMap(resultTy.getRank()),
 2626 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
 2628 auto genericOp = linalg::GenericOp::create(
 2631 rewriter.replaceOp(op, genericOp.getResult(0));
// Body is built manually into the generic's region.
 2634 OpBuilder::InsertionGuard regionGuard(rewriter);
 2635 Block *block = rewriter.createBlock(
 2636 &genericOp.getRegion(), genericOp.getRegion().end(),
 2637 TypeRange({inputElementTy, resultElementTy}), {loc, loc});
 2640 rewriter.setInsertionPointToStart(block);
// 8-bit path: index = input (re-biased on the missing line) into the table.
 2641 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
 2642 resultElementTy.isInteger(8)) {
 2643 Value index = arith::IndexCastOp::create(
 2644 rewriter, loc, rewriter.getIndexType(), inputValue);
 2646 index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
 2649 tensor::ExtractOp::create(rewriter, loc, table,
 ValueRange{index});
 2650 linalg::YieldOp::create(rewriter, loc, extract);
// 16-bit path: split (input + 32768) into a 9-bit table index (>> 7) and a
// 7-bit fraction (& 127), then interpolate base + (next-base)*fraction.
 2654 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
 2655 resultElementTy.isInteger(32)) {
 2656 Value extend = arith::ExtSIOp::create(
 2657 rewriter, loc, rewriter.getI32Type(), inputValue);
 2659 auto offset = arith::ConstantOp::create(
 2660 rewriter, loc, rewriter.getI32IntegerAttr(32768));
 2661 auto seven = arith::ConstantOp::create(rewriter, loc,
 2662 rewriter.getI32IntegerAttr(7));
 2663 auto one = arith::ConstantOp::create(rewriter, loc,
 2664 rewriter.getI32IntegerAttr(1));
 2665 auto b1111111 = arith::ConstantOp::create(
 2666 rewriter, loc, rewriter.getI32IntegerAttr(127));
 2672 auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
 2673 Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
 2675 arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);
 2680 Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);
 2682 index = arith::IndexCastOp::create(rewriter, loc,
 2683 rewriter.getIndexType(), index);
 2684 indexPlusOne = arith::IndexCastOp::create(
 2685 rewriter, loc, rewriter.getIndexType(), indexPlusOne);
 2688 tensor::ExtractOp::create(rewriter, loc, table,
 ValueRange{index});
 2689 Value next = tensor::ExtractOp::create(rewriter, loc, table,
 2693 arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
 2695 arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);
// result = (base << 7) + (next - base) * fraction.
 2699 Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
 2700 Value diff = arith::SubIOp::create(rewriter, loc, next, base);
 2701 Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
 2703 arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);
 2705 linalg::YieldOp::create(rewriter, loc,
 result);
// Any other type combination is unsupported.
 2711 return rewriter.notifyMatchFailure(
 2712 op,
 "unable to create body for tosa.table op");
2717 using OpRewritePattern<RFFT2dOp>::OpRewritePattern;
2719 static bool isRankedTensor(Type type) {
return isa<RankedTensorType>(type); }
// NOTE(review): garbled extraction fragment — the parameter list tail, the
// `one`/`two` constant definitions (~lines 2722-2726), and the return line
// are missing. Computes value/2 + 1 (createOrFold lets constant inputs fold
// away); used for the RFFT2d output width (W/2 + 1 frequency bins).
 2721 static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
 2727 auto divBy2 = builder.
 createOrFold<arith::DivUIOp>(loc, value, two);
 2728 auto plusOne = builder.
 createOrFold<arith::AddIOp>(loc, divBy2, one);
// NOTE(review): garbled extraction fragment — the dims-gathering lines
// (~2735-2739) and the static/dynamic split (~2743-2744) are missing.
// Computes the RFFT2d output type from the input: same shape except
// dim 2 (width) becomes W/2 + 1; dynamic sizes are appended to
// `dynamicSizes` for the caller's tensor.empty.
 2732 static RankedTensorType
 2733 computeOutputShape(OpBuilder &builder, Location loc, Value input,
 2734 llvm::SmallVectorImpl<Value> &dynamicSizes) {
 2740 dims[2] = halfPlusOne(builder, loc, dims[2]);
 2742 llvm::SmallVector<int64_t, 3> staticSizes;
 2745 auto elementType = cast<RankedTensorType>(input.
 getType()).getElementType();
 2746 return RankedTensorType::get(staticSizes, elementType);
// NOTE(review): garbled extraction fragment — fused line numbers; a couple of
// interior lines missing. Builds a tensor of the given type filled with the
// element type's zero: tensor.empty + linalg.fill of an arith.constant zero.
// Used as the accumulator initializer for the RFFT2d reduction.
 2749 static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
 2750 RankedTensorType type,
 2751 llvm::ArrayRef<Value> dynamicSizes) {
 2753 tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
 2754 auto fillValueAttr = rewriter.
 getZeroAttr(type.getElementType());
 2755 auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
 2757 linalg::FillOp::create(rewriter, loc,
 ValueRange{fillValue},
 2760 return filledTensor;
// NOTE(review): garbled extraction fragment — the else-branch of the ternary
// (~lines 2768-2770, presumably selecting an i32 type and passing `value`)
// is missing. Converts an index-typed value to the given float type:
// index -> unsigned integer (i64 when the float has > 32 bits of width,
// to preserve precision) -> float via uitofp.
 2763 static Value castIndexToFloat(OpBuilder &builder, Location loc,
 2764 FloatType type, Value value) {
 2765 auto integerVal = arith::IndexCastUIOp::create(
 2767 type.getIntOrFloatBitWidth() > 32 ? builder.
 getI64Type()
 2771 return arith::UIToFPOp::create(builder, loc, type, integerVal);
2774 static Value createLinalgIndex(OpBuilder &builder, Location loc,
2775 FloatType type, int64_t index) {
2776 auto indexVal = linalg::IndexOp::create(builder, loc, index);
2777 return castIndexToFloat(builder, loc, type, indexVal);
// NOTE(review): garbled extraction fragment — only the template header and
// the first signature line survive; the remaining parameters and the body
// (~lines 2782-2784) are missing. Judging by the call sites below
// (affineDimsExpr(rewriter, 0, 3, 4) etc.), it builds a small vector of
// AffineDimExprs for the given dimension positions — TODO confirm.
 2780 template <
 typename... Args>
 2781 static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
2786 LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
2787 PatternRewriter &rewriter)
const override {
2788 if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2789 !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2791 "only supports ranked tensors");
2794 auto loc = rfft2d.getLoc();
2795 auto input = rfft2d.getInputReal();
2797 dyn_cast<FloatType>(cast<ShapedType>(input.
getType()).getElementType());
2800 "only supports float element types");
2803 llvm::SmallVector<Value> dynamicSizes;
2804 auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2807 llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
2808 utils::IteratorType::parallel, utils::IteratorType::parallel,
2809 utils::IteratorType::parallel, utils::IteratorType::reduction,
2810 utils::IteratorType::reduction};
2813 llvm::SmallVector<Value> genericOpInputs = {input};
2814 llvm::SmallVector<Value> genericOpOutputs = {
2815 createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2816 createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2820 llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
2821 affineDimsExpr(rewriter, 0, 1, 2),
2822 affineDimsExpr(rewriter, 0, 1, 2)},
2826 auto dimH = rewriter.
createOrFold<tensor::DimOp>(loc, input, 1);
2827 auto dimW = rewriter.
createOrFold<tensor::DimOp>(loc, input, 2);
2830 auto twoPiAttr = rewriter.
getFloatAttr(elementType, 6.283185307179586);
2831 auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2832 auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2833 auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2835 auto buildBody = [&](OpBuilder &builder, Location loc,
ValueRange args) {
2836 Value valReal = args[0];
2837 Value sumReal = args[1];
2838 Value sumImag = args[2];
2841 Value oy = linalg::IndexOp::create(builder, loc, 1);
2842 Value ox = linalg::IndexOp::create(builder, loc, 2);
2843 Value iy = linalg::IndexOp::create(builder, loc, 3);
2844 Value ix = linalg::IndexOp::create(builder, loc, 4);
2849 auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2850 auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2852 auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2853 auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2855 auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2856 auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2858 auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2859 auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2860 auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2861 auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2865 auto cosAngle = math::CosOp::create(builder, loc, angle);
2866 auto sinAngle = math::SinOp::create(builder, loc, angle);
2867 auto realComponent =
2868 arith::MulFOp::create(builder, loc, valReal, cosAngle);
2869 auto imagComponent =
2870 arith::MulFOp::create(builder, loc, valReal, sinAngle);
2875 arith::AddFOp::create(builder, loc, sumReal, realComponent);
2877 arith::SubFOp::create(builder, loc, sumImag, imagComponent);
2879 linalg::YieldOp::create(builder, loc,
ValueRange{outReal, outImag});
2883 rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2884 indexingMaps, iteratorTypes, buildBody);
2893 LogicalResult matchAndRewrite(FFT2dOp fft2d,
2894 PatternRewriter &rewriter)
const override {
2895 if (!llvm::all_of(fft2d->getOperandTypes(),
2896 RFFT2dConverter::isRankedTensor) ||
2897 !llvm::all_of(fft2d->getResultTypes(),
2898 RFFT2dConverter::isRankedTensor)) {
2902 Location loc = fft2d.getLoc();
2903 Value input_real = fft2d.getInputReal();
2904 Value input_imag = fft2d.getInputImag();
2905 BoolAttr inverse = fft2d.getInverseAttr();
2907 auto real_el_ty = cast<FloatType>(
2908 cast<ShapedType>(input_real.
getType()).getElementType());
2909 [[maybe_unused]]
auto imag_el_ty = cast<FloatType>(
2910 cast<ShapedType>(input_imag.
getType()).getElementType());
2912 assert(real_el_ty == imag_el_ty);
2915 SmallVector<Value> dynamicSizes;
2920 SmallVector<int64_t, 3> staticSizes;
2923 auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
2926 SmallVector<utils::IteratorType, 5> iteratorTypes = {
2927 utils::IteratorType::parallel, utils::IteratorType::parallel,
2928 utils::IteratorType::parallel, utils::IteratorType::reduction,
2929 utils::IteratorType::reduction};
2932 SmallVector<Value> genericOpInputs = {input_real, input_imag};
2933 SmallVector<Value> genericOpOutputs = {
2934 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2936 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2941 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2942 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2943 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2944 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2948 auto dimH = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 1);
2949 auto dimW = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 2);
2952 auto twoPiAttr = rewriter.
getFloatAttr(real_el_ty, 6.283185307179586);
2953 auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
2955 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2957 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2959 auto buildBody = [&](OpBuilder &builder, Location loc,
ValueRange args) {
2960 Value valReal = args[0];
2961 Value valImag = args[1];
2962 Value sumReal = args[2];
2963 Value sumImag = args[3];
2966 Value oy = linalg::IndexOp::create(builder, loc, 1);
2967 Value ox = linalg::IndexOp::create(builder, loc, 2);
2968 Value iy = linalg::IndexOp::create(builder, loc, 3);
2969 Value ix = linalg::IndexOp::create(builder, loc, 4);
2973 auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
2974 auto ixXox = index::MulOp::create(builder, loc, ix, ox);
2976 auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
2977 auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);
2980 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2982 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2984 auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
2985 auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
2987 auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
2988 auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);
2991 angle = arith::MulFOp::create(
2992 builder, loc, angle,
2993 arith::ConstantOp::create(rewriter, loc,
2999 auto cosAngle = math::CosOp::create(builder, loc, angle);
3000 auto sinAngle = math::SinOp::create(builder, loc, angle);
3002 auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
3003 auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
3004 auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);
3006 auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
3007 auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);
3009 auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);
3014 arith::AddFOp::create(builder, loc, sumReal, realComponent);
3016 arith::AddFOp::create(builder, loc, sumImag, imagComponent);
3018 linalg::YieldOp::create(builder, loc,
ValueRange{outReal, outImag});
3022 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
3023 indexingMaps, iteratorTypes, buildBody);
3035 patterns->
add<GenericResizeConverter>(patterns->
getContext(),
3039 patterns->
add<MaterializeResizeBroadcast>(patterns->
getContext(),
3044 PointwiseConverter<tosa::AddOp>,
3045 PointwiseConverter<tosa::SubOp>,
3046 PointwiseConverter<tosa::MulOp>,
3047 PointwiseConverter<tosa::IntDivOp>,
3048 PointwiseConverter<tosa::NegateOp>,
3049 PointwiseConverter<tosa::PowOp>,
3050 PointwiseConverter<tosa::ReciprocalOp>,
3051 PointwiseConverter<tosa::RsqrtOp>,
3052 PointwiseConverter<tosa::LogOp>,
3053 PointwiseConverter<tosa::ExpOp>,
3054 PointwiseConverter<tosa::AbsOp>,
3055 PointwiseConverter<tosa::SinOp>,
3056 PointwiseConverter<tosa::CosOp>,
3057 PointwiseConverter<tosa::TanhOp>,
3058 PointwiseConverter<tosa::ErfOp>,
3059 PointwiseConverter<tosa::BitwiseAndOp>,
3060 PointwiseConverter<tosa::BitwiseOrOp>,
3061 PointwiseConverter<tosa::BitwiseNotOp>,
3062 PointwiseConverter<tosa::BitwiseXorOp>,
3063 PointwiseConverter<tosa::LogicalAndOp>,
3064 PointwiseConverter<tosa::LogicalNotOp>,
3065 PointwiseConverter<tosa::LogicalOrOp>,
3066 PointwiseConverter<tosa::LogicalXorOp>,
3067 PointwiseConverter<tosa::CastOp>,
3068 PointwiseConverter<tosa::LogicalLeftShiftOp>,
3069 PointwiseConverter<tosa::LogicalRightShiftOp>,
3070 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
3071 PointwiseConverter<tosa::ClzOp>,
3072 PointwiseConverter<tosa::SelectOp>,
3073 PointwiseConverter<tosa::GreaterOp>,
3074 PointwiseConverter<tosa::GreaterEqualOp>,
3075 PointwiseConverter<tosa::EqualOp>,
3076 PointwiseConverter<tosa::MaximumOp>,
3077 PointwiseConverter<tosa::MinimumOp>,
3078 PointwiseConverter<tosa::CeilOp>,
3079 PointwiseConverter<tosa::FloorOp>,
3080 PointwiseConverter<tosa::ClampOp>,
3081 PointwiseConverter<tosa::SigmoidOp>
3085 IdentityNConverter<tosa::IdentityOp>,
3086 ReduceConverter<tosa::ReduceAllOp>,
3087 ReduceConverter<tosa::ReduceAnyOp>,
3088 ReduceConverter<tosa::ReduceMinOp>,
3089 ReduceConverter<tosa::ReduceMaxOp>,
3090 ReduceConverter<tosa::ReduceSumOp>,
3091 ReduceConverter<tosa::ReduceProductOp>,
*if copies could not be generated due to yet unimplemented cases. *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming and outgoing copies should be inserted (the insertion happens right before the *insertion point). Since `begin` can itself be invalidated due to the memref *rewriting done from this method
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static LogicalResult emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape, const TypeConverter &converter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
DenseMap< int64_t, Value > IndexPool
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, PatternRewriter &rewriter)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &converter)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, ConversionPatternRewriter &rewriter)
static ValueRange getBroadcastableOperands(Operation *operation, ValueRange operands)
static Value materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, Value lhs, Value rhs, Value result)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static bool operandsAndResultsRanked(Operation *operation)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the largest dim in exprs, and the same number of symbols.
BlockArgument getArgument(unsigned i)
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
AffineMap getMultiDimIdentityMap(unsigned rank)
FloatAttr getFloatAttr(Type type, double value)
AffineExpr getAffineConstantExpr(int64_t constant)
IntegerType getIntegerType(unsigned width)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
BoolAttr getBoolAttr(bool value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
IntegerAttr getI8IntegerAttr(int8_t value)
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
void populateTosaToLinalgConversionPatterns(const TypeConverter &converter, RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter, bool isUnsigned)
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...