#include "llvm/ADT/SmallVectorExtras.h"
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            TypedAttr padAttr, OpBuilder &rewriter) {
  // Pad the input only if some padding value is non-zero.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = cast<ShapedType>(input.getType());
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult, 8> lowIndices;
  SmallVector<OpFoldResult, 8> highIndices;
  for (size_t i : llvm::seq(inputShape.size())) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr);

  return tensor::PadOp::create(rewriter, loc,
                               RankedTensorType::get(paddedShape, inputETy),
                               input, lowIndices, highIndices, padValue);
}
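// Example (a sketch, not from the original source): for an input of type
// tensor<1x7x7x3xf32> and pad = {0, 0, 1, 1, 2, 2, 0, 0} (interleaved as
// [low, high] per dimension), applyPad emits a tensor.pad producing a
// tensor<1x9x11x3xf32> filled with padAttr.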
// Broadcast the source value to all the outer dimensions of the result value.
// If required, the element type is expanded using an arith.extsi operation.
static mlir::Value
linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
                           Value conv, Value result,
                           ArrayRef<AffineMap> indexingMaps) {
  ShapedType resultTy = cast<ShapedType>(conv.getType());
  return linalg::GenericOp::create(
             rewriter, loc, resultTy, ValueRange({bias, conv}), result,
             indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
             [](OpBuilder &builder, Location loc, ValueRange args) {
               Value biasVal = args[0];
               Type resType = args[1].getType();
               if (resType != biasVal.getType()) {
                 biasVal =
                     arith::ExtSIOp::create(builder, loc, resType, biasVal);
               }

               Value added =
                   arith::AddIOp::create(builder, loc, biasVal, args[1]);
               linalg::YieldOp::create(builder, loc, added);
             })
      .getResult(0);
}
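// The generic op above broadcasts the bias across every outer dimension of
// the convolution result and sign-extends it first whenever the accumulator
// element type is wider, e.g. an i8 bias added into an i32 accumulator.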
static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
                                    Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  ShapedType sourceTy = cast<ShapedType>(source.getType());
  const int64_t resultRank = resultTy.getRank();
  const int64_t sourceRank = sourceTy.getRank();

  // The source tensor is broadcast to all the outer dimensions of the
  // result tensor.
  SmallVector<AffineExpr> sourceDims;
  // In the case of a rank-1 source tensor with a single element, TOSA
  // specifies that the value is broadcast, which requires an edge case for a
  // constant map.
  assert(sourceTy.hasStaticShape() &&
         "Dynamic broadcasting shapes not supported!");
  if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
    sourceDims.push_back(rewriter.getAffineConstantExpr(0));
  } else {
    for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
      auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
      sourceDims.push_back(expr);
    }
  }

  return AffineMap::get(/*dimCount=*/resultRank, /*symbolCount=*/0, sourceDims,
                        rewriter.getContext());
}
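// Schematic example: broadcasting a tensor<8xi32> bias into a
// tensor<1x14x14x8xi32> result yields the map (d0, d1, d2, d3) -> (d3);
// a single-element tensor<1xi32> bias instead maps every output index to
// the constant expression (0).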
// Broadcast the source value to all the outer dimensions of the result value.
// If required, the element type is expanded using arith.extsi or arith.extf.
static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
                                              Location loc, Value source,
                                              Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  const int64_t resultRank = resultTy.getRank();

  // Create the affine maps for the input and output of the broadcast-like
  // generic op.
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

  // Build the broadcast-like operation as a linalg.generic.
  return linalg::GenericOp::create(
             rewriter, loc, resultTy, ValueRange({source}), ValueRange{result},
             indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
             [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
               Value biasVal = args[0];
               Type resType = args[1].getType();
               if (resType != biasVal.getType()) {
                 biasVal =
                     resultTy.getElementType().isFloat()
                         ? arith::ExtFOp::create(builder, loc, resType, biasVal)
                               .getResult()
                         : arith::ExtSIOp::create(builder, loc, resType,
                                                  biasVal)
                               .getResult();
               }
               linalg::YieldOp::create(builder, loc, biasVal);
             })
      .getResult(0);
}
static mlir::Value reifyConstantDim(int64_t attr,
                                    ImplicitLocOpBuilder &builder) {
  return arith::ConstantIndexOp::create(builder, attr);
}

// Calculating the output width/height using the formula:
// H = ((IH + pad_top + pad_bottom - (dilation_y * (KH - 1) + 1)) / stride_y) + 1
// W = ((IW + pad_left + pad_right - (dilation_x * (KW - 1) + 1)) / stride_x) + 1
static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
                                          int64_t padBeforeAttr,
                                          int64_t padAfterAttr, Value kernelDim,
                                          int64_t strideAttr,
                                          int64_t dilationAttr,
                                          OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = arith::ConstantOp::create(rewriter, loc,
                                       IntegerAttr::get(inputDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter);

  Value subOne = arith::SubIOp::create(builder, kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = arith::MulIOp::create(builder, dilation, subOne);
  Value addOne = arith::AddIOp::create(builder, dilated, one);

  Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne);
  Value stride = reifyConstantDim(strideAttr, builder);
  Value divide = arith::DivUIOp::create(builder, subtract, stride);
  return arith::AddIOp::create(builder, divide, one);
}
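// Worked example: with IH = 7, pad = 1 + 1, KH = 3, dilation = 1 and
// stride = 2, the formula gives ((7 + 1 + 1 - (1 * (3 - 1) + 1)) / 2) + 1
// = (6 / 2) + 1 = 4 output rows.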
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
    ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
    ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
  ShapedType inputTy = cast<ShapedType>(input.getType());
  int64_t inputRank = inputTy.getRank();

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());

  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
    int64_t inputDim = inputSizeDims[i];
    int64_t kernelDim = kernelSizeDims[i];
    if (resultTy.isDynamicDim(inputDim)) {
      auto padTop = padAttr[i * 2];
      auto padBottom = padAttr[i * 2 + 1];
      auto stride = strideAttr[i];
      auto dilation = dilationAttr[i];
      Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim);
      Value kernelDynDim =
          tensor::DimOp::create(rewriter, loc, weight, kernelDim);
      // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
      dynDims[inputDim] =
          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
                                 kernelDynDim, stride, dilation, rewriter);
    }
  }

  // Get the batch/channels dimensions.
  for (int i = 0; i < inputRank; i++) {
    if (resultTy.isDynamicDim(i) && !dynDims[i])
      dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i);
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);
  return filteredDims;
}
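// Only the dimensions that are actually dynamic survive condenseValues; the
// condensed list is handed to tensor.empty so destination tensors carry
// exactly one SSA size value per dynamic dimension.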
// Creates a map to collapse the last dimension of the depthwise convolution
// op due to a shape mismatch.
static void createDepthwiseConvCollapseMap(
    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
    OpBuilder &rewriter) {
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++) {
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  }
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
}
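// E.g. for outputRank == 4 this builds the reassociation
// [[0], [1], [2], [3, 4]], collapsing the trailing (channel, multiplier)
// pair of the linalg depthwise result into a single channel dimension.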
template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();

    DenseI64ArrayAttr padAttr = op.getPadAttr();
    DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
    DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();

    Type accETy = op.getAccType();
    Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);
    // Get and verify the zero points.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");

    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");
    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops does not support unsigned integer input");

    llvm::SmallVector<int64_t> inputSizeDims;
    llvm::SmallVector<int64_t> kernelSizeDims;
    for (int i = 1; i < resultTy.getRank() - 1; i++) {
      inputSizeDims.push_back(i);
      kernelSizeDims.push_back(i);
    }

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        inputSizeDims, kernelSizeDims, rewriter);
    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (hasZp) {
      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);
    input = applyPad(loc, input, pad, zeroAttr, rewriter);
    if (4 == inputTy.getRank()) {
      // For 2D convolutions, check whether the target linalg op expects an
      // HWCF kernel layout.
      bool wantHwcf =
          hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
      if (wantHwcf) {
        // Transpose the kernel to match the dimension ordering of the linalg
        // convolution operation.
        SmallVector<int32_t> weightPerm;
        for (int i = 1; i < resultTy.getRank(); i++)
          weightPerm.push_back(i);
        weightPerm.push_back(0);

        SmallVector<int64_t> newWeightShape;
        for (auto dim : weightPerm)
          newWeightShape.push_back(weightShape[dim]);
        auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
        Type newWeightTy =
            RankedTensorType::get(newWeightShape, weightTy.getElementType());
        weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
                                           weightPermAttr);
      }
    }
    // For Conv3D, transpose the kernel unconditionally to match the dimension
    // ordering of the linalg convolution operation.
    if (5 == inputTy.getRank()) {
      SmallVector<int32_t> weightPerm;
      for (int i = 1; i < resultTy.getRank(); i++)
        weightPerm.push_back(i);
      weightPerm.push_back(0);

      SmallVector<int64_t> newWeightShape;
      for (auto dim : weightPerm)
        newWeightShape.push_back(weightShape[dim]);
      auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
      Type newWeightTy =
          RankedTensorType::get(newWeightShape, weightTy.getElementType());
      weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
                                         weightPermAttr);
    }
    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);

    Value biasEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, resultTy.getShape(), accETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
    if (hasZp) {
      auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
      auto kZp = rewriter.getI32IntegerAttr(weightZpVal);

      auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
      auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);

      Value conv = LinalgConvQOp::create(
                       rewriter, loc, resultTy,
                       ValueRange{input, weight, iZpVal, kZpVal},
                       ValueRange{broadcastBias}, strideAttr, dilationAttr)
                       ->getResult(0);

      rewriter.replaceOp(op, conv);
      return success();
    }

    Value conv = LinalgConvOp::create(
                     rewriter, loc, accTy, ValueRange{input, weight},
                     ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);

    // If the accumulator is wider than the result, cast back down.
    if (resultTy != accTy)
      conv = tosa::CastOp::create(rewriter, loc, resultTy, conv);

    rewriter.replaceOp(op, conv);
    return success();
  }
};
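// Lowering sketch (schematic IR, shapes hypothetical): a floating-point
//   %0 = tosa.conv2d %input, %weight, %bias {pad, stride, dilation}
// becomes a tensor.pad of the input (when padding is non-zero), a
// linalg.generic that broadcasts %bias into the accumulator tensor, and a
// linalg.conv_2d_nhwc_fhwc (or *_hwcf after the kernel transpose above)
// that accumulates on top of the broadcast bias.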
class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
    auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
    auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));
    Type accETy = op.getAccType();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute output dynamic dims.
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        /*inputSizeDims=*/{1, 2},
        /*kernelSizeDims=*/{0, 1}, rewriter);
    // Get and verify the zero points.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");
    bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (!hasNullZps) {
      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);
    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              accETy);

    auto resultZeroAttr = rewriter.getZeroAttr(accETy);
    Value emptyTensor = tensor::EmptyOp::create(
        rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
    Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                              ValueRange{emptyTensor})
                           .result();

    Value biasEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, resultTy.getShape(), resultETy, filteredDims);

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    if (hasNullZps) {
      Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
                       rewriter, loc, linalgConvTy, ValueRange{input, weight},
                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      // We may need to truncate back to the result type if the accumulator
      // was wider than the result.
      if (accETy != resultETy)
        conv = tosa::CastOp::create(
            rewriter, loc,
            RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
                                  resultETy),
            conv);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = tensor::CollapseShapeOp::create(
          rewriter, loc, resultTy, conv, reassociationMap);

      Value result =
          linalg::GenericOp::create(
              rewriter, loc, resultTy, ValueRange({bias, convReshape}),
              biasEmptyTensor, indexingMaps,
              getNParallelLoopsAttrs(resultRank),
              [&](OpBuilder &nestedBuilder, Location nestedLoc,
                  ValueRange args) {
                Value added;
                if (llvm::isa<FloatType>(inputETy))
                  added = arith::AddFOp::create(nestedBuilder, loc, args[0],
                                                args[1]);
                else
                  added = arith::AddIOp::create(nestedBuilder, loc, args[0],
                                                args[1]);
                linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
              })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
      IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
      auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
      auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
      Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
                       rewriter, loc, linalgConvTy,
                       ValueRange{input, weight, iZpVal, kZpVal},
                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = tensor::CollapseShapeOp::create(
          rewriter, loc, resultTy, conv, reassociationMap);
      Value result = linalgIntBroadcastExtSIAdd(
          rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};
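// Note: linalg's depthwise op produces a 5-D (N, H, W, C, M) tensor; the
// collapse above folds the (C, M) pair back into the 4-D TOSA result shape
// before the bias addition.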
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = cast<ShapedType>(op.getType());
    auto outputElementTy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
      dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
      dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
      dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
                                outputTy.getElementType(), filteredDims);
    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                              ValueRange{emptyTensor})
                           .result();
    FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
    FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
    if (failed(maybeAZp))
      return rewriter.notifyMatchFailure(
          op, "input a zero point cannot be statically determined");
    if (failed(maybeBZp))
      return rewriter.notifyMatchFailure(
          op, "input b zero point cannot be statically determined");

    const int64_t aZpVal = *maybeAZp;
    const int64_t bZpVal = *maybeBZp;

    if (op.verifyAZeroPoint(aZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input a zero point must be zero for non-int8 integer types");

    if (op.verifyBZeroPoint(bZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input b zero point must be zero for non-int8 integer types");
    if (aZpVal == 0 && bZpVal == 0) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
      return success();
    }

    auto aZp = arith::ConstantOp::create(rewriter, loc,
                                         rewriter.getI32IntegerAttr(aZpVal));
    auto bZp = arith::ConstantOp::create(rewriter, loc,
                                         rewriter.getI32IntegerAttr(bZpVal));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

    return success();
  }
};
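// tosa.matmul is batched over 3-D operands, so it maps directly onto
// linalg.batch_matmul (or linalg.quantized_batch_matmul when either zero
// point is non-zero); the zero-filled init tensor provides the accumulator.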
class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  // Compute the dynamic output sizes of the maxpool operation.
  static SmallVector<Value>
  computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) {
    TensorType resultTy = op.getType();
    Location loc = op.getLoc();

    Value input = adaptor.getInput();
    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> pad = op.getPad();
    ArrayRef<int64_t> stride = op.getStride();

    SmallVector<Value> dynamicDims;

    // Batch dimension.
    if (resultTy.isDynamicDim(0))
      dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));

    // Height/width dimensions.
    for (int64_t dim : {1, 2}) {
      if (!resultTy.isDynamicDim(dim))
        continue;

      // Index into the pad/stride arrays, which have no batch entry.
      int64_t index = dim - 1;

      // Input height/width.
      Value ihw = tensor::DimOp::create(rewriter, loc, input, dim);

      // Kernel height/width.
      Value khw = arith::ConstantIndexOp::create(rewriter, loc, kernel[index]);

      // Output height/width.
      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
                                         pad[index * 2 + 1], khw, stride[index],
                                         /*dilationAttr=*/1, rewriter);
      dynamicDims.push_back(ohw);
    }

    // Channel dimension.
    if (resultTy.isDynamicDim(3))
      dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3));

    return dynamicDims;
  }
  LogicalResult
  matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = adaptor.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());

    bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
    ShapedType resultTy =
        getTypeConverter()->convertType<ShapedType>(op.getType());
    if (!resultTy)
      return rewriter.notifyMatchFailure(op, "failed to convert type");
    Type resultETy = inputTy.getElementType();

    SmallVector<Value> dynamicDims =
        computeDynamicOutputSizes(op, adaptor, rewriter);
    // Determine what the initial value needs to be for the max pool op.
    TypedAttr initialAttr;
    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
      initialAttr = rewriter.getFloatAttr(
          resultETy, APFloat::getLargest(
                         cast<FloatType>(resultETy).getFloatSemantics(),
                         /*Negative=*/true));
    else if (isUnsigned)
      initialAttr = rewriter.getIntegerAttr(
          resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
    else if (isa<IntegerType>(resultETy))
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);

    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
    Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultTy.getElementType(), dynamicDims);

    Value filledEmptyTensor =
        linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor)
            .result();

    Value fakeWindowDims =
        tensor::EmptyOp::create(rewriter, loc, kernel, resultETy);

    if (isUnsigned) {
      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
          op, ArrayRef<Type>{resultTy},
          ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor,
          strideAttr, dilationAttr);
      return llvm::success();
    }

    auto resultOp = linalg::PoolingNhwcMaxOp::create(
        rewriter, op->getLoc(), ArrayRef<Type>{resultTy},
        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
        dilationAttr);

    NanPropagationMode nanMode = op.getNanMode();
    rewriter.replaceOp(op, resultOp);
    // In "IGNORE" mode, NaN inputs must not replace a running max, so clone
    // the pooling payload into a linalg.generic and guard the computed max
    // with an isNaN check that falls back to the accumulator value.
    if (nanMode == NanPropagationMode::IGNORE) {
      auto genericOp = linalg::GenericOp::create(
          rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
          resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
          resultOp.getIteratorTypesArray(),
          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
            IRMapping map;
            auto oldBlock = resultOp.getRegion().begin();
            auto oldArgs = oldBlock->getArguments();
            auto &oldMaxOp = *resultOp.getBlock()->begin();
            map.map(oldArgs, blockArgs);
            auto *newOp = opBuilder.clone(oldMaxOp, map);
            Value isNaN =
                arith::CmpFOp::create(opBuilder, loc, arith::CmpFPredicate::UNO,
                                      blockArgs.front(), blockArgs.front());
            auto selectOp = arith::SelectOp::create(
                opBuilder, loc, isNaN, blockArgs.back(), newOp->getResult(0));
            linalg::YieldOp::create(opBuilder, loc, selectOp.getResult());
          });
      rewriter.replaceOp(resultOp, genericOp);
    }

    return success();
  }
};
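// arith.cmpf with the "uno" (unordered) predicate is true exactly when one
// operand is NaN, so comparing a value against itself serves as an isNaN
// test for the select above.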
class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = cast<ShapedType>(op.getType());
    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = *dynamicDimsOr;
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point could not be statically determined");
    if (failed(maybeOZp))
      return rewriter.notifyMatchFailure(
          op, "output zero point could not be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t outputZpVal = *maybeOZp;
    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);
    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);

    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
    auto initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, accTy.getShape(), accETy, dynamicDims);

    Value filledEmptyTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
                               ValueRange{poolEmptyTensor})
            .result();

    Value fakeWindowDims =
        tensor::EmptyOp::create(rewriter, loc, kernel, accETy);

    // Sum the pooled region.
    Value poolingOp = linalg::PoolingNhwcSumOp::create(
                          rewriter, loc, ArrayRef<Type>{accTy},
                          ValueRange{paddedInput, fakeWindowDims},
                          filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);
    // Normalize the summed value by the number of elements grouped in each
    // pool.
    Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1);
    Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2);

    auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
    iH = arith::SubIOp::create(rewriter, loc, iH, one);
    iW = arith::SubIOp::create(rewriter, loc, iW, one);

    Value genericEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, resultTy.getShape(), resultETy, dynamicDims);

    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericEmptyTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
          // Determines what portion of valid input is covered by the kernel.
          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
            if (pad == 0)
              return valid;

            auto padVal = arith::ConstantIndexOp::create(rewriter, loc, pad);
            Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal);

            Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero);
            return arith::AddIOp::create(rewriter, loc, valid, offset)
                ->getResult(0);
          };

          // Compute the valid overlap of the kernel with the input from
          // either end of the pooled dimension.
          auto coverageFn = [&](int64_t i, Value isize) -> Value {
            Value strideVal =
                arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]);
            Value val =
                arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]);

            // Find the position relative to the input dimension start.
            Value left = linalg::IndexOp::create(rewriter, loc, i);
            Value right = arith::SubIOp::create(rewriter, loc, isize, left);
            left = arith::MulIOp::create(rewriter, loc, left, strideVal);
            right = arith::MulIOp::create(rewriter, loc, right, strideVal);

            // Determine how much padding was included.
            val = padFn(val, left, pad[i * 2]);
            val = padFn(val, right, pad[i * 2 + 1]);
            return arith::MaxSIOp::create(rewriter, loc, one, val);
          };
          // Compute the indices from either end.
          Value kH3 = coverageFn(1, iH);
          Value kW3 = coverageFn(2, iW);

          // Compute the total number of elements and normalize.
          auto count = arith::IndexCastOp::create(
              rewriter, loc, rewriter.getI32Type(),
              arith::MulIOp::create(rewriter, loc, kH3, kW3));
          // Divide by the number of summed values. For floats this is just a
          // division; for quantized values input normalization must also be
          // applied.
          Value poolVal = args[0];
          if (isa<FloatType>(accETy)) {
            auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count);
            poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF)
                          ->getResult(0);
            if (accETy.getIntOrFloatBitWidth() >
                resultETy.getIntOrFloatBitWidth())
              poolVal =
                  arith::TruncFOp::create(rewriter, loc, resultETy, poolVal);
          } else {
            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (inputZpVal != 0) {
              auto inputZp = arith::ConstantOp::create(
                  rewriter, loc, b.getIntegerAttr(accETy, inputZpVal));
              Value offset =
                  arith::MulIOp::create(rewriter, loc, accETy, count, inputZp);
              poolVal =
                  arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset);
            }
            // Compute: k = 32 - count_leading_zeros(value - 1)
            Value one32 = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI32IntegerAttr(1));
            Value thirtyTwo32 = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI32IntegerAttr(32));

            Value countSubOne =
                arith::SubIOp::create(rewriter, loc, count, one32);
            Value leadingZeros =
                math::CountLeadingZerosOp::create(rewriter, loc, countSubOne);
            Value k =
                arith::SubIOp::create(rewriter, loc, thirtyTwo32, leadingZeros);

            // Compute: numerator = ((1 << 30) + 1) << k
            Value k64 =
                arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k);
            Value thirtyShiftPlusOne = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
            Value numerator =
                arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64);

            // Compute: scale.multiplier = numerator / value
            Value count64 = arith::ExtUIOp::create(
                rewriter, loc, rewriter.getI64Type(), count);
            Value multiplier =
                arith::DivUIOp::create(rewriter, loc, numerator, count64);
            multiplier = arith::TruncIOp::create(
                rewriter, loc, rewriter.getI32Type(), multiplier);

            // Compute: scale.shift = 30 + k
            Value k8 =
                arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k);
            Value thirty8 = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI8IntegerAttr(30));
            Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);

            auto roundingAttr = RoundingModeAttr::get(
                rewriter.getContext(), RoundingMode::SINGLE_ROUND);

            auto scaled = tosa::ApplyScaleOp::create(
                              rewriter, loc, rewriter.getI32Type(), poolVal,
                              multiplier, shift, roundingAttr)
                              .getResult();
            // If we have quantization information we need to apply the output
            // zeropoint.
            if (outputZpVal != 0) {
              auto outputZp = arith::ConstantOp::create(
                  rewriter, loc,
                  b.getIntegerAttr(scaled.getType(), outputZpVal));
              scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp)
                           .getResult();
            }

            // Apply clip.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = arith::ConstantIntOp::create(
                rewriter, loc, accETy,
                APInt::getSignedMinValue(outBitwidth).getSExtValue());
            auto max = arith::ConstantIntOp::create(
                rewriter, loc, accETy,
                APInt::getSignedMaxValue(outBitwidth).getSExtValue());
            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
                                        /*isUnsigned=*/false);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType()) {
              poolVal =
                  arith::TruncIOp::create(rewriter, loc, resultETy, poolVal);
            }
          }

          linalg::YieldOp::create(rewriter, loc, poolVal);
        });
    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};
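// Worked example (a sketch): for a fully covered 3x3 kernel, count = 9, so
// k = 32 - clz(9 - 1) = 4, numerator = ((1 << 30) + 1) << 4 (about 2^34),
// multiplier = numerator / 9 and shift = 30 + 4 = 34; tosa.apply_scale then
// computes (x * multiplier) >> 34, which approximates x / 9 in fixed point.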
class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();

    Location loc = op.getLoc();
    // The verifier guarantees the permutation is a valid TOSA permutation, so
    // the input sizes can be permuted directly.
    SmallVector<OpFoldResult> inputSizes =
        tensor::getMixedSizes(rewriter, loc, op.getInput1());
    auto permutedSizes =
        applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);

    auto permutedInit =
        tensor::EmptyOp::create(rewriter, loc, permutedSizes,
                                op.getInput1().getType().getElementType());
    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
        op, op.getInput1(), permutedInit,
        llvm::map_to_vector(constantPerms,
                            [](int32_t v) -> int64_t { return v; }));
    return success();
  }
};
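// linalg.transpose expects its permutation as i64 values while TOSA stores
// perms as i32, hence the map_to_vector widening above (this is also what
// the SmallVectorExtras.h include at the top of the file provides).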
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns,
    const TosaToLinalgNamedOptions &options) {
  if (options.preferConv2DKernelLayoutHWCF) {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                linalg::Conv2DNhwcHwcfQOp>>(
        converter, patterns->getContext());
  } else {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
                                linalg::Conv2DNhwcFhwcQOp>>(
        converter, patterns->getContext());
  }
  patterns->add<
      // clang-format off
      ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp,
                    linalg::Conv3DNdhwcDhwcfQOp>,
      DepthwiseConvConverter,
      MatMulConverter,
      AvgPool2dConverter,
      MaxPool2dConverter,
      TransposeConverter
      // clang-format on
      >(converter, patterns->getContext());
}
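// Usage sketch (hypothetical driver code, not part of this file): a
// conversion pass would typically wire these patterns up roughly as follows,
// assuming a TypeConverter `converter`, a ConversionTarget `target`, and a
// func::FuncOp `func` are already set up:
//
//   RewritePatternSet patterns(&getContext());
//   TosaToLinalgNamedOptions options;
//   options.preferConv2DKernelLayoutHWCF = true;
//   tosa::populateTosaToLinalgNamedConversionPatterns(converter, &patterns,
//                                                     options);
//   if (failed(applyFullConversion(func, target, std::move(patterns))))
//     signalPassFailure();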