34 static arith::ConstantOp
37 auto castedN =
static_cast<T
>(
38 op->
getAttr(attrName).
cast<IntegerAttr>().getValue().getSExtValue());
39 return rewriter.
create<arith::ConstantOp>(
40 op->
getLoc(), IntegerAttr::get(requiredAttrType, castedN));
52 if (isa<tosa::AbsOp>(op) && elementTy.isa<
FloatType>())
53 return rewriter.
create<math::AbsOp>(loc, resultTypes, args);
55 if (isa<tosa::AbsOp>(op) && elementTy.isa<IntegerType>()) {
56 auto zero = rewriter.
create<arith::ConstantOp>(
58 auto cmp = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
60 auto neg = rewriter.
create<arith::SubIOp>(loc, zero, args[0]);
61 return rewriter.
create<arith::SelectOp>(loc, cmp, args[0], neg);
65 if (isa<tosa::AddOp>(op) && elementTy.isa<
FloatType>())
66 return rewriter.
create<arith::AddFOp>(loc, resultTypes, args);
68 if (isa<tosa::AddOp>(op) && elementTy.isa<IntegerType>())
69 return rewriter.
create<arith::AddIOp>(loc, resultTypes, args);
72 if (isa<tosa::SubOp>(op) && elementTy.isa<
FloatType>())
73 return rewriter.
create<arith::SubFOp>(loc, resultTypes, args);
75 if (isa<tosa::SubOp>(op) && elementTy.isa<IntegerType>())
76 return rewriter.
create<arith::SubIOp>(loc, resultTypes, args);
79 if (isa<tosa::MulOp>(op) && elementTy.isa<
FloatType>()) {
80 if (dyn_cast<tosa::MulOp>(op).shift() != 0) {
82 "Cannot have shift value for float");
85 return rewriter.
create<arith::MulFOp>(loc, resultTypes, args);
89 if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
90 return rewriter.
create<arith::DivSIOp>(loc, resultTypes, args);
93 if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<
FloatType>()) {
95 rewriter.
create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
96 return rewriter.
create<arith::DivFOp>(loc, resultTypes, one, args[0]);
99 if (isa<tosa::MulOp>(op) && elementTy.isa<IntegerType>()) {
103 op->
getAttr(
"shift").
cast<IntegerAttr>().getValue().getSExtValue();
113 auto result = rewriter.
create<tosa::ApplyScaleOp>(
117 if (elementTy.isInteger(32))
120 return rewriter.
create<arith::TruncIOp>(loc, elementTy, result);
125 int cWidth = resultTypes[0].getIntOrFloatBitWidth();
128 a = rewriter.
create<arith::ExtSIOp>(loc, resultTypes[0], a);
130 b = rewriter.
create<arith::ExtSIOp>(loc, resultTypes[0], b);
132 return rewriter.
create<arith::MulIOp>(loc, resultTypes, a, b);
136 if (isa<tosa::NegateOp>(op) && elementTy.isa<
FloatType>())
137 return rewriter.
create<arith::NegFOp>(loc, resultTypes, args);
139 if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
140 !cast<tosa::NegateOp>(op).quantization_info()) {
142 rewriter.
create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
143 return rewriter.
create<arith::SubIOp>(loc, resultTypes, constant, args[0]);
146 if (isa<tosa::NegateOp>(op) && elementTy.isa<IntegerType>() &&
147 cast<tosa::NegateOp>(op).quantization_info()) {
148 auto quantizationInfo = cast<tosa::NegateOp>(op).quantization_info();
149 int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
150 int64_t inZp = quantizationInfo.getValue().getInputZp();
151 int64_t outZp = quantizationInfo.getValue().getOutputZp();
154 int64_t zpAdd = inZp + outZp;
155 int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
160 int intermediateBitWidth = 64;
161 if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
162 intermediateBitWidth = 16;
163 }
else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
164 intermediateBitWidth = 32;
165 }
else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
166 intermediateBitWidth = 48;
170 Value zpAddValue = rewriter.
create<arith::ConstantOp>(
175 auto ext = rewriter.
create<arith::ExtSIOp>(loc, intermediateType, args[0]);
176 auto sub = rewriter.
create<arith::SubIOp>(loc, zpAddValue, ext);
180 loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
183 loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
185 auto clamp = clampHelper<arith::CmpIOp>(
186 loc, sub,
min,
max, arith::CmpIPredicate::slt, rewriter);
189 return rewriter.
create<arith::TruncIOp>(loc, elementTy,
clamp);
193 if (isa<tosa::BitwiseAndOp>(op) && elementTy.isa<IntegerType>())
194 return rewriter.
create<arith::AndIOp>(loc, resultTypes, args);
197 if (isa<tosa::BitwiseOrOp>(op) && elementTy.isa<IntegerType>())
198 return rewriter.
create<arith::OrIOp>(loc, resultTypes, args);
201 if (isa<tosa::BitwiseNotOp>(op) && elementTy.isa<IntegerType>()) {
203 elementTy, APInt::getAllOnes(elementTy.getIntOrFloatBitWidth()));
204 auto allOnes = rewriter.
create<arith::ConstantOp>(loc, allOnesAttr);
205 return rewriter.
create<arith::XOrIOp>(loc, resultTypes, args[0], allOnes);
209 if (isa<tosa::BitwiseXorOp>(op) && elementTy.isa<IntegerType>())
210 return rewriter.
create<arith::XOrIOp>(loc, resultTypes, args);
213 if (isa<tosa::LogicalLeftShiftOp>(op) && elementTy.isa<IntegerType>())
214 return rewriter.
create<arith::ShLIOp>(loc, resultTypes, args);
217 if (isa<tosa::LogicalRightShiftOp>(op) && elementTy.isa<IntegerType>())
218 return rewriter.
create<arith::ShRUIOp>(loc, resultTypes, args);
221 if (isa<tosa::ArithmeticRightShiftOp>(op) && elementTy.isa<IntegerType>()) {
222 auto result = rewriter.
create<arith::ShRSIOp>(loc, resultTypes, args);
230 rewriter.
create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
232 rewriter.
create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
234 rewriter.
create<arith::ConstantOp>(loc, IntegerAttr::get(i1Ty, 1));
237 auto shiftValueGreaterThanZero = rewriter.
create<arith::CmpIOp>(
238 loc, arith::CmpIPredicate::sgt, args[1], zero);
242 rewriter.
create<arith::SubIOp>(loc, resultTypes, args[1], one);
244 rewriter.
create<arith::ShRSIOp>(loc, resultTypes, args[0],
subtract)
247 rewriter.
create<arith::TruncIOp>(loc, i1Ty, shifted, mlir::None);
249 rewriter.
create<arith::AndIOp>(loc, i1Ty, truncated, i1one);
251 auto shouldRound = rewriter.
create<arith::AndIOp>(
252 loc, i1Ty, shiftValueGreaterThanZero, isInputOdd);
254 rewriter.
create<arith::ExtUIOp>(loc, resultTypes, shouldRound);
255 return rewriter.
create<arith::AddIOp>(loc, resultTypes, result, extended);
259 if (isa<tosa::ClzOp>(op) && elementTy.isa<IntegerType>()) {
260 return rewriter.
create<math::CountLeadingZerosOp>(loc, elementTy, args[0]);
264 if (isa<tosa::LogicalAndOp>(op) && elementTy.isInteger(1))
265 return rewriter.
create<arith::AndIOp>(loc, resultTypes, args);
268 if (isa<tosa::LogicalNotOp>(op) && elementTy.isInteger(1)) {
269 auto one = rewriter.
create<arith::ConstantOp>(
271 return rewriter.
create<arith::XOrIOp>(loc, resultTypes, args[0], one);
275 if (isa<tosa::LogicalOrOp>(op) && elementTy.isInteger(1))
276 return rewriter.
create<arith::OrIOp>(loc, resultTypes, args);
279 if (isa<tosa::LogicalXorOp>(op) && elementTy.isInteger(1))
280 return rewriter.
create<arith::XOrIOp>(loc, resultTypes, args);
283 if (isa<tosa::PowOp>(op) && elementTy.isa<
FloatType>())
284 return rewriter.
create<mlir::math::PowFOp>(loc, resultTypes, args);
287 if (isa<tosa::RsqrtOp>(op) && elementTy.isa<
FloatType>())
288 return rewriter.
create<mlir::math::RsqrtOp>(loc, resultTypes, args);
291 if (isa<tosa::LogOp>(op) && elementTy.isa<
FloatType>())
292 return rewriter.
create<mlir::math::LogOp>(loc, resultTypes, args);
295 if (isa<tosa::ExpOp>(op) && elementTy.isa<
FloatType>())
296 return rewriter.
create<mlir::math::ExpOp>(loc, resultTypes, args);
299 if (isa<tosa::TanhOp>(op) && elementTy.isa<
FloatType>())
300 return rewriter.
create<mlir::math::TanhOp>(loc, resultTypes, args);
303 if (isa<tosa::GreaterOp>(op) && elementTy.isa<
FloatType>())
304 return rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
307 if (isa<tosa::GreaterOp>(op) && elementTy.isSignlessInteger())
308 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
312 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isa<
FloatType>())
313 return rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
316 if (isa<tosa::GreaterEqualOp>(op) && elementTy.isSignlessInteger())
317 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
321 if (isa<tosa::EqualOp>(op) && elementTy.isa<
FloatType>())
322 return rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
325 if (isa<tosa::EqualOp>(op) && elementTy.isSignlessInteger())
326 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
330 if (isa<tosa::SelectOp>(op)) {
332 if (elementTy.isa<
FloatType>() || elementTy.
isa<IntegerType>())
333 return rewriter.
create<arith::SelectOp>(loc, args[0], args[1], args[2]);
337 if (isa<tosa::MaximumOp>(op) && elementTy.isa<
FloatType>()) {
338 auto predicate = rewriter.
create<arith::CmpFOp>(
339 loc, arith::CmpFPredicate::OGT, args[0], args[1]);
340 return rewriter.
create<arith::SelectOp>(loc, predicate, args[0], args[1]);
343 if (isa<tosa::MaximumOp>(op) && elementTy.isSignlessInteger()) {
344 auto predicate = rewriter.
create<arith::CmpIOp>(
345 loc, arith::CmpIPredicate::sgt, args[0], args[1]);
346 return rewriter.
create<arith::SelectOp>(loc, predicate, args[0], args[1]);
350 if (isa<tosa::MinimumOp>(op) && elementTy.isa<
FloatType>()) {
351 auto predicate = rewriter.
create<arith::CmpFOp>(
352 loc, arith::CmpFPredicate::OLT, args[0], args[1]);
353 return rewriter.
create<arith::SelectOp>(loc, predicate, args[0], args[1]);
356 if (isa<tosa::MinimumOp>(op) && elementTy.isSignlessInteger()) {
357 auto predicate = rewriter.
create<arith::CmpIOp>(
358 loc, arith::CmpIPredicate::slt, args[0], args[1]);
359 return rewriter.
create<arith::SelectOp>(loc, predicate, args[0], args[1]);
363 if (isa<tosa::CeilOp>(op) && elementTy.isa<
FloatType>())
364 return rewriter.
create<math::CeilOp>(loc, resultTypes, args);
367 if (isa<tosa::FloorOp>(op) && elementTy.isa<
FloatType>())
368 return rewriter.
create<math::FloorOp>(loc, resultTypes, args);
371 if (isa<tosa::ClampOp>(op) && elementTy.isa<
FloatType>()) {
372 auto min = rewriter.
create<arith::ConstantOp>(loc, elementTy,
374 auto max = rewriter.
create<arith::ConstantOp>(loc, elementTy,
376 return clampHelper<arith::CmpFOp>(loc, args[0],
min,
max,
377 arith::CmpFPredicate::OLT, rewriter);
380 if (isa<tosa::ClampOp>(op) && elementTy.isa<IntegerType>()) {
381 auto intTy = elementTy.cast<IntegerType>();
382 int32_t
min =
static_cast<int32_t
>(
383 op->
getAttr(
"min_int").
cast<IntegerAttr>().getValue().getSExtValue());
384 int32_t
max =
static_cast<int32_t
>(
385 op->
getAttr(
"max_int").
cast<IntegerAttr>().getValue().getSExtValue());
387 if (intTy.isUnsignedInteger()) {
388 min = std::max<int32_t>(
min, 0);
389 max = std::min<int32_t>(
391 APInt::getMaxValue(intTy.getIntOrFloatBitWidth()).getSExtValue());
393 min = std::max<int32_t>(
394 min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
396 max = std::min<int32_t>(
397 max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
402 loc,
min, intTy.getIntOrFloatBitWidth());
404 loc,
max, intTy.getIntOrFloatBitWidth());
405 return clampHelper<arith::CmpIOp>(loc, args[0], minVal, maxVal,
406 arith::CmpIPredicate::slt, rewriter);
410 if (isa<tosa::ReluNOp>(op) && elementTy.isa<
FloatType>()) {
412 rewriter.
create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 0));
413 auto n = rewriter.
create<arith::ConstantOp>(loc, elementTy,
415 return clampHelper<arith::CmpFOp>(loc, args[0], zero, n,
416 arith::CmpFPredicate::OLT, rewriter);
419 if (isa<tosa::ReluNOp>(op) && elementTy.isa<IntegerType>()) {
421 rewriter.
create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
422 auto n = createConstFromIntAttribute<int32_t>(op,
"max_int", elementTy,
424 return clampHelper<arith::CmpIOp>(loc, args[0], zero, n,
425 arith::CmpIPredicate::slt, rewriter);
429 if (isa<tosa::SigmoidOp>(op) && elementTy.isa<
FloatType>()) {
431 rewriter.
create<arith::ConstantOp>(loc, FloatAttr::get(elementTy, 1));
432 auto negate = rewriter.
create<arith::NegFOp>(loc, resultTypes, args[0]);
433 auto exp = rewriter.
create<mlir::math::ExpOp>(loc, resultTypes, negate);
434 auto added = rewriter.
create<arith::AddFOp>(loc, resultTypes, exp, one);
435 return rewriter.
create<arith::DivFOp>(loc, resultTypes, one, added);
439 if (isa<tosa::CastOp>(op)) {
440 Type srcTy = elementTy;
441 Type dstTy = resultTypes.front();
449 return rewriter.
create<arith::ExtFOp>(loc, resultTypes, args, mlir::None);
452 return rewriter.
create<arith::TruncFOp>(loc, resultTypes, args,
456 if (srcTy.
isInteger(1) && arith::UIToFPOp::areCastCompatible(srcTy, dstTy))
457 return rewriter.
create<arith::UIToFPOp>(loc, resultTypes, args,
460 if (srcTy.
isInteger(1) && dstTy.
isa<IntegerType>() && bitExtend)
461 return rewriter.
create<arith::ExtUIOp>(loc, resultTypes, args,
467 auto unrealizedCast =
469 .
create<UnrealizedConversionCastOp>(
473 return rewriter.
create<arith::UIToFPOp>(loc, resultTypes[0],
478 if (arith::SIToFPOp::areCastCompatible(srcTy, dstTy))
479 return rewriter.
create<arith::SIToFPOp>(loc, resultTypes, args,
486 return rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
490 if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
491 auto zero = rewriter.
create<arith::ConstantOp>(
493 auto half = rewriter.
create<arith::ConstantOp>(
496 auto intMin = rewriter.
create<arith::ConstantOp>(
501 auto intMax = rewriter.
create<arith::ConstantOp>(
506 auto added = rewriter.
create<arith::AddFOp>(loc, args[0], half);
507 auto subbed = rewriter.
create<arith::SubFOp>(loc, args[0], half);
508 auto negative = rewriter.
create<arith::CmpFOp>(
509 loc, arith::CmpFPredicate::OLT, args[0], zero);
511 rewriter.
create<arith::SelectOp>(loc, negative, subbed, added);
513 auto clamped = clampHelper<arith::CmpFOp>(
514 loc, rounded, intMin, intMax, arith::CmpFPredicate::OLT, rewriter);
516 return rewriter.
create<arith::FPToSIOp>(loc, dstTy, clamped);
524 return rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
528 if (srcTy.
isa<IntegerType>() && dstTy.
isa<IntegerType>() && bitExtend)
529 return rewriter.
create<arith::ExtSIOp>(loc, resultTypes, args,
532 if (srcTy.
isa<IntegerType>() && dstTy.
isa<IntegerType>() && !bitExtend) {
545 auto clamped = clampHelper<arith::CmpIOp>(
546 loc, args[0], intMin, intMax, arith::CmpIPredicate::slt, rewriter);
547 return rewriter.
create<arith::TruncIOp>(loc, dstTy, clamped);
552 op,
"unhandled op for linalg body calculation for elementwise op");
559 auto loc = operation->
getLoc();
562 "All TOSA elementwise ops should only return a single result.");
569 "All results must be a shaped type");
571 unsigned rank = resultTy.getRank();
583 dynDims.resize(results.front().getType().cast<ShapedType>().getRank());
586 auto operandTy = arg.getType().cast<ShapedType>();
587 for (
int i = 0; i < operandTy.getRank(); i++) {
588 if (operandTy.isDynamicDim(i) && !dynDims[i])
589 dynDims[i] = rewriter.
create<tensor::DimOp>(loc, arg, i);
595 for (
auto result : results) {
596 auto resultTy = result.getType().template cast<ShapedType>();
597 initTensors.push_back(rewriter.
create<linalg::InitTensorOp>(
598 loc, filteredDims, resultTy.getShape(), resultTy.getElementType()));
599 opResultTypes.push_back(result.getType());
602 auto bodyResultTypes = llvm::to_vector<4>(llvm::map_range(
607 indexingMaps.reserve(operation->
getNumOperands() + bodyResultTypes.size());
611 ShapedType type = operand.getType().cast<ShapedType>();
613 if (type.getShape() == resultTy.getShape()) {
614 operands.push_back(operand);
621 newShape.reserve(type.getRank());
623 if (it.value() == resultTy.getDimSize(it.index())) {
624 newShape.push_back(it.value());
625 affineExprs.push_back(
630 if (newShape.size() != rank) {
631 operand = rewriter.
create<tosa::ReshapeOp>(
632 loc, RankedTensorType::get(newShape, type.getElementType()), operand,
636 operands.push_back(operand);
638 type.getRank(), 0, affineExprs,
645 bool didEncounterError =
false;
646 auto linalgOp = rewriter.
create<linalg::GenericOp>(
647 loc, opResultTypes, operands, initTensors, indexingMaps,
652 bodyResultTypes, rewriter);
654 didEncounterError =
true;
657 nestedBuilder.create<linalg::YieldOp>(loc, opResult);
660 if (didEncounterError)
671 if (isa<tosa::ReduceSumOp>(op) && elementTy.
isa<
FloatType>())
674 if (isa<tosa::ReduceSumOp>(op) && elementTy.
isa<IntegerType>())
677 if (isa<tosa::ReduceProdOp>(op) && elementTy.
isa<
FloatType>())
680 if (isa<tosa::ReduceProdOp>(op) && elementTy.
isa<IntegerType>())
683 if (isa<tosa::ReduceMinOp>(op) && elementTy.
isa<
FloatType>())
685 elementTy, APFloat::getLargest(
688 if (isa<tosa::ReduceMinOp>(op) && elementTy.
isa<IntegerType>())
692 if (isa<tosa::ReduceMaxOp>(op) && elementTy.
isa<
FloatType>())
694 elementTy, APFloat::getLargest(
697 if (isa<tosa::ReduceMaxOp>(op) && elementTy.
isa<IntegerType>())
701 if (isa<tosa::ReduceAllOp>(op) && elementTy.
isInteger(1))
704 if (isa<tosa::ReduceAnyOp>(op) && elementTy.
isInteger(1))
707 if (isa<tosa::ArgMaxOp>(op) && elementTy.
isa<
FloatType>())
709 elementTy, APFloat::getLargest(
712 if (isa<tosa::ArgMaxOp>(op) && elementTy.
isa<IntegerType>())
726 if (isa<tosa::ReduceSumOp>(op) && elementTy.
isa<
FloatType>()) {
727 return rewriter.
create<arith::AddFOp>(loc, args);
730 if (isa<tosa::ReduceSumOp>(op) && elementTy.
isa<IntegerType>()) {
731 return rewriter.
create<arith::AddIOp>(loc, args);
734 if (isa<tosa::ReduceProdOp>(op) && elementTy.
isa<
FloatType>()) {
735 return rewriter.
create<arith::MulFOp>(loc, args);
738 if (isa<tosa::ReduceProdOp>(op) && elementTy.
isa<IntegerType>()) {
739 return rewriter.
create<arith::MulIOp>(loc, args);
742 if (isa<tosa::ReduceMinOp>(op) && elementTy.
isa<
FloatType>()) {
743 auto predicate = rewriter.
create<arith::CmpFOp>(
744 loc, arith::CmpFPredicate::OLT, args[0], args[1]);
745 return rewriter.
create<arith::SelectOp>(loc, predicate, args[0], args[1]);
748 if (isa<tosa::ReduceMinOp>(op) && elementTy.
isa<IntegerType>()) {
749 auto predicate = rewriter.
create<arith::CmpIOp>(
750 loc, arith::CmpIPredicate::slt, args[0], args[1]);
751 return rewriter.
create<arith::SelectOp>(loc, predicate, args[0], args[1]);
754 if (isa<tosa::ReduceMaxOp>(op) && elementTy.
isa<
FloatType>()) {
755 auto predicate = rewriter.
create<arith::CmpFOp>(
756 loc, arith::CmpFPredicate::OGT, args[0], args[1]);
757 return rewriter.
create<arith::SelectOp>(loc, predicate, args[0], args[1]);
760 if (isa<tosa::ReduceMaxOp>(op) && elementTy.
isa<IntegerType>()) {
761 auto predicate = rewriter.
create<arith::CmpIOp>(
762 loc, arith::CmpIPredicate::sgt, args[0], args[1]);
763 return rewriter.
create<arith::SelectOp>(loc, predicate, args[0], args[1]);
766 if (isa<tosa::ReduceAllOp>(op) && elementTy.
isInteger(1))
767 return rewriter.
create<arith::AndIOp>(loc, args);
769 if (isa<tosa::ReduceAnyOp>(op) && elementTy.
isInteger(1))
770 return rewriter.
create<arith::OrIOp>(loc, args);
783 auto elementTy = resultTy.getElementType();
788 for (
unsigned i = 0; i < inputTy.getRank(); i++) {
790 reduceShape.push_back(inputTy.getDimSize(i));
791 if (inputTy.isDynamicDim(i))
792 dynDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, i));
796 Type reduceTy = RankedTensorType::get(reduceShape, resultTy.getElementType());
799 auto initTensor = rewriter
800 .
create<linalg::InitTensorOp>(loc, dynDims, reduceShape,
801 resultTy.getElementType())
807 op,
"No initial value found for reduction operation");
809 auto fillValue = rewriter.
create<arith::ConstantOp>(loc, fillValueAttr);
810 auto filledTensor = rewriter
818 for (
unsigned int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
827 bool didEncounterError =
false;
829 auto linalgOp = rewriter.
create<linalg::GenericOp>(
830 loc, reduceTy, input, filledTensor, maps, iteratorTypes,
833 op, blockArgs, elementTy, rewriter);
835 didEncounterError =
true;
837 nestedBuilder.create<linalg::YieldOp>(loc, result);
840 if (!didEncounterError)
854 intermediateShape = {-1};
858 if (lhsShape.empty() || rhsShape.empty()) {
859 intermediateShape = {};
863 unsigned currLhsDim = 0, currRhsDim = 0;
864 while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
865 int64_t rhsSize = rhsShape[currRhsDim];
866 int64_t lhsSize = lhsShape[currLhsDim];
867 while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
868 currRhsDim < rhsShape.size()) {
869 if (lhsSize < rhsSize) {
871 lhsSize *= lhsShape[currLhsDim];
874 rhsSize *= rhsShape[currRhsDim];
877 if (lhsSize == rhsSize) {
878 intermediateShape.push_back(lhsSize);
886 while (currLhsDim < lhsShape.size()) {
887 if (lhsShape[currLhsDim++] != 1) {
892 while (currRhsDim < rhsShape.size()) {
893 if (rhsShape[currRhsDim++] != 1) {
909 for (
int i = 0, s = srcShape.size(); i < s; ++i)
911 reassociationMap = {exprs};
915 if (dstShape.empty()) {
916 reassociationMap = {};
920 reassociationMap.resize(dstShape.size());
921 unsigned currSrcDim = 0, currDstDim = 0;
922 while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
923 int64_t dstSize = dstShape[currDstDim];
924 int64_t srcSize = srcShape[currSrcDim];
925 while (srcSize < dstSize && currSrcDim < srcShape.size()) {
926 reassociationMap[currDstDim].push_back(
928 srcSize *= srcShape[currSrcDim];
930 if (srcSize == dstSize) {
931 reassociationMap[currDstDim].push_back(
935 if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
936 while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
937 reassociationMap[currDstDim].push_back(
947 return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
952 template <
typename SrcOp>
968 matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
970 ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
971 ShapedType resultTy = reshape.getType().template cast<ShapedType>();
972 bool isDynamic = !operandTy.hasStaticShape();
974 if (isDynamic && resultTy.getRank() != 1) {
976 reshape,
"Cannot collapse dynamic dims to more than one dimension");
979 if (operandTy == resultTy) {
980 rewriter.
replaceOp(reshape, adaptor.getOperands()[0]);
987 reassociationMap, isDynamic)) {
990 "tosa.reshape Attempting to collapse into an incompatible shape");
995 intermediateShape, isDynamic)) {
997 reshape,
"tosa.reshape Cannot collapse into given shape");
1001 reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
1011 matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
1013 ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
1014 ShapedType resultTy = reshape.getType().template cast<ShapedType>();
1015 bool isDynamic = !operandTy.hasStaticShape();
1017 if (operandTy == resultTy) {
1018 rewriter.
replaceOp(reshape, adaptor.getOperands()[0]);
1022 if (isDynamic && operandTy.getRank() != 1) {
1024 reshape,
"Cannot expand dynamic dims from more than one dimension");
1029 operandTy.getShape(),
1030 reassociationMap, isDynamic)) {
1033 "tosa.reshape Attempting to expand into an incompatible shape");
1038 intermediateShape, isDynamic) ||
1039 intermediateShape != operandTy.getShape()) {
1041 reshape,
"tosa.reshape Cannot expand into given shape");
1044 reshape, resultTy, adaptor.getOperands()[0], reassociationMap);
1049 class ReshapeConverterCollapseExpand
1055 matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
1057 ShapedType operandTy = adaptor.input1().getType().cast<ShapedType>();
1058 ShapedType resultTy = reshape.getType().template cast<ShapedType>();
1059 bool isDynamic = !operandTy.hasStaticShape();
1061 if (operandTy == resultTy) {
1062 rewriter.
replaceOp(reshape, adaptor.getOperands()[0]);
1068 intermediateShape, isDynamic)) {
1070 reshape,
"tosa.reshape Cannot identify an intermediate shape between " 1071 "the given two shapes");
1074 Value collapse = rewriter.
create<tosa::ReshapeOp>(
1076 RankedTensorType::get(intermediateShape,
1077 reshape.getType().getElementType()),
1080 rewriter.
create<tosa::ReshapeOp>(reshape.getLoc(), resultTy, collapse);
1098 auto loc = op.getLoc();
1099 auto input = op->getOperand(0);
1100 auto resultTy = op.getType().cast<ShapedType>();
1103 dynDims.resize(op->getResult(0).getType().cast<ShapedType>().getRank());
1106 inputExprs.resize(resultTy.getRank());
1107 auto operandTy = input.getType().cast<ShapedType>();
1108 for (
const auto &permutation :
llvm::enumerate(perms.getValues<APInt>())) {
1109 auto index = permutation.index();
1110 auto value = permutation.value().getZExtValue();
1111 if (!operandTy.hasRank() || operandTy.isDynamicDim(index)) {
1112 dynDims[
value] = rewriter.
create<tensor::DimOp>(loc, input, index);
1119 auto initTensor = rewriter.
create<linalg::InitTensorOp>(
1120 loc, filteredDims, resultTy.getShape(), resultTy.getElementType());
1128 op, resultTy, op.input1(),
ValueRange{initTensor}, affineMaps,
1131 nestedBuilder.
create<linalg::YieldOp>(loc, *args.begin());
1143 auto loc = op.getLoc();
1144 auto input = op.input();
1145 auto inputTy = op.input().getType().cast<ShapedType>();
1146 auto outputTy = op.output().getType().cast<ShapedType>();
1147 unsigned rank = inputTy.getRank();
1150 if (op.double_round() && !op.scale32())
1152 op,
"tosa.rescale requires scale32 for double_round to be true");
1154 auto dynamicDimsOr =
1156 if (!dynamicDimsOr.hasValue())
1168 for (
int i = 0, s = multiplierValues.size(); i < s; i++) {
1169 if (shiftValues[i] > 63) {
1171 multiplierValues[i] = 0;
1178 op.double_round() &&
1179 llvm::any_of(shiftValues, [](int32_t v) {
return v > 31; });
1187 Value multiplierConstant;
1188 int64_t multiplierArg = 0;
1189 if (multiplierValues.size() == 1) {
1190 multiplierConstant = rewriter.
create<arith::ConstantOp>(
1195 auto multiplierType =
1196 RankedTensorType::get({
static_cast<int64_t
>(multiplierValues.size())},
1198 genericInputs.push_back(rewriter.
create<arith::ConstantOp>(
1205 multiplierArg = indexingMaps.size() - 1;
1210 Value shiftConstant;
1211 int64_t shiftArg = 0;
1212 if (shiftValues.size() == 1) {
1213 shiftConstant = rewriter.
create<arith::ConstantOp>(
1219 RankedTensorType::get({
static_cast<int64_t
>(shiftValues.size())},
1221 genericInputs.push_back(rewriter.
create<arith::ConstantOp>(
1226 shiftArg = indexingMaps.size() - 1;
1233 Value initTensor = rewriter.
create<linalg::InitTensorOp>(
1234 loc, dynamicDims, outputTy.getShape(), outputTy.getElementType());
1236 auto linalgOp = rewriter.
create<linalg::GenericOp>(
1237 loc, outputTy, genericInputs,
ValueRange{initTensor}, indexingMaps,
1249 auto inputZp = createConstFromIntAttribute<int32_t>(
1250 op,
"input_zp", nestedBuilder.getIntegerType(inBitwidth),
1252 auto outputZp = createConstFromIntAttribute<int32_t>(
1253 op,
"output_zp", nestedBuilder.getI32Type(), nestedBuilder);
1255 Value multiplier = multiplierConstant ? multiplierConstant
1256 : blockArgs[multiplierArg];
1257 Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1261 value = nestedBuilder
1262 .create<UnrealizedConversionCastOp>(
1264 nestedBuilder.getIntegerType(
1268 value = nestedBuilder.create<arith::ExtUIOp>(
1269 nestedLoc, nestedBuilder.getI32Type(),
value);
1271 value = nestedBuilder.create<arith::ExtSIOp>(
1272 nestedLoc, nestedBuilder.getI32Type(),
value);
1277 nestedBuilder.create<arith::SubIOp>(nestedLoc,
value, inputZp);
1279 value = nestedBuilder.create<tosa::ApplyScaleOp>(
1280 loc, nestedBuilder.getI32Type(),
value, multiplier, shift,
1281 nestedBuilder.getBoolAttr(doubleRound));
1285 nestedBuilder.create<arith::AddIOp>(nestedLoc,
value, outputZp);
1288 IntegerType outIntType =
1290 unsigned outBitWidth = outIntType.getWidth();
1292 int32_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
1293 int32_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
1296 if (outIntType.isUnsignedInteger()) {
1298 intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
1301 auto intMinVal = nestedBuilder.create<arith::ConstantOp>(
1302 loc, nestedBuilder.getI32IntegerAttr(intMin));
1303 auto intMaxVal = nestedBuilder.create<arith::ConstantOp>(
1304 loc, nestedBuilder.getI32IntegerAttr(intMax));
1306 value = clampHelper<arith::CmpIOp>(
1307 nestedLoc,
value, intMinVal, intMaxVal, arith::CmpIPredicate::slt,
1310 if (outIntType.getWidth() < 32) {
1311 value = nestedBuilder.create<arith::TruncIOp>(
1315 if (outIntType.isUnsignedInteger()) {
1316 value = nestedBuilder
1317 .create<UnrealizedConversionCastOp>(nestedLoc,
1323 nestedBuilder.create<linalg::YieldOp>(loc,
value);
1326 rewriter.
replaceOp(op, linalgOp->getResults());
1338 auto input = op.input();
1339 auto inputTy = input.getType().
cast<ShapedType>();
1340 auto resultTy = op.getType().cast<ShapedType>();
1341 auto resultElementTy = resultTy.getElementType();
1343 auto imageH = inputTy.getShape()[1];
1344 auto imageW = inputTy.getShape()[2];
1346 auto dynamicDimsOr =
1348 if (!dynamicDimsOr.hasValue())
1352 if (op.mode() !=
"NEAREST_NEIGHBOR" && op.mode() !=
"BILINEAR")
1355 auto initTensor = rewriter.
create<linalg::InitTensorOp>(
1356 loc, dynamicDims, resultTy.getShape(), resultElementTy);
1361 auto genericOp = rewriter.
create<linalg::GenericOp>(
1364 rewriter.
replaceOp(op, genericOp.getResult(0));
1367 rewriter.
createBlock(&genericOp.region(), genericOp.region().end(),
1369 Value batch = rewriter.
create<linalg::IndexOp>(loc, 0);
1370 Value y = rewriter.
create<linalg::IndexOp>(loc, 1);
1371 Value x = rewriter.
create<linalg::IndexOp>(loc, 2);
1372 Value channel = rewriter.
create<linalg::IndexOp>(loc, 3);
1376 auto hMax = rewriter.
create<arith::ConstantOp>(
1378 auto wMax = rewriter.
create<arith::ConstantOp>(
1386 int32_t shift = op.shift();
1387 bool floatingPointMode = shift == 0;
1389 Value yStride, xStride, yOffset, xOffset;
1390 if (floatingPointMode) {
1391 yStride = rewriter.
create<arith::ConstantOp>(loc, op.stride_fp()[0]);
1392 xStride = rewriter.
create<arith::ConstantOp>(loc, op.stride_fp()[1]);
1393 yOffset = rewriter.
create<arith::ConstantOp>(loc, op.offset_fp()[0]);
1394 xOffset = rewriter.
create<arith::ConstantOp>(loc, op.offset_fp()[1]);
1400 yStride = rewriter.
create<arith::ConstantOp>(
1402 xStride = rewriter.
create<arith::ConstantOp>(
1404 yOffset = rewriter.
create<arith::ConstantOp>(
1406 xOffset = rewriter.
create<arith::ConstantOp>(
1414 Value ix, iy, dx, dy;
1415 if (floatingPointMode) {
1421 y = rewriter.
create<arith::MulFOp>(loc, y, yStride);
1422 x = rewriter.
create<arith::MulFOp>(loc, x, xStride);
1424 y = rewriter.
create<arith::AddFOp>(loc, y, yOffset);
1425 x = rewriter.
create<arith::AddFOp>(loc, x, xOffset);
1427 iy = rewriter.
create<math::FloorOp>(loc, y);
1428 ix = rewriter.
create<math::FloorOp>(loc, x);
1430 dy = rewriter.
create<arith::SubFOp>(loc, y, iy);
1431 dx = rewriter.
create<arith::SubFOp>(loc, x, ix);
1436 Value shiftVal = rewriter.
create<arith::ConstantOp>(
1439 Value y = rewriter.
create<arith::MulIOp>(loc, inY, yStride);
1440 Value x = rewriter.
create<arith::MulIOp>(loc, inX, xStride);
1442 y = rewriter.
create<arith::AddIOp>(loc, y, yOffset);
1443 x = rewriter.
create<arith::AddIOp>(loc, x, xOffset);
1445 iy = rewriter.
create<arith::ShRSIOp>(loc, y, shiftVal);
1446 ix = rewriter.
create<arith::ShRSIOp>(loc, x, shiftVal);
1448 Value yTrunc = rewriter.
create<arith::ShLIOp>(loc, iy, shiftVal);
1449 Value xTrunc = rewriter.
create<arith::ShLIOp>(loc, ix, shiftVal);
1451 dy = rewriter.
create<arith::SubIOp>(loc, y, yTrunc);
1452 dx = rewriter.
create<arith::SubIOp>(loc, x, xTrunc);
1455 if (op.mode() ==
"NEAREST_NEIGHBOR") {
1458 if (floatingPointMode) {
1459 auto halfVal = rewriter.
create<arith::ConstantOp>(
1461 yPred = rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
1463 xPred = rewriter.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
1466 auto halfVal = rewriter.
create<arith::ConstantOp>(
1468 yPred = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
1470 xPred = rewriter.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
1474 auto zeroVal = rewriter.
create<arith::ConstantOp>(
1476 auto oneVal = rewriter.
create<arith::ConstantOp>(
1480 rewriter.
create<arith::SelectOp>(loc, yPred, oneVal, zeroVal);
1482 rewriter.
create<arith::SelectOp>(loc, xPred, oneVal, zeroVal);
1484 iy = rewriter.
create<arith::AddIOp>(loc, iy, yOffset);
1485 ix = rewriter.
create<arith::AddIOp>(loc, ix, xOffset);
1489 iy = clampHelper<arith::CmpIOp>(loc, iy, hwMin, hMax,
1490 arith::CmpIPredicate::slt, rewriter);
1491 ix = clampHelper<arith::CmpIOp>(loc, ix, hwMin, wMax,
1492 arith::CmpIPredicate::slt, rewriter);
1500 Value result = rewriter.
create<tensor::ExtractOp>(
1501 loc, input,
ValueRange{batch, iy, ix, channel});
1503 rewriter.
create<linalg::YieldOp>(loc, result);
1508 if (op.mode() ==
"BILINEAR") {
1512 auto oneVal = rewriter.
create<arith::ConstantOp>(
1514 Value y1 = rewriter.
create<arith::AddIOp>(loc, y0, oneVal);
1515 Value x1 = rewriter.
create<arith::AddIOp>(loc, x0, oneVal);
1517 y0 = clampHelper<arith::CmpIOp>(loc, y0, hwMin, hMax,
1518 arith::CmpIPredicate::slt, rewriter);
1519 y1 = clampHelper<arith::CmpIOp>(loc, y1, hwMin, hMax,
1520 arith::CmpIPredicate::slt, rewriter);
1522 x0 = clampHelper<arith::CmpIOp>(loc, x0, hwMin, wMax,
1523 arith::CmpIPredicate::slt, rewriter);
1524 x1 = clampHelper<arith::CmpIOp>(loc, x1, hwMin, wMax,
1525 arith::CmpIPredicate::slt, rewriter);
1537 loc, input,
ValueRange{batch, y0, x0, channel});
1539 loc, input,
ValueRange{batch, y0, x1, channel});
1541 loc, input,
ValueRange{batch, y1, x0, channel});
1543 loc, input,
ValueRange{batch, y1, x1, channel});
1545 if (floatingPointMode) {
1546 auto oneVal = rewriter.
create<arith::ConstantOp>(
1548 Value rightPart = dx;
1549 Value leftPart = rewriter.
create<arith::SubFOp>(loc, oneVal, dx);
1551 y0x0 = rewriter.
create<arith::MulFOp>(loc, y0x0, leftPart);
1552 y0x1 = rewriter.
create<arith::MulFOp>(loc, y0x1, rightPart);
1553 Value topAcc = rewriter.
create<arith::AddFOp>(loc, y0x0, y0x1);
1555 y1x0 = rewriter.
create<arith::MulFOp>(loc, y1x0, leftPart);
1556 y1x1 = rewriter.
create<arith::MulFOp>(loc, y1x1, rightPart);
1557 Value bottomAcc = rewriter.
create<arith::AddFOp>(loc, y1x0, y1x1);
1559 Value bottomPart = dy;
1560 Value topPart = rewriter.
create<arith::SubFOp>(loc, oneVal, dy);
1561 topAcc = rewriter.
create<arith::MulFOp>(loc, topAcc, topPart);
1562 bottomAcc = rewriter.
create<arith::MulFOp>(loc, bottomAcc, bottomPart);
1563 Value result = rewriter.
create<arith::AddFOp>(loc, topAcc, bottomAcc);
1565 rewriter.
create<linalg::YieldOp>(loc, result);
1568 y0x0 = rewriter.
create<arith::ExtSIOp>(loc, resultElementTy, y0x0);
1569 y0x1 = rewriter.
create<arith::ExtSIOp>(loc, resultElementTy, y0x1);
1570 y1x0 = rewriter.
create<arith::ExtSIOp>(loc, resultElementTy, y1x0);
1571 y1x1 = rewriter.
create<arith::ExtSIOp>(loc, resultElementTy, y1x1);
1573 if (resultElementTy.getIntOrFloatBitWidth() > 32) {
1574 dx = rewriter.
create<arith::ExtSIOp>(loc, resultElementTy, dx);
1575 dy = rewriter.
create<arith::ExtSIOp>(loc, resultElementTy, dy);
1578 auto unitVal = rewriter.
create<arith::ConstantOp>(
1580 Value rightPart = dx;
1581 Value leftPart = rewriter.
create<arith::SubIOp>(loc, unitVal, dx);
1583 y0x0 = rewriter.
create<arith::MulIOp>(loc, y0x0, leftPart);
1584 y0x1 = rewriter.
create<arith::MulIOp>(loc, y0x1, rightPart);
1585 Value topAcc = rewriter.
create<arith::AddIOp>(loc, y0x0, y0x1);
1587 y1x0 = rewriter.
create<arith::MulIOp>(loc, y1x0, leftPart);
1588 y1x1 = rewriter.
create<arith::MulIOp>(loc, y1x1, rightPart);
1589 Value bottomAcc = rewriter.
create<arith::AddIOp>(loc, y1x0, y1x1);
1591 Value bottomPart = dy;
1592 Value topPart = rewriter.
create<arith::SubIOp>(loc, unitVal, dy);
1593 topAcc = rewriter.
create<arith::MulIOp>(loc, topAcc, topPart);
1594 bottomAcc = rewriter.
create<arith::MulIOp>(loc, bottomAcc, bottomPart);
1595 Value result = rewriter.
create<arith::AddIOp>(loc, topAcc, bottomAcc);
1597 rewriter.
create<linalg::YieldOp>(loc, result);
1607 template <
typename SrcOp>
1614 rewriter.
replaceOp(op, op.getOperation()->getOperands());
1619 template <
typename SrcOp>
1634 matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
1636 auto inputType = op.getOperand(0).getType().template cast<ShapedType>();
1637 auto resultType = op.getType().dyn_cast<RankedTensorType>();
1640 int axis = op.axis();
1643 int rank = resultType.getRank();
1645 sizes.reserve(rank);
1650 for (
int i = 0; i < rank; ++i) {
1652 loc, adaptor.getOperands()[0], i));
1653 if (inputType.isDynamicDim(i)) {
1655 rewriter.
create<tensor::DimOp>(loc, op.getOperand(0), i));
1659 Value resultDimSize = sizes[axis];
1660 for (
auto arg : adaptor.getOperands().drop_front()) {
1661 auto size = rewriter.
createOrFold<tensor::DimOp>(loc, arg, axisValue);
1663 rewriter.
createOrFold<arith::AddIOp>(loc, resultDimSize, size);
1665 sizes[axis] = resultDimSize;
1667 Value init = rewriter.
create<linalg::InitTensorOp>(
1668 loc, dynDims, resultType.getShape(), resultType.getElementType());
1671 loc, rewriter.
getZeroAttr(resultType.getElementType()));
1681 return op.getValue();
1683 for (
auto arg : adaptor.getOperands()) {
1684 sizes[axis] = rewriter.
createOrFold<tensor::DimOp>(loc, arg, axisValue);
1687 llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)),
1688 llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)),
1689 llvm::to_vector(llvm::map_range(strides, toOpFoldResult)));
1691 rewriter.
createOrFold<arith::AddIOp>(loc, offsets[axis], sizes[axis]);
1704 auto loc = op.getLoc();
1705 Value input = op.input();
1706 auto inputTy = input.
getType().template cast<ShapedType>();
1707 auto resultTy = op.getType().template cast<ShapedType>();
1708 auto axis = op.axis();
1711 for (
int i = 0; i < inputTy.getRank(); i++) {
1712 if (inputTy.isDynamicDim(i)) {
1713 dynDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, i));
1717 Value axisDimSize = rewriter.
create<tensor::DimOp>(loc, input, axis);
1720 auto initTensor = rewriter
1721 .
create<linalg::InitTensorOp>(
1723 inputTy.getShape(), inputTy.getElementType())
1733 for (
unsigned int i = 0; i < inputTy.getRank(); i++) {
1735 rewriter.
create<linalg::IndexOp>(nestedLoc, i).getResult();
1739 rewriter.
create<arith::SubIOp>(nestedLoc, axisDimSize, one);
1740 index = rewriter.
create<arith::SubIOp>(nestedLoc, sizeMinusOne,
1744 indices.push_back(index);
1747 auto extract = nestedBuilder.
create<tensor::ExtractOp>(
1748 nestedLoc, input, indices);
1749 nestedBuilder.
create<linalg::YieldOp>(op.getLoc(),
1750 extract.getResult());
1764 matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor,
1766 auto loc = op.getLoc();
1767 auto input = op.input1();
1768 auto inputTy = input.getType().cast<ShapedType>();
1769 auto inputShape = inputTy.getShape();
1770 auto resultTy = op.getType().cast<ShapedType>();
1771 auto elementTy = inputTy.getElementType();
1772 int64_t rank = inputTy.getRank();
1779 for (
int i = 0; i < rank; i++) {
1780 genericShape.push_back(multiples[i]);
1781 genericShape.push_back(inputShape[i]);
1785 for (
int i = 0; i < inputTy.getRank(); i++) {
1786 if (inputTy.isDynamicDim(i) || multiples[i] == -1) {
1787 dynDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, i));
1791 auto initTensor = rewriter.
create<linalg::InitTensorOp>(
1792 op.getLoc(), dynDims, genericShape, elementTy);
1796 dimExprs.reserve(rank);
1797 for (
unsigned i = 0; i < rank; ++i)
1800 auto readAffineMap =
1807 auto genericOp = rewriter.
create<linalg::GenericOp>(
1808 loc, RankedTensorType::get(genericShape, elementTy), input,
1812 nestedBuilder.
create<linalg::YieldOp>(op.getLoc(), *args.begin());
1816 op, resultTy, genericOp.getResult(0),
1828 auto loc = padOp.getLoc();
1829 auto input = padOp.input1();
1830 auto padding = padOp.padding();
1832 ShapedType inputTy = input.getType().cast<ShapedType>();
1833 Type elementTy = inputTy.getElementType();
1834 int64_t rank = inputTy.getRank();
1840 if (padOp.pad_const()) {
1841 padConstant = rewriter.
createOrFold<tensor::ExtractOp>(
1847 }
else if (elementTy.isa<IntegerType>() && !padOp.quantization_info()) {
1849 }
else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
1850 int64_t
value = padOp.quantization_info()->getInputZp();
1854 padConstant = rewriter.
create<arith::ConstantOp>(loc, constantAttr);
1859 padOp,
"tosa.pad was unable to determine the pad constant value.");
1870 lowValues.reserve(rank);
1871 highValues.reserve(rank);
1873 for (
int i = 0; i < rank; i++) {
1876 loc, padding,
ValueRange({inputIndex, lowIndex}));
1878 loc, padding,
ValueRange({inputIndex, highIndex}));
1885 lowValues.push_back(lowVal);
1886 highValues.push_back(highVal);
1890 padOp.getType(), input, padConstant, lowValues, highValues,
1891 false, loc, rewriter);
1893 rewriter.
replaceOp(padOp, newPadOp.getResult());
1917 auto loc = argmaxOp.getLoc();
1918 Value input = argmaxOp.input();
1919 auto inputTy = input.
getType().
cast<ShapedType>();
1920 auto resultTy = argmaxOp.output().getType().cast<ShapedType>();
1921 auto inElementTy = inputTy.getElementType();
1922 auto outElementTy = resultTy.getElementType();
1923 int axis = argmaxOp.axis();
1924 auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy);
1926 if (!outElementTy.isa<IntegerType>())
1929 "tosa.arg_max to linalg.* requires integer-like result type");
1932 for (
int i = 0; i < inputTy.getRank(); i++) {
1933 if (inputTy.isDynamicDim(i) && i != axis) {
1934 dynDims.push_back(rewriter.
create<tensor::DimOp>(loc, input, i));
1939 auto initTensorIdx =
1941 .
create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
1944 auto fillValueIdx = rewriter.
create<arith::ConstantOp>(
1946 auto filledTensorIdx =
1953 auto initTensorMax = rewriter
1954 .
create<linalg::InitTensorOp>(
1955 loc, dynDims, resultTy.getShape(), inElementTy)
1957 auto fillValueMaxAttr =
1960 if (!fillValueMaxAttr)
1962 argmaxOp,
"unsupported tosa.argmax element type");
1965 rewriter.
create<arith::ConstantOp>(loc, fillValueMaxAttr);
1966 auto filledTensorMax =
1980 for (
int i = 0, rank = inputTy.getRank(); i != rank; ++i) {
1986 bool didEncounterError =
false;
1988 auto linalgOp = rewriter.
create<linalg::GenericOp>(
1990 ValueRange({filledTensorIdx, filledTensorMax}), maps, iteratorTypes,
1993 auto newValue = blockArgs[0];
1994 auto oldIndex = blockArgs[1];
1995 auto oldValue = blockArgs[2];
1997 Value newIndex = rewriter.
create<arith::IndexCastOp>(
1998 nestedLoc, oldIndex.getType(),
1999 rewriter.
create<linalg::IndexOp>(loc, axis));
2003 predicate = rewriter.
create<arith::CmpFOp>(
2004 nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
2005 }
else if (inElementTy.isa<IntegerType>()) {
2006 predicate = rewriter.
create<arith::CmpIOp>(
2007 nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
2009 didEncounterError =
true;
2013 auto resultMax = rewriter.
create<arith::SelectOp>(
2014 nestedLoc, predicate, newValue, oldValue);
2016 nestedLoc, predicate, newIndex, oldIndex);
2017 nestedBuilder.
create<linalg::YieldOp>(
2021 if (didEncounterError)
2023 argmaxOp,
"unsupported tosa.argmax element type");
2025 rewriter.
replaceOp(argmaxOp, linalgOp.getResult(0));
2034 matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor,
2036 auto input = adaptor.getOperands()[0];
2037 auto indices = adaptor.getOperands()[1];
2039 auto resultTy = op.getType().cast<ShapedType>();
2041 auto dynamicDimsOr =
2043 if (!dynamicDimsOr.hasValue())
2047 auto resultElementTy = resultTy.getElementType();
2049 auto loc = op.getLoc();
2053 .
create<linalg::InitTensorOp>(loc, dynamicDims, resultTy.getShape(),
2059 resultTy.getRank(), 0,
2064 auto genericOp = rewriter.
create<linalg::GenericOp>(
2069 auto indexValue = args[0];
2070 auto index0 = rewriter.
create<linalg::IndexOp>(loc, 0);
2071 Value index1 = rewriter.
create<arith::IndexCastOp>(
2073 auto index2 = rewriter.
create<linalg::IndexOp>(loc, 2);
2074 Value extract = rewriter.
create<tensor::ExtractOp>(
2075 loc, input,
ValueRange{index0, index1, index2});
2076 rewriter.
create<linalg::YieldOp>(loc, extract);
2078 rewriter.
replaceOp(op, genericOp.getResult(0));
2092 auto loc = op.getLoc();
2093 Value input = op.input();
2094 Value table = op.table();
2095 auto inputTy = input.
getType().
cast<ShapedType>();
2096 auto tableTy = table.
getType().
cast<ShapedType>();
2097 auto resultTy = op.getType().cast<ShapedType>();
2099 auto inputElementTy = inputTy.getElementType();
2100 auto tableElementTy = tableTy.getElementType();
2101 auto resultElementTy = resultTy.getElementType();
2104 for (
int i = 0; i < resultTy.getRank(); ++i) {
2105 if (inputTy.isDynamicDim(i)) {
2107 rewriter.
create<tensor::DimOp>(loc, op.getOperand(0), i));
2113 .
create<linalg::InitTensorOp>(loc, dynDims, resultTy.getShape(),
2121 auto genericOp = rewriter.
create<linalg::GenericOp>(
2124 rewriter.
replaceOp(op, genericOp.getResult(0));
2129 &genericOp.region(), genericOp.region().end(),
2130 TypeRange({inputElementTy, resultElementTy}), {loc, loc});
2134 if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) &&
2135 resultElementTy.isInteger(8)) {
2136 Value index = rewriter.
create<arith::IndexCastOp>(
2143 rewriter.
create<linalg::YieldOp>(loc, extract);
2147 if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) &&
2148 resultElementTy.isInteger(32)) {
2152 auto offset = rewriter.
create<arith::ConstantOp>(
2154 auto seven = rewriter.
create<arith::ConstantOp>(
2156 auto one = rewriter.
create<arith::ConstantOp>(
2158 auto b1111111 = rewriter.
create<arith::ConstantOp>(
2165 auto extendAdd = rewriter.
create<arith::AddIOp>(loc, extend, offset);
2166 Value index = rewriter.
create<arith::ShRUIOp>(loc, extendAdd, seven);
2168 rewriter.
create<arith::AndIOp>(loc, extendAdd, b1111111);
2173 Value indexPlusOne = rewriter.
create<arith::AddIOp>(loc, index, one);
2175 index = rewriter.
create<arith::IndexCastOp>(
2177 indexPlusOne = rewriter.
create<arith::IndexCastOp>(
2192 Value baseScaled = rewriter.
create<arith::ShLIOp>(loc, base, seven);
2193 Value diff = rewriter.
create<arith::SubIOp>(loc, next, base);
2194 Value diffScaled = rewriter.
create<arith::MulIOp>(loc, diff, fraction);
2196 rewriter.
create<arith::AddIOp>(loc, baseScaled, diffScaled);
2198 rewriter.
create<linalg::YieldOp>(loc, result);
2205 op,
"unable to create body for tosa.table op");
2215 PointwiseConverter<tosa::AddOp>,
2216 PointwiseConverter<tosa::SubOp>,
2217 PointwiseConverter<tosa::MulOp>,
2218 PointwiseConverter<tosa::DivOp>,
2219 PointwiseConverter<tosa::NegateOp>,
2220 PointwiseConverter<tosa::PowOp>,
2221 PointwiseConverter<tosa::ReciprocalOp>,
2222 PointwiseConverter<tosa::RsqrtOp>,
2223 PointwiseConverter<tosa::LogOp>,
2224 PointwiseConverter<tosa::ExpOp>,
2225 PointwiseConverter<tosa::AbsOp>,
2226 PointwiseConverter<tosa::TanhOp>,
2227 PointwiseConverter<tosa::BitwiseAndOp>,
2228 PointwiseConverter<tosa::BitwiseOrOp>,
2229 PointwiseConverter<tosa::BitwiseNotOp>,
2230 PointwiseConverter<tosa::BitwiseXorOp>,
2231 PointwiseConverter<tosa::LogicalAndOp>,
2232 PointwiseConverter<tosa::LogicalNotOp>,
2233 PointwiseConverter<tosa::LogicalOrOp>,
2234 PointwiseConverter<tosa::LogicalXorOp>,
2235 PointwiseConverter<tosa::CastOp>,
2236 PointwiseConverter<tosa::LogicalLeftShiftOp>,
2237 PointwiseConverter<tosa::LogicalRightShiftOp>,
2238 PointwiseConverter<tosa::ArithmeticRightShiftOp>,
2239 PointwiseConverter<tosa::ClzOp>,
2240 PointwiseConverter<tosa::SelectOp>,
2241 PointwiseConverter<tosa::GreaterOp>,
2242 PointwiseConverter<tosa::GreaterEqualOp>,
2243 PointwiseConverter<tosa::EqualOp>,
2244 PointwiseConverter<tosa::MaximumOp>,
2245 PointwiseConverter<tosa::MinimumOp>,
2246 PointwiseConverter<tosa::CeilOp>,
2247 PointwiseConverter<tosa::FloorOp>,
2248 PointwiseConverter<tosa::ClampOp>,
2249 PointwiseConverter<tosa::ReluNOp>,
2250 PointwiseConverter<tosa::SigmoidOp>,
2251 IdentityNConverter<tosa::IdentityOp>,
2252 ReduceConverter<tosa::ReduceAllOp>,
2253 ReduceConverter<tosa::ReduceAnyOp>,
2254 ReduceConverter<tosa::ReduceMinOp>,
2255 ReduceConverter<tosa::ReduceMaxOp>,
2256 ReduceConverter<tosa::ReduceSumOp>,
2257 ReduceConverter<tosa::ReduceProdOp>,
2262 ReshapeConverterCollapse,
2263 ReshapeConverterExpand,
2264 ReshapeConverterCollapseExpand,
TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
MLIRContext * getContext() const
constexpr StringRef getParallelIteratorTypeName()
Use to encode that a particular iterator type has parallel semantics.
static Value createLinalgBodyCalculationForReduceOp(Operation *op, ValueRange args, Type elementTy, PatternRewriter &rewriter)
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
AffineMap getMultiDimIdentityMap(unsigned rank)
Operation is a basic unit of execution within MLIR.
Attribute getZeroAttr(Type type)
operand_range getOperands()
Returns an iterator on the underlying Value's.
Specialization of arith.constant op that returns an integer value.
Block represents an ordered list of Operations.
static Attribute createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter)
This class represents a single result from folding an operation.
Value getOperand(unsigned idx)
bool isInteger(unsigned width) const
Return true if this is an integer type with the specified width.
unsigned getNumOperands()
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
void getValuesFromIntArrayAttribute(ArrayAttr attr, SmallVector< T > &arrayValues)
SmallVector< Value > condenseValues(const SmallVector< Value > &values)
SmallVector< StringRef > getNParallelLoopsAttrs(unsigned nParallelLoops)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Optional< SmallVector< Value > > checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op, ArrayRef< Value > params)
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
BlockArgument getArgument(unsigned i)
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
FloatAttr getFloatAttr(Type type, double value)
IntegerAttr getI32IntegerAttr(int32_t value)
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerAttr getI8IntegerAttr(int8_t value)
Attributes are known-constant values of operations.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
PadOp createPadScalarOp(Type type, Value source, Value pad, ArrayRef< OpFoldResult > low, ArrayRef< OpFoldResult > high, bool nofold, Location loc, OpBuilder &builder)
IntegerType getIntegerType(unsigned width)
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This class provides an abstraction over the various different ranges of value types.
Location getLoc()
The source location the operation was defined or derived from.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
static bool findIntermediateShape(ArrayRef< int64_t > lhsShape, ArrayRef< int64_t > rhsShape, SmallVector< int64_t > &intermediateShape, bool isDynamic)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Value createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, ArrayRef< Type > resultTypes, PatternRewriter &rewriter)
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static SmallVector< int64_t, 8 > subtract(ArrayRef< int64_t > vecA, ArrayRef< int64_t > vecB)
static int resultIndex(int i)
static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound, Value upperBound)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
RAII guard to reset the insertion point of the builder when destroyed.
static LogicalResult reduceMatchAndRewriteHelper(Operation *op, uint64_t axis, PatternRewriter &rewriter)
Type getType() const
Return the type of this value.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
static bool createReassociationMapsForCollapse(PatternRewriter &rewriter, ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape, SmallVector< ReassociationExprs, 4 > &reassociationMap, bool isDynamic)
Specialization of arith.constant op that returns an integer of index type.
BoolAttr getBoolAttr(bool value)
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class implements a pattern rewriter for use with ConversionPatterns.
AffineExpr getAffineDimExpr(unsigned position)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure...
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns)
Populates conversion passes from TOSA dialect to Linalg dialect.
unsigned getNumResults()
Return the number of results held by this operation.
SlowMPInt abs(const SlowMPInt &x)
Redeclarations of friend declarations above to make it discoverable by lookups.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=llvm::None, ArrayRef< Location > locs=llvm::None)
Add new block with 'argTypes' arguments and set the insertion point to the end of it...
static LogicalResult elementwiseMatchAndRewriteHelper(Operation *operation, PatternRewriter &rewriter)
constexpr StringRef getReductionIteratorTypeName()
Use to encode that a particular iterator type has reduction semantics.
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
FloatAttr getF32FloatAttr(float value)
result_range getResults()
This class helps build Operations.
This class provides an abstraction over the different types of ranges over Values.
IntegerAttr getIndexAttr(int64_t value)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
static arith::ConstantOp createConstFromIntAttribute(Operation *op, const std::string &attrName, Type requiredAttrType, OpBuilder &rewriter)
MLIRContext * getContext() const
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
An attribute that represents a reference to a dense integer vector or tensor object.