32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/Sequence.h"
41 static arith::ConstantOp
44 auto castedN =
static_cast<T
>(
45 cast<IntegerAttr>(op->
getAttr(attrName)).getValue().getSExtValue());
46 return rewriter.
create<arith::ConstantOp>(
59 if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
60 return rewriter.
create<math::AbsFOp>(loc, resultTypes, args);
62 if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
63 auto zero = rewriter.
create<arith::ConstantOp>(
65 auto neg = rewriter.
create<arith::SubIOp>(loc, zero, args[0]);
66 return rewriter.
create<arith::MaxSIOp>(loc, args[0], neg);
70 if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
71 return rewriter.
create<arith::AddFOp>(loc, resultTypes, args);
73 if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
74 return rewriter.
create<arith::AddIOp>(loc, resultTypes, args);
77 if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
78 return rewriter.
create<arith::SubFOp>(loc, resultTypes, args);
80 if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
81 return rewriter.
create<arith::SubIOp>(loc, resultTypes, args);
84 if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
85 if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
87 "Cannot have shift value for float");
90 return rewriter.
create<arith::MulFOp>(loc, resultTypes, args);
94 if (isa<tosa::DivOp>(op) && isa<IntegerType>(elementTy))
95 return rewriter.
create<arith::DivSIOp>(loc, resultTypes, args);
98 if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
101 return rewriter.
create<arith::DivFOp>(loc, resultTypes, one, args[0]);
104 if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
108 cast<IntegerAttr>(op->
getAttr(
"shift")).getValue().getSExtValue();
111 rewriter.
create<arith::ConstantIntOp>(loc, shift, 8);
118 auto result = rewriter.
create<tosa::ApplyScaleOp>(
122 if (elementTy.isInteger(32))
125 return rewriter.
create<arith::TruncIOp>(loc, elementTy, result);
130 int cWidth = resultTypes[0].getIntOrFloatBitWidth();
133 a = rewriter.
create<arith::ExtSIOp>(loc, resultTypes[0], a);
135 b = rewriter.
create<arith::ExtSIOp>(loc, resultTypes[0], b);
137 return rewriter.
create<arith::MulIOp>(loc, resultTypes, a, b);
141 if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
142 return rewriter.
create<arith::NegFOp>(loc, resultTypes, args);
144 if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
145 !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
148 return rewriter.
create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
151 if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
152 cast<tosa::NegateOp>(op).getQuantizationInfo()) {
153 auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
154 int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
155 int64_t inZp = quantizationInfo.value().getInputZp();
156 int64_t outZp = quantizationInfo.value().getOutputZp();
159 int64_t zpAdd = inZp + outZp;
160 int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
165 int intermediateBitWidth = 64;
166 if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
167 intermediateBitWidth = 16;
168 }
else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
169 intermediateBitWidth = 32;
170 }
else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
171 intermediateBitWidth = 48;
175 Value zpAddValue = rewriter.
create<arith::ConstantOp>(
180 auto ext = rewriter.
create<arith::ExtSIOp>(loc, intermediateType, args[0]);
181 auto sub = rewriter.
create<arith::SubIOp>(loc, zpAddValue, ext);
185 loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
188 loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
193 return rewriter.
create<arith::TruncIOp>(loc, elementTy,
clamp);
197 if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
198 return rewriter.
create<arith::AndIOp>(loc, resultTypes, args);
201 if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
202 return rewriter.
create<arith::OrIOp>(loc, resultTypes, args);
205 if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
207 elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
208 auto allOnes = rewriter.
create<arith::ConstantOp>(loc, allOnesAttr);
209 return rewriter.
create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
213 if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
214 return rewriter.
create<arith::XOrIOp>(loc, resultTypes, args);
217 if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
218 return rewriter.
create<arith::ShLIOp>(loc, resultTypes, args);
221 if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
222 return rewriter.
create<arith::ShRUIOp>(loc, resultTypes, args);
225 if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
226 auto result = rewriter.
create<arith::ShRSIOp>(loc, resultTypes, args);
227 auto round = cast<BoolAttr>(op->
getAttr(
"round")).getValue();
241 auto shiftValueGreaterThanZero = rewriter.
create<arith::CmpIOp>(
242 loc, arith::CmpIPredicate::sgt, args[1], zero);
246 rewriter.
create<arith::SubIOp>(loc, resultTypes, args[1], one);
248 rewriter.
create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
251 rewriter.
create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
253 rewriter.
create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
255 auto shouldRound = rewriter.
create<arith::AndIOp>(
256 loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
258 rewriter.
create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
259 return rewriter.
create<arith::AddIOp>(loc, resultTypes, result, extended);
263 if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
264 return rewriter.
create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
268 if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
269 return rewriter.
create<arith::AndIOp>(loc, resultTypes, args);
272 if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
273 auto one = rewriter.
create<arith::ConstantOp>(
275 return rewriter.
create<arith::XOrIOp>(loc, resultTypes, args[0], one);
279 if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
280 return rewriter.
create<arith::OrIOp>(loc, resultTypes, args);
283 if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
284 return rewriter.
create<arith::XOrIOp>(loc, resultTypes, args);
287 if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
288 return rewriter.
create<mlir::math::PowFOp>(loc, resultTypes, args);
291 if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
292 return rewriter.
create<mlir::math::RsqrtOp>(loc, resultTypes, args);
295 if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
296 return rewriter.
create<mlir::math::LogOp>(loc, resultTypes, args);
299 if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
300 return rewriter.
create<mlir::math::ExpOp>(loc, resultTypes, args);
303 if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
304 return rewriter.
create<mlir::math::TanhOp>(loc, resultTypes, args);
307 if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
308 return rewriter.
create<mlir::math::ErfOp>(loc, resultTypes, args);
311 if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
312 return rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
315 if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
316 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
320 if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
321 return rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
324 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
325 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
329 if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
330 return rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
333 if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
334 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
338 if (isa<tosa::SelectOp>(op)) {
340 if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
341 return rewriter.
create<arith::SelectOp>(loc, args[0], args[1], args[2]);
345 if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
346 return rewriter.
create<arith::MaximumFOp>(loc, args[0], args[1]);
349 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
350 return rewriter.
create<arith::MaxSIOp>(loc, args[0], args[1]);
354 if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
355 return rewriter.
create<arith::MinimumFOp>(loc, args[0], args[1]);
358 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
359 return rewriter.
create<arith::MinSIOp>(loc, args[0], args[1]);
363 if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
364 return rewriter.
create<math::CeilOp>(loc, resultTypes, args);
367 if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
368 return rewriter.
create<math::FloorOp>(loc, resultTypes, args);
371 if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
372 bool losesInfo =
false;
373 APFloat minApf = cast<FloatAttr>(op->
getAttr(
"min_fp")).getValue();
374 APFloat maxApf = cast<FloatAttr>(op->
getAttr(
"max_fp")).getValue();
375 minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
376 APFloat::rmNearestTiesToEven, &losesInfo);
377 maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
378 APFloat::rmNearestTiesToEven, &losesInfo);
379 auto min = rewriter.
create<arith::ConstantOp>(
380 loc, elementTy, rewriter.
getFloatAttr(elementTy, minApf));
381 auto max = rewriter.
create<arith::ConstantOp>(
382 loc, elementTy, rewriter.
getFloatAttr(elementTy, maxApf));
386 if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
387 auto intTy = cast<IntegerType>(elementTy);
389 cast<IntegerAttr>(op->
getAttr(
"min_int")).getValue().getSExtValue();
391 cast<IntegerAttr>(op->
getAttr(
"max_int")).getValue().getSExtValue();
393 if (intTy.isUnsignedInteger()) {
397 APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
400 std::max(
min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
403 std::min(
max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
407 auto minVal = rewriter.
create<arith::ConstantIntOp>(
408 loc,
min, intTy.getIntOrFloatBitWidth());
409 auto maxVal = rewriter.
create<arith::ConstantIntOp>(
410 loc,
max, intTy.getIntOrFloatBitWidth());
415 if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
418 auto negate = rewriter.
create<arith::NegFOp>(loc, resultTypes, args[0]);
419 auto exp = rewriter.
create<mlir::math::ExpOp>(loc, resultTypes, negate);
420 auto added = rewriter.
create<arith::AddFOp>(loc, resultTypes, exp, one);
421 return rewriter.
create<arith::DivFOp>(loc, resultTypes, one, added);
425 if (isa<tosa::CastOp>(op)) {
426 Type srcTy = elementTy;
427 Type dstTy = resultTypes.front();
434 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
435 return rewriter.
create<arith::ExtFOp>(loc, resultTypes, args,
438 if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
439 return rewriter.
create<arith::TruncFOp>(loc, resultTypes, args,
443 if (srcTy.
isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
444 return rewriter.
create<arith::UIToFPOp>(loc, resultTypes, args,
447 if (srcTy.
isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
448 return rewriter.
create<arith::ExtUIOp>(loc, resultTypes, args,
454 auto unrealizedCast =
456 .
create<UnrealizedConversionCastOp>(
460 return rewriter.
create<arith::UIToFPOp>(loc, resultTypes[0],
465 if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
466 return rewriter.
create<arith::SIToFPOp>(loc, resultTypes, args,
470 if (isa<FloatType>(srcTy) && dstTy.
isInteger(1)) {
473 return rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
477 if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
478 auto rounded = rewriter.
create<math::RoundEvenOp>(loc, args[0]);
480 const auto &fltSemantics = cast<FloatType>(srcTy).getFloatSemantics();
484 APFloat::semanticsMaxExponent(fltSemantics)) {
487 auto conv = rewriter.
create<arith::FPToSIOp>(loc, dstTy, rounded);
488 auto posInf = rewriter.
create<arith::ConstantOp>(
490 APFloat::getInf(fltSemantics)));
491 auto negInf = rewriter.
create<arith::ConstantOp>(
494 APFloat::getInf(fltSemantics,
true)));
495 auto overflow = rewriter.
create<arith::CmpFOp>(
496 loc, arith::CmpFPredicate::UEQ, rounded, posInf);
497 auto underflow = rewriter.
create<arith::CmpFOp>(
498 loc, arith::CmpFPredicate::UEQ, rounded, negInf);
499 auto intMin = rewriter.
create<arith::ConstantOp>(
502 APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())));
503 auto intMax = rewriter.
create<arith::ConstantOp>(
506 APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())));
508 rewriter.
create<arith::SelectOp>(loc, overflow, intMax, conv);
509 return rewriter.
create<arith::SelectOp>(loc, underflow, intMin,
513 auto intMinFP = rewriter.
create<arith::ConstantOp>(
520 if (cast<FloatType>(srcTy).getFPMantissaWidth() >=
526 auto intMaxFP = rewriter.
create<arith::ConstantOp>(
534 return rewriter.
create<arith::FPToSIOp>(loc, dstTy, clamped);
541 auto intMaxPlusOneFP = rewriter.
create<arith::ConstantOp>(
548 auto intMax = rewriter.
create<arith::ConstantOp>(
553 rewriter.
create<arith::MaximumFOp>(loc, rounded, intMinFP);
555 rewriter.
create<arith::FPToSIOp>(loc, dstTy, minClampedFP);
556 auto overflow = rewriter.
create<arith::CmpFOp>(
557 loc, arith::CmpFPredicate::UGE, rounded, intMaxPlusOneFP);
558 return rewriter.
create<arith::SelectOp>(loc, overflow, intMax,
564 if (isa<IntegerType>(srcTy) && dstTy.
isInteger(1)) {
565 Value zero = rewriter.
create<arith::ConstantIntOp>(
567 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
571 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
572 return rewriter.
create<arith::ExtSIOp>(loc, resultTypes, args,
575 if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
576 return rewriter.
create<arith::TruncIOp>(loc, dstTy, args[0]);
581 op,
"unhandled op for linalg body calculation for elementwise op");
588 auto shapedType = dyn_cast<ShapedType>(tensor.
getType());
589 assert(shapedType && shapedType.hasRank() &&
"expected a ranked shaped type");
590 int64_t numExtraDims = rank - shapedType.getRank();
591 assert(numExtraDims >= 0 &&
"cannot expand tensor to a lower rank");
597 shapedType.getRank());
599 for (index = 0; index <= numExtraDims; index++)
600 reassociationIndices[0].push_back(index);
601 for (
size_t position = 1; position < reassociationIndices.size(); position++)
602 reassociationIndices[position].push_back(index++);
606 for (index = 0; index < numExtraDims; index++)
607 resultShape.push_back(1);
608 for (
auto size : shapedType.getShape())
609 resultShape.push_back(size);
614 return rewriter.
create<tensor::ExpandShapeOp>(loc, resultType, tensor,
615 reassociationIndices);
623 return expandRank(rewriter, loc, operand, rank);
634 auto [it, inserted] = indexPool.try_emplace(index);
643 auto indexValue =
createIndex(rewriter, loc, indexPool, index);
644 return rewriter.
create<tensor::DimOp>(loc, tensor, indexValue).getResult();
650 auto shapedType = dyn_cast<ShapedType>(tensor.
getType());
651 assert(shapedType && shapedType.hasRank() &&
"expected a ranked shaped type");
652 assert(index >= 0 && index < shapedType.getRank() &&
"index out of bounds");
653 if (shapedType.isDynamicDim(index))
654 return getTensorDim(rewriter, loc, indexPool, tensor, index);
655 return rewriter.
getIndexAttr(shapedType.getDimSize(index));
659 auto isRanked = [](
Value value) {
660 return isa<RankedTensorType>(value.getType());
662 return llvm::all_of(operation->
getOperands(), isRanked) &&
663 llvm::all_of(operation->
getResults(), isRanked);
676 static std::pair<OpFoldResult, Value>
682 for (
auto operand : operands) {
683 auto size = cast<RankedTensorType>(operand.getType()).getDimSize(dim);
684 if (!ShapedType::isDynamic(size) && size > 1)
689 auto operandsWithDynamicDim =
690 llvm::to_vector(llvm::make_filter_range(operands, [&](
Value operand) {
691 return cast<RankedTensorType>(operand.
getType()).isDynamicDim(dim);
695 if (operandsWithDynamicDim.empty())
702 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
703 if (operandsWithDynamicDim.size() == 1)
704 return {targetSize, operandsWithDynamicDim[0]};
707 for (
size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
709 getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
710 targetSize = rewriter.
create<arith::MaxUIOp>(loc, targetSize, nextSize);
712 return {targetSize,
nullptr};
720 assert(!operands.empty());
721 auto rank = cast<RankedTensorType>(operands.front().
getType()).getRank();
724 for (
auto dim : llvm::seq<int64_t>(0, rank)) {
725 auto [targetSize, masterOperand] =
727 targetShape.push_back(targetSize);
728 masterOperands.push_back(masterOperand);
730 return {targetShape, masterOperands};
736 Value masterOperand) {
738 auto rankedTensorType = cast<RankedTensorType>(operand.
getType());
739 if (!rankedTensorType.isDynamicDim(dim))
746 if (operand == masterOperand)
750 auto rank = rankedTensorType.getRank();
752 for (
auto index : llvm::seq<int64_t>(0, rank)) {
755 affineExprs.push_back(affineExpr);
757 auto broadcastAffineMap =
763 auto one =
createIndex(rewriter, loc, indexPool, 1);
764 auto runtimeSize =
getTensorDim(rewriter, loc, indexPool, operand, dim);
765 auto broadcastNecessary = rewriter.
create<arith::CmpIOp>(
766 loc, arith::CmpIPredicate::eq, runtimeSize, one);
776 for (
auto index : llvm::seq<int64_t>(0, rank)) {
777 auto size = index == dim ? targetSize
780 outputTensorShape.push_back(size);
782 Value outputTensor = opBuilder.
create<tensor::EmptyOp>(
783 loc, outputTensorShape, rankedTensorType.getElementType());
788 .
create<linalg::GenericOp>(
789 loc, outputTensor.
getType(), operand, outputTensor, affineMaps,
793 opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
798 auto castResultTensor = rewriter.
createOrFold<tensor::CastOp>(
799 loc, operand.
getType(), resultTensor);
802 opBuilder.create<scf::YieldOp>(loc, castResultTensor);
807 opBuilder.
create<scf::YieldOp>(loc, operand);
811 auto ifOp = rewriter.
create<scf::IfOp>(loc, broadcastNecessary,
812 emitThenRegion, emitElseRegion);
820 int64_t rank = cast<RankedTensorType>(operand.
getType()).getRank();
821 assert((int64_t)targetShape.size() == rank);
822 assert((int64_t)masterOperands.size() == rank);
823 for (
auto index : llvm::seq<int64_t>(0, rank))
826 targetShape[index], masterOperands[index]);
836 if (operands.size() == 1)
840 return llvm::map_to_vector(operands, [&](
Value operand) {
842 targetShape, masterOperands);
852 Value outputTensor = rewriter.
create<tensor::EmptyOp>(
853 loc, targetShape, resultType.getElementType());
858 auto rank = resultType.getRank();
859 auto affineMaps = llvm::map_to_vector(operands, [&](
Value operand) {
860 auto shape = cast<ShapedType>(operand.
getType()).getShape();
865 affineExprs.push_back(affineExpr);
872 bool encounteredError =
false;
873 auto linalgOp = rewriter.
create<linalg::GenericOp>(
874 loc, outputTensor.
getType(), operands, outputTensor, affineMaps,
879 {resultType.getElementType()}, rewriter);
881 encounteredError =
true;
884 opBuilder.create<linalg::YieldOp>(loc, opResult);
886 if (encounteredError)
888 operation,
"unable to create linalg.generic body for elementwise op");
891 auto castResult = rewriter.
createOrFold<tensor::CastOp>(
892 loc, resultType, linalgOp->getResult(0));
893 rewriter.
replaceOp(operation, castResult);
902 assert(operation->
getNumResults() == 1 &&
"elementwise op expects 1 result");
904 "elementwise op expects at least 1 operand");
907 "Unranked tensors not supported");
911 auto loc = operation->
getLoc();
913 auto [targetShape, masterOperands] =
916 rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
925 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
928 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
931 if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
934 if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
937 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
939 elementTy, APFloat::getLargest(
940 cast<FloatType>(elementTy).getFloatSemantics(),
false));
942 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
946 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
948 elementTy, APFloat::getLargest(
949 cast<FloatType>(elementTy).getFloatSemantics(),
true));
951 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
955 if (isa<tosa::ReduceAllOp>(op) && elementTy.
isInteger(1))
958 if (isa<tosa::ReduceAnyOp>(op) && elementTy.
isInteger(1))
961 if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
963 elementTy, APFloat::getLargest(
964 cast<FloatType>(elementTy).getFloatSemantics(),
true));
966 if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
980 if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
981 return rewriter.
create<arith::AddFOp>(loc, args);
984 if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
985 return rewriter.
create<arith::AddIOp>(loc, args);
988 if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
989 return rewriter.
create<arith::MulFOp>(loc, args);
992 if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
993 return rewriter.
create<arith::MulIOp>(loc, args);
996 if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
997 return rewriter.
create<arith::MinimumFOp>(loc, args[0], args[1]);
1000 if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
1001 return rewriter.
create<arith::MinSIOp>(loc, args[0], args[1]);
1004 if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
1005 return rewriter.
create<arith::MaximumFOp>(loc, args[0], args[1]);
1008 if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
1009 return rewriter.
create<arith::MaxSIOp>(loc, args[0], args[1]);
1012 if (isa<tosa::ReduceAllOp>(op) && elementTy.
isInteger(1))
1013 return rewriter.
create<arith::AndIOp>(loc, args);
1015 if (isa<tosa::ReduceAnyOp>(op) && elementTy.
isInteger(1))
1016 return rewriter.
create<arith::OrIOp>(loc, args);
1029 auto elementTy = resultTy.getElementType();
1034 for (
unsigned i = 0; i < inputTy.getRank(); i++) {
1036 reduceShape.push_back(inputTy.getDimSize(i));
1037 if (inputTy.isDynamicDim(i))
1038 dynDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, i));
1045 .
create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
1052 op,
"No initial value found for reduction operation");
1054 auto fillValue = rewriter.
create<arith::ConstantOp>(loc, fillValueAttr);
1055 auto filledTensor = rewriter
1060 bool didEncounterError =
false;
1061 auto linalgOp = rewriter.
create<linalg::ReduceOp>(
1062 loc, input, filledTensor, axis,
1065 op, blockArgs, elementTy, rewriter);
1067 didEncounterError =
true;
1069 nestedBuilder.create<linalg::YieldOp>(loc, result);
1072 if (!didEncounterError)
1074 op,
"unable to create linalg.generic body for reduce op");
1077 uint64_t expandInputRank =
1078 cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
1079 reassociationMap.resize(expandInputRank);
1081 for (uint64_t i = 0; i < expandInputRank; i++) {
1082 int32_t dimToPush = i > axis ? i + 1 : i;
1086 if (expandInputRank != 0) {
1087 int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
1088 reassociationMap[expandedDim].push_back(
1097 op, resultTy, linalgOp.
getResults()[0], reassociationMap);
1103 template <
typename SrcOp>
1121 auto input = op.getInput();
1122 auto inputTy = cast<ShapedType>(op.getInput().getType());
1123 auto outputTy = cast<ShapedType>(op.getOutput().getType());
1124 unsigned rank = inputTy.getRank();
1127 if (op.getDoubleRound() && !op.getScale32())
1128 return rewriter.notifyMatchFailure(
1129 op,
"tosa.rescale requires scale32 for double_round to be true");
1132 for (
int i = 0; i < outputTy.getRank(); i++) {
1133 if (outputTy.isDynamicDim(i)) {
1134 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1143 for (
int i = 0, s = multiplierValues.size(); i < s; i++) {
1144 if (shiftValues[i] > 63) {
1146 multiplierValues[i] = 0;
1153 op.getDoubleRound() &&
1154 llvm::any_of(shiftValues, [](int32_t v) {
return v > 31; });
1157 rewriter.getMultiDimIdentityMap(rank)};
1162 Value multiplierConstant;
1163 int64_t multiplierArg = 0;
1164 if (multiplierValues.size() == 1) {
1165 multiplierConstant = rewriter.create<arith::ConstantOp>(
1166 loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1169 rewriter.getAffineDimExpr(rank - 1)};
1170 auto multiplierType =
1172 rewriter.getI32Type());
1173 genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1178 rewriter.getContext()));
1180 multiplierArg = indexingMaps.size() - 1;
1185 Value shiftConstant;
1186 int64_t shiftArg = 0;
1187 if (shiftValues.size() == 1) {
1188 shiftConstant = rewriter.create<arith::ConstantOp>(
1189 loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1192 rewriter.getAffineDimExpr(rank - 1)};
1195 rewriter.getIntegerType(8));
1196 genericInputs.push_back(rewriter.create<arith::ConstantOp>(
1200 rewriter.getContext()));
1201 shiftArg = indexingMaps.size() - 1;
1205 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));
1208 Value emptyTensor = rewriter.create<tensor::EmptyOp>(
1209 loc, outputTy.getShape(), outputTy.getElementType(),
1212 auto linalgOp = rewriter.create<linalg::GenericOp>(
1213 loc, outputTy, genericInputs,
ValueRange{emptyTensor}, indexingMaps,
1217 Value value = blockArgs[0];
1225 auto inputZp = createConstFromIntAttribute<int32_t>(
1226 op,
"input_zp", nestedBuilder.getIntegerType(inBitwidth),
1228 auto outputZp = createConstFromIntAttribute<int32_t>(
1229 op,
"output_zp", nestedBuilder.getI32Type(), nestedBuilder);
1231 Value multiplier = multiplierConstant ? multiplierConstant
1232 : blockArgs[multiplierArg];
1233 Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1237 value = nestedBuilder
1238 .create<UnrealizedConversionCastOp>(
1240 nestedBuilder.getIntegerType(
1244 value = nestedBuilder.create<arith::ExtUIOp>(
1245 nestedLoc, nestedBuilder.getI32Type(), value);
1247 value = nestedBuilder.create<arith::ExtSIOp>(
1248 nestedLoc, nestedBuilder.getI32Type(), value);
1253 nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);
1255 value = nestedBuilder.create<tosa::ApplyScaleOp>(
1256 loc, nestedBuilder.getI32Type(), value, multiplier, shift,
1257 nestedBuilder.getBoolAttr(doubleRound));
1261 nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);
1264 IntegerType outIntType =
1265 cast<IntegerType>(blockArgs.back().getType());
1266 unsigned outBitWidth = outIntType.getWidth();
1268 int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1269 int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1272 if (outIntType.isUnsignedInteger()) {
1274 intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1277 auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1278 loc, nestedBuilder.getI32IntegerAttr(intMin));
1279 auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1280 loc, nestedBuilder.getI32IntegerAttr(intMax));
1285 if (outIntType.getWidth() < 32) {
1286 value = nestedBuilder.create<arith::TruncIOp>(
1287 nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
1290 if (outIntType.isUnsignedInteger()) {
1291 value = nestedBuilder
1292 .create<UnrealizedConversionCastOp>(nestedLoc,
1298 nestedBuilder.create<linalg::YieldOp>(loc, value);
1301 rewriter.replaceOp(op, linalgOp->getResults());
1317 auto input = op.getInput();
1318 auto inputTy = cast<RankedTensorType>(input.getType());
1319 auto resultTy = cast<RankedTensorType>(op.getType());
1320 const bool isBilinear = op.getMode() ==
"BILINEAR";
1322 auto inputH = inputTy.getDimSize(1);
1323 auto inputW = inputTy.getDimSize(2);
1324 auto outputH = resultTy.getDimSize(1);
1325 auto outputW = resultTy.getDimSize(2);
1327 if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
1328 return rewriter.notifyMatchFailure(
1329 op,
"tosa.resize is not a pure 1x1->1x1 image operation");
1332 if (op.getMode() !=
"NEAREST_NEIGHBOR" && op.getMode() !=
"BILINEAR")
1333 return rewriter.notifyMatchFailure(
1334 op,
"tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1336 if (inputTy == resultTy) {
1337 rewriter.replaceOp(op, input);
1345 reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1346 reassociationMap[1].push_back(builder.getAffineDimExpr(1));
1347 reassociationMap[1].push_back(builder.getAffineDimExpr(2));
1348 reassociationMap[1].push_back(builder.getAffineDimExpr(3));
1352 inputTy.getElementType());
1353 Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
1358 if (inputTy.isDynamicDim(0))
1359 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1360 if (inputTy.isDynamicDim(3))
1361 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1364 auto genericTy = collapseTy.clone(resultTy.getElementType());
1365 Value empty = builder.create<tensor::EmptyOp>(
1366 genericTy.getShape(), resultTy.getElementType(), outputDynSize);
1367 auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
1369 utils::IteratorType::parallel);
1371 auto generic = builder.create<linalg::GenericOp>(
1375 Value value = args[0];
1377 if (inputTy.getElementType() != resultTy.getElementType()) {
1379 b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);
1381 if (isBilinear && scale[0] != 0) {
1382 Value scaleY = b.create<arith::ConstantOp>(
1383 loc, b.getI32IntegerAttr(scale[0]));
1384 value = b.create<arith::MulIOp>(loc, value, scaleY);
1387 if (isBilinear && scale[2] != 0) {
1388 Value scaleX = b.create<arith::ConstantOp>(
1389 loc, b.getI32IntegerAttr(scale[2]));
1390 value = b.create<arith::MulIOp>(loc, value, scaleX);
1394 b.create<linalg::YieldOp>(loc, value);
1397 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
1398 op, resultTy,
generic.
getResults()[0], reassociationMap);
1414 auto input = op.getInput();
1415 auto inputTy = dyn_cast<RankedTensorType>(input.getType());
1416 auto resultTy = dyn_cast<RankedTensorType>(op.getType());
1418 if (!inputTy || !resultTy)
1419 return rewriter.notifyMatchFailure(op,
1420 "requires ranked input/output types");
1422 auto batch = inputTy.getDimSize(0);
1423 auto channels = inputTy.getDimSize(3);
1424 auto inputH = inputTy.getDimSize(1);
1425 auto inputW = inputTy.getDimSize(2);
1426 auto outputH = resultTy.getDimSize(1);
1427 auto outputW = resultTy.getDimSize(2);
1429 if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
1430 return rewriter.notifyMatchFailure(
1431 op,
"tosa.resize has no broadcasting behavior");
1436 resizeShape.push_back(batch);
1437 resizeShape.push_back(inputH == 1 ? 1 : outputH);
1438 resizeShape.push_back(inputW == 1 ? 1 : outputW);
1439 resizeShape.push_back(channels);
1441 auto resizeTy = resultTy.clone(resizeShape);
1443 builder.create<tosa::ResizeOp>(resizeTy, input, op->
getAttrs());
1447 reassociationMap[0].push_back(builder.getAffineDimExpr(0));
1448 reassociationMap.back().push_back(builder.getAffineDimExpr(1));
1450 reassociationMap.push_back({});
1451 reassociationMap.back().push_back(builder.getAffineDimExpr(2));
1453 reassociationMap.push_back({});
1454 reassociationMap.back().push_back(builder.getAffineDimExpr(3));
1458 collapseShape.push_back(outputH);
1460 collapseShape.push_back(outputW);
1461 collapseShape.push_back(channels);
1463 auto collapseTy = resultTy.clone(collapseShape);
1464 Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
1469 if (inputTy.isDynamicDim(0))
1470 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
1471 if (inputTy.isDynamicDim(3))
1472 outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));
1475 utils::IteratorType::parallel);
1476 Value empty = builder.create<tensor::EmptyOp>(
1477 resultTy.getShape(), resultTy.getElementType(), outputDynSize);
1481 inputExprs.push_back(rewriter.getAffineDimExpr(1));
1483 inputExprs.push_back(rewriter.getAffineDimExpr(2));
1484 inputExprs.push_back(rewriter.getAffineDimExpr(3));
1487 inputExprs, rewriter.getContext());
1489 auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
1490 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1494 Value value = args[0];
1495 b.create<linalg::YieldOp>(loc, value);
1510 auto input = op.getInput();
1511 auto inputTy = cast<ShapedType>(input.getType());
1512 auto resultTy = cast<ShapedType>(op.getType());
1513 auto resultETy = resultTy.getElementType();
1515 bool floatingPointMode = resultETy.isF16() || resultETy.isF32();
1516 auto floatTy = resultETy.isF16() ? b.getF16Type() : b.getF32Type();
1518 auto imageH = inputTy.getShape()[1];
1519 auto imageW = inputTy.getShape()[2];
1521 auto dynamicDimsOr =
1523 if (!dynamicDimsOr.has_value())
1524 return rewriter.notifyMatchFailure(
1525 op,
"unable to get dynamic dimensions of tosa.resize");
1527 if (op.getMode() !=
"NEAREST_NEIGHBOR" && op.getMode() !=
"BILINEAR")
1528 return rewriter.notifyMatchFailure(
1529 op,
"tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");
1532 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1533 auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
1535 auto genericOp = b.create<linalg::GenericOp>(
1538 Value resize = genericOp.getResult(0);
1542 b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
1544 Value batch = b.create<linalg::IndexOp>(0);
1545 Value y = b.create<linalg::IndexOp>(1);
1546 Value x = b.create<linalg::IndexOp>(2);
1547 Value channel = b.create<linalg::IndexOp>(3);
1550 b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
1551 Value zeroFp = b.create<arith::ConstantOp>(b.getZeroAttr(floatTy));
1552 Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
1553 Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));
1555 Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
1556 Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
1562 Value yScaleN, yScaleD, xScaleN, xScaleD;
1563 yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
1564 yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
1565 xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
1566 xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));
1568 Value yOffset, xOffset, yBorder, xBorder;
1569 yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
1570 xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
1571 yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
1572 xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));
1585 Value val = b.create<arith::MulIOp>(in, scaleD);
1586 val = b.create<arith::AddIOp>(val, offset);
1587 index = b.create<arith::FloorDivSIOp>(val, scaleN);
1591 Value r = b.create<arith::RemSIOp>(val, scaleN);
1592 Value rFp = b.create<arith::SIToFPOp>(floatTy, r);
1593 Value scaleNfp = b.create<arith::UIToFPOp>(floatTy, scaleN);
1594 delta = b.create<arith::DivFOp>(rFp, scaleNfp);
1609 Value val = b.create<arith::MulIOp>(in, scaleD);
1610 val = b.create<arith::AddIOp>(val, offset);
1611 index = b.create<arith::DivSIOp>(val, scaleN);
1612 delta = b.create<arith::MulIOp>(index, scaleN);
1613 delta = b.create<arith::SubIOp>(val, delta);
1616 Value ix, iy, dx, dy;
1617 if (floatingPointMode) {
1618 getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1619 getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1621 getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
1622 getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
1625 if (op.getMode() ==
"NEAREST_NEIGHBOR") {
1626 auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1628 auto getNearestIndexAndClamp = [&](
Value val,
Value dval,
Value scale,
1632 return b.create<arith::ConstantIndexOp>(0);
1636 if (floatingPointMode) {
1637 auto h = b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 0.5f));
1638 pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
1640 Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
1641 pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
1645 auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
1646 val = b.create<arith::AddIOp>(val, offset);
1648 return b.create<arith::IndexCastOp>(b.getIndexType(), val);
1651 iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
1652 ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);
1654 Value result = b.create<tensor::ExtractOp>(
1657 b.create<linalg::YieldOp>(result);
1660 assert(op.getMode() ==
"BILINEAR");
1662 auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));
1664 auto getClampedIdxs = [&](
Value &val0,
Value &val1,
int size,
Value in,
1667 val1 = b.create<arith::AddIOp>(val0, oneVal);
1670 val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
1671 val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
1679 Value x0, x1, y0, y1;
1680 getClampedIdxs(y0, y1, imageH, iy, hMax, b);
1681 getClampedIdxs(x0, x1, imageW, ix, wMax, b);
1683 Value y0x0 = b.create<tensor::ExtractOp>(
1685 Value y0x1 = b.create<tensor::ExtractOp>(
1687 Value y1x0 = b.create<tensor::ExtractOp>(
1689 Value y1x1 = b.create<tensor::ExtractOp>(
1692 if (floatingPointMode) {
1694 b.create<arith::ConstantOp>(b.getFloatAttr(floatTy, 1.0f));
1700 Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
1701 Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
1702 Value mul1 = b.create<arith::MulFOp>(val1, delta);
1703 return b.create<arith::AddFOp>(mul0, mul1);
1709 Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);
1714 Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);
1718 Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
1719 b.create<linalg::YieldOp>(result);
1722 y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
1723 y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
1724 y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
1725 y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);
1728 if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
1729 dx = b.create<arith::ExtSIOp>(resultETy, dx);
1730 dy = b.create<arith::ExtSIOp>(resultETy, dy);
1733 Value yScaleNExt = yScaleN;
1734 Value xScaleNExt = xScaleN;
1736 const int64_t scaleBitwidth =
1738 if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
1739 yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
1740 xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
1744 Value scale,
int inputSize,
1747 return b.create<arith::MulIOp>(val0, scale);
1748 Value weight0 = b.create<arith::SubIOp>(scale, weight1);
1749 Value mul0 = b.create<arith::MulIOp>(val0, weight0);
1750 Value mul1 = b.create<arith::MulIOp>(val1, weight1);
1751 return b.create<arith::AddIOp>(mul0, mul1);
1754 Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
1755 Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
1757 interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
1758 b.create<linalg::YieldOp>(result);
1763 rewriter.replaceOp(op, resize);
1771 template <
typename SrcOp>
1778 rewriter.replaceOp(op, op.getOperation()->
getOperands());
1783 template <
typename SrcOp>
1801 Value input = op.getInput();
1802 auto inputTy = cast<ShapedType>(input.
getType());
1803 auto resultTy = cast<ShapedType>(op.getType());
1804 auto axis = op.getAxis();
1807 for (
int i = 0; i < inputTy.getRank(); i++) {
1808 if (inputTy.isDynamicDim(i)) {
1809 dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
1813 Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);
1816 auto emptyTensor = rewriter
1817 .create<tensor::EmptyOp>(loc, inputTy.getShape(),
1818 inputTy.getElementType(),
1822 rewriter.getMultiDimIdentityMap(resultTy.getRank())};
1824 rewriter.replaceOpWithNewOp<linalg::GenericOp>(
1829 for (
unsigned int i = 0; i < inputTy.getRank(); i++) {
1831 rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
1833 auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
1835 rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1836 index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1840 indices.push_back(index);
1843 auto extract = nestedBuilder.create<tensor::ExtractOp>(
1844 nestedLoc, input, indices);
1845 nestedBuilder.create<linalg::YieldOp>(op.
getLoc(),
1846 extract.getResult());
1860 matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1863 auto input = op.getInput1();
1864 auto inputTy = cast<ShapedType>(input.getType());
1865 auto inputShape = inputTy.getShape();
1866 auto resultTy = cast<ShapedType>(op.getType());
1867 auto elementTy = inputTy.getElementType();
1868 int64_t rank = inputTy.getRank();
1874 for (
int i = 0; i < rank; i++) {
1875 int64_t dim = multiples[i];
1876 genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
1877 genericShape.push_back(inputShape[i]);
1881 for (
int i = 0; i < inputTy.getRank(); i++) {
1882 if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1883 dynDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, i));
1887 auto emptyTensor = rewriter.
create<tensor::EmptyOp>(
1888 op.
getLoc(), genericShape, elementTy, dynDims);
1892 dimExprs.reserve(rank);
1893 for (
unsigned i = 0; i < rank; ++i)
1896 auto readAffineMap =
1903 auto genericOp = rewriter.
create<linalg::GenericOp>(
1908 nestedBuilder.create<linalg::YieldOp>(op.
getLoc(), *args.begin());
1937 auto loc = argmaxOp.getLoc();
1938 Value input = argmaxOp.getInput();
1939 auto inputTy = cast<ShapedType>(input.
getType());
1940 auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
1941 auto inElementTy = inputTy.getElementType();
1942 auto outElementTy = resultTy.getElementType();
1943 int axis = argmaxOp.getAxis();
1946 if (!isa<IntegerType>(outElementTy))
1949 "tosa.arg_max to linalg.* requires integer-like result type");
1952 for (
int i = 0; i < inputTy.getRank(); i++) {
1953 if (inputTy.isDynamicDim(i) && i != axis) {
1954 dynDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, i));
1959 auto emptyTensorIdx = rewriter
1960 .
create<tensor::EmptyOp>(loc, resultTy.getShape(),
1961 outElementTy, dynDims)
1963 auto fillValueIdx = rewriter.
create<arith::ConstantOp>(
1965 auto filledTensorIdx =
1972 auto emptyTensorMax = rewriter
1973 .
create<tensor::EmptyOp>(loc, resultTy.getShape(),
1974 inElementTy, dynDims)
1976 auto fillValueMaxAttr =
1979 if (!fillValueMaxAttr)
1981 argmaxOp,
"unsupported tosa.argmax element type");
1984 rewriter.
create<arith::ConstantOp>(loc, fillValueMaxAttr);
1985 auto filledTensorMax =
1994 iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
1995 iteratorTypes[axis] = utils::IteratorType::reduction;
1999 for (
int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
2005 bool didEncounterError =
false;
2008 auto linalgOp = rewriter.
create<linalg::GenericOp>(
2010 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
2013 auto newValue = blockArgs[0];
2014 auto oldIndex = blockArgs[1];
2015 auto oldValue = blockArgs[2];
2017 Value newIndex = rewriter.
create<arith::IndexCastOp>(
2018 nestedLoc, oldIndex.getType(),
2019 rewriter.
create<linalg::IndexOp>(loc, axis));
2022 if (isa<FloatType>(inElementTy)) {
2023 predicate = rewriter.
create<arith::CmpFOp>(
2024 nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2025 }
else if (isa<IntegerType>(inElementTy)) {
2026 predicate = rewriter.
create<arith::CmpIOp>(
2027 nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2029 didEncounterError =
true;
2033 auto resultMax = rewriter.
create<arith::SelectOp>(
2034 nestedLoc, predicate, newValue, oldValue);
2035 auto resultIndex = rewriter.
create<arith::SelectOp>(
2036 nestedLoc, predicate, newIndex, oldIndex);
2037 nestedBuilder.
create<linalg::YieldOp>(
2038 nestedLoc,
ValueRange({resultIndex, resultMax}));
2041 if (didEncounterError)
2043 argmaxOp,
"unsupported tosa.argmax element type");
2045 rewriter.
replaceOp(argmaxOp, linalgOp.getResult(0));
2054 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2056 auto input = adaptor.getOperands()[0];
2057 auto indices = adaptor.getOperands()[1];
2060 dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
2061 auto resultTy = cast<ShapedType>(op.getType());
2066 auto dynamicDims = inferDynamicDimsForGather(
2067 rewriter, op.
getLoc(), adaptor.getValues(), adaptor.getIndices());
2069 auto resultElementTy = resultTy.getElementType();
2074 .
create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
2080 resultTy.getRank(), 0,
2081 {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
2085 auto genericOp = rewriter.
create<linalg::GenericOp>(
2090 auto indexValue = args[0];
2091 auto index0 = rewriter.
create<linalg::IndexOp>(loc, 0);
2092 Value index1 = rewriter.
create<arith::IndexCastOp>(
2094 auto index2 = rewriter.
create<linalg::IndexOp>(loc, 2);
2095 Value extract = rewriter.
create<tensor::ExtractOp>(
2096 loc, input,
ValueRange{index0, index1, index2});
2097 rewriter.
create<linalg::YieldOp>(loc, extract);
2099 rewriter.
replaceOp(op, genericOp.getResult(0));
2109 auto addDynamicDimension = [&](
Value source, int64_t dim) {
2111 if (
auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
2112 results.push_back(dimValue);
2115 addDynamicDimension(values, 0);
2116 addDynamicDimension(indices, 1);
2117 addDynamicDimension(values, 2);
2132 Value input = op.getInput();
2133 Value table = op.getTable();
2134 auto inputTy = cast<ShapedType>(input.
getType());
2135 auto tableTy = cast<ShapedType>(table.
getType());
2136 auto resultTy = cast<ShapedType>(op.getType());
2138 auto inputElementTy = inputTy.getElementType();
2139 auto tableElementTy = tableTy.getElementType();
2140 auto resultElementTy = resultTy.getElementType();
2143 for (
int i = 0; i < resultTy.getRank(); ++i) {
2144 if (inputTy.isDynamicDim(i)) {
2150 auto emptyTensor = rewriter
2151 .
create<tensor::EmptyOp>(loc, resultTy.getShape(),
2152 resultElementTy, dynDims)
2159 auto genericOp = rewriter.
create<linalg::GenericOp>(
2162 rewriter.
replaceOp(op, genericOp.getResult(0));
2167 &genericOp.getRegion(), genericOp.getRegion().end(),
2168 TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2172 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2173 resultElementTy.isInteger(8)) {
2174 Value index = rewriter.
create<arith::IndexCastOp>(
2176 Value offset = rewriter.
create<arith::ConstantIndexOp>(loc, 128);
2181 rewriter.
create<linalg::YieldOp>(loc, extract);
2185 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2186 resultElementTy.isInteger(32)) {
2190 auto offset = rewriter.
create<arith::ConstantOp>(
2192 auto seven = rewriter.
create<arith::ConstantOp>(
2194 auto one = rewriter.
create<arith::ConstantOp>(
2196 auto b1111111 = rewriter.
create<arith::ConstantOp>(
2203 auto extendAdd = rewriter.
create<arith::AddIOp>(loc, extend, offset);
2204 Value index = rewriter.
create<arith::ShRUIOp>(loc, extendAdd, seven);
2206 rewriter.
create<arith::AndIOp>(loc, extendAdd, b1111111);
2211 Value indexPlusOne = rewriter.
create<arith::AddIOp>(loc, index, one);
2213 index = rewriter.
create<arith::IndexCastOp>(
2215 indexPlusOne = rewriter.
create<arith::IndexCastOp>(
2230 Value baseScaled = rewriter.
create<arith::ShLIOp>(loc, base, seven);
2231 Value diff = rewriter.
create<arith::SubIOp>(loc, next, base);
2232 Value diffScaled = rewriter.
create<arith::MulIOp>(loc, diff, fraction);
2234 rewriter.
create<arith::AddIOp>(loc, baseScaled, diffScaled);
2236 rewriter.
create<linalg::YieldOp>(loc, result);
2243 op,
"unable to create body for tosa.table op");
2250 static bool isRankedTensor(
Type type) {
return isa<RankedTensorType>(type); }
2254 auto one = builder.
create<arith::ConstantIndexOp>(loc, 1);
2255 auto two = builder.
create<arith::ConstantIndexOp>(loc, 2);
2258 auto divBy2 = builder.
createOrFold<arith::DivUIOp>(loc, value, two);
2259 auto plusOne = builder.
createOrFold<arith::AddIOp>(loc, divBy2, one);
2263 static RankedTensorType
2271 dims[2] = halfPlusOne(builder, loc, dims[2]);
2276 auto elementType = cast<RankedTensorType>(input.
getType()).getElementType();
2281 RankedTensorType type,
2284 rewriter.
create<tensor::EmptyOp>(loc, type, dynamicSizes);
2285 auto fillValueAttr = rewriter.
getZeroAttr(type.getElementType());
2286 auto fillValue = rewriter.
create<arith::ConstantOp>(loc, fillValueAttr);
2287 auto filledTensor = rewriter
2291 return filledTensor;
2296 auto integerVal = builder.
create<arith::IndexCastUIOp>(
2302 return builder.
create<arith::UIToFPOp>(loc, type, integerVal);
2307 auto indexVal = builder.
create<linalg::IndexOp>(loc, index);
2308 return castIndexToFloat(builder, loc, type, indexVal);
2311 template <
typename... Args>
2319 if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
2320 !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
2322 "only supports ranked tensors");
2325 auto loc = rfft2d.getLoc();
2326 auto input = rfft2d.getInput();
2328 cast<FloatType>(cast<ShapedType>(input.
getType()).getElementType());
2332 auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);
2336 utils::IteratorType::parallel, utils::IteratorType::parallel,
2337 utils::IteratorType::parallel, utils::IteratorType::reduction,
2338 utils::IteratorType::reduction};
2343 createZeroTensor(rewriter, loc, outputType, dynamicSizes),
2344 createZeroTensor(rewriter, loc, outputType, dynamicSizes)};
2349 affineDimsExpr(rewriter, 0, 1, 2),
2350 affineDimsExpr(rewriter, 0, 1, 2)},
2354 auto dimH = rewriter.
createOrFold<tensor::DimOp>(loc, input, 1);
2355 auto dimW = rewriter.
createOrFold<tensor::DimOp>(loc, input, 2);
2358 auto twoPiAttr = rewriter.
getFloatAttr(elementType, 6.283185307179586);
2359 auto twoPi = rewriter.
create<arith::ConstantOp>(loc, twoPiAttr);
2360 auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
2361 auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);
2364 Value valReal = args[0];
2365 Value sumReal = args[1];
2366 Value sumImag = args[2];
2369 Value oy = builder.
create<linalg::IndexOp>(loc, 1);
2370 Value ox = builder.
create<linalg::IndexOp>(loc, 2);
2371 Value iy = builder.
create<linalg::IndexOp>(loc, 3);
2372 Value ix = builder.
create<linalg::IndexOp>(loc, 4);
2377 auto iyXoy = builder.
create<index::MulOp>(loc, iy, oy);
2378 auto ixXox = builder.
create<index::MulOp>(loc, ix, ox);
2380 auto iyRem = builder.
create<index::RemUOp>(loc, iyXoy, dimH);
2381 auto ixRem = builder.
create<index::RemUOp>(loc, ixXox, dimW);
2383 auto iyRemFloat = castIndexToFloat(builder, loc, elementType, iyRem);
2384 auto ixRemFloat = castIndexToFloat(builder, loc, elementType, ixRem);
2386 auto yComponent = builder.
create<arith::DivFOp>(loc, iyRemFloat, constH);
2387 auto xComponent = builder.
create<arith::DivFOp>(loc, ixRemFloat, constW);
2388 auto sumXY = builder.
create<arith::AddFOp>(loc, yComponent, xComponent);
2389 auto angle = builder.
create<arith::MulFOp>(loc, twoPi, sumXY);
2393 auto cosAngle = builder.
create<math::CosOp>(loc, angle);
2394 auto sinAngle = builder.
create<math::SinOp>(loc, angle);
2395 auto realComponent =
2396 builder.
create<arith::MulFOp>(loc, valReal, cosAngle);
2397 auto imagComponent =
2398 builder.
create<arith::MulFOp>(loc, valReal, sinAngle);
2402 auto outReal = builder.
create<arith::AddFOp>(loc, sumReal, realComponent);
2403 auto outImag = builder.
create<arith::SubFOp>(loc, sumImag, imagComponent);
2409 rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2410 indexingMaps, iteratorTypes, buildBody);
2421 if (!llvm::all_of(fft2d->getOperandTypes(),
2422 RFFT2dConverter::isRankedTensor) ||
2423 !llvm::all_of(fft2d->getResultTypes(),
2424 RFFT2dConverter::isRankedTensor)) {
2429 Value input_real = fft2d.getInputReal();
2430 Value input_imag = fft2d.getInputImag();
2431 BoolAttr inverse = fft2d.getInverseAttr();
2433 auto real_el_ty = cast<FloatType>(
2434 cast<ShapedType>(input_real.
getType()).getElementType());
2435 [[maybe_unused]]
auto imag_el_ty = cast<FloatType>(
2436 cast<ShapedType>(input_imag.
getType()).getElementType());
2438 assert(real_el_ty == imag_el_ty);
2453 utils::IteratorType::parallel, utils::IteratorType::parallel,
2454 utils::IteratorType::parallel, utils::IteratorType::reduction,
2455 utils::IteratorType::reduction};
2460 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2462 RFFT2dConverter::createZeroTensor(rewriter, loc, outputType,
2467 ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2468 RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4),
2469 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2),
2470 RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)},
2474 auto dimH = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 1);
2475 auto dimW = rewriter.
createOrFold<tensor::DimOp>(loc, input_real, 2);
2478 auto twoPiAttr = rewriter.
getFloatAttr(real_el_ty, 6.283185307179586);
2479 auto twoPi = rewriter.
create<arith::ConstantOp>(loc, twoPiAttr);
2481 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH);
2483 RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW);
2486 Value valReal = args[0];
2487 Value valImag = args[1];
2488 Value sumReal = args[2];
2489 Value sumImag = args[3];
2492 Value oy = builder.
create<linalg::IndexOp>(loc, 1);
2493 Value ox = builder.
create<linalg::IndexOp>(loc, 2);
2494 Value iy = builder.
create<linalg::IndexOp>(loc, 3);
2495 Value ix = builder.
create<linalg::IndexOp>(loc, 4);
2499 auto iyXoy = builder.
create<index::MulOp>(loc, iy, oy);
2500 auto ixXox = builder.
create<index::MulOp>(loc, ix, ox);
2502 auto iyRem = builder.
create<index::RemUOp>(loc, iyXoy, dimH);
2503 auto ixRem = builder.
create<index::RemUOp>(loc, ixXox, dimW);
2506 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, iyRem);
2508 RFFT2dConverter::castIndexToFloat(builder, loc, real_el_ty, ixRem);
2510 auto yComponent = builder.
create<arith::DivFOp>(loc, iyRemFloat, constH);
2511 auto xComponent = builder.
create<arith::DivFOp>(loc, ixRemFloat, constW);
2513 auto sumXY = builder.
create<arith::AddFOp>(loc, yComponent, xComponent);
2514 auto angle = builder.
create<arith::MulFOp>(loc, twoPi, sumXY);
2517 angle = builder.
create<arith::MulFOp>(
2519 rewriter.
create<arith::ConstantOp>(
2525 auto cosAngle = builder.
create<math::CosOp>(loc, angle);
2526 auto sinAngle = builder.
create<math::SinOp>(loc, angle);
2528 auto rcos = builder.
create<arith::MulFOp>(loc, valReal, cosAngle);
2529 auto rsin = builder.
create<arith::MulFOp>(loc, valImag, sinAngle);
2530 auto realComponent = builder.
create<arith::AddFOp>(loc, rcos, rsin);
2532 auto icos = builder.
create<arith::MulFOp>(loc, valImag, cosAngle);
2533 auto isin = builder.
create<arith::MulFOp>(loc, valReal, sinAngle);
2535 auto imagComponent = builder.
create<arith::SubFOp>(loc, icos, isin);
2539 auto outReal = builder.
create<arith::AddFOp>(loc, sumReal, realComponent);
2540 auto outImag = builder.
create<arith::AddFOp>(loc, sumImag, imagComponent);
2546 fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
2547 indexingMaps, iteratorTypes, buildBody);
2559 patterns->
add<GenericResizeConverter>(patterns->
getContext(),
2563 patterns->
add<MaterializeResizeBroadcast>(patterns->
getContext(),
2568 PointwiseConverter<tosa::AddOp>,
2569 PointwiseConverter<tosa::SubOp>,
2570 PointwiseConverter<tosa::MulOp>,
2571 PointwiseConverter<tosa::DivOp>,
2572 PointwiseConverter<tosa::NegateOp>,
2573 PointwiseConverter<tosa::PowOp>,
2574 PointwiseConverter<tosa::ReciprocalOp>,
2575 PointwiseConverter<tosa::RsqrtOp>,
2576 PointwiseConverter<tosa::LogOp>,
2577 PointwiseConverter<tosa::ExpOp>,
2578 PointwiseConverter<tosa::AbsOp>,
2579 PointwiseConverter<tosa::TanhOp>,
2580 PointwiseConverter<tosa::ErfOp>,
2581 PointwiseConverter<tosa::BitwiseAndOp>,
2582 PointwiseConverter<tosa::BitwiseOrOp>,
2583 PointwiseConverter<tosa::BitwiseNotOp>,
2584 PointwiseConverter<tosa::BitwiseXorOp>,
2585 PointwiseConverter<tosa::LogicalAndOp>,
2586 PointwiseConverter<tosa::LogicalNotOp>,
2587 PointwiseConverter<tosa::LogicalOrOp>,
2588 PointwiseConverter<tosa::LogicalXorOp>,
2589 PointwiseConverter<tosa::CastOp>,
2590 PointwiseConverter<tosa::LogicalLeftShiftOp>,
2591 PointwiseConverter<tosa::LogicalRightShiftOp>,
2592 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2593 PointwiseConverter<tosa::ClzOp>,
2594 PointwiseConverter<tosa::SelectOp>,
2595 PointwiseConverter<tosa::GreaterOp>,
2596 PointwiseConverter<tosa::GreaterEqualOp>,
2597 PointwiseConverter<tosa::EqualOp>,
2598 PointwiseConverter<tosa::MaximumOp>,
2599 PointwiseConverter<tosa::MinimumOp>,
2600 PointwiseConverter<tosa::CeilOp>,
2601 PointwiseConverter<tosa::FloorOp>,
2602 PointwiseConverter<tosa::ClampOp>,
2603 PointwiseConverter<tosa::SigmoidOp>,
2604 IdentityNConverter<tosa::IdentityOp>,
2605 ReduceConverter<tosa::ReduceAllOp>,
2606 ReduceConverter<tosa::ReduceAnyOp>,
2607 ReduceConverter<tosa::ReduceMinOp>,
2608 ReduceConverter<tosa::ReduceMaxOp>,
2609 ReduceConverter<tosa::ReduceSumOp>,
2610 ReduceConverter<tosa::ReduceProdOp>,
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor, int64_t rank)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, PatternRewriter &rewriter)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, PatternRewriter &rewriter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, PatternRewriter &rewriter)
static LogicalResult emitElementwiseComputation(PatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static arith::ConstantOp createConstFromIntAttribute(Operation *op, const std::string &attrName, Type requiredAttrType, OpBuilder &rewriter)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static SmallVector< Value > expandInputRanks(PatternRewriter &rewriter, Location loc, Operation *operation)
static bool operandsAndResultsRanked(Operation *operation)
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
FloatAttr getFloatAttr(Type type, double value)
AffineExpr getAffineConstantExpr(int64_t constant)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
MPInt round(const Fraction &f)
Fraction abs(const Fraction &f)
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...