#include <type_traits>
if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
  return input;
ShapedType inputTy = cast<ShapedType>(input.getType());
Type inputETy = inputTy.getElementType();
auto inputShape = inputTy.getShape();

assert((inputShape.size() * 2) == pad.size());
SmallVector<int64_t, 4> paddedShape;
SmallVector<OpFoldResult, 8> lowIndices;
SmallVector<OpFoldResult, 8> highIndices;
for (size_t i : llvm::seq(inputShape.size())) {
  auto lowPad = pad[i * 2];
  auto highPad = pad[i * 2 + 1];
  if (ShapedType::isDynamic(inputShape[i]))
    paddedShape.push_back(inputShape[i]);
  else
    paddedShape.push_back(inputShape[i] + highPad + lowPad);
  lowIndices.push_back(rewriter.getIndexAttr(lowPad));
  highIndices.push_back(rewriter.getIndexAttr(highPad));
}
Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr);

return tensor::PadOp::create(rewriter, loc,
                             RankedTensorType::get(paddedShape, inputETy),
                             input, lowIndices, highIndices, padValue);
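// Example: a 1x7x7x3 input with pad = {0, 0, 1, 1, 2, 2, 0, 0} (a low/high
// pair per dimension) pads only the spatial dims, yielding a 1x9x11x3
// tensor.pad result; dynamic dimensions are passed through unchanged.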
ShapedType resultTy = cast<ShapedType>(conv.getType());
return linalg::GenericOp::create(
           rewriter, loc, resultTy, ValueRange({bias, conv}), result,
           indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
           [](OpBuilder &builder, Location loc, ValueRange args) {
             Value biasVal = args[0];
             Type resType = args[1].getType();
             if (resType != biasVal.getType()) {
               biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal);
             }
             Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]);
             linalg::YieldOp::create(builder, loc, added);
           })
    .getResult(0);
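// The bias element type may be narrower than the convolution accumulator
// type, so each bias element is sign-extended (arith.extsi) before the
// elementwise add with the convolution result.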
ShapedType resultTy = cast<ShapedType>(result.getType());
ShapedType sourceTy = cast<ShapedType>(source.getType());
const int64_t resultRank = resultTy.getRank();
const int64_t sourceRank = sourceTy.getRank();

SmallVector<AffineExpr> sourceDims;
assert(sourceTy.hasStaticShape() &&
       "Dynamic broadcasting shapes not supported!");
if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
  for (auto dim : llvm::seq<int64_t>(0, resultRank))
    sourceDims.push_back(rewriter.getAffineConstantExpr(0));
} else {
  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
    auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
    sourceDims.push_back(expr);
  }
}
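// For a size-C rank-1 bias broadcast into a rank-4 NHWC result this builds
// the map (d0, d1, d2, d3) -> (d3): the bias indexes only the innermost
// result dim. The one-element special case instead maps every result index
// to constant 0, splatting the single value everywhere.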
ShapedType resultTy = cast<ShapedType>(result.getType());
const int64_t resultRank = resultTy.getRank();

SmallVector<AffineMap, 2> indexingMaps;
indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

return linalg::GenericOp::create(
           rewriter, loc, resultTy, ValueRange({source}), result,
           indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
           [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
             Value biasVal = args[0];
             Type resType = args[1].getType();
             if (resType != biasVal.getType()) {
               biasVal =
                   resultTy.getElementType().isFloat()
                       ? arith::ExtFOp::create(builder, loc, resType, biasVal)
                             .getResult()
                       : arith::ExtSIOp::create(builder, loc, resType, biasVal)
                             .getResult();
             }
             linalg::YieldOp::create(builder, loc, biasVal);
           })
    .getResult(0);
return arith::ConstantIndexOp::create(builder, attr);
static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
                                          int64_t padBeforeAttr,
                                          int64_t padAfterAttr, Value kernelDim,
                                          int64_t strideAttr,
                                          int64_t dilationAttr,
                                          OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = arith::ConstantOp::create(rewriter, loc,
                                       IntegerAttr::get(inputDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter);

  Value subOne = arith::SubIOp::create(builder, kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = arith::MulIOp::create(builder, dilation, subOne);
  Value addOne = arith::AddIOp::create(builder, dilated, one);

  Value stride = reifyConstantDim(strideAttr, builder);
  Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne);
  Value divide = arith::DivUIOp::create(builder, subtract, stride);
  return arith::AddIOp::create(builder, divide, one);
}
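// Worked example: IH = 7, pad_top = pad_bottom = 1, KH = 3, dilation = 1,
// stride = 2 gives ((7 + 1 + 1 - (1 * (3 - 1) + 1)) / 2) + 1
//                = ((9 - 3) / 2) + 1 = 4 output rows.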
ShapedType inputTy = cast<ShapedType>(input.getType());
int64_t inputRank = inputTy.getRank();

SmallVector<Value> dynDims;
dynDims.resize(resultTy.getRank());

for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
  int64_t inputDim = inputSizeDims[i];
  int64_t kernelDim = kernelSizeDims[i];
  if (resultTy.isDynamicDim(inputDim)) {
    auto padTop = padAttr[i * 2];
    auto padBottom = padAttr[i * 2 + 1];
    auto stride = strideAttr[i];
    auto dilation = dilationAttr[i];
    Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim);
    Value kernelDynDim =
        tensor::DimOp::create(rewriter, loc, weight, kernelDim);
    dynDims[inputDim] =
        getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
                               kernelDynDim, stride, dilation, rewriter);
  }
}

for (int i = 0; i < inputRank; i++) {
  if (resultTy.isDynamicDim(i) && !dynDims[i])
    dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i);
}

return condenseValues(dynDims);
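// Dynamic dims not covered above (batch and channels) are read straight off
// the input with tensor.dim; condenseValues drops the null slots so the
// result holds exactly one SSA value per dynamic result dimension.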
reassociationMap.resize(outputRank);
for (int i = 0; i < outputRank; i++) {
  reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
}
reassociationMap[outputRank - 1].push_back(
    rewriter.getAffineDimExpr(outputRank));
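// Example for outputRank = 4: the reassociation is [[0], [1], [2], [3, 4]],
// collapsing the trailing channel/multiplier pair of the depthwise result
// (N x H x W x C x M) into a single C * M dimension.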
template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
Value input = op->getOperand(0);
Value weight = op->getOperand(1);
Value bias = op->getOperand(2);

ShapedType inputTy = cast<ShapedType>(input.getType());
ShapedType weightTy = cast<ShapedType>(weight.getType());
ShapedType biasTy = cast<ShapedType>(bias.getType());
ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
Type inputETy = inputTy.getElementType();

auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

Type accETy = op.getAccType();
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp))
  return rewriter.notifyMatchFailure(
      op, "input zero point cannot be statically determined");

FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
if (failed(maybeWZp))
  return rewriter.notifyMatchFailure(
      op, "weight zero point cannot be statically determined");

const int64_t inputZpVal = *maybeIZp;
const int64_t weightZpVal = *maybeWZp;

if (op.verifyInputZeroPoint(inputZpVal).failed())
  return rewriter.notifyMatchFailure(
      op, "input zero point must be zero for non-int8 integer types");

if (op.verifyWeightZeroPoint(weightZpVal).failed())
  return rewriter.notifyMatchFailure(
      op, "weight zero point must be zero for non-int8 integer types");
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);

if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
  return rewriter.notifyMatchFailure(
      op, "tosa.conv ops require static shapes for weight and bias");

if (inputETy.isUnsignedInteger())
  return rewriter.notifyMatchFailure(
      op, "tosa.conv ops does not support unsigned integer input");
SmallVector<int64_t> inputSizeDims;
SmallVector<int64_t> kernelSizeDims;
for (int i = 1; i < resultTy.getRank() - 1; i++) {
  inputSizeDims.push_back(i);
  kernelSizeDims.push_back(i);
}

SmallVector<Value> filteredDims = inferDynamicDimsForConv(
    loc, input, weight, resultTy, padAttr.asArrayRef(),
    strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
    inputSizeDims, kernelSizeDims, rewriter);
auto weightShape = weightTy.getShape();

TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (hasZp) {
  int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
                       .getSExtValue();
  int64_t intMax = APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
                       .getSExtValue();

  if (inputZpVal < intMin || inputZpVal > intMax)
    return rewriter.notifyMatchFailure(
        op, "tosa.conv op quantization has zp outside of input range");

  zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
}
llvm::SmallVector<int64_t> pad;
pad.resize(2, 0);
llvm::append_range(pad, padAttr.asArrayRef());
pad.resize(pad.size() + 2, 0);
input = applyPad(loc, input, pad, zeroAttr, rewriter);
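// The TOSA pad attribute covers only the spatial dims; the two leading and
// two trailing zeros extend it with {low, high} pairs for the batch and
// channel dims so applyPad sees one pair per NHWC input dimension.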
if (4 == inputTy.getRank()) {
  // Transpose the weights only when the target 2-D linalg op expects HWCF.
  if (hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
            : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>) {
    SmallVector<int32_t> weightPerm;
    for (int i = 1; i < resultTy.getRank(); i++)
      weightPerm.push_back(i);
    weightPerm.push_back(0);

    SmallVector<int64_t> newWeightShape;
    for (auto dim : weightPerm)
      newWeightShape.push_back(weightShape[dim]);
    auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
    Type newWeightTy =
        RankedTensorType::get(newWeightShape, weightTy.getElementType());
    weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
                                       weightPermAttr);
  }
}
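// TOSA 2-D conv weights are laid out FHWC (output channels first). The
// permutation built above is [1, 2, 3, 0], moving the output-channel dim to
// the back to produce the HWCF layout that linalg.conv_2d_nhwc_hwcf expects.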
if (5 == inputTy.getRank()) {
  SmallVector<int32_t> weightPerm;
  for (int i = 1; i < resultTy.getRank(); i++)
    weightPerm.push_back(i);
  weightPerm.push_back(0);

  SmallVector<int64_t> newWeightShape;
  for (auto dim : weightPerm)
    newWeightShape.push_back(weightShape[dim]);
  auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
  Type newWeightTy =
      RankedTensorType::get(newWeightShape, weightTy.getElementType());
  weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
                                     weightPermAttr);
}
auto strideAttr = rewriter.getI64TensorAttr(stride);
auto dilationAttr = rewriter.getI64TensorAttr(dilation);

Value biasEmptyTensor = tensor::EmptyOp::create(
    rewriter, loc, resultTy.getShape(), accETy, filteredDims);

Value broadcastBias =
    linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
if (hasZp) {
  auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
  auto kZp = rewriter.getI32IntegerAttr(weightZpVal);

  auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
  auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);

  Value conv = LinalgConvQOp::create(
                   rewriter, loc, resultTy,
                   ValueRange{input, weight, iZpVal, kZpVal},
                   ValueRange{broadcastBias}, strideAttr, dilationAttr)
                   ->getResult(0);

  rewriter.replaceOp(op, conv);
  return success();
}
Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);
Value conv = LinalgConvOp::create(
                 rewriter, loc, accTy, ValueRange{input, weight},
                 ValueRange{broadcastBias}, strideAttr, dilationAttr)
                 ->getResult(0);

if (resultTy != accTy)
  conv = tosa::CastOp::create(rewriter, loc, resultTy, conv);

rewriter.replaceOp(op, conv);
return success();
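// Illustrative shape of the rewrite (a hand-written sketch, not verbatim
// pass output): a float tosa.conv2d becomes roughly
//
//   %init  = tensor.empty() : tensor<1x112x112x64xf32>
//   %bbias = linalg.generic ... ins(%bias) outs(%init)    // broadcast bias
//   %conv  = linalg.conv_2d_nhwc_fhwc {dilations = ..., strides = ...}
//              ins(%input, %weight) outs(%bbias)
//
// followed by an optional tosa.cast back to the result type when the
// accumulator type is wider.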
class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
Value input = op->getOperand(0);
Value weight = op->getOperand(1);
Value bias = op->getOperand(2);

ShapedType inputTy = cast<ShapedType>(input.getType());
ShapedType weightTy = cast<ShapedType>(weight.getType());
ShapedType biasTy = cast<ShapedType>(bias.getType());
ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
int64_t resultRank = resultTy.getRank();

Type inputETy = inputTy.getElementType();
Type resultETy = resultTy.getElementType();
auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

Type accETy = op.getAccType();
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
  return rewriter.notifyMatchFailure(
      op, "tosa.depthwise_conv ops require static shapes");
SmallVector<Value> filteredDims = inferDynamicDimsForConv(
    loc, input, weight, resultTy, padAttr.asArrayRef(),
    strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
    /*inputSizeDims=*/{1, 2}, /*kernelSizeDims=*/{0, 1}, rewriter);
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
if (failed(maybeIZp))
  return rewriter.notifyMatchFailure(
      op, "input zero point cannot be statically determined");
if (failed(maybeWZp))
  return rewriter.notifyMatchFailure(
      op, "weight zero point cannot be statically determined");

const int64_t inputZpVal = *maybeIZp;
const int64_t weightZpVal = *maybeWZp;

if (op.verifyInputZeroPoint(inputZpVal).failed())
  return rewriter.notifyMatchFailure(
      op, "input zero point must be zero for non-int8 integer types");

if (op.verifyWeightZeroPoint(weightZpVal).failed())
  return rewriter.notifyMatchFailure(
      op, "weight zero point must be zero for non-int8 integer types");
bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
auto weightShape = weightTy.getShape();
auto resultShape = resultTy.getShape();

TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (!hasNullZps) {
  int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
                       .getSExtValue();
  int64_t intMax = APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
                       .getSExtValue();

  if (inputZpVal < intMin || inputZpVal > intMax)
    return rewriter.notifyMatchFailure(
        op, "tosa.depthwise_conv op quantization has zp outside of input "
            "range");

  zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
}
llvm::SmallVector<int64_t> pad;
pad.resize(2, 0);
llvm::append_range(pad, padAttr.asArrayRef());
pad.resize(pad.size() + 2, 0);

input = applyPad(loc, input, pad, zeroAttr, rewriter);
auto strideAttr = rewriter.getI64TensorAttr(stride);
auto dilationAttr = rewriter.getI64TensorAttr(dilation);
ShapedType linalgConvTy =
    RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                           weightShape[2], weightShape[3]},
                          accETy);
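// The depthwise convolution is first materialized with an un-collapsed 5-D
// result type (N x H x W x C x M, with M the channel multiplier from the
// HWCM weight); the trailing C and M dims are collapsed to C * M afterwards.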
auto resultZeroAttr = rewriter.getZeroAttr(accETy);
Value emptyTensor = tensor::EmptyOp::create(
    rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                          ValueRange{emptyTensor})
                       .result();

Value biasEmptyTensor = tensor::EmptyOp::create(
    rewriter, loc, resultTy.getShape(), resultETy, filteredDims);
SmallVector<AffineMap, 4> indexingMaps;
indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
if (hasNullZps) {
  Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
                   rewriter, loc, linalgConvTy, ValueRange{input, weight},
                   ValueRange{zeroTensor}, strideAttr, dilationAttr)
                   .getResult(0);

  if (accETy != resultETy)
    conv = tosa::CastOp::create(
        rewriter, loc,
        RankedTensorType::get(linalgConvTy.getShape(), resultETy), conv);

  SmallVector<ReassociationExprs, 4> reassociationMap;
  createDepthwiseConvCollapseMap(resultRank + 1, reassociationMap, rewriter);
  Value convReshape = tensor::CollapseShapeOp::create(
      rewriter, loc, resultTy, conv, reassociationMap);
  Value result =
      linalg::GenericOp::create(
          rewriter, loc, resultTy, ValueRange({bias, convReshape}),
          biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank),
          [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
            Value added;
            if (llvm::isa<FloatType>(inputETy))
              added = arith::AddFOp::create(nestedBuilder, loc, args[0],
                                            args[1]);
            else
              added = arith::AddIOp::create(nestedBuilder, loc, args[0],
                                            args[1]);
            linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
          })
          .getResult(0);
  rewriter.replaceOp(op, result);
} else {
  IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
  IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
  auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
  auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
  Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
                   rewriter, loc, linalgConvTy,
                   ValueRange{input, weight, iZpVal, kZpVal},
                   ValueRange{zeroTensor}, strideAttr, dilationAttr)
                   .getResult(0);
  SmallVector<ReassociationExprs, 4> reassociationMap;
  createDepthwiseConvCollapseMap(resultRank + 1, reassociationMap, rewriter);
  Value convReshape = tensor::CollapseShapeOp::create(
      rewriter, loc, resultTy, conv, reassociationMap);
  Value result = linalgIntBroadcastExtSIAdd(
      rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
  rewriter.replaceOp(op, result);
}
return success();
LogicalResult
matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const final {
  Location loc = op.getLoc();
auto outputTy = cast<ShapedType>(op.getType());
auto outputElementTy = outputTy.getElementType();

SmallVector<Value> dynDims;
dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
  dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0);
}

if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
  dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1);
}

if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
  dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2);
}

SmallVector<Value> filteredDims = condenseValues(dynDims);
auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
auto emptyTensor =
    tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
                            outputTy.getElementType(), filteredDims);
Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                          ValueRange{emptyTensor})
                       .result();
FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
if (failed(maybeAZp))
  return rewriter.notifyMatchFailure(
      op, "input a zero point cannot be statically determined");
if (failed(maybeBZp))
  return rewriter.notifyMatchFailure(
      op, "input b zero point cannot be statically determined");

const int64_t aZpVal = *maybeAZp;
const int64_t bZpVal = *maybeBZp;

if (op.verifyAZeroPoint(aZpVal).failed())
  return rewriter.notifyMatchFailure(
      op, "input a zero point must be zero for non-int8 integer types");

if (op.verifyBZeroPoint(bZpVal).failed())
  return rewriter.notifyMatchFailure(
      op, "input b zero point must be zero for non-int8 integer types");
if (aZpVal == 0 && bZpVal == 0) {
  rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
      op, TypeRange{op.getType()},
      ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
  return success();
}

auto aZp = arith::ConstantOp::create(rewriter, loc,
                                     rewriter.getI32IntegerAttr(aZpVal));
auto bZp = arith::ConstantOp::create(rewriter, loc,
                                     rewriter.getI32IntegerAttr(bZpVal));
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
    op, TypeRange{op.getType()},
    ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

return success();
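// linalg.quantized_batch_matmul takes the two zero points as extra scalar
// operands and subtracts them from the inputs inside the contraction, so no
// separate rescale is needed here.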
static SmallVector<Value>
computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) {
  TensorType resultTy = op.getType();
  Location loc = op.getLoc();

  Value input = adaptor.getInput();
  ArrayRef<int64_t> kernel = op.getKernel();
  ArrayRef<int64_t> pad = op.getPad();
  ArrayRef<int64_t> stride = op.getStride();

  SmallVector<Value> dynamicDims;

  if (resultTy.isDynamicDim(0))
    dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));

  for (int64_t dim : {1, 2}) {
    if (!resultTy.isDynamicDim(dim))
      continue;

    int64_t index = dim - 1;
    Value ihw = tensor::DimOp::create(rewriter, loc, input, dim);
    Value khw = arith::ConstantIndexOp::create(rewriter, loc, kernel[index]);
    Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
                                       pad[index * 2 + 1], khw, stride[index],
                                       /*dilationAttr=*/1, rewriter);
    dynamicDims.push_back(ohw);
  }

  if (resultTy.isDynamicDim(3))
    dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3));

  return dynamicDims;
}
LogicalResult
matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const final {
  Location loc = op.getLoc();
  Value input = adaptor.getInput();
  ShapedType inputTy = cast<ShapedType>(input.getType());

  bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
  ShapedType resultTy =
      cast<ShapedType>(getTypeConverter()->convertType(op.getType()));

  Type resultETy = inputTy.getElementType();

  SmallVector<Value> dynamicDims =
      computeDynamicOutputSizes(op, adaptor, rewriter);
TypedAttr initialAttr;
if (resultETy.isF32())
  initialAttr = rewriter.getFloatAttr(
      resultETy, APFloat::getLargest(
                     cast<FloatType>(resultETy).getFloatSemantics(),
                     true));
else if (isa<IntegerType>(resultETy))
  initialAttr = rewriter.getIntegerAttr(
      resultETy,
      APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

if (!initialAttr)
  return rewriter.notifyMatchFailure(
      op, "Unsupported initial value for tosa.maxpool_2d op");
llvm::SmallVector<int64_t> pad;
pad.resize(2, 0);
llvm::append_range(pad, op.getPad());
pad.resize(pad.size() + 2, 0);

Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);

ArrayRef<int64_t> kernel = op.getKernel();
ArrayRef<int64_t> stride = op.getStride();
auto strideAttr = rewriter.getI64VectorAttr(stride);
auto dilationAttr = rewriter.getI64VectorAttr({1, 1});

Value emptyTensor =
    tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                            resultTy.getElementType(), dynamicDims);

Value filledEmptyTensor =
    linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor)
        .result();

Value fakeWindowDims =
    tensor::EmptyOp::create(rewriter, loc, kernel, resultETy);

if (isUnsigned) {
  rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
      op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
      filledEmptyTensor, strideAttr, dilationAttr);
  return llvm::success();
}
auto resultOp = linalg::PoolingNhwcMaxOp::create(
    rewriter, loc, ArrayRef<Type>{resultTy},
    ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
    dilationAttr);
NanPropagationMode nanMode = op.getNanMode();
// ...
if (nanMode == NanPropagationMode::IGNORE) {
auto genericOp = linalg::GenericOp::create(
    rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
    resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
    resultOp.getIteratorTypesArray(),
    [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
      IRMapping map;
      auto oldBlock = resultOp.getRegion().begin();
      auto oldArgs = oldBlock->getArguments();
      auto &oldMaxOp = *resultOp.getBlock()->begin();
      map.map(oldArgs, blockArgs);
      auto *newOp = opBuilder.clone(oldMaxOp, map);
      Value isNaN =
          arith::CmpFOp::create(opBuilder, loc, arith::CmpFPredicate::UNO,
                                blockArgs.front(), blockArgs.front());
      auto selectOp = arith::SelectOp::create(
          opBuilder, loc, isNaN, blockArgs.back(), newOp->getResult(0));
      linalg::YieldOp::create(opBuilder, loc, selectOp.getResult());
    });
rewriter.replaceOp(op, genericOp.getResult(0));
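// NaN-ignore semantics: the cloned maximum op would propagate NaN, so the
// pooling body is rebuilt with a guard. "cmpf uno" comparing a value with
// itself is true only for NaN, in which case the select keeps the running
// maximum (blockArgs.back()) instead of the freshly computed result.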
Value input = op.getInput();
ShapedType inputTy = cast<ShapedType>(input.getType());
Type inElementTy = inputTy.getElementType();

ShapedType resultTy = cast<ShapedType>(op.getType());
Type resultETy = cast<ShapedType>(op.getType()).getElementType();

Type accETy = op.getAccType();
ShapedType accTy = resultTy.clone(accETy);
auto dynamicDimsOr =
    checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
if (!dynamicDimsOr.has_value())
  return failure();
SmallVector<Value> dynamicDims = *dynamicDimsOr;
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
if (failed(maybeIZp))
  return rewriter.notifyMatchFailure(
      op, "input zero point could not be statically determined");
if (failed(maybeOZp))
  return rewriter.notifyMatchFailure(
      op, "output zero point could not be statically determined");

const int64_t inputZpVal = *maybeIZp;
const int64_t outputZpVal = *maybeOZp;
llvm::SmallVector<int64_t> pad;
pad.resize(2, 0);
llvm::append_range(pad, op.getPad());
pad.resize(pad.size() + 2, 0);
TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
auto initialAttr = rewriter.getZeroAttr(accETy);

Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);

ArrayRef<int64_t> kernel = op.getKernel();
ArrayRef<int64_t> stride = op.getStride();
auto strideAttr = rewriter.getI64VectorAttr(stride);
auto dilationAttr = rewriter.getI64VectorAttr({1, 1});

Value poolEmptyTensor = tensor::EmptyOp::create(
    rewriter, loc, accTy.getShape(), accETy, dynamicDims);

Value filledEmptyTensor =
    linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
                           ValueRange{poolEmptyTensor})
        .result();

Value fakeWindowDims =
    tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
Value poolingOp = linalg::PoolingNhwcSumOp::create(
                      rewriter, loc, ArrayRef<Type>{accTy},
                      ValueRange{paddedInput, fakeWindowDims},
                      filledEmptyTensor, strideAttr, dilationAttr)
                      .getResult(0);

auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());

Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1);
Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2);

auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
iH = arith::SubIOp::create(rewriter, loc, iH, one);
iW = arith::SubIOp::create(rewriter, loc, iW, one);
Value genericEmptyTensor = tensor::EmptyOp::create(
    rewriter, loc, resultTy.getShape(), resultETy, dynamicDims);

auto genericOp = linalg::GenericOp::create(
    rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
    ValueRange{genericEmptyTensor},
    ArrayRef<AffineMap>({affineMap, affineMap}),
    getNParallelLoopsAttrs(resultTy.getRank()),
    [&](OpBuilder &b, Location loc, ValueRange args) {
      auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
  Value padVal = arith::ConstantIndexOp::create(rewriter, loc, pad);
  Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal);
  Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero);
  return arith::AddIOp::create(rewriter, loc, valid, offset)
      .getResult();
};
auto coverageFn = [&](int64_t i, Value isize) -> Value {
  Value strideVal =
      arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]);
  Value val = arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]);

  Value left = linalg::IndexOp::create(rewriter, loc, i);
  Value right = arith::SubIOp::create(rewriter, loc, isize, left);
  left = arith::MulIOp::create(rewriter, loc, left, strideVal);
  right = arith::MulIOp::create(rewriter, loc, right, strideVal);

  val = padFn(val, left, pad[i * 2]);
  val = padFn(val, right, pad[i * 2 + 1]);
  return arith::MaxSIOp::create(rewriter, loc, one, val);
};
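// Example: a 3x3 window at the top-left corner of an input padded by one on
// each side overlaps one padded row and one padded column, so padFn trims
// the per-axis count from 3 to 2 and coverageFn yields 2 * 2 = 4 valid
// elements; the MaxSIOp guard keeps the divisor at least 1.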
Value kH3 = coverageFn(1, iH);
Value kW3 = coverageFn(2, iW);

auto count = arith::IndexCastOp::create(
    rewriter, loc, rewriter.getI32Type(),
    arith::MulIOp::create(rewriter, loc, kH3, kW3));
Value poolVal = args[0];
if (isa<FloatType>(accETy)) {
  auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count);
  poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF)
                .getResult();
  if (accETy.getIntOrFloatBitWidth() > resultETy.getIntOrFloatBitWidth())
    poolVal =
        arith::TruncFOp::create(rewriter, loc, resultETy, poolVal);
} else {
  if (inputZpVal != 0) {
    auto inputZp = arith::ConstantOp::create(
        rewriter, loc, rewriter.getIntegerAttr(accETy, inputZpVal));
    Value offset =
        arith::MulIOp::create(rewriter, loc, accETy, count, inputZp);
    poolVal =
        arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset);
  }
Value one32 = arith::ConstantOp::create(
    rewriter, loc, rewriter.getI32IntegerAttr(1));
Value thirtyTwo32 = arith::ConstantOp::create(
    rewriter, loc, rewriter.getI32IntegerAttr(32));

Value countSubOne =
    arith::SubIOp::create(rewriter, loc, count, one32);
Value leadingZeros =
    math::CountLeadingZerosOp::create(rewriter, loc, countSubOne);
Value k =
    arith::SubIOp::create(rewriter, loc, thirtyTwo32, leadingZeros);

Value k64 =
    arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k);
Value thirtyShiftPlusOne = arith::ConstantOp::create(
    rewriter, loc, rewriter.getI64IntegerAttr(1 << 30));
Value numerator =
    arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64);
Value count64 = arith::ExtUIOp::create(
    rewriter, loc, rewriter.getI64Type(), count);
Value multiplier =
    arith::DivUIOp::create(rewriter, loc, numerator, count64);
multiplier = arith::TruncIOp::create(
    rewriter, loc, rewriter.getI32Type(), multiplier);
Value k8 =
    arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k);
Value thirty8 = arith::ConstantOp::create(
    rewriter, loc, rewriter.getI8IntegerAttr(30));
Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);

auto roundingAttr = RoundingModeAttr::get(
    rewriter.getContext(), RoundingMode::SINGLE_ROUND);

auto scaled = tosa::ApplyScaleOp::create(
                  rewriter, loc, rewriter.getI32Type(), poolVal,
                  multiplier, shift, roundingAttr)
                  .getResult();
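// Fixed-point division by the window size: k = ceil(log2(count)), the
// multiplier is (1 << (30 + k)) / count and the shift is 30 + k, so
// tosa.apply_scale computes (poolVal * multiplier) >> shift ~= poolVal /
// count. E.g. count = 9: k = 4, multiplier = 2^34 / 9, shift = 34.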
if (outputZpVal != 0) {
  auto outputZp = arith::ConstantOp::create(
      rewriter, loc,
      rewriter.getIntegerAttr(scaled.getType(), outputZpVal));
  scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp)
               .getResult();
}
int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
auto min = arith::ConstantIntOp::create(
    rewriter, loc, accETy,
    APInt::getSignedMinValue(outBitwidth).getSExtValue());
auto max = arith::ConstantIntOp::create(
    rewriter, loc, accETy,
    APInt::getSignedMaxValue(outBitwidth).getSExtValue());
Value clamp = clampIntHelper(loc, scaled, min, max, rewriter,
                             /*isUnsigned=*/false);
poolVal = clamp;

if (resultETy != clamp.getType())
  poolVal = arith::TruncIOp::create(rewriter, loc, resultETy, poolVal);
}

linalg::YieldOp::create(rewriter, loc, poolVal);
});

rewriter.replaceOp(op, genericOp.getResult(0));
return success();
auto permutedSizes =
    applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);

auto permutedInit =
    tensor::EmptyOp::create(rewriter, loc, permutedSizes,
                            op.getInput1().getType().getElementType());
rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
    op, op.getInput1(), permutedInit,
    llvm::to_vector(llvm::map_range(
        constantPerms, [](int32_t v) -> int64_t { return v; })));
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns,
    const TosaToLinalgNamedOptions &options) {
  if (options.preferConv2DKernelLayoutHWCF) {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                linalg::Conv2DNhwcHwcfQOp>>(
        converter, patterns->getContext());
  } else {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
                                linalg::Conv2DNhwcFhwcQOp>>(
        converter, patterns->getContext());
  }
  patterns->add<
      ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp,
                    linalg::Conv3DNdhwcDhwcfQOp>,
      DepthwiseConvConverter,
      MatMulConverter,
      MaxPool2dConverter,
      AvgPool2dConverter,
      TransposeConverter
      >(converter, patterns->getContext());
}