#include <type_traits>
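// applyPad: when any entry of the TOSA pad array is non-zero, wrap the input
// in a tensor.pad that pads with the value described by `padAttr`; all-zero
// padding returns the input unchanged.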
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = cast<ShapedType>(input.getType());
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (size_t i : llvm::seq(inputShape.size())) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr);

  return tensor::PadOp::create(rewriter, loc,
                               RankedTensorType::get(paddedShape, inputETy),
                               input, lowIndices, highIndices, padValue);
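// linalgIntBroadcastExtSIAdd: emits a linalg.generic that broadcasts an
// integer bias over the convolution result, sign-extending the bias to the
// accumulator element type when the types differ, then adds and yields.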
  ShapedType resultTy = cast<ShapedType>(conv.getType());
  return linalg::GenericOp::create(
             rewriter, loc, resultTy, ValueRange({bias, conv}), result,
             indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
             [&](OpBuilder &builder, Location loc, ValueRange args) {
               Value biasVal = args[0];
               Type resType = args[1].getType();
               if (resType != biasVal.getType()) {
                 biasVal =
                     arith::ExtSIOp::create(builder, loc, resType, biasVal);
               }
               Value added =
                   arith::AddIOp::create(builder, loc, biasVal, args[1]);
               linalg::YieldOp::create(builder, loc, added);
             })
      .getResult(0);
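// getBroadcastingMap: builds the affine map that broadcasts a statically
// shaped source tensor (typically the bias) over the trailing dimensions of
// the result; a rank-1, size-1 source maps to a constant-0 expression instead.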
  ShapedType resultTy = cast<ShapedType>(result.getType());
  ShapedType sourceTy = cast<ShapedType>(source.getType());
  const int64_t resultRank = resultTy.getRank();
  const int64_t sourceRank = sourceTy.getRank();

  // The source tensor is broadcast to all the outer dimensions of the
  // result tensor.
  SmallVector<AffineExpr> sourceDims;
  assert(sourceTy.hasStaticShape() &&
         "Dynamic broadcasting shapes not supported!");
  if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
    // A rank-1, single-element source is broadcast everywhere, which needs a
    // constant map rather than a dimension expression.
    sourceDims.push_back(rewriter.getAffineConstantExpr(0));
  } else {
    for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
      auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
      sourceDims.push_back(expr);
    }
  }
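// linalgBroadcastAndMaybeExt: broadcasts `source` into `result` with a
// linalg.generic, extending the element type with arith.extf (float) or
// arith.extsi (integer) when the accumulator type is wider.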
  ShapedType resultTy = cast<ShapedType>(result.getType());
  const int64_t resultRank = resultTy.getRank();

  return linalg::GenericOp::create(
             rewriter, loc, resultTy, ValueRange({source}), result,
             indexingMaps, getNParallelLoopsAttrs(resultRank),
             [&](OpBuilder &builder, Location loc, ValueRange args) {
               Value biasVal = args[0];
               Type resType = args[1].getType();
               if (resType != biasVal.getType()) {
                 biasVal =
                     resultTy.getElementType().isFloat()
                         ? arith::ExtFOp::create(builder, loc, resType, biasVal)
                               .getResult()
                         : arith::ExtSIOp::create(builder, loc, resType,
                                                  biasVal)
                               .getResult();
               }
               linalg::YieldOp::create(builder, loc, biasVal);
             })
      .getResult(0);
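// reifyConstantDim: materializes an int64_t attribute value as an index-typed
// arith.constant so it can take part in dynamic shape arithmetic.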
  return arith::ConstantIndexOp::create(builder, attr);
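// getConvOrPoolOutputDim: emits the index arithmetic for one spatial output
// dimension of a convolution or pooling op:
//   out = (in + padBefore + padAfter - (dilation * (kernel - 1) + 1)) / stride + 1
// For example, in = 14, pad = 1/1, kernel = 3, dilation = 1, stride = 2 gives
// (14 + 1 + 1 - 3) / 2 + 1 = 7.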
static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
                                           int64_t padBeforeAttr,
                                           int64_t padAfterAttr, Value kernelDim,
                                           int64_t strideAttr,
                                           int64_t dilationAttr,
                                           OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = arith::ConstantOp::create(rewriter, loc,
                                       IntegerAttr::get(inputDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter);

  Value subOne = arith::SubIOp::create(builder, kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = arith::MulIOp::create(builder, dilation, subOne);
  Value addOne = arith::AddIOp::create(builder, dilated, one);

  Value stride = reifyConstantDim(strideAttr, builder);
  Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne);
  Value divide = arith::DivUIOp::create(builder, subtract, stride);
  return arith::AddIOp::create(builder, divide, one);
}
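// inferDynamicDimsForConv: for every dynamic result dimension, derives its
// extent either from the output-size formula above (spatial dimensions) or
// from tensor.dim on the input (batch/channel dimensions).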
  ShapedType inputTy = cast<ShapedType>(input.getType());
  int64_t inputRank = inputTy.getRank();

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());

  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
    int64_t inputDim = inputSizeDims[i];
    int64_t kernelDim = kernelSizeDims[i];
    if (resultTy.isDynamicDim(inputDim)) {
      auto padTop = padAttr[i * 2];
      auto padBottom = padAttr[i * 2 + 1];
      auto stride = strideAttr[i];
      auto dilation = dilationAttr[i];
      Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim);
      Value kernelDynDim =
          tensor::DimOp::create(rewriter, loc, weight, kernelDim);
      dynDims[inputDim] =
          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
                                 kernelDynDim, stride, dilation, rewriter);
    }
  }

  // Get the batch/channels dimensions.
  for (int i = 0; i < inputRank; i++) {
    if (resultTy.isDynamicDim(i) && !dynDims[i])
      dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i);
  }
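// createDepthwiseConvCollapseMap: builds the reassociation map that collapses
// the trailing (channel, multiplier) dimension pair of the depthwise
// convolution result back into the single channel dimension of the TOSA
// result type.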
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++) {
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  }
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
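// ConvConverter: shared lowering for tosa.conv2d and tosa.conv3d. It checks
// the zero points, pads the input, transposes the kernel into the layout the
// target linalg op expects when necessary, broadcasts the bias into the
// accumulator tensor, and emits either the quantized or the plain linalg
// convolution.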
template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
  LogicalResult
  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();

    Type accETy = op.getAccType();
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");

    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");
    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");
    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops does not support unsigned integer input");
    SmallVector<int64_t> inputSizeDims;
    SmallVector<int64_t> kernelSizeDims;
    for (int i = 1; i < resultTy.getRank() - 1; i++) {
      inputSizeDims.push_back(i);
      kernelSizeDims.push_back(i);
    }

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        inputSizeDims, kernelSizeDims, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (hasZp) {
      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    llvm::SmallVector<int64_t> pad;
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    if (4 == inputTy.getRank()) {
          hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;

        // Transpose the kernel to match the dimension ordering of the linalg
        // convolution operation.
        SmallVector<int32_t> weightPerm;
        for (int i = 1; i < resultTy.getRank(); i++)
          weightPerm.push_back(i);
        weightPerm.push_back(0);

        SmallVector<int64_t> newWeightShape;
        for (auto dim : weightPerm)
          newWeightShape.push_back(weightShape[dim]);
        auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
        Type newWeightTy =
            RankedTensorType::get(newWeightShape, weightTy.getElementType());
        weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
                                           weightPermAttr);
    }

    if (5 == inputTy.getRank()) {
      SmallVector<int32_t> weightPerm;
      for (int i = 1; i < resultTy.getRank(); i++)
        weightPerm.push_back(i);
      weightPerm.push_back(0);

      SmallVector<int64_t> newWeightShape;
      for (auto dim : weightPerm)
        newWeightShape.push_back(weightShape[dim]);
      auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
      Type newWeightTy =
          RankedTensorType::get(newWeightShape, weightTy.getElementType());
      weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
                                         weightPermAttr);
    }
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);

    Value biasEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, resultTy.getShape(), accETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);

    if (hasZp) {
      auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
      auto kZp = rewriter.getI32IntegerAttr(weightZpVal);

      auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
      auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);

      Value conv = LinalgConvQOp::create(
                       rewriter, loc, resultTy,
                       ValueRange{input, weight, iZpVal, kZpVal},
                       ValueRange{broadcastBias}, strideAttr, dilationAttr)
                       ->getResult(0);

      rewriter.replaceOp(op, conv);
      return success();
    }

    Value conv = LinalgConvOp::create(
                     rewriter, loc, accTy, ValueRange{input, weight},
                     ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);

    // Truncate back to the result type if the accumulator was wider.
    if (resultTy != accTy)
      conv = tosa::CastOp::create(rewriter, loc, resultTy, conv);

    rewriter.replaceOp(op, conv);
    return success();
  }
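// DepthwiseConvConverter: lowers tosa.depthwise_conv2d to
// linalg.depthwise_conv_2d_nhwc_hwcm (or its quantized variant), collapses
// the channel/multiplier pair of the result, and adds the bias through a
// broadcasting linalg.generic.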
class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
    auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
    auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

    Type accETy = op.getAccType();
    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        /*inputSizeDims=*/{1, 2}, /*kernelSizeDims=*/{0, 1}, rewriter);

    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (!hasNullZps) {
      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    llvm::SmallVector<int64_t> pad;
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              accETy);

    auto resultZeroAttr = rewriter.getZeroAttr(accETy);
    Value emptyTensor = tensor::EmptyOp::create(
        rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
    Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                              ValueRange{emptyTensor})
                           .result();

    Value biasEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, resultTy.getShape(), resultETy, filteredDims);

    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    if (hasNullZps) {
      Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
                       rewriter, loc, linalgConvTy, ValueRange{input, weight},
                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      // Truncate back to the result element type if the accumulator was wider.
      if (accETy != resultETy)
        conv = tosa::CastOp::create(
            rewriter, loc,
            RankedTensorType::get(linalgConvTy.getShape(), resultETy), conv);

      Value convReshape = tensor::CollapseShapeOp::create(
          rewriter, loc, resultTy, conv, reassociationMap);

      Value result =
          linalg::GenericOp::create(
              rewriter, loc, resultTy, ValueRange({bias, convReshape}),
              [&](OpBuilder &nestedBuilder, Location nestedLoc,
                  ValueRange args) {
                Value added;
                if (llvm::isa<FloatType>(inputETy))
                  added = arith::AddFOp::create(nestedBuilder, loc, args[0],
                                                args[1]);
                else
                  added = arith::AddIOp::create(nestedBuilder, loc, args[0],
                                                args[1]);
                linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
              })
              .getResult(0);
      rewriter.replaceOp(op, result);
      return success();
    }

    IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
    IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
    auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
    auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
    Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
                     rewriter, loc, linalgConvTy,
                     ValueRange{input, weight, iZpVal, kZpVal},
                     ValueRange{zeroTensor}, strideAttr, dilationAttr)
                     .getResult(0);

    Value convReshape = tensor::CollapseShapeOp::create(
        rewriter, loc, resultTy, conv, reassociationMap);
    Value result = linalgIntBroadcastExtSIAdd(
        rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
    rewriter.replaceOp(op, result);
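// MatMulConverter: lowers tosa.matmul to linalg.batch_matmul, or to
// linalg.quantized_batch_matmul when either operand has a non-zero zero
// point, accumulating into a zero-filled init tensor whose dynamic dimensions
// come from the operands.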
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto outputTy = cast<ShapedType>(op.getType());
    auto outputElementTy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
      dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
      dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
      dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2);
    }

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
                                outputTy.getElementType(), filteredDims);
    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                              ValueRange{emptyTensor})
                           .result();
    FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
    FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
    if (failed(maybeAZp))
      return rewriter.notifyMatchFailure(
          op, "input a zero point cannot be statically determined");
    if (failed(maybeBZp))
      return rewriter.notifyMatchFailure(
          op, "input b zero point cannot be statically determined");
    const int64_t aZpVal = *maybeAZp;
    const int64_t bZpVal = *maybeBZp;

    if (op.verifyAZeroPoint(aZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input a zero point must be zero for non-int8 integer types");

    if (op.verifyBZeroPoint(bZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input b zero point must be zero for non-int8 integer types");
    if (aZpVal == 0 && bZpVal == 0) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
    auto aZp = arith::ConstantOp::create(rewriter, loc,
                                         rewriter.getI32IntegerAttr(aZpVal));
    auto bZp = arith::ConstantOp::create(rewriter, loc,
                                         rewriter.getI32IntegerAttr(bZpVal));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
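// MaxPool2dConverter: computes any dynamic output sizes, pads the input with
// the lowest representable value of the element type, and lowers
// tosa.max_pool2d to linalg.pooling_nhwc_max (or its unsigned variant). When
// the NaN propagation mode is IGNORE, the pooling body is re-emitted inside a
// linalg.generic that keeps the previous maximum whenever the incoming value
// is NaN.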
  computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) {
    Value input = adaptor.getInput();

    if (resultTy.isDynamicDim(0))
      dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));

    for (int64_t dim : {1, 2}) {
      if (!resultTy.isDynamicDim(dim))
        continue;

      // Index into the pad/stride/kernel arrays: 0 for height, 1 for width.
      int64_t index = dim - 1;

      Value ihw = tensor::DimOp::create(rewriter, loc, input, dim);
      Value khw = arith::ConstantIndexOp::create(rewriter, loc, kernel[index]);
      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
                                         pad[index * 2 + 1], khw, stride[index],
                                         /*dilation=*/1, rewriter);
      dynamicDims.push_back(ohw);
    }

    if (resultTy.isDynamicDim(3))
      dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3));
  matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Value input = adaptor.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
    ShapedType resultTy =
        cast<ShapedType>(getTypeConverter()->convertType(op.getType()));

    Type resultETy = inputTy.getElementType();
    SmallVector<Value> dynamicDims =
        computeDynamicOutputSizes(op, adaptor, rewriter);

    TypedAttr initialAttr;
    if (isa<FloatType>(resultETy))
      initialAttr = rewriter.getFloatAttr(
          resultETy, APFloat::getLargest(
                         cast<FloatType>(resultETy).getFloatSemantics(), true));
    else if (isa<IntegerType>(resultETy))
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");
    llvm::SmallVector<int64_t> pad;
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);

    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);

    Value emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultTy.getElementType(), dynamicDims);

    Value filledEmptyTensor =
        linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor)
            .result();

    Value fakeWindowDims =
        tensor::EmptyOp::create(rewriter, loc, kernel, resultETy);
      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
          filledEmptyTensor, strideAttr, dilationAttr);
      return llvm::success();
    auto resultOp = linalg::PoolingNhwcMaxOp::create(
        rewriter, loc, ArrayRef<Type>{resultTy},
        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
        dilationAttr);

    NanPropagationMode nanMode = op.getNanMode();
    // For NaN propagation mode IGNORE, compare each operand against NaN and
    // keep the previously computed maximum instead of propagating the NaN.
    if (nanMode == NanPropagationMode::IGNORE) {
      auto genericOp = linalg::GenericOp::create(
          rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
          resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
          resultOp.getIteratorTypesArray(),
          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
            IRMapping map;
            auto oldBlock = resultOp.getRegion().begin();
            auto oldArgs = oldBlock->getArguments();
            auto &oldMaxOp = *resultOp.getBlock()->begin();
            map.map(oldArgs, blockArgs);
            auto *newOp = opBuilder.clone(oldMaxOp, map);
            Value isNaN =
                arith::CmpFOp::create(opBuilder, loc, arith::CmpFPredicate::UNO,
                                      blockArgs.front(), blockArgs.front());
            auto selectOp = arith::SelectOp::create(
                opBuilder, loc, isNaN, blockArgs.back(), newOp->getResult(0));
            linalg::YieldOp::create(opBuilder, loc, selectOp.getResult());
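// AvgPool2dConverter: lowers tosa.avg_pool2d by summing each window with
// linalg.pooling_nhwc_sum into an accumulator tensor, then dividing every
// output element by the number of input elements its window actually covered
// (padding excluded). Float results use arith.divf; quantized results derive
// a per-element multiplier/shift, run tosa.apply_scale, apply the output zero
// point, and clamp back to the result type.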
    Value input = op.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = cast<ShapedType>(op.getType());
    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);
    if (!dynamicDimsOr.has_value())
      return failure();

    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point could not be statically determined");
    if (failed(maybeOZp))
      return rewriter.notifyMatchFailure(
          op, "output zero point could not be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t outputZpVal = *maybeOZp;
    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);
    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);

    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);

    // Create the tensors for the accumulating sum pooling.
    Value poolEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, accTy.getShape(), accETy, dynamicDims);

    Value filledEmptyTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
                               ValueRange{poolEmptyTensor})
            .result();

    Value fakeWindowDims =
        tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
    Value poolingOp = linalg::PoolingNhwcSumOp::create(
                          rewriter, loc, ArrayRef<Type>({accTy}),
                          ValueRange{paddedInput, fakeWindowDims},
                          filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);

    Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1);
    Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2);

    auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
    iH = arith::SubIOp::create(rewriter, loc, iH, one);
    iW = arith::SubIOp::create(rewriter, loc, iW, one);

    Value genericEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, resultTy.getShape(), resultETy, dynamicDims);
    auto genericOp = linalg::GenericOp::create(

            auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
              Value padVal = arith::ConstantIndexOp::create(rewriter, loc, pad);
              Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal);
              Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero);
              return arith::AddIOp::create(rewriter, loc, valid, offset)
            auto coverageFn = [&](int64_t i, Value isize) -> Value {
              Value strideVal =
                  arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]);
              Value val =
                  arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]);

              // Find the position relative to the input dimension boundaries.
              Value left = linalg::IndexOp::create(rewriter, loc, i);
              Value right = arith::SubIOp::create(rewriter, loc, isize, left);
              left = arith::MulIOp::create(rewriter, loc, left, strideVal);
              right = arith::MulIOp::create(rewriter, loc, right, strideVal);

              // Determine how much padding was included.
              val = padFn(val, left, pad[i * 2]);
              val = padFn(val, right, pad[i * 2 + 1]);
              return arith::MaxSIOp::create(rewriter, loc, one, val);
            };

            Value kH3 = coverageFn(1, iH);
            Value kW3 = coverageFn(2, iW);

            // Compute the total number of covered elements as an i32 count.
            auto count = arith::IndexCastOp::create(
                rewriter, loc, rewriter.getI32Type(),
                arith::MulIOp::create(rewriter, loc, kH3, kW3));
            // Divide the accumulated sum by the per-element count; float
            // accumulators use floating-point division.
            Value poolVal = args[0];
            if (isa<FloatType>(accETy)) {
              auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count);
              poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF)
                            ->getResult(0);
              if (resultETy != accETy)
                poolVal =
                    arith::TruncFOp::create(rewriter, loc, resultETy, poolVal);
              if (inputZpVal != 0) {
                auto inputZp = arith::ConstantOp::create(
                    rewriter, loc, rewriter.getIntegerAttr(accETy, inputZpVal));
                Value offset =
                    arith::MulIOp::create(rewriter, loc, accETy, count, inputZp);
                poolVal =
                    arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset);
              }
              // Compute the multiplier and shift for tosa.apply_scale.
              Value one32 = arith::ConstantOp::create(
                  rewriter, loc, rewriter.getI32IntegerAttr(1));
              Value thirtyTwo32 = arith::ConstantOp::create(
                  rewriter, loc, rewriter.getI32IntegerAttr(32));

              Value countSubOne =
                  arith::SubIOp::create(rewriter, loc, count, one32);
              Value leadingZeros =
                  math::CountLeadingZerosOp::create(rewriter, loc, countSubOne);
              Value k =
                  arith::SubIOp::create(rewriter, loc, thirtyTwo32, leadingZeros);

              Value k64 =
                  arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k);
              Value thirtyShiftPlusOne = arith::ConstantOp::create(
                  rewriter, loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
              Value numerator =
                  arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64);
              Value count64 = arith::ExtUIOp::create(
                  rewriter, loc, rewriter.getI64Type(), count);
              Value multiplier =
                  arith::DivUIOp::create(rewriter, loc, numerator, count64);
              multiplier = arith::TruncIOp::create(
                  rewriter, loc, rewriter.getI32Type(), multiplier);

              Value k8 =
                  arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k);
              Value thirty8 = arith::ConstantOp::create(
                  rewriter, loc, rewriter.getI8IntegerAttr(30));
              Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
              auto roundingAttr = RoundingModeAttr::get(
                  rewriter.getContext(), RoundingMode::SINGLE_ROUND);

              auto scaled = tosa::ApplyScaleOp::create(
                                rewriter, loc, rewriter.getI32Type(), poolVal,
                                multiplier, shift, roundingAttr)
                                .getResult();

              if (outputZpVal != 0) {
                auto outputZp = arith::ConstantOp::create(
                    rewriter, loc, rewriter.getI32IntegerAttr(outputZpVal));
                scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp)
                             .getResult();
              }
              auto min = arith::ConstantIntOp::create(
                  rewriter, loc, accETy,
                  APInt::getSignedMinValue(outBitwidth).getSExtValue());
              auto max = arith::ConstantIntOp::create(
                  rewriter, loc, accETy,
                  APInt::getSignedMaxValue(outBitwidth).getSExtValue());

              poolVal =
                  arith::TruncIOp::create(rewriter, loc, resultETy, poolVal);

            linalg::YieldOp::create(rewriter, loc, poolVal);

    rewriter.replaceOp(op, genericOp.getResult(0));
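// Transpose lowering: creates an init tensor with the permuted sizes and
// replaces tosa.transpose with linalg.transpose, widening the i32 permutation
// values to the int64_t permutation linalg expects.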
    auto permutedSizes =
        applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);

    auto permutedInit =
        tensor::EmptyOp::create(rewriter, loc, permutedSizes,
                                op.getInput1().getType().getElementType());
    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
        op, op.getInput1(), permutedInit,
        llvm::to_vector(llvm::map_range(
            constantPerms, [](int32_t v) -> int64_t { return v; })));
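// populateTosaToLinalgNamedConversionPatterns: registers the converters
// defined above. The tosa.conv2d lowering targets either the HWCF or the FHWC
// kernel layout depending on
// TosaToLinalgNamedOptions::preferConv2DKernelLayoutHWCF.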
    const TosaToLinalgNamedOptions &options) {
  if (options.preferConv2DKernelLayoutHWCF) {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                linalg::Conv2DNhwcHwcfQOp>>(
        converter, patterns->getContext());
  } else {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
                                linalg::Conv2DNhwcFhwcQOp>>(
        converter, patterns->getContext());
  }
  patterns->add<
      ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp,
                    linalg::Conv3DNdhwcDhwcfQOp>,
      DepthwiseConvConverter,
      >(converter, patterns->getContext());
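// Typical driver usage (a sketch only; the surrounding pass, `target`, and
// `typeConverter` below are illustrative and not part of this file):
//
//   RewritePatternSet patterns(&getContext());
//   TosaToLinalgNamedOptions options;
//   populateTosaToLinalgNamedConversionPatterns(typeConverter, &patterns,
//                                               options);
//   if (failed(applyFullConversion(getOperation(), target,
//                                  std::move(patterns))))
//     signalPassFailure();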