static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            TypedAttr padAttr, OpBuilder &rewriter) {
  // Input should be padded only if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = cast<ShapedType>(input.getType());
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t> paddedShape;
  SmallVector<OpFoldResult> lowIndices;
  SmallVector<OpFoldResult> highIndices;
  for (size_t i : llvm::seq(inputShape.size())) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr);

  return tensor::PadOp::create(rewriter, loc,
                               RankedTensorType::get(paddedShape, inputETy),
                               input, lowIndices, highIndices, padValue);
}
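// Illustrative sketch (not from the original source): applyPad interprets
// `pad` as interleaved (low, high) amounts per dimension, so for an NHWC
// input the hypothetical pad {0, 0, 1, 1, 2, 2, 0, 0} pads only H and W:
//
//   %padded = tensor.pad %input low[0, 1, 2, 0] high[0, 1, 2, 0] { ... }
//       : tensor<1x14x14x8xf32> to tensor<1x16x18x8xf32>
//
// i.e. H grows from 14 to 14 + 1 + 1 = 16 and W from 14 to 14 + 2 + 2 = 18.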
static mlir::Value
linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
                           Value conv, Value result,
                           ArrayRef<AffineMap> indexingMaps) {
  ShapedType resultTy = cast<ShapedType>(conv.getType());
  return linalg::GenericOp::create(
             rewriter, loc, resultTy, ValueRange({bias, conv}), result,
             indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
             [](OpBuilder &builder, Location loc, ValueRange args) {
               Value biasVal = args[0];
               Type resType = args[1].getType();
               if (resType != biasVal.getType()) {
                 biasVal =
                     arith::ExtSIOp::create(builder, loc, resType, biasVal);
               }
               Value added =
                   arith::AddIOp::create(builder, loc, biasVal, args[1]);
               linalg::YieldOp::create(builder, loc, added);
             })
      .getResult(0);
}
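// Illustrative note (types assumed): when the bias element type is narrower
// than the accumulator (say, an i32 bias accumulated into i48), the generic
// body above sign-extends before adding, roughly:
//
//   %b = arith.extsi %bias : i32 to i48
//   %s = arith.addi %b, %acc : i48
//   linalg.yield %s : i48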
static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
                                    Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  ShapedType sourceTy = cast<ShapedType>(source.getType());
  const int64_t resultRank = resultTy.getRank();
  const int64_t sourceRank = sourceTy.getRank();

  // The source is broadcast across the outer dimensions of the result. For a
  // rank-1, single-element source, TOSA specifies that the value be broadcast
  // to every element, which requires an edge case for a constant map.
  SmallVector<AffineExpr> sourceDims;
  assert(sourceTy.hasStaticShape() &&
         "Dynamic broadcasting shapes not supported!");
  if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
    sourceDims.push_back(rewriter.getAffineConstantExpr(0));
  } else {
    for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
      auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
      sourceDims.push_back(expr);
    }
  }

  return AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0, sourceDims,
                        rewriter.getContext());
}
static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
                                              Location loc, Value source,
                                              Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  const int64_t resultRank = resultTy.getRank();

  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

  // Build the broadcast-like operation as a linalg.generic, extending the
  // element type on the way if required.
  return linalg::GenericOp::create(
             rewriter, loc, resultTy, ValueRange({source}), ValueRange{result},
             indexingMaps, getNParallelLoopsAttrs(resultRank),
             [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
               Value biasVal = args[0];
               Type resType = args[1].getType();
               if (resType != biasVal.getType()) {
                 biasVal =
                     resultTy.getElementType().isFloat()
                         ? arith::ExtFOp::create(builder, loc, resType, biasVal)
                               .getResult()
                         : arith::ExtSIOp::create(builder, loc, resType,
                                                  biasVal)
                               .getResult();
               }
               linalg::YieldOp::create(builder, loc, biasVal);
             })
      .getResult(0);
}
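// Illustrative sketch (shapes assumed): broadcasting a rank-1 bias into a
// rank-4 NHWC accumulator pairs getBroadcastingMap's (d0, d1, d2, d3) -> (d3)
// map with an identity map, roughly:
//
//   %bc = linalg.generic
//       {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>,
//                         affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
//        iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
//       ins(%bias : tensor<8xf32>) outs(%init : tensor<1x16x16x8xf32>) { ... }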
// out = (in + padBefore + padAfter - (dilation * (kernel - 1) + 1)) / stride + 1
static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
                                          int64_t padBeforeAttr,
                                          int64_t padAfterAttr, Value kernelDim,
                                          int64_t strideAttr,
                                          int64_t dilationAttr,
                                          OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = arith::ConstantOp::create(rewriter, loc,
                                       IntegerAttr::get(inputDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter);
  Value subOne = arith::SubIOp::create(builder, kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = arith::MulIOp::create(builder, dilation, subOne);
  Value addOne = arith::AddIOp::create(builder, dilated, one);
  Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne);
  Value stride = reifyConstantDim(strideAttr, builder);
  Value divide = arith::DivUIOp::create(builder, subtract, stride);
  return arith::AddIOp::create(builder, divide, one);
}
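// Worked example (illustrative values): for in = 14, pad = 1 + 1, kernel = 3,
// dilation = 1, stride = 2:
//   (14 + 1 + 1 - (1 * (3 - 1) + 1)) / 2 + 1 = (16 - 3) / 2 + 1 = 6 + 1 = 7.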
// Compute the dynamic output sizes of the convolution.
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
    ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
    ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
  ShapedType inputTy = cast<ShapedType>(input.getType());
  int64_t inputRank = inputTy.getRank();

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());

  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
    int64_t inputDim = inputSizeDims[i];
    int64_t kernelDim = kernelSizeDims[i];
    if (resultTy.isDynamicDim(inputDim)) {
      auto padTop = padAttr[i * 2];
      auto padBottom = padAttr[i * 2 + 1];
      auto stride = strideAttr[i];
      auto dilation = dilationAttr[i];
      Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim);
      Value kernelDynDim =
          tensor::DimOp::create(rewriter, loc, weight, kernelDim);
      dynDims[inputDim] =
          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
                                 kernelDynDim, stride, dilation, rewriter);
    }
  }

  // Gather the remaining dynamic batch/channel dimensions directly.
  for (int i = 0; i < inputRank; i++) {
    if (resultTy.isDynamicDim(i) && !dynDims[i])
      dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i);
  }

  return condenseValues(dynDims);
}
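// Illustrative note: for a fully dynamic conv2d result, the spatial entries
// of dynDims are materialized through getConvOrPoolOutputDim while batch and
// channel sizes are read back with tensor.dim; condenseValues then drops the
// null entries so tensor.empty receives only the dynamic sizes it needs.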
// Builds the reassociation that collapses the trailing (channel, multiplier)
// pair of the depthwise convolution result into a single dimension.
static void createDepthwiseConvCollapseMap(
    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
    OpBuilder &rewriter) {
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++)
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
}
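// Illustrative sketch (shapes assumed): with outputRank = 4 the map is
// [[d0], [d1], [d2], [d3, d4]], which collapses a rank-5 depthwise result
// tensor<1x14x14x8x2xf32> into the rank-4 tensor<1x14x14x16xf32> expected by
// tosa.depthwise_conv2d (channels * multiplier = 8 * 2 = 16).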
template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();

    DenseI64ArrayAttr padAttr = op.getPadAttr();
    DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
    DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();

    Type accETy = op.getAccType();
    Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);

    // Get and verify the zero points.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");

    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    const bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops does not support unsigned integer input");

    llvm::SmallVector<int64_t> inputSizeDims;
    llvm::SmallVector<int64_t> kernelSizeDims;
    for (int i = 1; i < resultTy.getRank() - 1; i++) {
      inputSizeDims.push_back(i);
      kernelSizeDims.push_back(i);
    }

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        inputSizeDims, kernelSizeDims, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (hasZp) {
      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    if (4 == inputTy.getRank()) {
      // For 2D convolutions, check whether the target linalg op expects an
      // HWCF kernel layout and transpose the TOSA FHWC weights if so.
      const bool wantHwcf =
          hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
      if (wantHwcf) {
        SmallVector<int32_t> weightPerm;
        for (int i = 1; i < resultTy.getRank(); i++)
          weightPerm.push_back(i);
        weightPerm.push_back(0);

        SmallVector<int64_t> newWeightShape;
        for (auto dim : weightPerm)
          newWeightShape.push_back(weightShape[dim]);
        auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
        Type newWeightTy =
            RankedTensorType::get(newWeightShape, weightTy.getElementType());
        weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
                                           weightPermAttr);
      }
    }

    // For 3D convolutions the kernel is transposed to DHWCF ordering.
    if (5 == inputTy.getRank()) {
      SmallVector<int32_t> weightPerm;
      for (int i = 1; i < resultTy.getRank(); i++)
        weightPerm.push_back(i);
      weightPerm.push_back(0);

      SmallVector<int64_t> newWeightShape;
      for (auto dim : weightPerm)
        newWeightShape.push_back(weightShape[dim]);
      auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
      Type newWeightTy =
          RankedTensorType::get(newWeightShape, weightTy.getElementType());
      weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
                                         weightPermAttr);
    }

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);

    Value biasEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, resultTy.getShape(), accETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);

    if (hasZp) {
      auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
      auto kZp = rewriter.getI32IntegerAttr(weightZpVal);

      auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
      auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);

      Value conv = LinalgConvQOp::create(
                       rewriter, loc, resultTy,
                       ValueRange{input, weight, iZpVal, kZpVal},
                       ValueRange{broadcastBias}, strideAttr, dilationAttr)
                       ->getResult(0);

      rewriter.replaceOp(op, conv);
      return success();
    }

    Value conv = LinalgConvOp::create(
                     rewriter, loc, accTy, ValueRange{input, weight},
                     ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);

    // Truncate back to the result type if the accumulator was wider.
    if (resultTy != accTy)
      conv = tosa::CastOp::create(rewriter, loc, resultTy, conv);

    rewriter.replaceOp(op, conv);
    return success();
  }
};
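// Illustrative lowering (shapes and attributes assumed): with zero zps and an
// f32 accumulator, a tosa.conv2d such as
//
//   %0 = tosa.conv2d %input, %weight, %bias, %izp, %wzp
//          {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>,
//           dilation = array<i64: 1, 1>, acc_type = f32}
//
// becomes, roughly, a broadcast of %bias into the init tensor followed by
//
//   %1 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
//                                  strides = dense<1> : tensor<2xi64>}
//          ins(%input, %weight : ...) outs(%broadcast_bias : ...)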
class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
    auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
    auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

    Type accETy = op.getAccType();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute output dynamic dims.
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        /*inputSizeDims=*/{1, 2}, /*kernelSizeDims=*/{0, 1}, rewriter);

    // Get and verify the zero points.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    const bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (!hasNullZps) {
      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op. The depthwise result keeps channels and the
    // channel multiplier as separate trailing dimensions.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              accETy);

    auto resultZeroAttr = rewriter.getZeroAttr(accETy);
    Value emptyTensor = tensor::EmptyOp::create(
        rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
    Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                              ValueRange{emptyTensor})
                           .result();

    Value biasEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, resultTy.getShape(), resultETy, filteredDims);

    // Broadcast the bias into the output tensor after convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    if (hasNullZps) {
      Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
                       rewriter, loc, linalgConvTy, ValueRange{input, weight},
                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      // Truncate back to the result element type if the accumulator was
      // wider.
      if (accETy != resultETy)
        conv = tosa::CastOp::create(
            rewriter, loc,
            RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
                                  resultETy),
            conv);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = tensor::CollapseShapeOp::create(
          rewriter, loc, resultTy, conv, reassociationMap);

      Value result =
          linalg::GenericOp::create(
              rewriter, loc, resultTy, ValueRange({bias, convReshape}),
              biasEmptyTensor, indexingMaps,
              getNParallelLoopsAttrs(resultRank),
              [&](OpBuilder &nestedBuilder, Location nestedLoc,
                  ValueRange args) {
                Value added;
                if (llvm::isa<FloatType>(inputETy))
                  added = arith::AddFOp::create(nestedBuilder, loc, args[0],
                                                args[1]);
                else
                  added = arith::AddIOp::create(nestedBuilder, loc, args[0],
                                                args[1]);
                linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
              })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
      IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
      auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
      auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
      Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
                       rewriter, loc, linalgConvTy,
                       ValueRange{input, weight, iZpVal, kZpVal},
                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = tensor::CollapseShapeOp::create(
          rewriter, loc, resultTy, conv, reassociationMap);
      Value result = linalgIntBroadcastExtSIAdd(
          rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};
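// Illustrative shape walk-through (sizes assumed): for an NHWC input with
// C = 8 and channel multiplier M = 2, linalg.depthwise_conv_2d_nhwc_hwcm
// produces tensor<1x14x14x8x2xf32>; the collapse map built above reshapes it
// to tensor<1x14x14x16xf32> before the bias add, matching the TOSA result
// type.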
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = cast<ShapedType>(op.getType());
    auto outputElementTy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
      dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
      dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
      dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
                                outputTy.getElementType(), filteredDims);
    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                              ValueRange{emptyTensor})
                           .result();

    FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
    FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
    if (failed(maybeAZp))
      return rewriter.notifyMatchFailure(
          op, "input a zero point cannot be statically determined");
    if (failed(maybeBZp))
      return rewriter.notifyMatchFailure(
          op, "input b zero point cannot be statically determined");

    const int64_t aZpVal = *maybeAZp;
    const int64_t bZpVal = *maybeBZp;

    if (op.verifyAZeroPoint(aZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input a zero point must be zero for non-int8 integer types");

    if (op.verifyBZeroPoint(bZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input b zero point must be zero for non-int8 integer types");

    if (aZpVal == 0 && bZpVal == 0) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
      return success();
    }

    auto aZp = arith::ConstantOp::create(rewriter, loc,
                                         rewriter.getI32IntegerAttr(aZpVal));
    auto bZp = arith::ConstantOp::create(rewriter, loc,
                                         rewriter.getI32IntegerAttr(bZpVal));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

    return success();
  }
};
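// Illustrative lowering (types assumed): with both zero points equal to zero,
//
//   %0 = tosa.matmul %a, %b : (tensor<1x5x3xf32>, tensor<1x3x6xf32>)
//          -> tensor<1x5x6xf32>
//
// becomes a linalg.batch_matmul accumulating into a zero-filled init tensor:
//
//   %1 = linalg.batch_matmul ins(%a, %b : ...) outs(%zero_tensor : ...)
//
// The quantized path instead emits linalg.quantized_batch_matmul with the two
// zero points appended as scalar i32 operands.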
class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  // Compute the dynamic output sizes of the maxpool operation.
  static SmallVector<Value>
  computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) {
    TensorType resultTy = op.getType();
    Location loc = op.getLoc();

    Value input = adaptor.getInput();
    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> pad = op.getPad();
    ArrayRef<int64_t> stride = op.getStride();

    SmallVector<Value> dynamicDims;

    // Batch dimension.
    if (resultTy.isDynamicDim(0))
      dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));

    // Height/width dimensions.
    for (int64_t dim : {1, 2}) {
      if (!resultTy.isDynamicDim(dim))
        continue;

      // Index into the attribute arrays.
      int64_t index = dim - 1;

      // Input height/width.
      Value ihw = tensor::DimOp::create(rewriter, loc, input, dim);

      // Kernel height/width.
      Value khw = arith::ConstantIndexOp::create(rewriter, loc, kernel[index]);

      // Output height/width.
      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
                                         pad[index * 2 + 1], khw, stride[index],
                                         /*dilationAttr=*/1, rewriter);
      dynamicDims.push_back(ohw);
    }

    // Channel dimension.
    if (resultTy.isDynamicDim(3))
      dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3));

    return dynamicDims;
  }

  LogicalResult
  matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = adaptor.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());

    bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
    ShapedType resultTy =
        getTypeConverter()->convertType<ShapedType>(op.getType());
    if (!resultTy)
      return rewriter.notifyMatchFailure(op, "failed to convert type");
    Type resultETy = inputTy.getElementType();

    SmallVector<Value> dynamicDims =
        computeDynamicOutputSizes(op, adaptor, rewriter);

    // Determine what the initial value needs to be for the max pool op.
    TypedAttr initialAttr;
    if (isa<FloatType>(resultETy))
      initialAttr = rewriter.getFloatAttr(
          resultETy, APFloat::getLargest(
                         cast<FloatType>(resultETy).getFloatSemantics(),
                         /*Negative=*/true));
    else if (isUnsigned)
      initialAttr = rewriter.getIntegerAttr(
          resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
    else if (isa<IntegerType>(resultETy))
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);

    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultTy.getElementType(), dynamicDims);

    Value filledEmptyTensor =
        linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor)
            .result();

    Value fakeWindowDims =
        tensor::EmptyOp::create(rewriter, loc, kernel, resultETy);

    if (isUnsigned) {
      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
          filledEmptyTensor, strideAttr, dilationAttr);
      return llvm::success();
    }

    auto resultOp = linalg::PoolingNhwcMaxOp::create(
        rewriter, op->getLoc(), ArrayRef<Type>{resultTy},
        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
        dilationAttr);

    NanPropagationMode nanMode = op.getNanMode();
    rewriter.replaceOp(op, resultOp);

    // NaN propagation has no meaning for non floating point types.
    if (!isa<FloatType>(getElementTypeOrSelf(inputTy)))
      return success();

    // "PROPAGATE" matches the default pooling behaviour; for "IGNORE" the
    // pooling body is recreated with a compare-and-select that drops NaNs.
    if (nanMode == NanPropagationMode::IGNORE) {
      auto genericOp = linalg::GenericOp::create(
          rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
          resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
          resultOp.getIteratorTypesArray(),
          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
            IRMapping map;
            auto oldBlock = resultOp.getRegion().begin();
            auto oldArgs = oldBlock->getArguments();
            auto &oldMaxOp = *resultOp.getBlock()->begin();
            map.map(oldArgs, blockArgs);
            auto *newOp = opBuilder.clone(oldMaxOp, map);
            Value isNaN =
                arith::CmpFOp::create(opBuilder, loc, arith::CmpFPredicate::UNO,
                                      blockArgs.front(), blockArgs.front());
            auto selectOp = arith::SelectOp::create(
                opBuilder, loc, isNaN, blockArgs.back(), newOp->getResult(0));
            linalg::YieldOp::create(opBuilder, loc, selectOp.getResult());
          });
      rewriter.replaceOp(resultOp, genericOp);
    }

    return success();
  }
};
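// Illustrative note on the IGNORE NaN mode above: the pooling op's max body
// is cloned into a fresh linalg.generic and wrapped with, roughly,
//
//   %is_nan = arith.cmpf uno, %in, %in : f32
//   %sel    = arith.select %is_nan, %acc, %max : f32
//
// The "uno" predicate is true only when %in is NaN, so NaN inputs leave the
// running maximum untouched; PROPAGATE needs no rewrite because the default
// arith.maximumf body already propagates NaNs.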
class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = cast<ShapedType>(op.getType());
    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = *dynamicDimsOr;

    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point could not be statically determined");
    if (failed(maybeOZp))
      return rewriter.notifyMatchFailure(
          op, "output zero point could not be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t outputZpVal = *maybeOZp;

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);
    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
    // Unsupported element type.
    if (!padAttr)
      return failure();
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    auto initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, accTy.getShape(), accETy, dynamicDims);

    Value filledEmptyTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
                               ValueRange{poolEmptyTensor})
            .result();

    Value fakeWindowDims =
        tensor::EmptyOp::create(rewriter, loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = linalg::PoolingNhwcSumOp::create(
                          rewriter, loc, ArrayRef<Type>{accTy},
                          ValueRange{paddedInput, fakeWindowDims},
                          filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1);
    Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2);

    auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
    iH = arith::SubIOp::create(rewriter, loc, iH, one);
    iW = arith::SubIOp::create(rewriter, loc, iW, one);

    Value genericEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, resultTy.getShape(), resultETy, dynamicDims);

    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericEmptyTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

          // Determines what portion of valid input is covered by the kernel.
          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
            if (pad == 0)
              return valid;

            auto padVal = arith::ConstantIndexOp::create(rewriter, loc, pad);
            Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal);

            Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero);
            return arith::AddIOp::create(rewriter, loc, valid, offset)
                ->getResult(0);
          };

          // Computes the coverage of the kernel along dimension i
          // (1 = height, 2 = width).
          auto coverageFn = [&](int64_t i, Value isize) -> Value {
            Value strideVal =
                arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]);
            Value val =
                arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]);

            // Find the position relative to the input tensor's ends.
            Value left = linalg::IndexOp::create(rewriter, loc, i);
            Value right = arith::SubIOp::create(rewriter, loc, isize, left);
            left = arith::MulIOp::create(rewriter, loc, left, strideVal);
            right = arith::MulIOp::create(rewriter, loc, right, strideVal);

            // Determine how much padding was included.
            val = padFn(val, left, pad[i * 2]);
            val = padFn(val, right, pad[i * 2 + 1]);
            return arith::MaxSIOp::create(rewriter, loc, one, val);
          };

          // Compute the coverage from either end.
          Value kH3 = coverageFn(1, iH);
          Value kW3 = coverageFn(2, iW);

          // Compute the total number of elements and normalize.
          auto count = arith::IndexCastOp::create(
              rewriter, loc, rewriter.getI32Type(),
              arith::MulIOp::create(rewriter, loc, kH3, kW3));

          // Divide by the number of summed values. For floats this is just
          // a div; for quantized values input normalization must be applied.
          Value poolVal = args[0];
          if (isa<FloatType>(accETy)) {
            auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count);
            poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF)
                          ->getResult(0);
            if (accETy.getIntOrFloatBitWidth() >
                resultETy.getIntOrFloatBitWidth())
              poolVal =
                  arith::TruncFOp::create(rewriter, loc, resultETy, poolVal);
          } else {
            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (inputZpVal != 0) {
              auto inputZp = arith::ConstantOp::create(
                  rewriter, loc, b.getIntegerAttr(accETy, inputZpVal));
              Value offset =
                  arith::MulIOp::create(rewriter, loc, accETy, count, inputZp);
              poolVal =
                  arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset);
            }

            // Compute: k = 32 - count_leading_zeros(value - 1)
            Value one32 = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI32IntegerAttr(1));
            Value thirtyTwo32 = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI32IntegerAttr(32));

            Value countSubOne =
                arith::SubIOp::create(rewriter, loc, count, one32);
            Value leadingZeros =
                math::CountLeadingZerosOp::create(rewriter, loc, countSubOne);
            Value k =
                arith::SubIOp::create(rewriter, loc, thirtyTwo32, leadingZeros);

            // Compute: numerator = ((1 << 30) + 1) << k
            Value k64 =
                arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k);
            Value thirtyShiftPlusOne = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
            Value numerator =
                arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64);

            // Compute: scale.multiplier = numerator / value
            Value count64 = arith::ExtUIOp::create(
                rewriter, loc, rewriter.getI64Type(), count);
            Value multiplier =
                arith::DivUIOp::create(rewriter, loc, numerator, count64);
            multiplier = arith::TruncIOp::create(
                rewriter, loc, rewriter.getI32Type(), multiplier);

            // Compute: scale.shift = 30 + k
            Value k8 =
                arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k);
            Value thirty8 = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI8IntegerAttr(30));
            Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);

            auto roundingAttr = RoundingModeAttr::get(
                rewriter.getContext(), RoundingMode::SINGLE_ROUND);

            auto scaled = tosa::ApplyScaleOp::create(
                              rewriter, loc, rewriter.getI32Type(), poolVal,
                              multiplier, shift, roundingAttr)
                              .getResult();

            // If we have quantization information we need to apply the
            // output zeropoint.
            if (outputZpVal != 0) {
              auto outputZp = arith::ConstantOp::create(
                  rewriter, loc,
                  b.getIntegerAttr(scaled.getType(), outputZpVal));
              scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp)
                           .getResult();
            }

            // Apply a clip to stay within the bounds of the result type.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
            auto min = arith::ConstantIntOp::create(
                rewriter, loc, accETy,
                APInt::getSignedMinValue(outBitwidth).getSExtValue());
            auto max = arith::ConstantIntOp::create(
                rewriter, loc, accETy,
                APInt::getSignedMaxValue(outBitwidth).getSExtValue());
            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
                                        /*isUnsigned=*/false);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType())
              poolVal =
                  arith::TruncIOp::create(rewriter, loc, resultETy, poolVal);
          }

          linalg::YieldOp::create(rewriter, loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};
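// Illustrative arithmetic for the quantized path above (values assumed): a
// 3x3 window fully inside the input gives count = 9, so
//   k = 32 - clz(9 - 1) = 32 - 28 = 4
//   numerator = ((1 << 30) + 1) << 4
//   multiplier = numerator / 9, shift = 30 + 4 = 34
// and tosa.apply_scale(poolVal, multiplier, shift) approximates poolVal / 9
// in fixed point, avoiding an integer division inside the inner loop.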
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();

    Location loc = op.getLoc();
    // The verifier has already checked that we have a valid TOSA permutation.
    SmallVector<OpFoldResult> inputSizes =
        tensor::getMixedSizes(rewriter, loc, op.getInput1());
    auto permutedSizes =
        applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);

    auto permutedInit =
        tensor::EmptyOp::create(rewriter, loc, permutedSizes,
                                op.getInput1().getType().getElementType());
    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
        op, op.getInput1(), permutedInit,
        llvm::to_vector(llvm::map_range(
            constantPerms, [](int32_t v) -> int64_t { return v; })));
    return success();
  }
};
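// Illustrative lowering (shapes assumed): perms = [0, 2, 1] turns
//
//   %0 = tosa.transpose %arg0 {perms = array<i32: 0, 2, 1>}
//          : (tensor<1x3x5xf32>) -> tensor<1x5x3xf32>
//
// into a tensor.empty of the permuted shape followed by
//
//   %1 = linalg.transpose ins(%arg0 : tensor<1x3x5xf32>)
//          outs(%empty : tensor<1x5x3xf32>) permutation = [0, 2, 1]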
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns,
    const TosaToLinalgNamedOptions &options) {
  if (options.preferConv2DKernelLayoutHWCF) {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                linalg::Conv2DNhwcHwcfQOp>>(
        converter, patterns->getContext());
  } else {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
                                linalg::Conv2DNhwcFhwcQOp>>(
        converter, patterns->getContext());
  }
  patterns->add<
      // clang-format off
      ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
      DepthwiseConvConverter,
      MatMulConverter,
      AvgPool2dConverter,
      MaxPool2dConverter,
      TransposeConverter
      // clang-format on
      >(converter, patterns->getContext());
}
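// Hypothetical usage sketch (pass setup is assumed, not from this file):
//
//   RewritePatternSet patterns(ctx);
//   TypeConverter converter;            // configured by the calling pass
//   TosaToLinalgNamedOptions options;
//   mlir::tosa::populateTosaToLinalgNamedConversionPatterns(converter,
//                                                           &patterns,
//                                                           options);
//   // Then drive the patterns with applyFullConversion or
//   // applyPartialConversion against a ConversionTarget that marks the
//   // relevant TOSA ops illegal.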