#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
// Helper for NaN handling on binary min/max: with "IGNORE" semantics, when
// one operand is NaN the other operand is returned instead of the computed
// result.
template <typename OpTy>
static Value
materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,
                                    Value lhs, Value rhs, Value result) {
  auto nanMode = op.getNanMode();
  if (nanMode == NanPropagationMode::PROPAGATE)
    return result;

  // Unordered comparison of a value against itself is true only for NaN.
  Value lhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
                                         arith::CmpFPredicate::UNO, lhs, lhs);
  Value rhsIsNaN = arith::CmpFOp::create(rewriter, op.getLoc(),
                                         arith::CmpFPredicate::UNO, rhs, rhs);
  Value rhsOrResult =
      arith::SelectOp::create(rewriter, op.getLoc(), lhsIsNaN, rhs, result);
  return arith::SelectOp::create(rewriter, op.getLoc(), rhsIsNaN, lhs,
                                 rhsOrResult);
}
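// Illustrative sketch (not verbatim compiler output): for a tosa.maximum with
// nan_mode = "IGNORE" on f32 values %a and %b, the helper above appends
// roughly:
//   %a_nan = arith.cmpf uno, %a, %a : f32
//   %b_nan = arith.cmpf uno, %b, %b : f32
//   %t = arith.select %a_nan, %b, %max : f32
//   %r = arith.select %b_nan, %a, %t : f32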
static Value createLinalgBodyCalculationForElementwiseOp(
    Operation *op, ValueRange args, ArrayRef<Type> resultTypes,
    ConversionPatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();
  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return math::AbsFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = arith::ConstantOp::create(rewriter, loc,
                                          rewriter.getZeroAttr(elementTy));
    auto neg = arith::SubIOp::create(rewriter, loc, zero, args[0]);
    return arith::MaxSIOp::create(rewriter, loc, args[0], neg);
  }
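  // Integer abs is computed as |x| = max(x, 0 - x); e.g. for i32 this emits:
  //   %zero = arith.constant 0 : i32
  //   %neg  = arith.subi %zero, %x : i32
  //   %abs  = arith.maxsi %x, %neg : i32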
  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return arith::AddFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return arith::AddIOp::create(rewriter, loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return arith::SubFOp::create(rewriter, loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return arith::SubIOp::create(rewriter, loc, resultTypes, args);

  // tosa::IntDivOp
  if (isa<tosa::IntDivOp>(op) && isa<IntegerType>(elementTy))
    return arith::DivSIOp::create(rewriter, loc, resultTypes, args);
  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
    return arith::DivFOp::create(rewriter, loc, resultTypes, one, args[0]);
  }
  // tosa::MulOp
  if (isa<tosa::MulOp>(op)) {
    auto shiftVal = cast<tosa::MulOp>(op).getShift();
    DenseElementsAttr shiftElem;
    int32_t shift = 0;
    bool shiftIsConstant = true;
    if (matchPattern(shiftVal, m_Constant(&shiftElem)))
      shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
    else
      shiftIsConstant = false;

    if (isa<FloatType>(elementTy)) {
      if (shift != 0) {
        (void)rewriter.notifyMatchFailure(op,
                                          "Cannot have shift value for float");
        return nullptr;
      }
      return arith::MulFOp::create(rewriter, loc, resultTypes, args[0],
                                   args[1]);
    }

    if (isa<IntegerType>(elementTy)) {
      Value a = args[0];
      Value b = args[1];

      if (shift > 0 || !shiftIsConstant) {
        Value shiftConst =
            arith::ConstantIntOp::create(rewriter, loc, shift, /*width=*/8);
        if (!a.getType().isInteger(32))
          a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);

        if (!b.getType().isInteger(32))
          b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);

        auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
        auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(),
                                                  RoundingMode::SINGLE_ROUND);
        auto result =
            tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a,
                                       b, shiftAmount, roundingAttr);

        if (elementTy.isInteger(32))
          return result;

        return arith::TruncIOp::create(rewriter, loc, elementTy, result);
      }

      int aWidth = a.getType().getIntOrFloatBitWidth();
      int bWidth = b.getType().getIntOrFloatBitWidth();
      int cWidth = resultTypes[0].getIntOrFloatBitWidth();

      if (aWidth < cWidth)
        a = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], a);
      if (bWidth < cWidth)
        b = arith::ExtSIOp::create(rewriter, loc, resultTypes[0], b);

      return arith::MulIOp::create(rewriter, loc, resultTypes, a, b);
    }
  }
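  // The integer path with a positive (or non-constant) shift is a fixed-point
  // multiply: both operands are widened to i32 and combined with
  // tosa.apply_scale, which computes roughly round((a * b) >> shift) under
  // SINGLE_ROUND rounding before truncating back to the result type.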
  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op)) {
    auto negate = cast<tosa::NegateOp>(op);

    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    bool hasInZp = !failed(maybeInZp);
    bool hasOutZp = !failed(maybeOutZp);

    if (isa<FloatType>(elementTy))
      return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);

    if (isa<IntegerType>(elementTy)) {
      const int64_t inputBitWidth = elementTy.getIntOrFloatBitWidth();

      Type intermediateType;
      Value zpAddValue;
      int intermediateBitWidth = 64;

      if (hasInZp && hasOutZp) {
        const int64_t zpAdd = *maybeInZp + *maybeOutZp;
        const int64_t maxValue =
            APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
            std::abs(zpAdd) + 1;

        // Pick the narrowest signed type that can hold the intermediate
        // result without overflow.
        if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
          intermediateBitWidth = 16;
        } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
          intermediateBitWidth = 32;
        }
        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
        zpAddValue = arith::ConstantOp::create(
            rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
      } else {
        // Non-constant zero points arrive as extra arguments.
        intermediateType = rewriter.getIntegerType(intermediateBitWidth);
        auto arg1 =
            arith::ExtSIOp::create(rewriter, loc, intermediateType, args[1]);
        auto arg2 =
            arith::ExtSIOp::create(rewriter, loc, intermediateType, args[2]);
        zpAddValue =
            arith::AddIOp::create(rewriter, loc, intermediateType, arg1, arg2);
      }

      // The negation is computed as output = (inZp + outZp) - input.
      auto ext =
          arith::ExtSIOp::create(rewriter, loc, intermediateType, args[0]);
      auto sub = arith::SubIOp::create(rewriter, loc, zpAddValue, ext);

      // Clamp to the input type's signed range before truncating back.
      auto min = arith::ConstantIntOp::create(
          rewriter, loc, intermediateType,
          APInt::getSignedMinValue(inputBitWidth).getSExtValue());
      auto max = arith::ConstantIntOp::create(
          rewriter, loc, intermediateType,
          APInt::getSignedMaxValue(inputBitWidth).getSExtValue());
      auto clamp = clampIntHelper(loc, sub, min, max, rewriter, false);

      return arith::TruncIOp::create(rewriter, loc, elementTy, clamp);
    }
  }
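  // Worked example: negating an i8 value with inZp = 128 and outZp = 0
  // computes (128 - x); the worst case 127 + 128 + 1 fits in i16, so the
  // arithmetic runs in i16, is clamped to [-128, 127], and truncated to i8.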
  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return arith::AndIOp::create(rewriter, loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return arith::OrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = arith::ConstantOp::create(rewriter, loc, allOnesAttr);
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return arith::ShLIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return arith::ShRUIOp::create(rewriter, loc, resultTypes, args);
  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = arith::ShRSIOp::create(rewriter, loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round) {
      return result;
    }

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one = arith::ConstantOp::create(rewriter, loc,
                                         IntegerAttr::get(elementTy, 1));
    auto zero = arith::ConstantOp::create(rewriter, loc,
                                          IntegerAttr::get(elementTy, 0));
    auto i1zero =
        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 0));
    auto i1one =
        arith::ConstantOp::create(rewriter, loc, IntegerAttr::get(i1Ty, 1));

    // Only round when the shift amount is positive.
    auto shiftValueGreaterThanZero = arith::CmpIOp::create(
        rewriter, loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Check the last bit shifted out: shift right by (amount - 1) and
    // truncate to i1.
    auto subtract =
        arith::SubIOp::create(rewriter, loc, resultTypes, args[1], one);
    auto shifted =
        arith::ShRSIOp::create(rewriter, loc, resultTypes, args[0], subtract)
            ->getResults();
    auto truncated =
        arith::TruncIOp::create(rewriter, loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        arith::AndIOp::create(rewriter, loc, i1Ty, truncated, i1one);

    auto shouldRound = arith::SelectOp::create(
        rewriter, loc, i1Ty, shiftValueGreaterThanZero, isInputOdd, i1zero);
    auto extended =
        arith::ExtUIOp::create(rewriter, loc, resultTypes, shouldRound);
    return arith::AddIOp::create(rewriter, loc, resultTypes, result, extended);
  }
  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return math::CountLeadingZerosOp::create(rewriter, loc, elementTy, args[0]);
  }

  // tosa::LogicalAnd
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalNot
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIntegerAttr(elementTy, 1));
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOr
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogicalXor
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return arith::XOrIOp::create(rewriter, loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::PowFOp::create(rewriter, loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::RsqrtOp::create(rewriter, loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::LogOp::create(rewriter, loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::ExpOp::create(rewriter, loc, resultTypes, args);

  // tosa::SinOp
  if (isa<tosa::SinOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::SinOp::create(rewriter, loc, resultTypes, args);

  // tosa::CosOp
  if (isa<tosa::CosOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::CosOp::create(rewriter, loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return mlir::math::TanhOp::create(rewriter, loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return mlir::math::ErfOp::create(rewriter, loc, resultTypes, args);
  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGT,
                                 args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt,
                                 args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OGE,
                                 args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
                                 args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::OEQ,
                                 args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
                                 args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return arith::SelectOp::create(rewriter, loc, args[0], args[1], args[2]);
  }
  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    auto max = arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MaximumOp>(op),
                                               rewriter, args[0], args[1], max);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    auto min = arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
    return materializeBinaryNanCheckIfRequired(llvm::cast<tosa::MinimumOp>(op),
                                               rewriter, args[0], args[1], min);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return math::CeilOp::create(rewriter, loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return math::FloorOp::create(rewriter, loc, resultTypes, args);
  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = arith::ConstantOp::create(
        rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = arith::ConstantOp::create(
        rewriter, loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    Value result = clampFloatHelper(loc, args[0], min, max, rewriter);

    auto clampOp = llvm::cast<tosa::ClampOp>(op);
    const auto nanMode = clampOp.getNanMode();

    // NaN propagation has no meaning for non floating point types.
    if (!isa<FloatType>(elementTy))
      return result;

    // With "PROPAGATE" semantics no compare and select is required.
    if (nanMode == NanPropagationMode::PROPAGATE)
      return result;

    // With "IGNORE" semantics, compare the operand against itself (true only
    // for NaN) and select the clamp lower bound for NaN inputs.
    Value isNaN = arith::CmpFOp::create(
        rewriter, op->getLoc(), arith::CmpFPredicate::UNO, args[0], args[0]);
    return arith::SelectOp::create(rewriter, op->getLoc(), isNaN, min, result);
  }
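  // With nan_mode = "IGNORE" a NaN input to clamp produces the lower bound,
  // e.g. clamp(NaN, min_val = -1.0, max_val = 1.0) -> -1.0.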
  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int64_t min =
        cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
    int64_t max =
        cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
    if (intTy.isUnsignedInteger()) {
      minRepresentable = 0;
      if (intTy.getIntOrFloatBitWidth() <= 63) {
        maxRepresentable =
            (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
                .getZExtValue();
      }
    } else if (intTy.getIntOrFloatBitWidth() <= 64) {
      // Ensure that min & max fit into signed n-bit constants.
      minRepresentable =
          APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
              .getSExtValue();
      maxRepresentable =
          APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
              .getSExtValue();
    }
    // Clip the clamp bounds to what the element type can represent.
    min = std::max(min, minRepresentable);
    max = std::max(max, minRepresentable);
    min = std::min(min, maxRepresentable);
    max = std::min(max, maxRepresentable);

    auto minVal = arith::ConstantIntOp::create(rewriter, loc, min,
                                               intTy.getIntOrFloatBitWidth());
    auto maxVal = arith::ConstantIntOp::create(rewriter, loc, max,
                                               intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter,
                          intTy.isUnsignedInteger());
  }
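  // Example: clamping a ui8 tensor with min_val = -4 and max_val = 300 first
  // clips the attribute bounds to the representable range, giving a runtime
  // clamp to [0, 255].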
  // tosa::SigmoidOp, computed as 1 / (1 + exp(-x)).
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        arith::ConstantOp::create(rewriter, loc, FloatAttr::get(elementTy, 1));
    auto negate = arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
    auto exp = mlir::math::ExpOp::create(rewriter, loc, resultTypes, negate);
    auto added = arith::AddFOp::create(rewriter, loc, resultTypes, exp, one);
    return arith::DivFOp::create(rewriter, loc, resultTypes, one, added);
  }
  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) {
      (void)rewriter.notifyMatchFailure(op, "unsupported type");
      return nullptr;
    }

    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return arith::ExtFOp::create(rewriter, loc, resultTypes, args);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return arith::TruncFOp::create(rewriter, loc, resultTypes, args);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return arith::UIToFPOp::create(rewriter, loc, resultTypes, args);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return arith::ExtUIOp::create(rewriter, loc, resultTypes, args);

    // Unsigned integers need an unrealized cast so that they can be passed
    // to UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          UnrealizedConversionCastOp::create(
              rewriter, loc,
              rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0])
              .getResult(0);
      return arith::UIToFPOp::create(rewriter, loc, resultTypes[0],
                                     unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return arith::SIToFPOp::create(rewriter, loc, resultTypes, args);

    // Casting to boolean, floats need only be checked as not-equal to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getFloatAttr(srcTy, 0.0));
      return arith::CmpFOp::create(rewriter, loc, arith::CmpFPredicate::UNE,
                                   args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto rounded = math::RoundEvenOp::create(rewriter, loc, args[0]);

      const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
      // If neither int min nor int max can be represented in the input type
      // (the exponent range is too small), replace infinities by int min /
      // int max after conversion; all finite values convert exactly.
      if (static_cast<int>(dstTy.getIntOrFloatBitWidth()) - 1 >
          APFloat::semanticsMaxExponent(fltSemantics)) {
        auto conv = arith::FPToSIOp::create(rewriter, loc, dstTy, rounded);
        auto posInf = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(getElementTypeOrSelf(srcTy),
                                  APFloat::getInf(fltSemantics)));
        auto negInf = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APFloat::getInf(fltSemantics, /*Negative=*/true)));
        auto overflow = arith::CmpFOp::create(
            rewriter, loc, arith::CmpFPredicate::UEQ, rounded, posInf);
        auto underflow = arith::CmpFOp::create(
            rewriter, loc, arith::CmpFPredicate::UEQ, rounded, negInf);
        auto intMin = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
        auto intMax = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getIntegerAttr(
                getElementTypeOrSelf(dstTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
        auto maxClamped =
            arith::SelectOp::create(rewriter, loc, overflow, intMax, conv);
        return arith::SelectOp::create(rewriter, loc, underflow, intMin,
                                       maxClamped);
      }

      auto intMinFP = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getFloatAttr(
              getElementTypeOrSelf(srcTy),
              APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                  .getSExtValue()));

      // If the mantissa can represent int max, clamp in the float domain.
      // Int min is a power of two, so it is always representable as well.
      if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
          dstTy.getIntOrFloatBitWidth() - 1) {
        auto intMaxFP = arith::ConstantOp::create(
            rewriter, loc,
            rewriter.getFloatAttr(
                getElementTypeOrSelf(srcTy),
                APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                    .getSExtValue()));

        Value clamped =
            clampFloatHelper(loc, rounded, intMinFP, intMaxFP, rewriter);
        return arith::FPToSIOp::create(rewriter, loc, dstTy, clamped);
      }

      // Otherwise int max + 1 is representable (it is int min negated), so
      // clamp the minimum and select int max for anything at or above it.
      auto intMaxPlusOneFP = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getFloatAttr(
              getElementTypeOrSelf(srcTy),
              static_cast<double>(
                  APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                      .getSExtValue()) +
                  1.0f));

      auto intMax = arith::ConstantOp::create(
          rewriter, loc,
          rewriter.getIntegerAttr(
              getElementTypeOrSelf(dstTy),
              APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
      auto minClampedFP =
          arith::MaximumFOp::create(rewriter, loc, rounded, intMinFP);
      auto minClamped =
          arith::FPToSIOp::create(rewriter, loc, dstTy, minClampedFP);
      auto overflow = arith::CmpFOp::create(
          rewriter, loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
      return arith::SelectOp::create(rewriter, loc, overflow, intMax,
                                     minClamped);
    }

    // Casting to boolean, integers need only be checked as not-equal to zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = arith::ConstantIntOp::create(rewriter, loc, 0,
                                                srcTy.getIntOrFloatBitWidth());
      return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne,
                                   args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return arith::ExtSIOp::create(rewriter, loc, resultTypes, args);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return arith::TruncIOp::create(rewriter, loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return tensor::DimOp::create(rewriter, loc, tensor, indexValue).getResult();
}

static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}

static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
// Compute the broadcast target size of a dimension together with the operand
// that pins it ("master" operand), if any.
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc,
                  IndexPool &indexPool, ValueRange operands, int64_t dim) {
  // If any operand has a static size greater than 1 in this dimension, that
  // static size is the target size.
  for (auto operand : operands) {
    auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
    if (ShapedType::isStatic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with a dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::filter_to_vector(operands, [&](Value operand) {
        return cast<RankedTensorType>(operand.getType()).isDynamicDim(dim);
      });

  // If no operand has a dynamic dimension, all sizes were 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), operands.front()};

  // Emit code that computes the runtime size for this dimension.
  Value targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // If more than one operand is dynamic, the target size is their maximum.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    auto nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = arith::MaxUIOp::create(rewriter, loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}

static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = cast<RankedTensorType>(operands.front().getType()).getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
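// Example: for operand types tensor<3x1xf32> and tensor<?x4xf32> the target
// shape is {3, 4}: dim 0 is pinned by the first operand's static size 3 (so
// that operand becomes the dim's "master") and dim 1 by the second operand's
// static size 4.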
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do if this is a static dimension.
  auto rankedTensorType = cast<RankedTensorType>(operand.getType());
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size was taken directly from this operand, it already has
  // the right size and no broadcast is needed.
  if (operand == masterOperand)
    return operand;

  // Affine maps for the 'linalg.generic' op.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = (index == dim) ? rewriter.getAffineConstantExpr(0)
                                     : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check whether a broadcast is necessary at runtime.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = arith::CmpIOp::create(
      rewriter, loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit the 'then' region of the 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // It is not safe to cache constants across regions; new constants could
    // violate dominance requirements.
    IndexPool localPool;

    // Emit the 'tensor.empty' op.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = tensor::EmptyOp::create(
        opBuilder, loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit the 'linalg.generic' op that performs the broadcast copy.
    auto resultTensor =
        linalg::GenericOp::create(
            opBuilder, loc, outputTensor.getType(), operand, outputTensor,
            affineMaps, getNParallelLoopsAttrs(rank),
            [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
              linalg::YieldOp::create(opBuilder, loc, blockArgs.front());
            })
            .getResult(0);

    // Cast to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    scf::YieldOp::create(opBuilder, loc, castResultTensor);
  };

  // Emit the 'else' region of the 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    scf::YieldOp::create(opBuilder, loc, operand);
  };

  // Emit the 'scf.if' op.
  auto ifOp = scf::IfOp::create(rewriter, loc, broadcastNecessary,
                                emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
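// Schematic shape of the IR emitted above for a possibly-size-1 dynamic
// dimension:
//   %needs_bcast = arith.cmpi eq, %runtime_size, %c1 : index
//   %r = scf.if %needs_bcast -> (tensor<...>) {
//     ... tensor.empty + linalg.generic broadcast copy, then tensor.cast ...
//   } else {
//     scf.yield %operand
//   }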
static Value broadcastDynamicDimensions(PatternRewriter &rewriter,
                                        Location loc, IndexPool &indexPool,
                                        Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  int64_t rank = cast<RankedTensorType>(operand.getType()).getRank();
  assert((int64_t)targetShape.size() == rank);
  assert((int64_t)masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand = broadcastDynamicDimension(rewriter, loc, indexPool, operand,
                                        index, targetShape[index],
                                        masterOperands[index]);
  return operand;
}

static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No need to broadcast for unary operations.
  if (operands.size() == 1)
    return llvm::to_vector(operands);

  // If every operand is statically shaped, no broadcast IR is needed.
  bool hasDynamic = false;
  for (auto op : operands) {
    const auto tType = dyn_cast<RankedTensorType>(op.getType());
    if (tType && !tType.hasStaticShape()) {
      hasDynamic = true;
      break;
    }
  }
  if (!hasDynamic)
    return llvm::to_vector(operands);

  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
static LogicalResult
emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           const TypeConverter &converter) {
  // Generate the output tensor.
  auto resultType = cast_or_null<RankedTensorType>(
      converter.convertType(operation->getResultTypes().front()));
  if (!resultType) {
    return rewriter.notifyMatchFailure(operation, "failed to convert type");
  }
  Value outputTensor = tensor::EmptyOp::create(rewriter, loc, targetShape,
                                               resultType.getElementType());

  // Create affine maps. Input maps broadcast static size-1 dimensions; the
  // output map is the identity.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      // A static size-1 dimension broadcasts against a larger result dim.
      bool requiresBroadcast =
          (it.value() == 1 && resultType.getDimSize(it.index()) != 1);
      auto affineExpr = requiresBroadcast
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit the 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = linalg::GenericOp::create(
      rewriter, loc, outputTensor.getType(), operands, outputTensor,
      affineMaps, getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        linalg::YieldOp::create(opBuilder, loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast the 'linalg.generic' result into the original result type if needed.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
static ValueRange getBroadcastableOperands(Operation *operation,
                                           ValueRange operands) {
  // The shift operand of tosa.mul cannot broadcast; it only participates as
  // an input when it is not a compile-time constant.
  if (isa<tosa::MulOp>(operation)) {
    DenseElementsAttr shiftElems;
    if (matchPattern(cast<tosa::MulOp>(operation).getShift(),
                     m_Constant(&shiftElems)))
      return operands.take_front(2);
    return operands.take_front(3);
  }
  // The zero-point operands of tosa.negate cannot broadcast.
  if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
    FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
    FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
    if (failed(maybeOutZp) && failed(maybeInZp))
      return operands;
    return operands.take_front(1);
  }
  return operands;
}

static LogicalResult
elementwiseMatchAndRewrite(Operation *operation, ValueRange operands,
                           ConversionPatternRewriter &rewriter,
                           const TypeConverter &converter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower the operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
  auto broadcastOperands =
      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
                                 targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape, converter);
}
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy, APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return arith::AddFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::AddIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MulFOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceProductOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MulIOp::create(rewriter, loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MinimumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MinSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return arith::MaximumFOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    return arith::MaxSIOp::create(rewriter, loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return arith::AndIOp::create(rewriter, loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return arith::OrIOp::create(rewriter, loc, args);

  return {};
}
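// For instance, tosa.reduce_sum over f32 becomes a linalg.reduce whose
// combiner region is a single arith.addf, seeded with the 0.0 initial value
// produced by createInitialValueForReduceOp.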
template <typename OpTy>
static LogicalResult reduceMatchAndRewrite(OpTy op,
                                           PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
  if (!inputTy || !resultTy)
    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);
  uint64_t axis = op.getAxis();

  // Widen the accumulator to f32 for bf16 sums to limit rounding error.
  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
                    isa<FloatType>(elementTy) &&
                    cast<FloatType>(elementTy).isBF16();
  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;

  llvm::SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
    }
  }

  SmallVector<Value> inputs, outputs;
  inputs.push_back(input);

  // First fill the output buffer with the init value.
  auto emptyTensor =
      tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
  auto filledTensor =
      linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                             ValueRange{emptyTensor})
          .result();
  outputs.push_back(filledTensor);

  bool isNanIgnoreMode = false;
  if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> ||
                std::is_same_v<OpTy, tosa::ReduceMaxOp>) {
    // NaN propagation has no meaning for integer types.
    if (isa<FloatType>(elementTy) &&
        op.getNanMode() == NanPropagationMode::IGNORE) {
      isNanIgnoreMode = true;
      // The TOSA spec requires the result to be NaN iff all elements of the
      // reduced slice are NaN, so keep a per-slice boolean flag alongside the
      // running reduction and resolve it afterwards.
      auto trueAttr = rewriter.getBoolAttr(true);
      auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr);
      auto emptyBoolTensor =
          tensor::EmptyOp::create(rewriter, loc, reduceShape,
                                  trueValue.getType(), dynDims)
              .getResult();
      auto allResultsNaNTensor =
          linalg::FillOp::create(rewriter, loc, ValueRange{trueValue},
                                 ValueRange{emptyBoolTensor})
              .result();
      inputs.push_back(input);
      outputs.push_back(allResultsNaNTensor);
    }
  }

  bool didEncounterError = false;
  linalg::LinalgOp linalgOp = linalg::ReduceOp::create(
      rewriter, loc, inputs, outputs, axis,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        std::array<Value, 2> binaryArgs{
            blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};

        // Widen the input element to the accumulator type if necessary.
        if (binaryArgs[0].getType() != accTy)
          binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
                                                binaryArgs[0]);

        auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
                                                             accTy, rewriter);
        if (result)
          didEncounterError = true;

        SmallVector<Value> resultsToYield;
        if (isNanIgnoreMode) {
          auto inputValue = blockArgs[0];
          auto initialValue = blockArgs[2];
          auto oldAllResultsNanFlagValue = blockArgs[3];

          // Unordered comparison of NaN against itself is always true.
          Value isNaN = arith::CmpFOp::create(nestedBuilder, op->getLoc(),
                                              arith::CmpFPredicate::UNO,
                                              inputValue, inputValue);
          // If we've encountered a NaN, take the non-NaN value.
          auto selectOp = arith::SelectOp::create(nestedBuilder, op->getLoc(),
                                                  isNaN, initialValue, result);
          // Update the flag tracking whether every element so far was NaN.
          auto newAllResultsNanFlagValue = arith::AndIOp::create(
              nestedBuilder, op->getLoc(), oldAllResultsNanFlagValue, isNaN);
          resultsToYield.push_back(selectOp);
          resultsToYield.push_back(newAllResultsNanFlagValue);
        } else {
          resultsToYield.push_back(result);
        }
        linalg::YieldOp::create(nestedBuilder, loc, resultsToYield);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  if (isNanIgnoreMode) {
    // A slice that was all NaN must still produce NaN: select between the
    // computed reduction and a NaN-filled tensor using the all-NaN flag.
    auto nanValueAttr = rewriter.getFloatAttr(
        elementTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(),
                        false));
    auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
    auto emptyNanTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();
    auto nanFilledTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
                               ValueRange{emptyNanTensor})
            .result();

    auto finalEmptyTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
            .getResult();

    SmallVector<Value> ins, outs;
    ins.push_back(linalgOp->getOpResult(1));
    ins.push_back(nanFilledTensor);
    ins.push_back(linalgOp->getResult(0));
    outs.push_back(finalEmptyTensor);
    auto linalgSelect =
        linalg::SelectOp::create(rewriter, op->getLoc(), ins, outs);
    linalgOp = linalgSelect;
  }

  Value reducedRes = linalgOp->getResult(0);
  if (widenAccTy) {
    // Truncate the widened accumulator back to the result element type.
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
            .getResult();
    const unsigned reducedRank =
        cast<ShapedType>(reducedRes.getType()).getRank();
    SmallVector<AffineMap> identityMaps(
        2, rewriter.getMultiDimIdentityMap(reducedRank));
    auto truncOp = linalg::GenericOp::create(
        rewriter, loc, emptyTensor.getType(), ValueRange{reducedRes},
        ValueRange{emptyTensor}, identityMaps,
        getNParallelLoopsAttrs(reducedRank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
                                                 elementTy, args[0]);
          linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
        });
    reducedRes = truncOp.getResult(0);
  }

  // Expand the reduced result back to the original rank with a unit dimension
  // at the reduction axis.
  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, reducedRes, reassociationMap);
  return success();
}
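// Example: reducing tensor<4x5x6xf32> along axis 1 yields tensor<4x6xf32>
// from linalg.reduce; the tensor.expand_shape above restores the TOSA result
// shape tensor<4x1x6xf32> by re-inserting a unit dimension at the axis.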
template <typename SrcOp>
class PointwiseConverter : public OpConversionPattern<SrcOp> {
public:
  using OpConversionPattern<SrcOp>::OpConversionPattern;
  using typename OpConversionPattern<SrcOp>::OpAdaptor;

  LogicalResult
  matchAndRewrite(SrcOp op, OpAdaptor operands,
                  ConversionPatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewrite(
        op, operands.getOperands(), rewriter, *this->getTypeConverter());
  }
};
static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
                                  Location loc) {
  SmallVector<ReassociationExprs, 1> reassociation;
  // Create the collapsed (rank-0) type.
  auto inputType = cast<RankedTensorType>(input.getType());
  auto elemType = inputType.getElementType();
  auto collapsedType = RankedTensorType::get({}, elemType);
  // Emit the collapse op.
  return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
                                         reassociation);
}

static llvm::SmallVector<int8_t> convertToI8(llvm::ArrayRef<int32_t> input) {
  llvm::SmallVector<int8_t> output;
  output.reserve(input.size());

  for (auto v : llvm::map_range(
           input, [](int32_t val) { return static_cast<int8_t>(val); })) {
    output.push_back(v);
  }
  return output;
}
// The shift and multiplier may each be constant or non-constant.
// - If non-constant, add the operand as an input to the linalg::GenericOp by
//   pushing it into 'genericInputs' and appending a matching affine map to
//   'indexingMaps'.
// - If constant and scalar, materialize it in 'constant'; if constant and
//   per-channel, materialize a constant tensor input instead.
static void setupLinalgGenericOpInputAndIndexingMap(
    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
    bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
    bool isShift = false) {
  auto loc = op.getLoc();
  auto inputTy = cast<ShapedType>(op.getInput().getType());
  unsigned rank = inputTy.getRank();
  SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};

  if (isConstant) {
    // If we are rescaling per-channel then we need to store the values in a
    // buffer.
    if (values.size() == 1) {
      IntegerAttr intAttr = isShift
                                ? rewriter.getI8IntegerAttr(values.front())
                                : rewriter.getI32IntegerAttr(values.front());
      constant = arith::ConstantOp::create(rewriter, loc, intAttr);
    } else {
      auto elementType =
          isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
      auto tensorType = RankedTensorType::get(
          {static_cast<int64_t>(values.size())}, elementType);
      DenseIntElementsAttr EltAttr;
      if (isShift) {
        llvm::SmallVector<int8_t> i8Values = convertToI8(values);
        EltAttr = DenseIntElementsAttr::get(tensorType,
                                            llvm::ArrayRef<int8_t>(i8Values));
      } else {
        EltAttr = DenseIntElementsAttr::get(tensorType,
                                            llvm::ArrayRef<int32_t>(values));
      }
      genericInputs.push_back(
          arith::ConstantOp::create(rewriter, loc, EltAttr));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, exprs,
                                            rewriter.getContext()));
    }
  } else {
    // A non-constant shift/multiplier must be added as an input.
    auto operand = isShift ? op.getShift() : op.getMultiplier();
    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
    if (tensorType && tensorType.hasStaticShape() &&
        tensorType.getShape()[0] == 1) {
      // A shape {1} tensor is a scalar broadcast.
      AffineMap broadcastMap =
          AffineMap::get(rank, 0, {}, rewriter.getContext());
      genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
      indexingMaps.push_back(broadcastMap);
    } else {
      genericInputs.push_back(operand);
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, exprs,
                                            rewriter.getContext()));
    }
  }
  arg = indexingMaps.size() - 1;
}
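// For a per-channel rescale, e.g. a multiplier of type tensor<Cxi32>, the
// values are materialized as a constant 1-D tensor indexed by the innermost
// (channel) dimension; a single scalar multiplier instead becomes a plain
// arith.constant captured from outside the linalg.generic.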
static Value getExtendZp(OpBuilder &builder, Type valueTy,
                         FailureOr<int64_t> maybeZp, Location loc,
                         ValueRange blockArgs, int64_t zpArg,
                         bool isOutputZp = false) {
  Value result;
  const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
  const uint32_t attrBitwidth =
      isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
  auto extendType = builder.getIntegerType(attrBitwidth);
  // A non-constant zero point arrives as a block argument of the generic op.
  if (failed(maybeZp)) {
    result = blockArgs[zpArg];
    auto zpTy = result.getType();
    if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
      // Cast unsigned values to signless first, since ExtUIOp requires a
      // signless input.
      if (zpTy.isUnsignedInteger()) {
        result =
            UnrealizedConversionCastOp::create(
                builder, loc,
                builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
                .getResult(0);
      }
      if (zpTy.isUnsignedInteger()) {
        return arith::ExtUIOp::create(builder, loc, extendType, result);
      }
      return arith::ExtSIOp::create(builder, loc, extendType, result);
    }
    return result;
  }
  return arith::ConstantOp::create(builder, loc,
                                   IntegerAttr::get(extendType, *maybeZp));
}
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // These are illegal configurations: terminate and log an error.
    if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND)
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not "
              "currently supported");
    if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    if (!isa<IntegerType>(inputTy.getElementType()))
      return rewriter.notifyMatchFailure(op, "only support integer type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    // The shift and multiplier values.
    DenseElementsAttr shiftElems;
    bool isShiftConstant = false;
    if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
      isShiftConstant = true;

    DenseElementsAttr multiplierElems;
    bool isMultiplierConstant = false;
    if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
      isMultiplierConstant = true;

    llvm::SmallVector<int32_t> shiftValues;
    llvm::SmallVector<int32_t> multiplierValues;
    bool doubleRound;

    if (isMultiplierConstant && isShiftConstant) {
      shiftValues = llvm::to_vector(llvm::map_range(
          shiftElems.getValues<IntegerAttr>(),
          [](IntegerAttr attr) -> int32_t {
            return static_cast<int32_t>(attr.getInt());
          }));
      multiplierValues = llvm::to_vector(
          llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
                          [](IntegerAttr attr) -> int32_t {
                            return static_cast<int32_t>(attr.getInt());
                          }));

      // Shifting by more than the bitwidth just sets the result to 0.
      for (int i = 0, s = multiplierValues.size(); i < s; i++) {
        if (shiftValues[i] > 63) {
          shiftValues[i] = 0;
          multiplierValues[i] = 0;
        }
      }

      // Double rounding only matters for shifts greater than 31; fall back
      // to single rounding otherwise.
      doubleRound =
          op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
          llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
    } else {
      doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
    }

    RoundingMode roundingMode =
        doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    setupLinalgGenericOpInputAndIndexingMap(
        rewriter, multiplierValues, genericInputs, indexingMaps,
        isMultiplierConstant, op, multiplierConstant, multiplierArg);

    // Likewise for the shift values.
    Value shiftConstant;
    int64_t shiftArg = 0;
    setupLinalgGenericOpInputAndIndexingMap(
        rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant,
        op, shiftConstant, shiftArg, true);

    // Non-constant zero points are passed as extra scalar inputs.
    int64_t iZpArg = 0;
    int64_t oZpArg = 0;
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
    AffineMap broadcastMap =
        AffineMap::get(rank, 0, {}, rewriter.getContext());
    if (failed(maybeIZp)) {
      genericInputs.push_back(
          collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
      indexingMaps.push_back(broadcastMap);
      iZpArg = indexingMaps.size() - 1;
    }
    if (failed(maybeOZp)) {
      genericInputs.push_back(
          collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
      indexingMaps.push_back(broadcastMap);
      oZpArg = indexingMaps.size() - 1;
    }

    // Indexing map for the output.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the linalg.generic op.
    Value emptyTensor = tensor::EmptyOp::create(
        rewriter, loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = linalg::GenericOp::create(
        rewriter, loc, outputTy, genericInputs, ValueRange{emptyTensor},
        indexingMaps, getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
          auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
                                     nestedLoc, blockArgs, iZpArg);

          FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
          auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
                                      nestedLoc, blockArgs, oZpArg, true);

          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();
          assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          // Cast unsigned inputs to signless so the arith extensions below
          // accept them.
          if (valueTy.isUnsignedInteger()) {
            value = UnrealizedConversionCastOp::create(
                        nestedBuilder, nestedLoc,
                        nestedBuilder.getIntegerType(
                            valueTy.getIntOrFloatBitWidth()),
                        value)
                        .getResult(0);
          }
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (op.getInputUnsigned()) {
              value = arith::ExtUIOp::create(nestedBuilder, nestedLoc,
                                             nestedBuilder.getI32Type(),
                                             value);
            } else {
              value = arith::ExtSIOp::create(nestedBuilder, nestedLoc,
                                             nestedBuilder.getI32Type(),
                                             value);
            }
          }

          value =
              arith::SubIOp::create(nestedBuilder, nestedLoc, value, inputZp);

          value = tosa::ApplyScaleOp::create(nestedBuilder, loc,
                                             nestedBuilder.getI32Type(), value,
                                             multiplier, shift, roundingMode);

          // Move to the new zero-point.
          value =
              arith::AddIOp::create(nestedBuilder, nestedLoc, value, outputZp);

          // Saturate to the output size.
          int32_t intMin =
              APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax =
              APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output range.
          if (op.getOutputUnsigned()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = arith::ConstantOp::create(
              nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = arith::ConstantOp::create(
              nestedBuilder, loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder, false);

          if (outIntType.getWidth() < 32) {
            value = arith::TruncIOp::create(
                nestedBuilder, nestedLoc,
                rewriter.getIntegerType(outIntType.getWidth()), value);
          }

          if (outIntType.isUnsignedInteger()) {
            value = UnrealizedConversionCastOp::create(nestedBuilder,
                                                       nestedLoc, outIntType,
                                                       value)
                        .getResult(0);
          }
          linalg::YieldOp::create(nestedBuilder, loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
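// Schematic per-element computation generated for an i8 -> i8 rescale (SSA
// names reused for brevity):
//   %v   = arith.extsi %in : i8 to i32
//   %v   = arith.subi %v, %input_zp
//   %v   = tosa.apply_scale %v, %multiplier, %shift {rounding_mode = ...}
//   %v   = arith.addi %v, %output_zp
//   %v   = clamp(%v, -128, 127)        // via clampIntHelper
//   %out = arith.trunci %v : i32 to i8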
// Handle the resize case where the input is a 1x1 image; this avoids extract
// operations entirely and produces much better code.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == ResizeMode::BILINEAR;

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    SmallVector<int64_t> scale;
    if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale))
      return failure();

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[0].push_back(builder.getAffineDimExpr(1));
    reassociationMap[0].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy, input,
                                                     reassociationMap);

    // Collect any dynamic shapes of the input.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    // Generate the elementwise operation that casts and scales the input.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty =
        tensor::EmptyOp::create(builder, genericTy.getShape(),
                                resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = linalg::GenericOp::create(
        builder, genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value = arith::ExtSIOp::create(b, loc, resultTy.getElementType(),
                                           value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[0]));
              value = arith::MulIOp::create(b, loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = arith::ConstantOp::create(
                  b, loc, b.getI32IntegerAttr(scale[2]));
              value = arith::MulIOp::create(b, loc, value, scaleX);
            }
          }

          linalg::YieldOp::create(b, loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
// Handle a tosa.resize with broadcast behavior: materialize the non-broadcast
// resize first, then broadcast along the unit dimensions to the full output.
struct MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any broadcastable dimension, generate a width of 1 on the resize
    // output.
    llvm::SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        tosa::ResizeOp::create(builder, resizeTy, input, op.getScale(),
                               op.getOffset(), op.getBorder(), op.getMode());

    // Collapse the unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap;
    reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(0));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    llvm::SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = tensor::CollapseShapeOp::create(builder, collapseTy,
                                                     resize, reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(tensor::DimOp::create(builder, input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = tensor::EmptyOp::create(builder, resultTy.getShape(),
                                          resultTy.getElementType(),
                                          outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());
    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          linalg::YieldOp::create(b, loc, value);
        });

    return success();
  }
};
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    bool floatingPointMode = isa<FloatType>(resultETy);
    auto floatTy = resultETy;

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR &&
        op.getMode() != ResizeMode::BILINEAR)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = tensor::EmptyOp::create(b, resultTy.getShape(),
                                               resultETy, *dynamicDimsOr);
    auto genericOp = linalg::GenericOp::create(
        b, resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = linalg::IndexOp::create(b, 0);
      Value y = linalg::IndexOp::create(b, 1);
      Value x = linalg::IndexOp::create(b, 2);
      Value channel = linalg::IndexOp::create(b, 3);

      Value zeroI32 =
          arith::ConstantOp::create(b, b.getZeroAttr(b.getI32Type()));
      Value zeroFp = arith::ConstantOp::create(b, b.getZeroAttr(floatTy));
      Value hMax =
          arith::ConstantOp::create(b, b.getI32IntegerAttr(imageH - 1));
      Value wMax =
          arith::ConstantOp::create(b, b.getI32IntegerAttr(imageW - 1));

      Value inY = arith::IndexCastOp::create(b, b.getI32Type(), y);
      Value inX = arith::IndexCastOp::create(b, b.getI32Type(), x);

      SmallVector<int64_t> scale, offset, border;
      if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) ||
          !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) ||
          !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border))
        return rewriter.notifyMatchFailure(
            op, "tosa.resize scale/offset/border should have compile time "
                "constant values.");

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[0]));
      yScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[1]));
      xScaleN = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[2]));
      xScaleD = arith::ConstantOp::create(b, b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[0]));
      xOffset = arith::ConstantOp::create(b, b.getI32IntegerAttr(offset[1]));
      yBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[0]));
      xBorder = arith::ConstantOp::create(b, b.getI32IntegerAttr(border[1]));

      // Floating-point source index and fractional delta:
      //   val = x * scale_d + offset
      //   index = floor(val / scale_n)
      //   delta = (val mod scale_n) / scale_n
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp;
          return;
        }
        Value val = arith::MulIOp::create(b, in, scaleD);
        val = arith::AddIOp::create(b, val, offset);
        index = arith::FloorDivSIOp::create(b, val, scaleN);

        Value r = arith::RemSIOp::create(b, val, scaleN);
        Value rFp = arith::SIToFPOp::create(b, floatTy, r);
        Value scaleNfp = arith::UIToFPOp::create(b, floatTy, scaleN);
        delta = arith::DivFOp::create(b, rFp, scaleNfp);
      };

      // Integer source index; the delta stays unnormalized:
      //   index = val / scale_n
      //   delta = val - index * scale_n
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        Value val = arith::MulIOp::create(b, in, scaleD);
        val = arith::AddIOp::create(b, val, offset);
        index = arith::DivSIOp::create(b, val, scaleN);
        delta = arith::MulIOp::create(b, index, scaleN);
        delta = arith::SubIOp::create(b, val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) {
        auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return arith::ConstantIndexOp::create(b, 0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h =
                arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 0.5f));
            pred = arith::CmpFOp::create(b, arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = arith::ShLIOp::create(b, dval, one);
            pred = arith::CmpIOp::create(b, arith::CmpIPredicate::sge,
                                         dvalDouble, scale);
          }

          auto offset = arith::SelectOp::create(b, pred, one, zeroI32);
          val = arith::AddIOp::create(b, val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b, false);
          return arith::IndexCastOp::create(b, b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = tensor::ExtractOp::create(
            b, input, ValueRange{batch, iy, ix, channel});

        linalg::YieldOp::create(b, result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == ResizeMode::BILINEAR);

        auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = arith::AddIOp::create(b, val0, oneVal);
          val0 = clampIntHelper(loc, val0, zeroI32, max, b, false);
          val1 = clampIntHelper(loc, val1, zeroI32, max, b, false);
          val0 = arith::IndexCastOp::create(b, b.getIndexType(), val0);
          val1 = arith::IndexCastOp::create(b, b.getIndexType(), val1);
        };

        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = tensor::ExtractOp::create(
            b, input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal =
              arith::ConstantOp::create(b, b.getFloatAttr(floatTy, 1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = arith::SubFOp::create(b, oneVal, delta);
            Value mul0 = arith::MulFOp::create(b, val0, oneMinusDelta);
            Value mul1 = arith::MulFOp::create(b, val1, delta);
            return arith::AddFOp::create(b, mul0, mul1);
          };

          // Interpolate the top and bottom rows, then between them.
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          linalg::YieldOp::create(b, result);
        } else {
          // Perform the interpolation in the quantized integer space.
          y0x0 = arith::ExtSIOp::create(b, resultETy, y0x0);
          y0x1 = arith::ExtSIOp::create(b, resultETy, y0x1);
          y1x0 = arith::ExtSIOp::create(b, resultETy, y1x0);
          y1x1 = arith::ExtSIOp::create(b, resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = arith::ExtSIOp::create(b, resultETy, dx);
            dy = arith::ExtSIOp::create(b, resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = arith::ExtSIOp::create(b, resultETy, yScaleN);
            xScaleNExt = arith::ExtSIOp::create(b, resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return arith::MulIOp::create(b, val0, scale);
            Value weight0 = arith::SubIOp::create(b, scale, weight1);
            Value mul0 = arith::MulIOp::create(b, val0, weight0);
            Value mul1 = arith::MulIOp::create(b, val1, weight1);
            return arith::AddIOp::create(b, mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          linalg::YieldOp::create(b, result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
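// Note on the bilinear paths above: the floating-point case computes
//   top    = y0x0 * (1 - dx) + y0x1 * dx
//   bottom = y1x0 * (1 - dx) + y1x1 * dx
//   result = top * (1 - dy) + bottom * dy
// while the integer case keeps dx/dy unnormalized and weights by
// (scale_n - d) and d, so the result carries an extra x_scale_n * y_scale_n
// factor, matching TOSA's bilinear integer semantics.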
// At this point we know the identity op's input and output types match, so
// simply forward the operands.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};
template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewrite<SrcOp>(reduceOp, rewriter);
  }
};
class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis);

    // First fill the output buffer.
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, inputTy.getShape(),
                                inputTy.getElementType(),
                                ArrayRef<Value>({dynDims}))
            .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                linalg::IndexOp::create(rewriter, nestedLoc, i).getResult();
            if (i == axis) {
              auto one =
                  arith::ConstantIndexOp::create(rewriter, nestedLoc, 1);
              auto sizeMinusOne =
                  arith::SubIOp::create(rewriter, nestedLoc, axisDimSize, one);
              index = arith::SubIOp::create(rewriter, nestedLoc, sizeMinusOne,
                                            index);
            }
            indices.push_back(index);
          }

          auto extract = tensor::ExtractOp::create(nestedBuilder, nestedLoc,
                                                   input, indices);
          linalg::YieldOp::create(nestedBuilder, op.getLoc(),
                                  extract.getResult());
        });
    return success();
  }
};
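// The reversed element is read at (size - 1 - i) along the chosen axis; e.g.
// reversing axis 0 of tensor<4xi32> maps output index 2 to input index 1.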
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    SmallVector<int64_t> multiples;
    if (failed(op.getConstantMultiples(multiples)))
      return failure();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
      }
    }

    auto emptyTensor = tensor::EmptyOp::create(
        rewriter, op.getLoc(), genericShape, elementTy, dynDims);

    // Map the input shape to the non-broadcast dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          linalg::YieldOp::create(nestedBuilder, op.getLoc(), *args.begin());
        });

    auto shapeValue = getTosaConstShape(
        rewriter, loc, mlir::tosa::convertFromMlirShape(resultTy.getShape()));
    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0), shapeValue);
    return success();
  }
};
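// Example: tiling tensor<2x3xf32> with multiples = [2, 1] builds a
// tensor<2x2x1x3xf32> generic (a multiple/size pair per dimension) that the
// final tosa.reshape collapses to tensor<4x3xf32>.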
2381 using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;
2383 LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
2384 PatternRewriter &rewriter)
const final {
2385 auto loc = argmaxOp.getLoc();
2386 Value input = argmaxOp.getInput();
2387 auto inputTy = cast<ShapedType>(input.
getType());
2388 auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
2389 auto inElementTy = inputTy.getElementType();
2390 auto outElementTy = resultTy.getElementType();
2391 int axis = argmaxOp.getAxis();
2392 auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
2394 if (!isa<IntegerType>(outElementTy))
2395 return rewriter.notifyMatchFailure(
2397 "tosa.arg_max to linalg.* requires integer-like result type");
2399 SmallVector<Value> dynDims;
2400 for (
int i = 0; i < inputTy.getRank(); i++) {
2401 if (inputTy.isDynamicDim(i) && i != axis) {
2402 dynDims.push_back(tensor::DimOp::create(rewriter, loc, input, i));
2407 auto emptyTensorIdx =
2408 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
2409 outElementTy, dynDims)
2411 auto fillValueIdx = arith::ConstantOp::create(
2412 rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0));
2413 auto filledTensorIdx =
2414 linalg::FillOp::create(rewriter, loc,
ValueRange{fillValueIdx},
2419 auto emptyTensorMax =
2420 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy,
2423 auto fillValueMaxAttr =
2426 if (!fillValueMaxAttr)
2427 return rewriter.notifyMatchFailure(
2428 argmaxOp,
"unsupported tosa.argmax element type");
2431 arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr);
2432 auto filledTensorMax =
2433 linalg::FillOp::create(rewriter, loc,
ValueRange{fillValueMax},
2439 SmallVector<utils::IteratorType, 4> iteratorTypes;
2440 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
2441 iteratorTypes[axis] = utils::IteratorType::reduction;
2443 SmallVector<AffineExpr, 2> srcExprs;
2444 SmallVector<AffineExpr, 2> dstExprs;
2445 for (
int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2451 bool didEncounterError =
false;
2453 rewriter.getContext());
    auto linalgOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = arith::IndexCastOp::create(
              rewriter, nestedLoc, oldIndex.getType(),
              linalg::IndexOp::create(rewriter, loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) {
              // Only update the index and max value for non-NaN values. If
              // all values are NaN, the initial index (0) is returned.
              predicate = arith::CmpFOp::create(rewriter, nestedLoc,
                                                arith::CmpFPredicate::OGT,
                                                newValue, oldValue);
            } else {
              // Update the max value if either of the following is true:
              // - the new value is bigger
              // - the current max is not NaN and the new value is NaN
              Value gt = arith::CmpFOp::create(rewriter, nestedLoc,
                                               arith::CmpFPredicate::UGT,
                                               newValue, oldValue);
              Value oldNonNaN = arith::CmpFOp::create(rewriter, nestedLoc,
                                                      arith::CmpFPredicate::ORD,
                                                      oldValue, oldValue);
              predicate = arith::AndIOp::create(
                  rewriter, nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
            }
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = arith::CmpIOp::create(rewriter, nestedLoc,
                                              arith::CmpIPredicate::sgt,
                                              newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = arith::SelectOp::create(
              rewriter, nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = arith::SelectOp::create(
              rewriter, nestedLoc, predicate, newIndex, oldIndex);
          linalg::YieldOp::create(nestedBuilder, nestedLoc,
                                  ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
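// Lowers tosa.gather to a linalg.generic. Informally, tosa.gather computes
//
//   out[n, w, c] = values[n, indices[n, w], c]
//
// so the generic below iterates over the result shape, reads the index
// operand through a (d0, d1) projection map, and extracts the gathered
// element with tensor.extract in the body.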
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
    if (!valuesTy || !resultTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = linalg::IndexOp::create(rewriter, loc, 0);
          Value index1 = arith::IndexCastOp::create(
              rewriter, loc, rewriter.getIndexType(), indexValue);
          auto index2 = linalg::IndexOp::create(rewriter, loc, 2);
          Value extract = tensor::ExtractOp::create(
              rewriter, loc, input, ValueRange{index0, index1, index2});
          linalg::YieldOp::create(rewriter, loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
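// Lowers tosa.table to a linalg.generic that indexes into the table tensor.
// The all-i8 variant is a single gather; the i16 -> i32 variant linearly
// interpolates between adjacent table entries. Informal sketch of the i16
// arithmetic implemented below:
//
//   index    = (value + 32768) >> 7
//   fraction = (value + 32768) & 0x7F
//   result   = (table[index] << 7) + (table[index + 1] - table[index]) * fraction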
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput1();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            tensor::DimOp::create(rewriter, loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultElementTy, dynDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
        affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});
      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = arith::IndexCastOp::create(
            rewriter, loc, rewriter.getIndexType(), inputValue);
        Value offset = arith::ConstantIndexOp::create(rewriter, loc, 128);
        index = arith::AddIOp::create(rewriter, loc, rewriter.getIndexType(),
                                      index, offset);
        Value extract =
            tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
        linalg::YieldOp::create(rewriter, loc, extract);
        return success();
      }
      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = arith::ExtSIOp::create(
            rewriter, loc, rewriter.getI32Type(), inputValue);

        auto offset = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32IntegerAttr(32768));
        auto seven = arith::ConstantOp::create(rewriter, loc,
                                               rewriter.getI32IntegerAttr(7));
        auto one = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getI32IntegerAttr(1));
        auto b1111111 = arith::ConstantOp::create(
            rewriter, loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        //
        // value = value + 32768
        // index = value >> 7
        // fraction = value & 0x7F (lower 7 bits)
        auto extendAdd = arith::AddIOp::create(rewriter, loc, extend, offset);
        Value index = arith::ShRUIOp::create(rewriter, loc, extendAdd, seven);
        Value fraction =
            arith::AndIOp::create(rewriter, loc, extendAdd, b1111111);

        // Extract the base and next values from the table:
        //
        // base = (int32_t) table[index]
        // next = (int32_t) table[index + 1]
        Value indexPlusOne = arith::AddIOp::create(rewriter, loc, index, one);

        index = arith::IndexCastOp::create(rewriter, loc,
                                           rewriter.getIndexType(), index);
        indexPlusOne = arith::IndexCastOp::create(
            rewriter, loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            tensor::ExtractOp::create(rewriter, loc, table, ValueRange{index});
        Value next = tensor::ExtractOp::create(rewriter, loc, table,
                                               ValueRange{indexPlusOne});

        base =
            arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), base);
        next =
            arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the table values:
        //
        // result = (base << 7) + (next - base) * fraction
        Value baseScaled = arith::ShLIOp::create(rewriter, loc, base, seven);
        Value diff = arith::SubIOp::create(rewriter, loc, next, base);
        Value diffScaled = arith::MulIOp::create(rewriter, loc, diff, fraction);
        Value result =
            arith::AddIOp::create(rewriter, loc, baseScaled, diffScaled);

        linalg::YieldOp::create(rewriter, loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern<RFFT2dOp>::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = arith::ConstantIndexOp::create(builder, loc, 1);
    auto two = arith::ConstantIndexOp::create(builder, loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W].
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType =
        cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }
  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr);
    auto filledTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{fillValue},
                               ValueRange{emptyTensor})
            .result();
    return filledTensor;
  }

  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = arith::IndexCastUIOp::create(
        builder, loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);

    return arith::UIToFPOp::create(builder, loc, type, integerVal);
  }

  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = linalg::IndexOp::create(builder, loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }
  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInputReal();
    auto elementType = dyn_cast<FloatType>(
        cast<ShapedType>(input.getType()).getElementType());
    if (!elementType)
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports float element types");

    // Compute the output type and set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        llvm::ArrayRef{affineDimsExpr(rewriter, 0, 3, 4),
                       affineDimsExpr(rewriter, 0, 1, 2),
                       affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for the angle computation.
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // Compute the angle without the integer parts of the components, as
      // sin/cos are periodic:
      // angle = 2 * pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);

      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);

      auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
      auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);

      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);
      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);
      auto realComponent =
          arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto imagComponent =
          arith::MulFOp::create(builder, loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::SubFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
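// Lowers tosa.fft2d along the same lines as tosa.rfft2d above, but with both
// real and imaginary inputs, full-width outputs, and an `inverse` attribute
// that flips the sign of the angle so the same body computes the inverse
// transform.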
struct FFT2dConverter final : OpRewritePattern<FFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FFT2dOp fft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(fft2d->getOperandTypes(),
                      RFFT2dConverter::isRankedTensor) ||
        !llvm::all_of(fft2d->getResultTypes(),
                      RFFT2dConverter::isRankedTensor)) {
      return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors");
    }

    Location loc = fft2d.getLoc();
    Value input_real = fft2d.getInputReal();
    Value input_imag = fft2d.getInputImag();
    BoolAttr inverse = fft2d.getInverseAttr();

    auto real_el_ty = cast<FloatType>(
        cast<ShapedType>(input_real.getType()).getElementType());
    [[maybe_unused]] auto imag_el_ty = cast<FloatType>(
        cast<ShapedType>(input_imag.getType()).getElementType());

    assert(real_el_ty == imag_el_ty);

    // Compute the output type and set of dynamic sizes.
    SmallVector<Value> dynamicSizes;

    // Get [N, H, W].
    auto dims = tensor::getMixedSizes(rewriter, loc, input_real);

    SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto outputType = RankedTensorType::get(staticSizes, real_el_ty);
    // Iterator types for the linalg.generic implementation.
    SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    SmallVector<Value> genericOpInputs = {input_real, input_imag};
    SmallVector<Value> genericOpOutputs = {
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes),
        RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
                                          dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
                 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
        rewriter.getContext());

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input_real, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586);
    auto twoPi = arith::ConstantOp::create(rewriter, loc, twoPiAttr);
    Value constH =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
    Value constW =
        RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value valImag = args[1];
      Value sumReal = args[2];
      Value sumImag = args[3];

      // Indices for the angle computation.
      Value oy = linalg::IndexOp::create(builder, loc, 1);
      Value ox = linalg::IndexOp::create(builder, loc, 2);
      Value iy = linalg::IndexOp::create(builder, loc, 3);
      Value ix = linalg::IndexOp::create(builder, loc, 4);

      // angle = sign * 2 * pi * (((iy * oy) % H) / H + ((ix * ox) % W) / W)
      auto iyXoy = index::MulOp::create(builder, loc, iy, oy);
      auto ixXox = index::MulOp::create(builder, loc, ix, ox);

      auto iyRem = index::RemUOp::create(builder, loc, iyXoy, dimH);
      auto ixRem = index::RemUOp::create(builder, loc, ixXox, dimW);

      auto iyRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
      auto ixRemFloat =
          RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);

      auto yComponent = arith::DivFOp::create(builder, loc, iyRemFloat, constH);
      auto xComponent = arith::DivFOp::create(builder, loc, ixRemFloat, constW);

      auto sumXY = arith::AddFOp::create(builder, loc, yComponent, xComponent);
      auto angle = arith::MulFOp::create(builder, loc, twoPi, sumXY);

      if (inverse.getValue()) {
        angle = arith::MulFOp::create(
            builder, loc, angle,
            arith::ConstantOp::create(rewriter, loc,
                                      rewriter.getFloatAttr(real_el_ty, -1.0)));
      }

      // realComponent = valReal * cos(angle) + valImag * sin(angle)
      // imagComponent = valImag * cos(angle) - valReal * sin(angle)
      auto cosAngle = math::CosOp::create(builder, loc, angle);
      auto sinAngle = math::SinOp::create(builder, loc, angle);

      auto rcos = arith::MulFOp::create(builder, loc, valReal, cosAngle);
      auto rsin = arith::MulFOp::create(builder, loc, valImag, sinAngle);
      auto realComponent = arith::AddFOp::create(builder, loc, rcos, rsin);

      auto icos = arith::MulFOp::create(builder, loc, valImag, cosAngle);
      auto isin = arith::MulFOp::create(builder, loc, valReal, sinAngle);

      auto imagComponent = arith::SubFOp::create(builder, loc, icos, isin);

      // outReal = sumReal + realComponent
      // outImag = sumImag + imagComponent
      auto outReal =
          arith::AddFOp::create(builder, loc, sumReal, realComponent);
      auto outImag =
          arith::AddFOp::create(builder, loc, sumImag, imagComponent);

      linalg::YieldOp::create(builder, loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
} // namespace
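// Pattern registration: the lists below plug every converter defined in this
// file into the pattern set. The pointwise ops share a single templated
// converter, which also needs the TypeConverter; the op-specific patterns
// after it only need the MLIRContext.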
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {

  // We have multiple resize converters to handle degenerate cases.
  patterns->add<GenericResizeConverter>(patterns->getContext(),
                                        /*benefit=*/100);
  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
                                      /*benefit=*/200);
  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
                                            /*benefit=*/300);

  patterns->add<
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::IntDivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::SinOp>,
      PointwiseConverter<tosa::CosOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>
        >(converter, patterns->getContext());

  patterns->add<
      ArgMaxConverter,
      ConcatConverter,
      GatherConverter,
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProductOp>,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      FFT2dConverter,
      TableConverter,
      TileConverter>(patterns->getContext());
}