#include <type_traits>

using namespace mlir;
using namespace mlir::tosa;
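// applyPad materializes the TOSA pad attribute as an explicit tensor.pad op.
// `pad` holds interleaved [low, high] amounts, two entries per dimension, and
// `padAttr` supplies the fill value (zero, or the input zero point for
// quantized ops). When every amount is zero the input is returned unchanged.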
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            TypedAttr padAttr, OpBuilder &rewriter) {
  // Nothing to pad if every low/high amount is zero.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = cast<ShapedType>(input.getType());
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (size_t i : llvm::seq(inputShape.size())) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);

  return rewriter.create<tensor::PadOp>(
      loc, RankedTensorType::get(paddedShape, inputETy), input, lowIndices,
      highIndices, padValue);
}
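// Broadcasts an integer bias across the convolution result and adds the two
// inside a linalg.generic, sign-extending the bias when the accumulator
// element type is wider.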
static mlir::Value
linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
                           Value conv, Value result,
                           ArrayRef<AffineMap> indexingMaps) {
  ShapedType resultTy = cast<ShapedType>(conv.getType());
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
          getNParallelLoopsAttrs(resultTy.getRank()),
          [](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
            }
            Value added = builder.create<arith::AddIOp>(loc, biasVal, args[1]);
            builder.create<linalg::YieldOp>(loc, added);
          })
      .getResult(0);
}
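// Builds the affine map that broadcasts `source` onto `result`: a rank-1,
// size-1 source maps every output index to the constant 0; otherwise the
// source dims are aligned against the trailing result dims.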
static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
                                    Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  ShapedType sourceTy = cast<ShapedType>(source.getType());
  const int64_t resultRank = resultTy.getRank();
  const int64_t sourceRank = sourceTy.getRank();

  assert(sourceTy.hasStaticShape() &&
         "Dynamic broadcasting shapes not supported!");

  // A rank-1, size-1 source is broadcast to every element of the result.
  if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
    return AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0,
                          {rewriter.getAffineConstantExpr(0)},
                          rewriter.getContext());
  }

  // Otherwise align the source dims with the innermost result dims.
  SmallVector<AffineExpr> sourceDims;
  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
    auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
    sourceDims.push_back(expr);
  }

  return AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0, sourceDims,
                        rewriter.getContext());
}
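// Broadcasts the bias into the result tensor, extending the element type on
// the fly (arith.extf for floats, arith.extsi for integers) when the
// destination element type is wider.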
static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
                                              Location loc, Value source,
                                              Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  const int64_t resultRank = resultTy.getRank();

  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

  // Build the broadcast-like operation as a linalg.generic.
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({source}), result, indexingMaps,
          getNParallelLoopsAttrs(resultRank),
          [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal =
                  resultTy.getElementType().isFloat()
                      ? builder.create<arith::ExtFOp>(loc, resType, biasVal)
                            .getResult()
                      : builder.create<arith::ExtSIOp>(loc, resType, biasVal)
                            .getResult();
            }
            builder.create<linalg::YieldOp>(loc, biasVal);
          })
      .getResult(0);
}
static mlir::Value reifyConstantDim(int64_t attr,
                                    ImplicitLocOpBuilder &builder) {
  return builder.create<arith::ConstantIndexOp>(attr);
}
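// Computes one spatial output dimension of a convolution or pooling op as IR
// values (so it also covers dynamic input sizes):
//   out = (in + padBefore + padAfter - (dilation * (kernel - 1) + 1)) / stride + 1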
static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
                                          int64_t padBeforeAttr,
                                          int64_t padAfterAttr, Value kernelDim,
                                          int64_t strideAttr,
                                          int64_t dilationAttr,
                                          OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = rewriter.create<arith::ConstantOp>(
      loc, IntegerAttr::get(inputDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = builder.create<arith::AddIOp>(inputDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = builder.create<arith::AddIOp>(paddedBefore, padAfter);
  Value subOne = builder.create<arith::SubIOp>(kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = builder.create<arith::MulIOp>(dilation, subOne);
  Value addOne = builder.create<arith::AddIOp>(dilated, one);
  Value subtract = builder.create<arith::SubIOp>(paddedAfter, addOne);
  Value stride = reifyConstantDim(strideAttr, builder);
  Value divide = builder.create<arith::DivUIOp>(subtract, stride);
  return builder.create<arith::AddIOp>(divide, one);
}
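// For results with dynamic dimensions, derives each dynamic output size:
// spatial dims apply the output-dim formula above to tensor.dim queries of
// the input and kernel, while batch/channel dims are read off the input
// directly.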
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
    ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
    ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
  ShapedType inputTy = cast<ShapedType>(input.getType());
  int64_t inputRank = inputTy.getRank();

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());

  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
    int64_t inputDim = inputSizeDims[i];
    int64_t kernelDim = kernelSizeDims[i];
    if (resultTy.isDynamicDim(inputDim)) {
      auto padTop = padAttr[i * 2];
      auto padBottom = padAttr[i * 2 + 1];
      auto stride = strideAttr[i];
      auto dilation = dilationAttr[i];
      Value initDynDim = rewriter.create<tensor::DimOp>(loc, input, inputDim);
      Value kernelDynDim =
          rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
      dynDims[inputDim] =
          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
                                 kernelDynDim, stride, dilation, rewriter);
    }
  }

  // Remaining dynamic dims (batch/channels) come straight from the input.
  for (int i = 0; i < inputRank; i++) {
    if (resultTy.isDynamicDim(i) && !dynDims[i])
      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
  }

  return condenseValues(dynDims);
}
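// Builds the reassociation map that collapses the trailing two dims of the
// linalg depthwise-conv result (channels and channel multiplier) into the
// single channel dim of the TOSA result type.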
static void createDepthwiseConvCollapseMap(
    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
    OpBuilder &rewriter) {
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++)
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
}
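// Lowers tosa.conv2d/conv3d to a linalg named convolution. The recipe:
// resolve the zero points, pad the input explicitly, transpose the TOSA
// [F, ...spatial..., C] weights when the target op expects an (D)HWCF kernel,
// broadcast the bias into the init tensor, then emit the quantized or plain
// linalg convolution and cast back from the accumulator type if needed.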
template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();

    DenseI64ArrayAttr padAttr = op.getPadAttr();
    DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
    DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();

    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);

    // Zero points must be known constants for the lowering to proceed.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");

    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    const bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    SmallVector<int64_t> inputSizeDims;
    SmallVector<int64_t> kernelSizeDims;
    for (int i = 1; i < resultTy.getRank() - 1; i++) {
      inputSizeDims.push_back(i);
      kernelSizeDims.push_back(i);
    }

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        inputSizeDims, kernelSizeDims, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as its own separate operation.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (hasZp) {
      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    if (4 == inputTy.getRank()) {
      // Transpose the TOSA [F, H, W, C] weights to [H, W, C, F] when the
      // target linalg op expects an HWCF kernel layout.
      const bool weightIsHWCF =
          hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
      if (weightIsHWCF) {
        SmallVector<int32_t> weightPerm;
        for (int i = 1; i < resultTy.getRank(); i++)
          weightPerm.push_back(i);
        weightPerm.push_back(0);

        SmallVector<int64_t> newWeightShape;
        for (auto dim : weightPerm)
          newWeightShape.push_back(weightShape[dim]);
        auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
        Type newWeightTy =
            RankedTensorType::get(newWeightShape, weightTy.getElementType());
        weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                    weightPermAttr);
      }
    }

    if (5 == inputTy.getRank()) {
      // 3D convolutions transpose the weights into DHWCF analogously.
      SmallVector<int32_t> weightPerm;
      for (int i = 1; i < resultTy.getRank(); i++)
        weightPerm.push_back(i);
      weightPerm.push_back(0);

      SmallVector<int64_t> newWeightShape;
      for (auto dim : weightPerm)
        newWeightShape.push_back(weightShape[dim]);
      auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
      Type newWeightTy =
          RankedTensorType::get(newWeightShape, weightTy.getElementType());
      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
                                                  weightPermAttr);
    }

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr.asArrayRef();
    ArrayRef<int64_t> dilation = dilationTosaAttr.asArrayRef();

    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), accETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);

    if (hasZp) {
      auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
      auto kZp = rewriter.getI32IntegerAttr(weightZpVal);

      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);

      Value conv =
          rewriter
              .create<LinalgConvQOp>(
                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
              ->getResult(0);

      rewriter.replaceOp(op, conv);
      return success();
    }

    Value conv = rewriter
                     .create<LinalgConvOp>(
                         loc, accTy, ValueRange{input, weight},
                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);

    // Truncate back to the result type if the accumulator is wider.
    if (resultTy != accTy)
      conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);

    rewriter.replaceOp(op, conv);
    return success();
  }
};
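// Lowers tosa.depthwise_conv2d to linalg.depthwise_conv_2d_nhwc_hwcm(_q).
// The linalg op yields an [N, H, W, C, M] accumulator that is collapsed to
// the TOSA [N, H, W, C * M] shape before the bias is added.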
class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
    auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
    auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

    Type accETy = op.getAccType();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute output dynamic dims.
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        /*inputSizeDims=*/{1, 2}, /*kernelSizeDims=*/{0, 1}, rewriter);

    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as its own separate operation.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (!hasNullZps) {
      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr.asArrayRef();
    ArrayRef<int64_t> dilation = dilationTosaAttr.asArrayRef();

    // Create the convolution op. The depthwise conv accumulates into a
    // [N, H, W, C, M] tensor that is collapsed afterwards.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              accETy);

    auto resultZeroAttr = rewriter.getZeroAttr(accETy);
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, linalgConvTy.getShape(), accETy, filteredDims);
    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();

    Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, filteredDims);

    // Indexing maps for the bias broadcast-add.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    if (hasNullZps) {
      Value conv = rewriter
                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                           loc, linalgConvTy, ValueRange{input, weight},
                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      // Truncate back to the result element type if the accumulator is wider.
      if (accETy != resultETy)
        conv = rewriter.create<tosa::CastOp>(
            loc, RankedTensorType::get(linalgConvTy.getShape(), resultETy),
            conv);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);

      Value result =
          rewriter
              .create<linalg::GenericOp>(
                  loc, resultTy, ValueRange({bias, convReshape}),
                  biasEmptyTensor, indexingMaps,
                  getNParallelLoopsAttrs(resultRank),
                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
                      ValueRange args) {
                    Value added;
                    if (llvm::isa<FloatType>(inputETy))
                      added = nestedBuilder.create<arith::AddFOp>(loc, args[0],
                                                                  args[1]);
                    else
                      added = nestedBuilder.create<arith::AddIOp>(loc, args[0],
                                                                  args[1]);
                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
                  })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
      IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
      auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
      auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
      Value conv =
          rewriter
              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
                  loc, linalgConvTy,
                  ValueRange{input, weight, iZpVal, kZpVal},
                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
              .getResult(0);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
          loc, resultTy, conv, reassociationMap);
      Value result = linalgIntBroadcastExtSIAdd(
          rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};
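// Lowers tosa.matmul to linalg.batch_matmul over a zero-filled init tensor,
// or to linalg.quantized_batch_matmul when either operand carries a nonzero
// zero point.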
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = cast<ShapedType>(op.getType());
    auto outputElementTy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
      dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
      dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
      dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
    Value zeroTensor = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{zero},
                                                   ValueRange{emptyTensor})
                           .result();

    FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
    FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
    if (failed(maybeAZp))
      return rewriter.notifyMatchFailure(
          op, "input a zero point cannot be statically determined");
    if (failed(maybeBZp))
      return rewriter.notifyMatchFailure(
          op, "input b zero point cannot be statically determined");

    const int64_t aZpVal = *maybeAZp;
    const int64_t bZpVal = *maybeBZp;

    if (op.verifyAZeroPoint(aZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input a zero point must be zero for non-int8 integer types");

    if (op.verifyBZeroPoint(bZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input b zero point must be zero for non-int8 integer types");

    if (aZpVal == 0 && bZpVal == 0) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
      return success();
    }

    auto aZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(aZpVal));
    auto bZp = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(bZpVal));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

    return success();
  }
};
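// Lowers tosa.max_pool2d to linalg.pooling_nhwc_max (or the unsigned
// variant). Padding uses the pool's initial value, the smallest
// representable value of the element type, so padded cells never win the
// max.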
class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  // Compute the dynamic output sizes of the maxpool operation.
  static SmallVector<Value>
  computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) {
    TensorType resultTy = op.getType();
    Location loc = op.getLoc();

    Value input = adaptor.getInput();
    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> pad = op.getPad();
    ArrayRef<int64_t> stride = op.getStride();

    SmallVector<Value> dynamicDims;

    // Batch dimension.
    if (resultTy.isDynamicDim(0))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

    // Height/width dimensions.
    for (int64_t dim : {1, 2}) {
      if (!resultTy.isDynamicDim(dim))
        continue;

      // Index into the attribute arrays.
      int64_t index = dim - 1;

      // Input and kernel height/width.
      Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);
      Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);

      // Output height/width.
      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
                                         pad[index * 2 + 1], khw, stride[index],
                                         /*dilationAttr=*/1, rewriter);
      dynamicDims.push_back(ohw);
    }

    // Channel dimension.
    if (resultTy.isDynamicDim(3))
      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));

    return dynamicDims;
  }
  LogicalResult
  matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = adaptor.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());

    bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
    ShapedType resultTy =
        cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
    if (!resultTy)
      return rewriter.notifyMatchFailure(op, "failed to convert type");

    Type resultETy = inputTy.getElementType();

    SmallVector<Value> dynamicDims =
        computeDynamicOutputSizes(op, adaptor, rewriter);

    // Determine the initial value for the max pool: the most negative value
    // of the element type, so padding never wins the max.
    TypedAttr initialAttr;
    if (isa<FloatType>(resultETy))
      initialAttr = rewriter.getFloatAttr(
          resultETy, APFloat::getLargest(
                         cast<FloatType>(resultETy).getFloatSemantics(),
                         /*Negative=*/true));
    else if (isa<IntegerType>(resultETy))
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as its own separate operation.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);

    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);

    Value filledEmptyTensor =
        rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
            .result();

    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, resultETy);

    if (isUnsigned) {
      // Unsigned element types use the unsigned max-pooling variant.
      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
          op, ArrayRef<Type>{resultTy},
          ValueRange{paddedInput, fakeWindowDims},
          filledEmptyTensor, strideAttr, dilationAttr);
      return llvm::success();
    }

    auto resultOp = rewriter.create<linalg::PoolingNhwcMaxOp>(
        op->getLoc(), ArrayRef<Type>{resultTy},
        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
        dilationAttr);

    // "IGNORE" NaN mode (floating-point case): rebuild the pooling body so a
    // NaN input never replaces the running maximum. A value is NaN iff it
    // does not compare ordered with itself.
    if (const auto nanMode = op.getNanMode(); nanMode == "IGNORE") {
      auto genericOp = rewriter.create<linalg::GenericOp>(
          op->getLoc(), resultOp.getType(0), resultOp.getInputs(),
          resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
          resultOp.getIteratorTypesArray(),
          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
            IRMapping map;
            auto oldBlock = resultOp.getRegion().begin();
            auto oldArgs = oldBlock->getArguments();
            auto &oldMaxOp = *resultOp.getBlock()->begin();
            map.map(oldArgs, blockArgs);
            auto *newOp = opBuilder.clone(oldMaxOp, map);
            Value isNaN = opBuilder.create<arith::CmpFOp>(
                op->getLoc(), arith::CmpFPredicate::UNO, blockArgs.front(),
                blockArgs.front());
            auto selectOp = opBuilder.create<arith::SelectOp>(
                op->getLoc(), isNaN, blockArgs.back(), newOp->getResult(0));
            opBuilder.create<linalg::YieldOp>(loc, selectOp.getResult());
          });
      rewriter.replaceOp(op, genericOp.getResults());
      return success();
    }

    rewriter.replaceOp(op, resultOp);
    return success();
  }
};
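// Lowers tosa.avg_pool2d to linalg.pooling_nhwc_sum followed by a
// linalg.generic that divides each window sum by the number of input
// elements the window actually covered (padding is excluded from the count).
// The quantized path replaces the division with a fixed-point reciprocal
// applied via tosa.apply_scale, plus zero-point and clamping fix-ups.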
class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = cast<ShapedType>(op.getType());
    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = *dynamicDimsOr;

    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point could not be statically determined");
    if (failed(maybeOZp))
      return rewriter.notifyMatchFailure(
          op, "output zero point could not be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t outputZpVal = *maybeOZp;

    // Apply padding as its own separate operation. Padding with zero does not
    // change the pooled sum; the zero-point offset is corrected below.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);
    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
    // Unsupported element type.
    if (!padAttr)
      return failure();
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    auto initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that sums across each pooling window.
    Value poolEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, accTy.getShape(), accETy, dynamicDims);

    Value filledEmptyTensor =
        rewriter
            .create<linalg::FillOp>(loc, ValueRange{initialValue},
                                    ValueRange{poolEmptyTensor})
            .result();

    Value fakeWindowDims =
        rewriter.create<tensor::EmptyOp>(loc, kernel, accETy);

    Value poolingOp = rewriter
                          .create<linalg::PoolingNhwcSumOp>(
                              loc, ArrayRef<Type>{accTy},
                              ValueRange{paddedInput, fakeWindowDims},
                              filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize each sum by the number of input elements its window covered.
    Value iH = rewriter.create<tensor::DimOp>(loc, poolingOp, 1);
    Value iW = rewriter.create<tensor::DimOp>(loc, poolingOp, 2);

    auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    iH = rewriter.create<arith::SubIOp>(loc, iH, one);
    iW = rewriter.create<arith::SubIOp>(loc, iW, one);

    Value genericEmptyTensor = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultETy, dynamicDims);

    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericEmptyTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);

          // Subtracts the amount of padding the window overlapped from the
          // count of valid elements.
          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
            if (pad == 0)
              return valid;
            auto padVal = rewriter.create<arith::ConstantIndexOp>(loc, pad);
            Value dpos = rewriter.create<arith::SubIOp>(loc, pos, padVal);

            Value offset = rewriter.create<arith::MinSIOp>(loc, dpos, zero);
            return rewriter.create<arith::AddIOp>(loc, valid, offset)
                ->getResult(0);
          };

          // Number of input elements covered along dimension i (1 = height,
          // 2 = width), clamped to at least one.
          auto coverageFn = [&](int64_t i, Value isize) -> Value {
            Value strideVal =
                rewriter.create<arith::ConstantIndexOp>(loc, stride[i - 1]);
            Value val =
                rewriter.create<arith::ConstantIndexOp>(loc, kernel[i - 1]);

            // Find the distance to either edge of the input.
            Value left = rewriter.create<linalg::IndexOp>(loc, i);
            Value right = rewriter.create<arith::SubIOp>(loc, isize, left);
            left = rewriter.create<arith::MulIOp>(loc, left, strideVal);
            right = rewriter.create<arith::MulIOp>(loc, right, strideVal);

            // Remove the padding the window overlapped on each side.
            val = padFn(val, left, pad[i * 2]);
            val = padFn(val, right, pad[i * 2 + 1]);
            return rewriter.create<arith::MaxSIOp>(loc, one, val);
          };

          Value kH3 = coverageFn(1, iH);
          Value kW3 = coverageFn(2, iW);

          // Divide by the number of summed values. For floats this is a
          // plain division; quantized values use a fixed-point reciprocal.
          Value count = rewriter.create<arith::IndexCastOp>(
              loc, rewriter.getI32Type(),
              rewriter.create<arith::MulIOp>(loc, kH3, kW3));

          Value poolVal = args[0];
          if (isa<FloatType>(accETy)) {
            auto countF = rewriter.create<arith::SIToFPOp>(loc, accETy, count);
            poolVal = rewriter.create<arith::DivFOp>(loc, poolVal, countF)
                          ->getResult(0);
            if (accETy.getIntOrFloatBitWidth() >
                resultETy.getIntOrFloatBitWidth())
              poolVal =
                  rewriter.create<arith::TruncFOp>(loc, resultETy, poolVal);
          } else {
            // Apply an offset for a nonzero input zero point.
            if (inputZpVal != 0) {
              auto inputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(accETy, inputZpVal));
              Value offset =
                  rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
              poolVal =
                  rewriter.create<arith::SubIOp>(loc, accETy, poolVal, offset);
            }

            // Fixed-point reciprocal: k = 32 - clz(count - 1),
            // multiplier = (((1 << 30) + 1) << k) / count, shift = k + 30.
            Value one32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(1));
            Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI32IntegerAttr(32));

            Value countSubOne =
                rewriter.create<arith::SubIOp>(loc, count, one32);
            Value leadingZeros =
                rewriter.create<math::CountLeadingZerosOp>(loc, countSubOne);
            Value k32 =
                rewriter.create<arith::SubIOp>(loc, thirtyTwo32, leadingZeros);
            Value k64 = rewriter.create<arith::ExtUIOp>(
                loc, rewriter.getI64Type(), k32);

            Value thirtyShiftPlusOne = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
            Value numerator =
                rewriter.create<arith::ShLIOp>(loc, thirtyShiftPlusOne, k64);
            Value count64 = rewriter.create<arith::ExtUIOp>(
                loc, rewriter.getI64Type(), count);
            Value multiplier =
                rewriter.create<arith::DivUIOp>(loc, numerator, count64);
            multiplier = rewriter.create<arith::TruncIOp>(
                loc, rewriter.getI32Type(), multiplier);

            Value k8 = rewriter.create<arith::TruncIOp>(
                loc, rewriter.getI8Type(), k32);
            Value thirty8 = rewriter.create<arith::ConstantOp>(
                loc, rewriter.getI8IntegerAttr(30));
            Value shift = rewriter.create<arith::AddIOp>(loc, k8, thirty8);

            Value scaled =
                rewriter
                    .create<tosa::ApplyScaleOp>(
                        loc, rewriter.getI32Type(), poolVal, multiplier, shift,
                        // Rounding mode reconstructed; the elided original
                        // names it explicitly.
                        rewriter.getStringAttr("SINGLE_ROUND"))
                    .getResult();

            // Apply the output zero point.
            if (outputZpVal != 0) {
              auto outputZp = rewriter.create<arith::ConstantOp>(
                  loc, b.getIntegerAttr(scaled.getType(), outputZpVal));
              scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                           .getResult();
            }

            // Clamp to the range of the result element type.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMinValue(outBitwidth).getSExtValue(),
                accETy);
            auto max = rewriter.create<arith::ConstantIntOp>(
                loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
                accETy);
            poolVal = clampIntHelper(loc, scaled, min, max, rewriter,
                                     /*isUnsigned=*/false);

            if (resultETy != poolVal.getType())
              poolVal =
                  rewriter.create<arith::TruncIOp>(loc, resultETy, poolVal);
          }

          rewriter.create<linalg::YieldOp>(loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};
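// Lowers tosa.transpose to linalg.transpose on a tensor.empty init whose
// sizes are the input sizes permuted by the constant permutation.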
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();

    Location loc = op.getLoc();
    // The verifier has already checked that the permutation is valid, so the
    // input sizes can be permuted directly.
    auto inputSizes = tensor::getMixedSizes(rewriter, loc, op.getInput1());
    auto permutedSizes =
        applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);

    auto permutedInit = rewriter.create<tensor::EmptyOp>(
        loc, permutedSizes, op.getInput1().getType().getElementType());
    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
        op, op.getInput1(), permutedInit,
        llvm::to_vector(llvm::map_range(
            constantPerms, [](int32_t v) -> int64_t { return v; })));
    return success();
  }
};
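// Registers all named-op lowerings in this file. The Conv2D kernel layout
// (HWCF vs. FHWC) is chosen through TosaToLinalgNamedOptions.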
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns,
    const TosaToLinalgNamedOptions &options) {
  if (options.preferConv2DKernelLayoutHWCF) {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                linalg::Conv2DNhwcHwcfQOp>>(
        converter, patterns->getContext());
  } else {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
                                linalg::Conv2DNhwcFhwcQOp>>(
        converter, patterns->getContext());
  }
  // Converter list reconstructed from the patterns defined above.
  patterns->add<
      ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp,
                    linalg::Conv3DNdhwcDhwcfQOp>,
      DepthwiseConvConverter,
      MatMulConverter,
      MaxPool2dConverter,
      AvgPool2dConverter,
      TransposeConverter
      >(converter, patterns->getContext());
}