#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"

using namespace mlir;
using namespace mlir::tosa;

template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
                            Type requiredAttrType, OpBuilder &rewriter) {
  auto castedN = static_cast<T>(
      cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
  return rewriter.create<arith::ConstantOp>(
      op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
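// Usage sketch (illustrative only): RescaleConverter below uses this helper
// to materialize zero-point attributes as integer constants, e.g.
//   auto zp = createConstFromIntAttribute<int32_t>(
//       op, "output_zp", rewriter.getI32Type(), rewriter);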
static Value
createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
                                            ArrayRef<Type> resultTypes,
                                            PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  auto elementTy =
      cast<ShapedType>(op->getOperand(0).getType()).getElementType();

  // tosa::AbsOp
  if (isa<tosa::AbsOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::AbsFOp>(loc, resultTypes, args);

  if (isa<tosa::AbsOp>(op) && isa<IntegerType>(elementTy)) {
    auto zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(elementTy));
    auto cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                              args[0], zero);
    auto neg = rewriter.create<arith::SubIOp>(loc, zero, args[0]);
    return rewriter.create<arith::SelectOp>(loc, cmp, args[0], neg);
  }

  // tosa::AddOp
  if (isa<tosa::AddOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::AddFOp>(loc, resultTypes, args);

  if (isa<tosa::AddOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AddIOp>(loc, resultTypes, args);

  // tosa::SubOp
  if (isa<tosa::SubOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::SubFOp>(loc, resultTypes, args);

  if (isa<tosa::SubOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::SubIOp>(loc, resultTypes, args);

  // tosa::MulOp
  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy)) {
    if (dyn_cast<tosa::MulOp>(op).getShift() != 0) {
      (void)rewriter.notifyMatchFailure(op,
                                        "Cannot have shift value for float");
      return nullptr;
    }
    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
  }

  // tosa::DivOp
  if (isa<tosa::DivOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::DivSIOp>(loc, resultTypes, args);

  // tosa::ReciprocalOp
  if (isa<tosa::ReciprocalOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, args[0]);
  }
  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
    Value a = args[0];
    Value b = args[1];
    auto shift =
        cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
    if (shift > 0) {
      auto shiftConst =
          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
      if (!a.getType().isInteger(32))
        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
      if (!b.getType().isInteger(32))
        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);

      auto result = rewriter.create<tosa::ApplyScaleOp>(
          loc, rewriter.getI32Type(), a, b, shiftConst,
          rewriter.getBoolAttr(false));

      if (elementTy.isInteger(32))
        return result;

      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
    }

    int aWidth = a.getType().getIntOrFloatBitWidth();
    int bWidth = b.getType().getIntOrFloatBitWidth();
    int cWidth = resultTypes[0].getIntOrFloatBitWidth();

    if (aWidth < cWidth)
      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
    if (bWidth < cWidth)
      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);

    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
  }
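  // Note: for a non-zero shift, tosa.apply_scale computes roughly
  // (a * b) >> shift with rounding, keeping the intermediate in i32 before
  // truncating back to the element type; a zero shift is a plain multiply.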
  // tosa::NegateOp
  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);

  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
      !cast<tosa::NegateOp>(op).getQuantizationInfo()) {
    auto constant =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    return rewriter.create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
  }

  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy) &&
      cast<tosa::NegateOp>(op).getQuantizationInfo()) {
    auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
    int64_t inZp = quantizationInfo.value().getInputZp();
    int64_t outZp = quantizationInfo.value().getOutputZp();

    // Compute the maximum value that can occur in the intermediate buffer.
    int64_t zpAdd = inZp + outZp;
    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
                       std::abs(zpAdd) + 1;

    // Convert that maximum value into the maximum bitwidth needed to
    // represent it.
    int intermediateBitWidth = 64;
    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
      intermediateBitWidth = 16;
    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
      intermediateBitWidth = 32;
    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
      intermediateBitWidth = 48;
    }

    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
    Value zpAddValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));

    // The negation can be applied as: outputValue = inZp + outZp - inputValue.
    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);

    // Clamp to the negation range.
    Value min = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
        intermediateType);
    Value max = rewriter.create<arith::ConstantIntOp>(
        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
        intermediateType);
    auto clamp = clampIntHelper(loc, sub, min, max, rewriter);

    // Truncate to the final value.
    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
  }
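  // Worked example (assuming i8 with inZp = 0, outZp = 128): zpAdd = 128 and
  // maxValue = 127 + 128 + 1 = 256, so the subtraction runs in i16 and the
  // result is clamped to [-128, 127] before truncating back to i8.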
  // tosa::BitwiseAndOp
  if (isa<tosa::BitwiseAndOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::BitwiseOrOp
  if (isa<tosa::BitwiseOrOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::BitwiseNotOp
  if (isa<tosa::BitwiseNotOp>(op) && isa<IntegerType>(elementTy)) {
    auto allOnesAttr = rewriter.getIntegerAttr(
        elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
    auto allOnes = rewriter.create<arith::ConstantOp>(loc, allOnesAttr);
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
  }

  // tosa::BitwiseXorOp
  if (isa<tosa::BitwiseXorOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::LogicalLeftShiftOp
  if (isa<tosa::LogicalLeftShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShLIOp>(loc, resultTypes, args);

  // tosa::LogicalRightShiftOp
  if (isa<tosa::LogicalRightShiftOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.create<arith::ShRUIOp>(loc, resultTypes, args);
  // tosa::ArithmeticRightShiftOp
  if (isa<tosa::ArithmeticRightShiftOp>(op) && isa<IntegerType>(elementTy)) {
    auto result = rewriter.create<arith::ShRSIOp>(loc, resultTypes, args);
    auto round = cast<BoolAttr>(op->getAttr("round")).getValue();
    if (!round)
      return result;

    Type i1Ty = IntegerType::get(rewriter.getContext(), /*width=*/1);
    auto one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
    auto zero =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
    auto i1one =
        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));

    // Checking that the shift amount is non-zero.
    auto shiftValueGreaterThanZero = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[1], zero);

    // Checking the last bit of the original value that is dropped.
    auto subtract =
        rewriter.create<arith::SubIOp>(loc, resultTypes, args[1], one);
    auto shifted =
        rewriter.create<arith::ShRSIOp>(loc, resultTypes, args[0], subtract)
            ->getResult(0);
    auto truncated =
        rewriter.create<arith::TruncIOp>(loc, i1Ty, shifted, std::nullopt);
    auto isInputOdd =
        rewriter.create<arith::AndIOp>(loc, i1Ty, truncated, i1one);

    auto shouldRound = rewriter.create<arith::AndIOp>(
        loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
    auto extended =
        rewriter.create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
    return rewriter.create<arith::AddIOp>(loc, resultTypes, result, extended);
  }
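  // E.g. (-7) >> 1 with round = true: the plain arithmetic shift gives -4,
  // the dropped low bit of -7 is 1 and the shift amount is positive, so 1 is
  // added back, yielding -3.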
  // tosa::ClzOp
  if (isa<tosa::ClzOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
  }

  // tosa::LogicalAndOp
  if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, resultTypes, args);

  // tosa::LogicalNotOp
  if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
    auto one = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(elementTy, 1));
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args[0], one);
  }

  // tosa::LogicalOrOp
  if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, resultTypes, args);

  // tosa::LogicalXorOp
  if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::XOrIOp>(loc, resultTypes, args);

  // tosa::PowOp
  if (isa<tosa::PowOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::PowFOp>(loc, resultTypes, args);

  // tosa::RsqrtOp
  if (isa<tosa::RsqrtOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::RsqrtOp>(loc, resultTypes, args);

  // tosa::LogOp
  if (isa<tosa::LogOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::LogOp>(loc, resultTypes, args);

  // tosa::ExpOp
  if (isa<tosa::ExpOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ExpOp>(loc, resultTypes, args);

  // tosa::TanhOp
  if (isa<tosa::TanhOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::TanhOp>(loc, resultTypes, args);

  // tosa::ErfOp
  if (isa<tosa::ErfOp>(op) && llvm::isa<FloatType>(elementTy))
    return rewriter.create<mlir::math::ErfOp>(loc, resultTypes, args);

  // tosa::GreaterOp
  if (isa<tosa::GreaterOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                          args[0], args[1]);

  if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                          args[0], args[1]);

  // tosa::GreaterEqualOp
  if (isa<tosa::GreaterEqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
                                          args[0], args[1]);

  if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                          args[0], args[1]);

  // tosa::EqualOp
  if (isa<tosa::EqualOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
                                          args[0], args[1]);

  if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
    return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                          args[0], args[1]);

  // tosa::SelectOp
  if (isa<tosa::SelectOp>(op)) {
    elementTy = cast<ShapedType>(op->getOperand(1).getType()).getElementType();
    if (isa<FloatType>(elementTy) || isa<IntegerType>(elementTy))
      return rewriter.create<arith::SelectOp>(loc, args[0], args[1], args[2]);
  }
  // tosa::MaximumOp
  if (isa<tosa::MaximumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::MinimumOp
  if (isa<tosa::MinimumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  // tosa::CeilOp
  if (isa<tosa::CeilOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::CeilOp>(loc, resultTypes, args);

  // tosa::FloorOp
  if (isa<tosa::FloorOp>(op) && isa<FloatType>(elementTy))
    return rewriter.create<math::FloorOp>(loc, resultTypes, args);

  // tosa::ClampOp
  if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
    bool losesInfo = false;
    APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
    APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
    minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
                   APFloat::rmNearestTiesToEven, &losesInfo);
    auto min = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, minApf));
    auto max = rewriter.create<arith::ConstantOp>(
        loc, elementTy, rewriter.getFloatAttr(elementTy, maxApf));
    return clampFloatHelper(loc, args[0], min, max, rewriter);
  }
  if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
    auto intTy = cast<IntegerType>(elementTy);
    int32_t min = static_cast<int32_t>(
        cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue());
    int32_t max = static_cast<int32_t>(
        cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue());

    if (intTy.isUnsignedInteger()) {
      min = std::max<int32_t>(min, 0);
      max = std::min<int32_t>(
          max,
          APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
    } else {
      min = std::max<int32_t>(
          min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
      max = std::min<int32_t>(
          max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
                   .getSExtValue());
    }

    auto minVal = rewriter.create<arith::ConstantIntOp>(
        loc, min, intTy.getIntOrFloatBitWidth());
    auto maxVal = rewriter.create<arith::ConstantIntOp>(
        loc, max, intTy.getIntOrFloatBitWidth());
    return clampIntHelper(loc, args[0], minVal, maxVal, rewriter);
  }

  // tosa::SigmoidOp
  if (isa<tosa::SigmoidOp>(op) && isa<FloatType>(elementTy)) {
    auto one =
        rewriter.create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
    auto negate = rewriter.create<arith::NegFOp>(loc, resultTypes, args[0]);
    auto exp = rewriter.create<mlir::math::ExpOp>(loc, resultTypes, negate);
    auto added = rewriter.create<arith::AddFOp>(loc, resultTypes, exp, one);
    return rewriter.create<arith::DivFOp>(loc, resultTypes, one, added);
  }
  // tosa::CastOp
  if (isa<tosa::CastOp>(op)) {
    Type srcTy = elementTy;
    Type dstTy = resultTypes.front();
    bool bitExtend =
        srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth();

    if (srcTy == dstTy)
      return args.front();

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtFOp>(loc, resultTypes, args,
                                            std::nullopt);

    if (isa<FloatType>(srcTy) && isa<FloatType>(dstTy) && !bitExtend)
      return rewriter.create<arith::TruncFOp>(loc, resultTypes, args,
                                              std::nullopt);

    // 1-bit integers need to be treated as signless.
    if (srcTy.isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    if (srcTy.isInteger(1) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtUIOp>(loc, resultTypes, args,
                                             std::nullopt);

    // Unsigned integers need an unrealized cast so that they can be used
    // with UIToFP.
    if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) {
      auto unrealizedCast =
          rewriter
              .create<UnrealizedConversionCastOp>(
                  loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()),
                  args[0])
              .getResult(0);
      return rewriter.create<arith::UIToFPOp>(loc, resultTypes[0],
                                              unrealizedCast);
    }

    // All other si-to-fp conversions should be handled by SIToFP.
    if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
      return rewriter.create<arith::SIToFPOp>(loc, resultTypes, args,
                                              std::nullopt);

    // Casting to boolean: floats need only be checked as not-equal to zero.
    if (isa<FloatType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getFloatAttr(srcTy, 0.0));
      return rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
                                            args.front(), zero);
    }

    if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
      auto intMin = rewriter.create<arith::ConstantOp>(
          loc,
          rewriter.getFloatAttr(
              srcTy, APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
                         .getSExtValue()));

      auto intMax = rewriter.create<arith::ConstantOp>(
          loc,
          rewriter.getFloatAttr(
              srcTy, APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
                         .getSExtValue()));

      auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
      auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
      return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
    }

    // Casting to boolean: integers need only be checked as not-equal to zero.
    if (isa<IntegerType>(srcTy) && dstTy.isInteger(1)) {
      Value zero = rewriter.create<arith::ConstantIntOp>(
          loc, 0, srcTy.getIntOrFloatBitWidth());
      return rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
                                            args.front(), zero);
    }

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && bitExtend)
      return rewriter.create<arith::ExtSIOp>(loc, resultTypes, args,
                                             std::nullopt);

    if (isa<IntegerType>(srcTy) && isa<IntegerType>(dstTy) && !bitExtend) {
      return rewriter.create<arith::TruncIOp>(loc, dstTy, args[0]);
    }
  }

  (void)rewriter.notifyMatchFailure(
      op, "unhandled op for linalg body calculation for elementwise op");
  return nullptr;
}
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
                        int64_t rank) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  int64_t numExtraDims = rank - shapedType.getRank();
  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
  if (!numExtraDims)
    return tensor;

  // Compute reassociation indices.
  SmallVector<ReassociationIndices> reassociationIndices(
      shapedType.getRank());
  int64_t index = 0;
  for (index = 0; index <= numExtraDims; index++)
    reassociationIndices[0].push_back(index);
  for (size_t position = 1; position < reassociationIndices.size(); position++)
    reassociationIndices[position].push_back(index++);

  // Compute result type.
  SmallVector<int64_t> resultShape;
  for (index = 0; index < numExtraDims; index++)
    resultShape.push_back(1);
  for (auto size : shapedType.getShape())
    resultShape.push_back(size);
  auto resultType =
      RankedTensorType::get(resultShape, shapedType.getElementType());

  // Emit 'tensor.expand_shape' op.
  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
                                                reassociationIndices);
}
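// E.g. expanding a tensor<3x4xf32> to rank 4 yields tensor<1x1x3x4xf32> with
// reassociation [[0, 1, 2], [3]]: the new leading unit dims are grouped with
// the first original dim.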
static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
                                           Location loc,
                                           Operation *operation) {
  auto rank =
      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
  return llvm::map_to_vector(operation->getOperands(), [&](Value operand) {
    return expandRank(rewriter, loc, operand, rank);
  });
}
using IndexPool = DenseMap<int64_t, Value>;

// Emit an 'arith.constant' index op, memoizing previously emitted values.
static Value createIndex(PatternRewriter &rewriter, Location loc,
                         IndexPool &indexPool, int64_t index) {
  auto [it, inserted] = indexPool.try_emplace(index);
  if (inserted)
    it->second =
        rewriter.create<arith::ConstantIndexOp>(loc, index).getResult();
  return it->second;
}

static Value getTensorDim(PatternRewriter &rewriter, Location loc,
                          IndexPool &indexPool, Value tensor, int64_t index) {
  auto indexValue = createIndex(rewriter, loc, indexPool, index);
  return rewriter.create<tensor::DimOp>(loc, tensor, indexValue).getResult();
}
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value tensor,
                                       int64_t index) {
  auto shapedType = dyn_cast<ShapedType>(tensor.getType());
  assert(shapedType && shapedType.hasRank() && "expected a ranked shaped type");
  assert(index >= 0 && index < shapedType.getRank() && "index out of bounds");
  if (shapedType.isDynamicDim(index))
    return getTensorDim(rewriter, loc, indexPool, tensor, index);
  return rewriter.getIndexAttr(shapedType.getDimSize(index));
}
static bool operandsAndResultsRanked(Operation *operation) {
  auto isRanked = [](Value value) {
    return isa<RankedTensorType>(value.getType());
  };
  return llvm::all_of(operation->getOperands(), isRanked) &&
         llvm::all_of(operation->getResults(), isRanked);
}
static std::pair<OpFoldResult, Value>
computeTargetSize(PatternRewriter &rewriter, Location loc,
                  IndexPool &indexPool, ValueRange operands, int64_t dim) {
  // If any operand has a static size greater than 1 in the given dimension,
  // that is the target size.
  for (auto operand : operands) {
    auto size = operand.getType().cast<RankedTensorType>().getDimSize(dim);
    if (!ShapedType::isDynamic(size) && size > 1)
      return {rewriter.getIndexAttr(size), operand};
  }

  // Filter operands with a dynamic dimension.
  auto operandsWithDynamicDim =
      llvm::to_vector(llvm::make_filter_range(operands, [&](Value operand) {
        return operand.getType().cast<RankedTensorType>().isDynamicDim(dim);
      }));

  // If no operand has a dynamic dimension, the target size is 1.
  if (operandsWithDynamicDim.empty())
    return {rewriter.getIndexAttr(1), nullptr};

  // If exactly one operand has a dynamic dimension, it becomes the master
  // operand that determines the runtime target size.
  Value targetSize =
      getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[0], dim);
  if (operandsWithDynamicDim.size() == 1)
    return {targetSize, operandsWithDynamicDim[0]};

  // Otherwise, take the maximum of all dynamic sizes; there is no single
  // master operand in this case.
  for (size_t i = 1; i < operandsWithDynamicDim.size(); i++) {
    Value nextSize =
        getTensorDim(rewriter, loc, indexPool, operandsWithDynamicDim[i], dim);
    targetSize = rewriter.create<arith::MaxUIOp>(loc, targetSize, nextSize);
  }
  return {targetSize, nullptr};
}
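// E.g. for per-operand dim sizes {1, 5, ?} the static 5 wins immediately;
// for {1, ?, ?} the target is the runtime arith.maxui of the two dynamic
// sizes, and no single master operand is reported.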
static std::pair<SmallVector<OpFoldResult>, SmallVector<Value>>
computeTargetShape(PatternRewriter &rewriter, Location loc,
                   IndexPool &indexPool, ValueRange operands) {
  assert(!operands.empty());
  auto rank = operands.front().getType().cast<RankedTensorType>().getRank();
  SmallVector<OpFoldResult> targetShape;
  SmallVector<Value> masterOperands;
  for (auto dim : llvm::seq<int64_t>(0, rank)) {
    auto [targetSize, masterOperand] =
        computeTargetSize(rewriter, loc, indexPool, operands, dim);
    targetShape.push_back(targetSize);
    masterOperands.push_back(masterOperand);
  }
  return {targetShape, masterOperands};
}
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
                                       IndexPool &indexPool, Value operand,
                                       int64_t dim, OpFoldResult targetSize,
                                       Value masterOperand) {
  // Nothing to do for a static dimension.
  auto rankedTensorType = operand.getType().cast<RankedTensorType>();
  if (!rankedTensorType.isDynamicDim(dim))
    return operand;

  // If the target size for this dimension was computed from this very
  // operand, no broadcast is needed.
  if (operand == masterOperand)
    return operand;

  // Affine maps for 'linalg.generic' op.
  auto rank = rankedTensorType.getRank();
  SmallVector<AffineExpr> affineExprs;
  for (auto index : llvm::seq<int64_t>(0, rank)) {
    auto affineExpr = index == dim ? rewriter.getAffineConstantExpr(0)
                                   : rewriter.getAffineDimExpr(index);
    affineExprs.push_back(affineExpr);
  }
  auto broadcastAffineMap =
      AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  auto identityAffineMap = rewriter.getMultiDimIdentityMap(rank);
  SmallVector<AffineMap> affineMaps = {broadcastAffineMap, identityAffineMap};

  // Check if broadcast is necessary.
  auto one = createIndex(rewriter, loc, indexPool, 1);
  auto runtimeSize = getTensorDim(rewriter, loc, indexPool, operand, dim);
  auto broadcastNecessary = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::eq, runtimeSize, one);

  // Emit 'then' region of 'scf.if'.
  auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
    // Emit 'tensor.empty' op.
    SmallVector<OpFoldResult> outputTensorShape;
    for (auto index : llvm::seq<int64_t>(0, rank)) {
      auto size = index == dim ? targetSize
                               : getOrFoldTensorDim(rewriter, loc, indexPool,
                                                    operand, index);
      outputTensorShape.push_back(size);
    }
    Value outputTensor = opBuilder.create<tensor::EmptyOp>(
        loc, outputTensorShape, rankedTensorType.getElementType());

    // Emit 'linalg.generic' op.
    auto resultTensor =
        opBuilder
            .create<linalg::GenericOp>(
                loc, outputTensor.getType(), operand, outputTensor, affineMaps,
                getNParallelLoopsAttrs(rank),
                [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
                  opBuilder.create<linalg::YieldOp>(loc, blockArgs.front());
                })
            .getResult(0);

    // Cast to the original operand type if necessary.
    auto castResultTensor = rewriter.createOrFold<tensor::CastOp>(
        loc, operand.getType(), resultTensor);

    opBuilder.create<scf::YieldOp>(loc, castResultTensor);
  };

  // Emit 'else' region of 'scf.if'.
  auto emitElseRegion = [&](OpBuilder &opBuilder, Location loc) {
    opBuilder.create<scf::YieldOp>(loc, operand);
  };

  // Emit 'scf.if' op.
  auto ifOp = rewriter.create<scf::IfOp>(loc, broadcastNecessary,
                                         emitThenRegion, emitElseRegion);
  return ifOp.getResult(0);
}
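// The scf.if above is a runtime guard: the broadcasting linalg.generic runs
// only when the operand's size in this dimension turns out to be 1;
// otherwise the operand is yielded unchanged. The tensor.cast folds the two
// branch result types back to the original operand type.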
static Value broadcastDynamicDimensions(PatternRewriter &rewriter,
                                        Location loc, IndexPool &indexPool,
                                        Value operand,
                                        ArrayRef<OpFoldResult> targetShape,
                                        ArrayRef<Value> masterOperands) {
  size_t rank = operand.getType().cast<RankedTensorType>().getRank();
  assert(targetShape.size() == rank);
  assert(masterOperands.size() == rank);
  for (auto index : llvm::seq<int64_t>(0, rank))
    operand =
        broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
                                  targetShape[index], masterOperands[index]);
  return operand;
}
static SmallVector<Value>
broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                           IndexPool &indexPool, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape,
                           ArrayRef<Value> masterOperands) {
  // No broadcasting is needed for unary operations.
  if (operands.size() == 1)
    return llvm::to_vector(operands);
  return llvm::map_to_vector(operands, [&](Value operand) {
    return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
                                      targetShape, masterOperands);
  });
}
static LogicalResult
emitElementwiseComputation(PatternRewriter &rewriter, Location loc,
                           Operation *operation, ValueRange operands,
                           ArrayRef<OpFoldResult> targetShape) {
  // Generate output tensor.
  auto resultType =
      cast<RankedTensorType>(operation->getResultTypes().front());
  Value outputTensor = rewriter.create<tensor::EmptyOp>(
      loc, targetShape, resultType.getElementType());

  // Create affine maps. Input affine maps broadcast static dimensions of
  // size 1.
  auto rank = resultType.getRank();
  auto affineMaps = llvm::map_to_vector(operands, [&](Value operand) {
    auto shape = cast<ShapedType>(operand.getType()).getShape();
    SmallVector<AffineExpr> affineExprs;
    for (auto it : llvm::enumerate(shape)) {
      auto affineExpr = it.value() == 1
                            ? rewriter.getAffineConstantExpr(0)
                            : rewriter.getAffineDimExpr(it.index());
      affineExprs.push_back(affineExpr);
    }
    return AffineMap::get(rank, 0, affineExprs, rewriter.getContext());
  });
  affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

  // Emit 'linalg.generic' op.
  bool encounteredError = false;
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, outputTensor.getType(), operands, outputTensor, affineMaps,
      getNParallelLoopsAttrs(rank),
      [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
        Value opResult = createLinalgBodyCalculationForElementwiseOp(
            operation, blockArgs.take_front(operation->getNumOperands()),
            {resultType.getElementType()}, rewriter);
        if (!opResult) {
          encounteredError = true;
          return;
        }
        opBuilder.create<linalg::YieldOp>(loc, opResult);
      });
  if (encounteredError)
    return rewriter.notifyMatchFailure(
        operation, "unable to create linalg.generic body for elementwise op");

  // Cast the result to the original result type if necessary.
  auto castResult = rewriter.createOrFold<tensor::CastOp>(
      loc, resultType, linalgOp->getResult(0));
  rewriter.replaceOp(operation, castResult);
  return success();
}
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation,
                                 PatternRewriter &rewriter) {
  // Collect op properties.
  assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
  assert(operation->getNumOperands() >= 1 &&
         "elementwise op expects at least 1 operand");
  if (!operandsAndResultsRanked(operation))
    return rewriter.notifyMatchFailure(operation,
                                       "Unranked tensors not supported");

  // Lower operation.
  IndexPool indexPool;
  auto loc = operation->getLoc();
  auto expandedOperands = expandInputRanks(rewriter, loc, operation);
  auto [targetShape, masterOperands] =
      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
  auto broadcastOperands = broadcastDynamicDimensions(
      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
  return emitElementwiseComputation(rewriter, loc, operation,
                                    broadcastOperands, targetShape);
}
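// For example (a hypothetical input, shown only to illustrate the flow):
//   %0 = tosa.add %arg0, %arg1
//          : (tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<2x3xf32>
// becomes a linalg.generic whose indexing map for %arg1 uses a constant 0
// in the broadcast dimension, with an arith.addf in the body.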
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy,
                                               PatternRewriter &rewriter) {
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 0.0);

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 0);

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(elementTy, 1.0);

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(elementTy, 1);

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), false));

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy,
        APInt::getSignedMaxValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy,
        APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getAllOnes(1));

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.getIntegerAttr(elementTy, APInt::getZero(1));

  if (isa<tosa::ArgMaxOp>(op) && isa<FloatType>(elementTy))
    return rewriter.getFloatAttr(
        elementTy, APFloat::getLargest(
                       cast<FloatType>(elementTy).getFloatSemantics(), true));

  if (isa<tosa::ArgMaxOp>(op) && isa<IntegerType>(elementTy))
    return rewriter.getIntegerAttr(
        elementTy,
        APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth()));

  return {};
}
static Value createLinalgBodyCalculationForReduceOp(Operation *op,
                                                    ValueRange args,
                                                    Type elementTy,
                                                    PatternRewriter &rewriter) {
  Location loc = op->getLoc();
  if (isa<tosa::ReduceSumOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::AddFOp>(loc, args);
  }

  if (isa<tosa::ReduceSumOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::AddIOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MulFOp>(loc, args);
  }

  if (isa<tosa::ReduceProdOp>(op) && isa<IntegerType>(elementTy)) {
    return rewriter.create<arith::MulIOp>(loc, args);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MinimumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMinOp>(op) && isa<IntegerType>(elementTy)) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<FloatType>(elementTy)) {
    return rewriter.create<arith::MaximumFOp>(loc, args[0], args[1]);
  }

  if (isa<tosa::ReduceMaxOp>(op) && isa<IntegerType>(elementTy)) {
    auto predicate = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::sgt, args[0], args[1]);
    return rewriter.create<arith::SelectOp>(loc, predicate, args[0], args[1]);
  }

  if (isa<tosa::ReduceAllOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::AndIOp>(loc, args);

  if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1))
    return rewriter.create<arith::OrIOp>(loc, args);

  return {};
}
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis,
                                                 PatternRewriter &rewriter) {
  auto loc = op->getLoc();
  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
  auto elementTy = resultTy.getElementType();
  Value input = op->getOperand(0);

  SmallVector<int64_t> reduceShape;
  SmallVector<Value> dynDims;
  for (unsigned i = 0; i < inputTy.getRank(); i++) {
    if (axis != i) {
      reduceShape.push_back(inputTy.getDimSize(i));
      if (inputTy.isDynamicDim(i))
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
    }
  }

  // First fill the output buffer with the init value.
  auto emptyTensor =
      rewriter
          .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(),
                                   dynDims)
          .getResult();

  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
  if (!fillValueAttr)
    return rewriter.notifyMatchFailure(
        op, "No initial value found for reduction operation");

  auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
  auto filledTensor = rewriter
                          .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                  ValueRange{emptyTensor})
                          .result();

  SmallVector<AffineExpr, 2> srcExprs;
  SmallVector<AffineExpr, 2> dstExprs;
  SmallVector<utils::IteratorType, 4> iteratorTypes;
  for (unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
    srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    iteratorTypes.push_back(axis == i ? utils::IteratorType::reduction
                                      : utils::IteratorType::parallel);
    if (axis != i)
      dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
  }

  bool didEncounterError = false;
  auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs});
  Type reduceTy =
      RankedTensorType::get(reduceShape, resultTy.getElementType());
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, reduceTy, input, filledTensor, maps, iteratorTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        auto result = createLinalgBodyCalculationForReduceOp(
            op, blockArgs, elementTy, rewriter);
        if (result)
          didEncounterError = true;

        nestedBuilder.create<linalg::YieldOp>(loc, result);
      });

  if (!didEncounterError)
    return rewriter.notifyMatchFailure(
        op, "unable to create linalg.generic body for reduce op");

  SmallVector<ReassociationExprs, 4> reassociationMap;
  uint64_t expandInputRank =
      cast<ShapedType>(linalgOp.getResults()[0].getType()).getRank();
  reassociationMap.resize(expandInputRank);

  for (uint64_t i = 0; i < expandInputRank; i++) {
    int32_t dimToPush = i > axis ? i + 1 : i;
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(dimToPush));
  }

  if (expandInputRank != 0) {
    int32_t expandedDim = axis < expandInputRank ? axis : expandInputRank - 1;
    reassociationMap[expandedDim].push_back(
        rewriter.getAffineDimExpr(expandedDim + 1));
  }

  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
      op, resultTy, linalgOp.getResults()[0], reassociationMap);
  return success();
}
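// For example, a reduction such as
//   %0 = tosa.reduce_sum %arg0 {axis = 1 : i32}
//          : (tensor<4x5xf32>) -> tensor<4x1xf32>
// lowers to a linalg.generic producing tensor<4xf32> (filled with 0.0 and
// combined with arith.addf) followed by a tensor.expand_shape back to
// tensor<4x1xf32>.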
template <typename SrcOp>
class PointwiseConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    return elementwiseMatchAndRewriteHelper(op, rewriter);
  }
};

class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    DenseIntElementsAttr perms;
    if (!matchPattern(op.getPerms(), m_Constant(&perms)))
      return rewriter.notifyMatchFailure(op, "unmatched permutation tensor");

    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto resultTy = cast<ShapedType>(op.getType());

    SmallVector<Value> dynDims;
    dynDims.resize(resultTy.getRank());

    SmallVector<AffineExpr, 2> inputExprs;
    inputExprs.resize(resultTy.getRank());
    auto operandTy = cast<ShapedType>(input.getType());
    for (const auto &permutation : llvm::enumerate(perms.getValues<APInt>())) {
      auto index = permutation.index();
      auto value = permutation.value().getZExtValue();
      if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) {
        dynDims[value] = rewriter.create<tensor::DimOp>(loc, input, index);
      }
      inputExprs[value] = rewriter.getAffineDimExpr(index);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), filteredDims);

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(resultTy.getRank(), /*symbolCount=*/0, inputExprs,
                       rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, op.getInput1(), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
        });
    return success();
  }
};
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
  using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::RescaleOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(op.getInput().getType());
    auto outputTy = cast<ShapedType>(op.getOutput().getType());
    unsigned rank = inputTy.getRank();

    // This is an illegal configuration: terminate and log an error.
    if (op.getDoubleRound() && !op.getScale32())
      return rewriter.notifyMatchFailure(
          op, "tosa.rescale requires scale32 for double_round to be true");

    SmallVector<Value> dynDims;
    for (int i = 0; i < outputTy.getRank(); i++) {
      if (outputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // The shift and multiplier values.
    SmallVector<int32_t> multiplierValues = llvm::to_vector(op.getMultiplier());
    SmallVector<int8_t> shiftValues = llvm::to_vector(op.getShift());

    // If we shift by more than the bitwidth, this just sets to 0.
    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
      if (shiftValues[i] > 63) {
        shiftValues[i] = 0;
        multiplierValues[i] = 0;
      }
    }

    // Double rounding only occurs if the shift is greater than 31; check
    // whether that is ever true.
    bool doubleRound =
        op.getDoubleRound() &&
        llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });

    SmallVector<AffineMap> indexingMaps = {
        rewriter.getMultiDimIdentityMap(rank)};
    SmallVector<Value, 4> genericInputs = {input};

    // If we are rescaling per-channel then we need to store the multiplier
    // values in a buffer.
    Value multiplierConstant;
    int64_t multiplierArg = 0;
    if (multiplierValues.size() == 1) {
      multiplierConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
    } else {
      SmallVector<AffineExpr, 2> multiplierExprs{
          rewriter.getAffineDimExpr(rank - 1)};
      auto multiplierType = RankedTensorType::get(
          {static_cast<int64_t>(multiplierValues.size())},
          rewriter.getI32Type());
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, multiplierExprs,
                                            rewriter.getContext()));
      multiplierArg = indexingMaps.size() - 1;
    }

    // If we are rescaling per-channel then we need to store the shift
    // values in a buffer.
    Value shiftConstant;
    int64_t shiftArg = 0;
    if (shiftValues.size() == 1) {
      shiftConstant = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getI8IntegerAttr(shiftValues.front()));
    } else {
      SmallVector<AffineExpr, 2> shiftExprs = {
          rewriter.getAffineDimExpr(rank - 1)};
      auto shiftType =
          RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
                                rewriter.getIntegerType(8));
      genericInputs.push_back(rewriter.create<arith::ConstantOp>(
          loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, shiftExprs,
                                            rewriter.getContext()));
      shiftArg = indexingMaps.size() - 1;
    }

    // Indexing map for the output values.
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank));

    // Construct the linalg.generic op.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(),
        ArrayRef<Value>({dynDims}));

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, outputTy, genericInputs, ValueRange{emptyTensor}, indexingMaps,
        getNParallelLoopsAttrs(rank),
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          Value value = blockArgs[0];
          Type valueTy = value.getType();

          // For now we do all of our math in higher precision; consider
          // computing the correct bit depth later.
          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

          auto inputZp = createConstFromIntAttribute<int32_t>(
              op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
              nestedBuilder);
          auto outputZp = createConstFromIntAttribute<int32_t>(
              op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

          Value multiplier = multiplierConstant ? multiplierConstant
                                                : blockArgs[multiplierArg];
          Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];

          if (valueTy.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(
                            nestedLoc,
                            nestedBuilder.getIntegerType(
                                valueTy.getIntOrFloatBitWidth()),
                            value)
                        .getResult(0);
          }
          if (valueTy.getIntOrFloatBitWidth() < 32) {
            if (valueTy.isUnsignedInteger()) {
              value = nestedBuilder.create<arith::ExtUIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            } else {
              value = nestedBuilder.create<arith::ExtSIOp>(
                  nestedLoc, nestedBuilder.getI32Type(), value);
            }
          }

          value =
              nestedBuilder.create<arith::SubIOp>(nestedLoc, value, inputZp);

          value = nestedBuilder.create<tosa::ApplyScaleOp>(
              loc, nestedBuilder.getI32Type(), value, multiplier, shift,
              nestedBuilder.getBoolAttr(doubleRound));

          // Move to the new zero-point.
          value =
              nestedBuilder.create<arith::AddIOp>(nestedLoc, value, outputZp);

          // Saturate to the output size.
          IntegerType outIntType =
              cast<IntegerType>(blockArgs.back().getType());
          unsigned outBitWidth = outIntType.getWidth();

          int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
          int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();

          // Unsigned integers have a different output range.
          if (outIntType.isUnsignedInteger()) {
            intMin = 0;
            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
          }

          auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMin));
          auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
                                 nestedBuilder);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
                nestedLoc, rewriter.getIntegerType(outIntType.getWidth()),
                value);
          }

          if (outIntType.isUnsignedInteger()) {
            value = nestedBuilder
                        .create<UnrealizedConversionCastOp>(nestedLoc,
                                                            outIntType, value)
                        .getResult(0);
          }
          nestedBuilder.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOp(op, linalgOp->getResults());
    return success();
  }
};
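// The per-element pipeline above is:
//   t = (x - input_zp); t = apply_scale(t, multiplier, shift);
//   y = clamp(t + output_zp, out_min, out_max)
// e.g. (illustrative numbers) x = 10 with input_zp = 0, output_zp = 0,
// multiplier = 1 << 30, shift = 31 yields 10 * 2^30 >> 31 = 5.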
// Handle the resize case where the input is a 1x1 image. This case can
// entirely avoid having extract operations which are much more difficult
// to optimize away.
class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<RankedTensorType>(input.getType());
    auto resultTy = cast<RankedTensorType>(op.getType());
    const bool isBilinear = op.getMode() == "BILINEAR";

    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if (inputH != 1 || inputW != 1 || outputH != 1 || outputW != 1)
      return rewriter.notifyMatchFailure(
          op, "tosa.resize is not a pure 1x1->1x1 image operation");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    if (inputTy == resultTy) {
      rewriter.replaceOp(op, input);
      return success();
    }

    ArrayRef<int64_t> scale = op.getScale();

    // Collapse the unit width and height away.
    SmallVector<ReassociationExprs, 4> reassociationMap(2);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap[1].push_back(builder.getAffineDimExpr(1));
    reassociationMap[1].push_back(builder.getAffineDimExpr(2));
    reassociationMap[1].push_back(builder.getAffineDimExpr(3));

    auto collapseTy =
        RankedTensorType::get({inputTy.getDimSize(0), inputTy.getDimSize(3)},
                              inputTy.getElementType());
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, input,
                                                             reassociationMap);

    // Get any dynamic shapes that appear in the input format.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    // Generate the elementwise operation for casting and scaling the input
    // value.
    auto genericTy = collapseTy.clone(resultTy.getElementType());
    Value empty = builder.create<tensor::EmptyOp>(
        genericTy.getShape(), resultTy.getElementType(), outputDynSize);
    auto genericMap = rewriter.getMultiDimIdentityMap(genericTy.getRank());
    SmallVector<utils::IteratorType> iterators(genericTy.getRank(),
                                               utils::IteratorType::parallel);

    auto generic = builder.create<linalg::GenericOp>(
        genericTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{genericMap, genericMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          // This is the quantized case.
          if (inputTy.getElementType() != resultTy.getElementType()) {
            value =
                b.create<arith::ExtSIOp>(loc, resultTy.getElementType(), value);

            if (isBilinear && scale[0] != 0) {
              Value scaleY = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[0]));
              value = b.create<arith::MulIOp>(loc, value, scaleY);
            }

            if (isBilinear && scale[2] != 0) {
              Value scaleX = b.create<arith::ConstantOp>(
                  loc, b.getI32IntegerAttr(scale[2]));
              value = b.create<arith::MulIOp>(loc, value, scaleX);
            }
          }

          b.create<linalg::YieldOp>(loc, value);
        });

    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        op, resultTy, generic.getResults()[0], reassociationMap);
    return success();
  }
};
// Materialize the broadcasting behavior of tosa.resize: run the resize on
// the collapsed size-1 dims, then broadcast to the full output shape.
class MaterializeResizeBroadcast : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder builder(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
    auto resultTy = dyn_cast<RankedTensorType>(op.getType());

    if (!inputTy || !resultTy)
      return rewriter.notifyMatchFailure(op,
                                         "requires ranked input/output types");

    auto batch = inputTy.getDimSize(0);
    auto channels = inputTy.getDimSize(3);
    auto inputH = inputTy.getDimSize(1);
    auto inputW = inputTy.getDimSize(2);
    auto outputH = resultTy.getDimSize(1);
    auto outputW = resultTy.getDimSize(2);

    if ((inputH != 1 || outputH == 1) && (inputW != 1 || outputW == 1))
      return rewriter.notifyMatchFailure(
          op, "tosa.resize has no broadcasting behavior");

    // For any dimension that is broadcastable we generate a width of 1
    // on the output.
    SmallVector<int64_t> resizeShape;
    resizeShape.push_back(batch);
    resizeShape.push_back(inputH == 1 ? 1 : outputH);
    resizeShape.push_back(inputW == 1 ? 1 : outputW);
    resizeShape.push_back(channels);

    auto resizeTy = resultTy.clone(resizeShape);
    auto resize =
        builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());

    // Collapse an unit result dims.
    SmallVector<ReassociationExprs, 4> reassociationMap(1);
    reassociationMap[0].push_back(builder.getAffineDimExpr(0));
    reassociationMap.back().push_back(builder.getAffineDimExpr(1));
    if (inputH != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(2));
    if (inputW != 1)
      reassociationMap.push_back({});
    reassociationMap.back().push_back(builder.getAffineDimExpr(3));

    SmallVector<int64_t> collapseShape = {batch};
    if (inputH != 1)
      collapseShape.push_back(outputH);
    if (inputW != 1)
      collapseShape.push_back(outputW);
    collapseShape.push_back(channels);

    auto collapseTy = resultTy.clone(collapseShape);
    Value collapse = builder.create<tensor::CollapseShapeOp>(collapseTy, resize,
                                                             reassociationMap);

    // Broadcast the collapsed shape to the output result.
    llvm::SmallVector<Value> outputDynSize;
    if (inputTy.isDynamicDim(0))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 0));
    if (inputTy.isDynamicDim(3))
      outputDynSize.push_back(builder.create<tensor::DimOp>(input, 3));

    SmallVector<utils::IteratorType> iterators(resultTy.getRank(),
                                               utils::IteratorType::parallel);
    Value empty = builder.create<tensor::EmptyOp>(
        resultTy.getShape(), resultTy.getElementType(), outputDynSize);

    SmallVector<AffineExpr, 4> inputExprs{rewriter.getAffineDimExpr(0)};
    if (inputH != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(1));
    if (inputW != 1)
      inputExprs.push_back(rewriter.getAffineDimExpr(2));
    inputExprs.push_back(rewriter.getAffineDimExpr(3));

    auto inputMap = AffineMap::get(resultTy.getRank(), /*symbolCount=*/0,
                                   inputExprs, rewriter.getContext());

    auto outputMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ValueRange{collapse}, ValueRange{empty},
        ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
        [=](OpBuilder &b, Location loc, ValueRange args) {
          Value value = args[0];
          b.create<linalg::YieldOp>(loc, value);
        });

    return success();
  }
};
class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
public:
  using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ResizeOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ImplicitLocOpBuilder b(loc, rewriter);
    auto input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto resultETy = resultTy.getElementType();

    auto imageH = inputTy.getShape()[1];
    auto imageW = inputTy.getShape()[2];

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return rewriter.notifyMatchFailure(
          op, "unable to get dynamic dimensions of tosa.resize");

    if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR")
      return rewriter.notifyMatchFailure(
          op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR");

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};
    auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy,
                                                 *dynamicDimsOr);
    auto genericOp = b.create<linalg::GenericOp>(
        resultTy, ValueRange({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    Value resize = genericOp.getResult(0);

    {
      OpBuilder::InsertionGuard regionGuard(b);
      b.createBlock(&genericOp.getRegion(), genericOp.getRegion().end(),
                    TypeRange({resultETy}), loc);
      Value batch = b.create<linalg::IndexOp>(0);
      Value y = b.create<linalg::IndexOp>(1);
      Value x = b.create<linalg::IndexOp>(2);
      Value channel = b.create<linalg::IndexOp>(3);

      Value zeroI32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getI32Type()));
      Value zeroFp32 =
          b.create<arith::ConstantOp>(b.getZeroAttr(b.getF32Type()));
      Value hMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageH - 1));
      Value wMax = b.create<arith::ConstantOp>(b.getI32IntegerAttr(imageW - 1));

      Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
      Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);

      bool floatingPointMode = resultETy.isF32();

      ArrayRef<int64_t> offset = op.getOffset();
      ArrayRef<int64_t> border = op.getBorder();
      ArrayRef<int64_t> scale = op.getScale();

      Value yScaleN, yScaleD, xScaleN, xScaleD;
      yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));
      yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1]));
      xScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[2]));
      xScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[3]));

      Value yOffset, xOffset, yBorder, xBorder;
      yOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[0]));
      xOffset = b.create<arith::ConstantOp>(b.getI32IntegerAttr(offset[1]));
      yBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[0]));
      xBorder = b.create<arith::ConstantOp>(b.getI32IntegerAttr(border[1]));

      // Compute the integer index and partial offset in floating point:
      //   x = x * scale_d + offset; ix = floor(x / scale_n);
      //   dx = x / scale_n - ix
      auto getIndexAndDeltaFp = [&](Value &index, Value &delta, Value in,
                                    Value scaleN, Value scaleD, Value offset,
                                    int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroFp32;
          return;
        }
        Value val = b.create<arith::UIToFPOp>(b.getF32Type(), in);
        scaleN = b.create<arith::UIToFPOp>(b.getF32Type(), scaleN);
        scaleD = b.create<arith::UIToFPOp>(b.getF32Type(), scaleD);
        offset = b.create<arith::SIToFPOp>(b.getF32Type(), offset);
        val = b.create<arith::MulFOp>(val, scaleD);
        val = b.create<arith::AddFOp>(val, offset);
        val = b.create<arith::DivFOp>(val, scaleN);
        index = b.create<math::FloorOp>(val);
        delta = b.create<arith::SubFOp>(val, index);
        index = b.create<arith::FPToSIOp>(b.getI32Type(), index);
      };

      // Compute the integer index and partial offset in integer arithmetic:
      //   x = x * scale_d + offset; ix = floor(x / scale_n);
      //   dx = x - ix * scale_n
      auto getIndexAndDeltaInt = [&](Value &index, Value &delta, Value in,
                                     Value scaleN, Value scaleD, Value offset,
                                     int size, ImplicitLocOpBuilder &b) {
        if (size == 1) {
          index = zeroI32;
          delta = zeroI32;
          return;
        }
        Value val = b.create<arith::MulIOp>(in, scaleD);
        val = b.create<arith::AddIOp>(val, offset);
        index = b.create<arith::DivSIOp>(val, scaleN);
        delta = b.create<arith::MulIOp>(index, scaleN);
        delta = b.create<arith::SubIOp>(val, delta);
      };

      Value ix, iy, dx, dy;
      if (floatingPointMode) {
        getIndexAndDeltaFp(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaFp(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      } else {
        getIndexAndDeltaInt(iy, dy, inY, yScaleN, yScaleD, yOffset, imageH, b);
        getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b);
      }

      if (op.getMode() == "NEAREST_NEIGHBOR") {
        auto one = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale,
                                           Value max, int size,
                                           ImplicitLocOpBuilder &b) -> Value {
          if (size == 1) {
            return b.create<arith::ConstantIndexOp>(0);
          }

          Value pred;
          if (floatingPointMode) {
            auto h = b.create<arith::ConstantOp>(b.getF32FloatAttr(0.5f));
            pred = b.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, dval, h);
          } else {
            Value dvalDouble = b.create<arith::ShLIOp>(dval, one);
            pred = b.create<arith::CmpIOp>(arith::CmpIPredicate::sge,
                                           dvalDouble, scale);
          }

          auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
          val = b.create<arith::AddIOp>(val, offset);
          val = clampIntHelper(loc, val, zeroI32, max, b);
          return b.create<arith::IndexCastOp>(b.getIndexType(), val);
        };

        iy = getNearestIndexAndClamp(iy, dy, yScaleN, hMax, imageH, b);
        ix = getNearestIndexAndClamp(ix, dx, xScaleN, wMax, imageW, b);

        Value result = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, iy, ix, channel});

        b.create<linalg::YieldOp>(result);
      } else {
        // The mode here must be BILINEAR.
        assert(op.getMode() == "BILINEAR");

        auto oneVal = b.create<arith::ConstantOp>(b.getI32IntegerAttr(1));

        auto getClampedIdxs = [&](Value &val0, Value &val1, int size, Value in,
                                  Value max, ImplicitLocOpBuilder &b) {
          val0 = in;
          val1 = b.create<arith::AddIOp>(val0, oneVal);
          val0 = clampIntHelper(loc, val0, zeroI32, max, b);
          val1 = clampIntHelper(loc, val1, zeroI32, max, b);
          val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
          val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
        };

        Value x0, x1, y0, y1;
        getClampedIdxs(y0, y1, imageH, iy, hMax, b);
        getClampedIdxs(x0, x1, imageW, ix, wMax, b);

        Value y0x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x0, channel});
        Value y0x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y0, x1, channel});
        Value y1x0 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x0, channel});
        Value y1x1 = b.create<tensor::ExtractOp>(
            input, ValueRange{batch, y1, x1, channel});

        if (floatingPointMode) {
          auto oneVal = b.create<arith::ConstantOp>(b.getF32FloatAttr(1.0f));
          auto interpolate = [&](Value val0, Value val1, Value delta,
                                 int inputSize,
                                 ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return val0;
            Value oneMinusDelta = b.create<arith::SubFOp>(oneVal, delta);
            Value mul0 = b.create<arith::MulFOp>(val0, oneMinusDelta);
            Value mul1 = b.create<arith::MulFOp>(val1, delta);
            return b.create<arith::AddFOp>(mul0, mul1);
          };

          // topAcc = v00 * (1 - dx) + v01 * dx
          Value topAcc = interpolate(y0x0, y0x1, dx, imageW, b);

          // bottomAcc = v10 * (1 - dx) + v11 * dx
          Value bottomAcc = interpolate(y1x0, y1x1, dx, imageW, b);

          // result = topAcc * (1 - dy) + bottomAcc * dy
          Value result = interpolate(topAcc, bottomAcc, dy, imageH, b);
          b.create<linalg::YieldOp>(result);
        } else {
          // Perform the interpolation in quantized space.
          y0x0 = b.create<arith::ExtSIOp>(resultETy, y0x0);
          y0x1 = b.create<arith::ExtSIOp>(resultETy, y0x1);
          y1x0 = b.create<arith::ExtSIOp>(resultETy, y1x0);
          y1x1 = b.create<arith::ExtSIOp>(resultETy, y1x1);

          const int64_t deltaBitwidth = dx.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > deltaBitwidth) {
            dx = b.create<arith::ExtSIOp>(resultETy, dx);
            dy = b.create<arith::ExtSIOp>(resultETy, dy);
          }

          Value yScaleNExt = yScaleN;
          Value xScaleNExt = xScaleN;

          const int64_t scaleBitwidth =
              xScaleN.getType().getIntOrFloatBitWidth();
          if (resultETy.getIntOrFloatBitWidth() > scaleBitwidth) {
            yScaleNExt = b.create<arith::ExtSIOp>(resultETy, yScaleN);
            xScaleNExt = b.create<arith::ExtSIOp>(resultETy, xScaleN);
          }

          auto interpolate = [](Value val0, Value val1, Value weight1,
                                Value scale, int inputSize,
                                ImplicitLocOpBuilder &b) -> Value {
            if (inputSize == 1)
              return b.create<arith::MulIOp>(val0, scale);
            Value weight0 = b.create<arith::SubIOp>(scale, weight1);
            Value mul0 = b.create<arith::MulIOp>(val0, weight0);
            Value mul1 = b.create<arith::MulIOp>(val1, weight1);
            return b.create<arith::AddIOp>(mul0, mul1);
          };

          Value topAcc = interpolate(y0x0, y0x1, dx, xScaleNExt, imageW, b);
          Value bottomAcc = interpolate(y1x0, y1x1, dx, xScaleNExt, imageW, b);
          Value result =
              interpolate(topAcc, bottomAcc, dy, yScaleNExt, imageH, b);
          b.create<linalg::YieldOp>(result);
        }
      }
    }

    rewriter.replaceOp(op, resize);
    return success();
  }
};
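// Note: in the integer BILINEAR path the accumulated result carries an extra
// factor of scale_y_n * scale_x_n (the weights are not normalized here);
// per the TOSA spec that factor is expected to be folded away by a
// subsequent rescale.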
// At this point all tensors should have buffers allocated.
template <typename SrcOp>
class IdentityNConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp op,
                                PatternRewriter &rewriter) const final {
    rewriter.replaceOp(op, op.getOperation()->getOperands());
    return success();
  }
};
template <typename SrcOp>
class ReduceConverter : public OpRewritePattern<SrcOp> {
public:
  using OpRewritePattern<SrcOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SrcOp reduceOp,
                                PatternRewriter &rewriter) const final {
    return reduceMatchAndRewriteHelper(reduceOp, reduceOp.getAxis(), rewriter);
  }
};

class ReverseConverter : public OpRewritePattern<tosa::ReverseOp> {
public:
  using OpRewritePattern<tosa::ReverseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ReverseOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(op.getType());
    auto axis = op.getAxis();

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    Value axisDimSize = rewriter.create<tensor::DimOp>(loc, input, axis);

    // First fill the output buffer with the init value.
    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, inputTy.getShape(),
                                                    inputTy.getElementType(),
                                                    ArrayRef<Value>({dynDims}))
                           .getResult();
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, resultTy, ArrayRef<Value>({}), ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          llvm::SmallVector<Value> indices;
          for (unsigned int i = 0; i < inputTy.getRank(); i++) {
            Value index =
                rewriter.create<linalg::IndexOp>(nestedLoc, i).getResult();
            if (i == axis) {
              auto one = rewriter.create<arith::ConstantIndexOp>(nestedLoc, 1);
              auto sizeMinusOne =
                  rewriter.create<arith::SubIOp>(nestedLoc, axisDimSize, one);
              index = rewriter.create<arith::SubIOp>(nestedLoc, sizeMinusOne,
                                                     index);
            }

            indices.push_back(index);
          }

          auto extract = nestedBuilder.create<tensor::ExtractOp>(
              nestedLoc, input, indices);
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(),
                                                extract.getResult());
        });
    return success();
  }
};
// This converter translates a tile operation to a reshape, broadcast, and
// reshape. The first reshape minimally expands each tiled dimension to
// include a preceding size-1 dim. This dim is then broadcast to the
// appropriate multiple.
struct TileConverter : public OpConversionPattern<tosa::TileOp> {
  using OpConversionPattern<tosa::TileOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto input = op.getInput1();
    auto inputTy = cast<ShapedType>(input.getType());
    auto inputShape = inputTy.getShape();
    auto resultTy = cast<ShapedType>(op.getType());
    auto elementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    ArrayRef<int64_t> multiples = op.getMultiples();

    // Broadcast the newly added dimensions to their appropriate multiple.
    SmallVector<int64_t, 2> genericShape;
    for (int i = 0; i < rank; i++) {
      int64_t dim = multiples[i];
      genericShape.push_back(dim == -1 ? ShapedType::kDynamic : dim);
      genericShape.push_back(inputShape[i]);
    }

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        op.getLoc(), genericShape, elementTy, dynDims);

    // We need to map the input shape to the non-broadcast dimensions.
    SmallVector<AffineExpr, 4> dimExprs;
    dimExprs.reserve(rank);
    for (unsigned i = 0; i < rank; ++i)
      dimExprs.push_back(rewriter.getAffineDimExpr(i * 2 + 1));

    auto readAffineMap =
        AffineMap::get(/*dimCount=*/rank * 2, /*symbolCount=*/0, dimExprs,
                       rewriter.getContext());

    SmallVector<AffineMap, 2> affineMaps = {
        readAffineMap, rewriter.getMultiDimIdentityMap(genericShape.size())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, RankedTensorType::get(genericShape, elementTy), input,
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(genericShape.size()),
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(op.getLoc(), *args.begin());
        });

    rewriter.replaceOpWithNewOp<tosa::ReshapeOp>(
        op, resultTy, genericOp.getResult(0),
        rewriter.getDenseI64ArrayAttr(resultTy.getShape()));
    return success();
  }
};
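// E.g. tiling tensor<2x3xf32> with multiples [3, 1] builds a broadcasting
// linalg.generic of shape [3, 2, 1, 3]; the trailing tosa.reshape then
// collapses it to the final tensor<6x3xf32>.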
class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
public:
  using OpRewritePattern<tosa::ArgMaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::ArgMaxOp argmaxOp,
                                PatternRewriter &rewriter) const final {
    auto loc = argmaxOp.getLoc();
    Value input = argmaxOp.getInput();
    auto inputTy = cast<ShapedType>(input.getType());
    auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType());
    auto inElementTy = inputTy.getElementType();
    auto outElementTy = resultTy.getElementType();
    int axis = argmaxOp.getAxis();
    auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);

    if (!isa<IntegerType>(outElementTy))
      return rewriter.notifyMatchFailure(
          argmaxOp,
          "tosa.arg_max to linalg.* requires integer-like result type");

    SmallVector<Value> dynDims;
    for (int i = 0; i < inputTy.getRank(); i++) {
      if (inputTy.isDynamicDim(i) && i != axis) {
        dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i));
      }
    }

    // First fill the output buffer for the index.
    auto emptyTensorIdx =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), outElementTy,
                                     dynDims)
            .getResult();
    auto fillValueIdx = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(outElementTy, 0));
    auto filledTensorIdx =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueIdx},
                                    ValueRange{emptyTensorIdx})
            .result();

    // Second fill the output buffer for the running max.
    auto emptyTensorMax =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), inElementTy,
                                     dynDims)
            .getResult();
    auto fillValueMaxAttr =
        createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter);

    if (!fillValueMaxAttr)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    auto fillValueMax =
        rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr);
    auto filledTensorMax =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{fillValueMax},
                                    ValueRange{emptyTensorMax})
            .result();

    // We need to reduce along the arg-max axis, with parallel operations
    // along the rest.
    SmallVector<utils::IteratorType, 4> iteratorTypes;
    iteratorTypes.resize(inputTy.getRank(), utils::IteratorType::parallel);
    iteratorTypes[axis] = utils::IteratorType::reduction;

    SmallVector<AffineExpr, 2> srcExprs;
    SmallVector<AffineExpr> dstExprs;
    for (int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
      srcExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
      if (axis != i)
        dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext()));
    }

    bool didEncounterError = false;
    auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs});
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy, resultMaxTy}), input,
        ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc,
            ValueRange blockArgs) {
          auto newValue = blockArgs[0];
          auto oldIndex = blockArgs[1];
          auto oldValue = blockArgs[2];

          Value newIndex = rewriter.create<arith::IndexCastOp>(
              nestedLoc, oldIndex.getType(),
              rewriter.create<linalg::IndexOp>(loc, axis));

          Value predicate;
          if (isa<FloatType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpFOp>(
                nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
          } else if (isa<IntegerType>(inElementTy)) {
            predicate = rewriter.create<arith::CmpIOp>(
                nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
          } else {
            didEncounterError = true;
            return;
          }

          auto resultMax = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newValue, oldValue);
          auto resultIndex = rewriter.create<arith::SelectOp>(
              nestedLoc, predicate, newIndex, oldIndex);
          nestedBuilder.create<linalg::YieldOp>(
              nestedLoc, ValueRange({resultIndex, resultMax}));
        });

    if (didEncounterError)
      return rewriter.notifyMatchFailure(
          argmaxOp, "unsupported tosa.argmax element type");

    rewriter.replaceOp(argmaxOp, linalgOp.getResult(0));
    return success();
  }
};
class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
public:
  using OpConversionPattern<tosa::GatherOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto input = adaptor.getOperands()[0];
    auto indices = adaptor.getOperands()[1];

    auto valuesTy =
        dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
    auto resultTy = cast<ShapedType>(op.getType());

    if (!valuesTy)
      return rewriter.notifyMatchFailure(op, "unranked tensors not supported");

    auto dynamicDims = inferDynamicDimsForGather(
        rewriter, op.getLoc(), adaptor.getValues(), adaptor.getIndices());

    auto resultElementTy = resultTy.getElementType();

    auto loc = op.getLoc();
    auto emptyTensor =
        rewriter
            .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy,
                                     dynamicDims)
            .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        AffineMap::get(
            /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
            {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)},
            rewriter.getContext()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{indices},
        ValueRange{emptyTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto indexValue = args[0];
          auto index0 = rewriter.create<linalg::IndexOp>(loc, 0);
          Value index1 = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getIndexType(), indexValue);
          auto index2 = rewriter.create<linalg::IndexOp>(loc, 2);
          Value extract = rewriter.create<tensor::ExtractOp>(
              loc, input, ValueRange{index0, index1, index2});
          rewriter.create<linalg::YieldOp>(loc, extract);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }

  static llvm::SmallVector<Value> inferDynamicDimsForGather(OpBuilder &builder,
                                                            Location loc,
                                                            Value values,
                                                            Value indices) {
    llvm::SmallVector<Value> results;

    auto addDynamicDimension = [&](Value source, int64_t dim) {
      auto sz = tensor::getMixedSize(builder, loc, source, dim);
      if (auto dimValue = llvm::dyn_cast_if_present<Value>(sz))
        results.push_back(dimValue);
    };

    addDynamicDimension(values, 0);
    addDynamicDimension(indices, 1);
    addDynamicDimension(values, 2);
    return results;
  }
};
// Lowers tosa.table to a linalg.generic performing the (possibly
// interpolated) table lookup.
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.getInput();
    Value table = op.getTable();
    auto inputTy = cast<ShapedType>(input.getType());
    auto tableTy = cast<ShapedType>(table.getType());
    auto resultTy = cast<ShapedType>(op.getType());

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    SmallVector<Value> dynDims;
    for (int i = 0; i < resultTy.getRank(); ++i) {
      if (inputTy.isDynamicDim(i)) {
        dynDims.push_back(
            rewriter.create<tensor::DimOp>(loc, op.getOperand(0), i));
      }
    }

    auto emptyTensor = rewriter
                           .create<tensor::EmptyOp>(loc, resultTy.getShape(),
                                                    resultElementTy, dynDims)
                           .getResult();

    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{emptyTensor},
        affineMaps, getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    {
      OpBuilder::InsertionGuard regionGuard(rewriter);
      Block *block = rewriter.createBlock(
          &genericOp.getRegion(), genericOp.getRegion().end(),
          TypeRange({inputElementTy, resultElementTy}), {loc, loc});

      auto inputValue = block->getArgument(0);
      rewriter.setInsertionPointToStart(block);
      if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
          resultElementTy.isInteger(8)) {
        Value index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), inputValue);
        Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128);
        index = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                               index, offset);
        Value extract =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        rewriter.create<linalg::YieldOp>(loc, extract);
        return success();
      }

      if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
          resultElementTy.isInteger(32)) {
        Value extend = rewriter.create<arith::ExtSIOp>(
            loc, rewriter.getI32Type(), inputValue);

        auto offset = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(32768));
        auto seven = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(7));
        auto one = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(1));
        auto b1111111 = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getI32IntegerAttr(127));

        // Compute the index and fractional part from the input value:
        //   value = value + 32768
        //   index = value >> 7
        //   fraction = value & 0x7f
        auto extendAdd = rewriter.create<arith::AddIOp>(loc, extend, offset);
        Value index = rewriter.create<arith::ShRUIOp>(loc, extendAdd, seven);
        Value fraction =
            rewriter.create<arith::AndIOp>(loc, extendAdd, b1111111);

        // Extract the base and next values from the table:
        //   base = (int32_t) table[index]
        //   next = (int32_t) table[index + 1]
        Value indexPlusOne = rewriter.create<arith::AddIOp>(loc, index, one);

        index = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), index);
        indexPlusOne = rewriter.create<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), indexPlusOne);

        Value base =
            rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
        Value next = rewriter.create<tensor::ExtractOp>(
            loc, table, ValueRange{indexPlusOne});

        base =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), base);
        next =
            rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), next);

        // Use the fractional part to interpolate between the table values:
        //   result = (base << 7) + (next - base) * fraction
        Value baseScaled = rewriter.create<arith::ShLIOp>(loc, base, seven);
        Value diff = rewriter.create<arith::SubIOp>(loc, next, base);
        Value diffScaled = rewriter.create<arith::MulIOp>(loc, diff, fraction);
        Value result =
            rewriter.create<arith::AddIOp>(loc, baseScaled, diffScaled);

        rewriter.create<linalg::YieldOp>(loc, result);

        return success();
      }
    }

    return rewriter.notifyMatchFailure(
        op, "unable to create body for tosa.table op");
  }
};
struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> {
  using OpRewritePattern::OpRewritePattern;

  static bool isRankedTensor(Type type) { return isa<RankedTensorType>(type); }

  static OpFoldResult halfPlusOne(OpBuilder &builder, Location loc,
                                  OpFoldResult ofr) {
    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
    auto two = builder.create<arith::ConstantIndexOp>(loc, 2);

    auto value = getValueOrCreateConstantIndexOp(builder, loc, ofr);
    auto divBy2 = builder.createOrFold<arith::DivUIOp>(loc, value, two);
    auto plusOne = builder.createOrFold<arith::AddIOp>(loc, divBy2, one);
    return getAsOpFoldResult(plusOne);
  }

  static RankedTensorType
  computeOutputShape(OpBuilder &builder, Location loc, Value input,
                     llvm::SmallVectorImpl<Value> &dynamicSizes) {
    // Get [N, H, W].
    auto dims = tensor::getMixedSizes(builder, loc, input);

    // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
    // output tensors.
    dims[2] = halfPlusOne(builder, loc, dims[2]);

    llvm::SmallVector<int64_t, 3> staticSizes;
    dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes);

    auto elementType =
        cast<RankedTensorType>(input.getType()).getElementType();
    return RankedTensorType::get(staticSizes, elementType);
  }

  static Value createZeroTensor(PatternRewriter &rewriter, Location loc,
                                RankedTensorType type,
                                llvm::ArrayRef<Value> dynamicSizes) {
    auto emptyTensor =
        rewriter.create<tensor::EmptyOp>(loc, type, dynamicSizes);
    auto fillValueAttr = rewriter.getZeroAttr(type.getElementType());
    auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr);
    auto filledTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{fillValue},
                                                    ValueRange{emptyTensor})
                            .result();
    return filledTensor;
  }

  static Value castIndexToFloat(OpBuilder &builder, Location loc,
                                FloatType type, Value value) {
    auto integerVal = builder.create<arith::IndexCastUIOp>(
        loc,
        type.getIntOrFloatBitWidth() > 32 ? builder.getI64Type()
                                          : builder.getI32Type(),
        value);
    return builder.create<arith::UIToFPOp>(loc, type, integerVal);
  }

  static Value createLinalgIndex(OpBuilder &builder, Location loc,
                                 FloatType type, int64_t index) {
    auto indexVal = builder.create<linalg::IndexOp>(loc, index);
    return castIndexToFloat(builder, loc, type, indexVal);
  }

  template <typename... Args>
  static llvm::SmallVector<AffineExpr, 4> affineDimsExpr(OpBuilder &builder,
                                                         Args... args) {
    return {builder.getAffineDimExpr(args)...};
  }

  LogicalResult matchAndRewrite(RFFT2dOp rfft2d,
                                PatternRewriter &rewriter) const override {
    if (!llvm::all_of(rfft2d->getOperandTypes(), isRankedTensor) ||
        !llvm::all_of(rfft2d->getResultTypes(), isRankedTensor)) {
      return rewriter.notifyMatchFailure(rfft2d,
                                         "only supports ranked tensors");
    }

    auto loc = rfft2d.getLoc();
    auto input = rfft2d.getInput();
    auto elementType =
        cast<FloatType>(cast<ShapedType>(input.getType()).getElementType());

    // Compute the output type and set of dynamic sizes.
    llvm::SmallVector<Value> dynamicSizes;
    auto outputType = computeOutputShape(rewriter, loc, input, dynamicSizes);

    // Iterator types for the linalg.generic implementation.
    llvm::SmallVector<utils::IteratorType, 5> iteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::parallel, utils::IteratorType::reduction,
        utils::IteratorType::reduction};

    // Inputs/outputs to the linalg.generic implementation.
    llvm::SmallVector<Value> genericOpInputs = {input};
    llvm::SmallVector<Value> genericOpOutputs = {
        createZeroTensor(rewriter, loc, outputType, dynamicSizes),
        createZeroTensor(rewriter, loc, outputType, dynamicSizes)};

    // Indexing maps for input and output tensors.
    auto indexingMaps = AffineMap::inferFromExprList(
        {affineDimsExpr(rewriter, 0, 3, 4), affineDimsExpr(rewriter, 0, 1, 2),
         affineDimsExpr(rewriter, 0, 1, 2)});

    // Width and height dimensions of the original input.
    auto dimH = rewriter.createOrFold<tensor::DimOp>(loc, input, 1);
    auto dimW = rewriter.createOrFold<tensor::DimOp>(loc, input, 2);

    // Constants and dimension sizes.
    auto twoPiAttr = rewriter.getFloatAttr(elementType, 6.283185307179586);
    auto twoPi = rewriter.create<arith::ConstantOp>(loc, twoPiAttr);
    auto constH = castIndexToFloat(rewriter, loc, elementType, dimH);
    auto constW = castIndexToFloat(rewriter, loc, elementType, dimW);

    auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) {
      Value valReal = args[0];
      Value sumReal = args[1];
      Value sumImag = args[2];

      // Indices for the angle computation.
      auto oy = createLinalgIndex(builder, loc, elementType, 1);
      auto ox = createLinalgIndex(builder, loc, elementType, 2);
      auto iy = createLinalgIndex(builder, loc, elementType, 3);
      auto ix = createLinalgIndex(builder, loc, elementType, 4);

      // angle = 2 * pi * ((iy * oy) / H + (ix * ox) / W)
      auto iyXoy = builder.create<arith::MulFOp>(loc, iy, oy);
      auto ixXox = builder.create<arith::MulFOp>(loc, ix, ox);
      auto yComponent = builder.create<arith::DivFOp>(loc, iyXoy, constH);
      auto xComponent = builder.create<arith::DivFOp>(loc, ixXox, constW);
      auto sumXY = builder.create<arith::AddFOp>(loc, yComponent, xComponent);
      auto angle = builder.create<arith::MulFOp>(loc, twoPi, sumXY);

      // realComponent = valReal * cos(angle)
      // imagComponent = valReal * sin(angle)
      auto cosAngle = builder.create<math::CosOp>(loc, angle);
      auto sinAngle = builder.create<math::SinOp>(loc, angle);
      auto realComponent =
          builder.create<arith::MulFOp>(loc, valReal, cosAngle);
      auto imagComponent =
          builder.create<arith::MulFOp>(loc, valReal, sinAngle);

      // outReal = sumReal + realComponent
      // outImag = sumImag - imagComponent
      auto outReal =
          builder.create<arith::AddFOp>(loc, sumReal, realComponent);
      auto outImag =
          builder.create<arith::SubFOp>(loc, sumImag, imagComponent);

      builder.create<linalg::YieldOp>(loc, ValueRange{outReal, outImag});
    };

    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        rfft2d, rfft2d.getResultTypes(), genericOpInputs, genericOpOutputs,
        indexingMaps, iteratorTypes, buildBody);

    return success();
  }
};
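// The body accumulates the real-input DFT:
//   out_real[oy, ox] =  sum_{iy, ix} in[iy, ix] * cos(angle)
//   out_imag[oy, ox] = -sum_{iy, ix} in[iy, ix] * sin(angle)
// with angle = 2*pi*(iy*oy/H + ix*ox/W), matching the TOSA RFFT2D
// definition over the half-width (W/2 + 1) output.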
void mlir::tosa::populateTosaToLinalgConversionPatterns(
    RewritePatternSet *patterns) {
  // We have multiple resize converters to handle degenerate cases.
  patterns->add<GenericResizeConverter>(patterns->getContext(),
                                        /*benefit=*/100);
  patterns->add<ResizeUnaryConverter>(patterns->getContext(),
                                      /*benefit=*/200);
  patterns->add<MaterializeResizeBroadcast>(patterns->getContext(),
                                            /*benefit=*/300);

  patterns->add<
      // clang-format off
      PointwiseConverter<tosa::AddOp>,
      PointwiseConverter<tosa::SubOp>,
      PointwiseConverter<tosa::MulOp>,
      PointwiseConverter<tosa::DivOp>,
      PointwiseConverter<tosa::NegateOp>,
      PointwiseConverter<tosa::PowOp>,
      PointwiseConverter<tosa::ReciprocalOp>,
      PointwiseConverter<tosa::RsqrtOp>,
      PointwiseConverter<tosa::LogOp>,
      PointwiseConverter<tosa::ExpOp>,
      PointwiseConverter<tosa::AbsOp>,
      PointwiseConverter<tosa::TanhOp>,
      PointwiseConverter<tosa::ErfOp>,
      PointwiseConverter<tosa::BitwiseAndOp>,
      PointwiseConverter<tosa::BitwiseOrOp>,
      PointwiseConverter<tosa::BitwiseNotOp>,
      PointwiseConverter<tosa::BitwiseXorOp>,
      PointwiseConverter<tosa::LogicalAndOp>,
      PointwiseConverter<tosa::LogicalNotOp>,
      PointwiseConverter<tosa::LogicalOrOp>,
      PointwiseConverter<tosa::LogicalXorOp>,
      PointwiseConverter<tosa::CastOp>,
      PointwiseConverter<tosa::LogicalLeftShiftOp>,
      PointwiseConverter<tosa::LogicalRightShiftOp>,
      PointwiseConverter<tosa::ArithmeticRightShiftOp>,
      PointwiseConverter<tosa::ClzOp>,
      PointwiseConverter<tosa::SelectOp>,
      PointwiseConverter<tosa::GreaterOp>,
      PointwiseConverter<tosa::GreaterEqualOp>,
      PointwiseConverter<tosa::EqualOp>,
      PointwiseConverter<tosa::MaximumOp>,
      PointwiseConverter<tosa::MinimumOp>,
      PointwiseConverter<tosa::CeilOp>,
      PointwiseConverter<tosa::FloorOp>,
      PointwiseConverter<tosa::ClampOp>,
      PointwiseConverter<tosa::SigmoidOp>,
      IdentityNConverter<tosa::IdentityOp>,
      ReduceConverter<tosa::ReduceAllOp>,
      ReduceConverter<tosa::ReduceAnyOp>,
      ReduceConverter<tosa::ReduceMinOp>,
      ReduceConverter<tosa::ReduceMaxOp>,
      ReduceConverter<tosa::ReduceSumOp>,
      ReduceConverter<tosa::ReduceProdOp>,
      ArgMaxConverter,
      GatherConverter,
      RescaleConverter,
      ReverseConverter,
      RFFT2dConverter,
      TableConverter,
      TileConverter,
      TransposeConverter>(patterns->getContext());
  // clang-format on
}
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor, int64_t rank)
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, PatternRewriter &rewriter)
static OpFoldResult getOrFoldTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, PatternRewriter &rewriter)
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
static Value getTensorDim(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value tensor, int64_t index)
static Value createIndex(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, int64_t index)
static TypedAttr createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, PatternRewriter &rewriter)
static LogicalResult emitElementwiseComputation(PatternRewriter &rewriter, Location loc, Operation *operation, ValueRange operands, ArrayRef< OpFoldResult > targetShape)
static std::pair< OpFoldResult, Value > computeTargetSize(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands, int64_t dim)
static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, ArrayRef< OpFoldResult > targetShape, ArrayRef< Value > masterOperands)
static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, Value operand, int64_t dim, OpFoldResult targetSize, Value masterOperand)
static arith::ConstantOp createConstFromIntAttribute(Operation *op, const std::string &attrName, Type requiredAttrType, OpBuilder &rewriter)
static std::pair< SmallVector< OpFoldResult >, SmallVector< Value > > computeTargetShape(PatternRewriter &rewriter, Location loc, IndexPool &indexPool, ValueRange operands)
static SmallVector< Value > expandInputRanks(PatternRewriter &rewriter, Location loc, Operation *operation)
static bool operandsAndResultsRanked(Operation *operation)
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the largest dim in exprs, and as many symbols as the largest symbol in exprs.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getI32IntegerAttr(int32_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
FloatAttr getFloatAttr(Type type, double value)
AffineExpr getAffineConstantExpr(int64_t constant)
IntegerType getIntegerType(unsigned width)
BoolAttr getBoolAttr(bool value)
TypedAttr getZeroAttr(Type type)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
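For orientation, a small illustrative snippet exercising several of the Builder attribute getters listed above (the function name and values are arbitrary):

// Illustrative only: build a few attributes with the getters listed above.
static void buildExampleAttrs(Builder &b) {
  IntegerAttr idx = b.getIndexAttr(4);                 // index-typed 4
  IntegerAttr i32 = b.getI32IntegerAttr(-1);           // i32 -1
  IntegerAttr i16 = b.getIntegerAttr(b.getIntegerType(16), 7);
  FloatAttr f32 = b.getFloatAttr(b.getF32Type(), 0.5);
  BoolAttr yes = b.getBoolAttr(true);
  TypedAttr zero = b.getZeroAttr(b.getF32Type());
  AffineMap id3 = b.getMultiDimIdentityMap(3);         // (d0,d1,d2)->(d0,d1,d2)
  (void)idx; (void)i32; (void)i16; (void)f32; (void)yes; (void)zero; (void)id3;
}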
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
This class represents a single result from folding an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
LLVM_ATTRIBUTE_ALWAYS_INLINE MPInt abs(const MPInt &x)
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Value clampFloatHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
std::optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
SmallVector< utils::IteratorType > getNParallelLoopsAttrs(unsigned nParallelLoops)
SmallVector< Value > condenseValues(const SmallVector< Value > &values)
Value clampIntHelper(Location loc, Value arg, Value min, Value max, OpBuilder &rewriter)
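A hypothetical use of clampIntHelper (and, by analogy, clampFloatHelper above): materialize the bounds as constants of the argument's element type, then clamp. The wrapper name clampToRangeSketch is an assumption:

// Hedged sketch: clamp `arg` into [minVal, maxVal] using clampIntHelper.
static Value clampToRangeSketch(Location loc, Value arg, int64_t minVal,
                                int64_t maxVal, OpBuilder &rewriter) {
  Type ty = arg.getType();
  Value lo = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIntegerAttr(ty, minVal));
  Value hi = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getIntegerAttr(ty, maxVal));
  return clampIntHelper(loc, arg, lo, hi, rewriter);
}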
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
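An illustrative round-trip between OpFoldResult and Value using the helper above together with tensor::getMixedSizes: query a tensor's sizes (constant sizes stay attributes, dynamic sizes stay SSA values), then materialize one as an index value. The function name is hypothetical:

// Hedged sketch: fetch dimension 0 of `tensor` as an SSA index value.
static Value firstDimAsValueSketch(OpBuilder &b, Location loc, Value tensor) {
  SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, tensor);
  return getValueOrCreateConstantIndexOp(b, loc, sizes[0]);
}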
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
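m_Constant pairs with matchPattern above to test whether a value is defined by a constant-foldable operation; an illustrative predicate (the name is hypothetical):

// Hedged sketch: true only if `value` is produced by a constant-foldable op.
static bool producedByConstantSketch(Value value) {
  return matchPattern(value, m_Constant());
}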
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.