MLIR 22.0.0git
TosaToLinalgNamed.cpp
1//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These rewriters lower from the Tosa to the Linalg named ops.
10//
11//===----------------------------------------------------------------------===//
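//
// Schematic example (hypothetical shapes and attribute values): a
// floating-point pooling op such as
//
//   %0 = tosa.max_pool2d %arg0 {kernel = array<i64: 2, 2>,
//                               pad = array<i64: 0, 0, 0, 0>,
//                               stride = array<i64: 2, 2>}
//       : (tensor<1x112x112x8xf32>) -> tensor<1x56x56x8xf32>
//
// is rewritten by MaxPool2dConverter below into a tensor.empty + linalg.fill
// (seeded with the lowest representable value) followed by a
// linalg.pooling_nhwc_max over the padded input.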
12
13#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/Dialect/Linalg/IR/Linalg.h"
16#include "mlir/Dialect/Math/IR/Math.h"
17#include "mlir/Dialect/Tensor/IR/Tensor.h"
18#include "mlir/Dialect/Tosa/IR/TosaOps.h"
19#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
20#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Transforms/DialectConversion.h"
23
24#include <type_traits>
25
26using namespace mlir;
27using namespace mlir::tosa;
28
29static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
30 TypedAttr padAttr, OpBuilder &rewriter) {
31 // Input should be padded only if necessary.
32 if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
33 return input;
34
35 ShapedType inputTy = cast<ShapedType>(input.getType());
36 Type inputETy = inputTy.getElementType();
37 auto inputShape = inputTy.getShape();
38
39 assert((inputShape.size() * 2) == pad.size());
40
41 SmallVector<int64_t, 4> paddedShape;
42 SmallVector<OpFoldResult, 8> lowIndices;
43 SmallVector<OpFoldResult, 8> highIndices;
44 for (size_t i : llvm::seq(inputShape.size())) {
45 auto lowPad = pad[i * 2];
46 auto highPad = pad[i * 2 + 1];
47 if (ShapedType::isDynamic(inputShape[i]))
48 paddedShape.push_back(inputShape[i]);
49 else
50 paddedShape.push_back(inputShape[i] + highPad + lowPad);
51 lowIndices.push_back(rewriter.getIndexAttr(lowPad));
52 highIndices.push_back(rewriter.getIndexAttr(highPad));
53 }
54
55 Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr);
56
57 return tensor::PadOp::create(rewriter, loc,
58 RankedTensorType::get(paddedShape, inputETy),
59 input, lowIndices, highIndices, padValue);
60}
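// Illustrative use of applyPad (hypothetical values): for an NHWC input of
// type tensor<1x112x112x8xf32> and pad = {0, 0, 1, 1, 2, 2, 0, 0}
// (interleaved low/high padding per dimension), this builds a tensor.pad
// whose result type is tensor<1x114x116x8xf32>; if every entry of `pad` is
// zero, the input is returned unchanged.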
61
62static mlir::Value
63linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
64 Value conv, Value result,
65 ArrayRef<AffineMap> indexingMaps) {
66 ShapedType resultTy = cast<ShapedType>(conv.getType());
67 return linalg::GenericOp::create(
68 rewriter, loc, resultTy, ValueRange({bias, conv}), result,
69 indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
70 [](OpBuilder &builder, Location loc, ValueRange args) {
71 Value biasVal = args[0];
72 Type resType = args[1].getType();
73 if (resType != biasVal.getType()) {
74 biasVal =
75 arith::ExtSIOp::create(builder, loc, resType, biasVal);
76 }
77 Value added =
78 arith::AddIOp::create(builder, loc, biasVal, args[1]);
79 linalg::YieldOp::create(builder, loc, added);
80 })
81 .getResult(0);
82}
83
84// Construct the affine map that a linalg generic would use to broadcast the
85// source tensor into the shape of the result tensor.
86static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
87 Value result) {
88 ShapedType resultTy = cast<ShapedType>(result.getType());
89 ShapedType sourceTy = cast<ShapedType>(source.getType());
90 const int64_t resultRank = resultTy.getRank();
91 const int64_t sourceRank = sourceTy.getRank();
92
93 // The source tensor is broadcast to all the outer dimensions of the
94 // result tensor.
95 SmallVector<AffineExpr> sourceDims;
96 // In the case of a rank-one source tensor with a single element, TOSA
97 // specifies that the value be broadcast, meaning we need an edge case for
98 // a constant map.
99 assert(sourceTy.hasStaticShape() &&
100 "Dynamic broadcasting shapes not supported!");
101 if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
102 sourceDims.push_back(rewriter.getAffineConstantExpr(0));
103 } else {
104 for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
105 auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
106 sourceDims.push_back(expr);
107 }
108 }
109
110 return AffineMap::get(/*dimCount=*/resultRank,
111 /*symbolCount=*/0, sourceDims, rewriter.getContext());
112}
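// Illustrative example (hypothetical shapes): broadcasting a bias of type
// tensor<8xf32> into a result of type tensor<1x56x56x8xf32> yields the map
// (d0, d1, d2, d3) -> (d3), while a single-element tensor<1xf32> bias yields
// the constant map (d0, d1, d2, d3) -> (0).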
113
114// Broadcast the source value to all the outer dimensions of the result value.
115// If required, the element type is expanded using an arith.extsi or arith.extf
116// operation as appropriate.
117static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
118 Location loc, Value source,
119 Value result) {
120 ShapedType resultTy = cast<ShapedType>(result.getType());
121 const int64_t resultRank = resultTy.getRank();
122 // Creating maps for the input and output of the broadcast-like generic op.
123 SmallVector<AffineMap, 2> indexingMaps;
124 indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
125 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
126
127 // Build the broadcast-like operation as a linalg.generic.
128 return linalg::GenericOp::create(
129 rewriter, loc, resultTy, ValueRange({source}), result,
130 indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
131 [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
132 Value biasVal = args[0];
133 Type resType = args[1].getType();
134 if (resType != biasVal.getType()) {
135 biasVal =
136 resultTy.getElementType().isFloat()
137 ? arith::ExtFOp::create(builder, loc, resType, biasVal)
138 .getResult()
139 : arith::ExtSIOp::create(builder, loc, resType,
140 biasVal)
141 .getResult();
142 }
143 linalg::YieldOp::create(builder, loc, biasVal);
144 })
145 .getResult(0);
146}
147
148static mlir::Value reifyConstantDim(int64_t attr,
149 ImplicitLocOpBuilder &builder) {
150 return arith::ConstantIndexOp::create(builder, attr);
151}
152
153// Calculating the output width/height using the formula:
154// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
155// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
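// Worked example (hypothetical values): IH = 112, pad_top = pad_bottom = 1,
// KH = 3, dilation_y = 1, stride_y = 2 gives
// H = ((112+1+1-(1*(3-1)+1))/2)+1 = (111/2)+1 = 55+1 = 56 (integer division).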
156
157static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
158 int64_t padBeforeAttr,
159 int64_t padAfterAttr, Value kernelDim,
160 int64_t strideAttr,
161 int64_t dilationAttr,
162 OpBuilder &rewriter) {
163 ImplicitLocOpBuilder builder(loc, rewriter);
164 auto one = arith::ConstantOp::create(rewriter, loc,
165 IntegerAttr::get(inputDim.getType(), 1));
166 Value padBefore = reifyConstantDim(padBeforeAttr, builder);
167 Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore);
168 Value padAfter = reifyConstantDim(padAfterAttr, builder);
169 Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter);
170
171 Value subOne = arith::SubIOp::create(builder, kernelDim, one);
172 Value dilation = reifyConstantDim(dilationAttr, builder);
173 Value dilated = arith::MulIOp::create(builder, dilation, subOne);
174 Value addOne = arith::AddIOp::create(builder, dilated, one);
175
176 Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne);
177 Value stride = reifyConstantDim(strideAttr, builder);
178 Value divide = arith::DivUIOp::create(builder, subtract, stride);
179 return arith::AddIOp::create(builder, divide, one);
180}
181
182// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
183static SmallVector<Value> inferDynamicDimsForConv(
184 Location loc, Value input, Value weight, ShapedType resultTy,
185 ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
186 ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
187 ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
188 ShapedType inputTy = cast<ShapedType>(input.getType());
189 int64_t inputRank = inputTy.getRank();
190
191 SmallVector<Value> dynDims;
192 dynDims.resize(resultTy.getRank());
193
194 for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
195 int64_t inputDim = inputSizeDims[i];
196 int64_t kernelDim = kernelSizeDims[i];
197 if (resultTy.isDynamicDim(inputDim)) {
198 auto padTop = padAttr[i * 2];
199 auto padBottom = padAttr[i * 2 + 1];
200 auto stride = strideAttr[i];
201 auto dilation = dilationAttr[i];
202 Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim);
203 Value kernelDynDim =
204 tensor::DimOp::create(rewriter, loc, weight, kernelDim);
205 // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
206 dynDims[inputDim] =
207 getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
208 kernelDynDim, stride, dilation, rewriter);
209 }
210 }
211
212 // Get the batch/channels dimensions.
213 for (int i = 0; i < inputRank; i++) {
214 if (resultTy.isDynamicDim(i) && !dynDims[i])
215 dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i);
216 }
217
218 SmallVector<Value> filteredDims = condenseValues(dynDims);
219 return filteredDims;
220}
221
222// Creates a map to collapse the last dimension of the Depthwise convolution op
223// due to a shape mismatch
224static void createDepthwiseConvCollapseMap(
225 int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
226 OpBuilder &rewriter) {
227 reassociationMap.resize(outputRank);
228 for (int i = 0; i < outputRank; i++) {
229 reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
230 }
231 reassociationMap[outputRank - 1].push_back(
232 rewriter.getAffineDimExpr(outputRank));
233}
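// Illustrative example (hypothetical rank): for outputRank = 4 the
// reassociation map is [[d0], [d1], [d2], [d3, d4]], collapsing the rank-5
// depthwise result tensor<NxHxWxCxM> produced by
// linalg.depthwise_conv_2d_nhwc_hwcm into the rank-4 TOSA result
// tensor<NxHxWx(C*M)>.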
234
235namespace {
236
237template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
238class ConvConverter : public OpConversionPattern<TosaConvOp> {
239public:
240 using OpConversionPattern<TosaConvOp>::OpConversionPattern;
241 LogicalResult
242 matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
243 ConversionPatternRewriter &rewriter) const final {
244 Location loc = op->getLoc();
245 Value input = op->getOperand(0);
246 Value weight = op->getOperand(1);
247 Value bias = op->getOperand(2);
248
249 ShapedType inputTy = cast<ShapedType>(input.getType());
250 ShapedType weightTy = cast<ShapedType>(weight.getType());
251 ShapedType biasTy = cast<ShapedType>(bias.getType());
252 ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
253
254 Type inputETy = inputTy.getElementType();
255
256 DenseI64ArrayAttr padAttr = op.getPadAttr();
257 DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
258 DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
259
260 Type accETy = op.getAccType();
261 Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);
262
263 // Get and verify zero points.
264 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
265 if (failed(maybeIZp))
266 return rewriter.notifyMatchFailure(
267 op, "input zero point cannot be statically determined");
268
269 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
270 if (failed(maybeWZp))
271 return rewriter.notifyMatchFailure(
272 op, "weight zero point cannot be statically determined");
273
274 const int64_t inputZpVal = *maybeIZp;
275 const int64_t weightZpVal = *maybeWZp;
276
277 if (op.verifyInputZeroPoint(inputZpVal).failed())
278 return rewriter.notifyMatchFailure(
279 op, "input zero point must be zero for non-int8 integer types");
280
281 if (op.verifyWeightZeroPoint(weightZpVal).failed())
282 return rewriter.notifyMatchFailure(
283 op, "weight zero point must be zero for non-int8 integer types");
284
285 bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
286
287 if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
288 return rewriter.notifyMatchFailure(
289 op, "tosa.conv ops require static shapes for weight and bias");
290
291 if (inputETy.isUnsignedInteger())
292 return rewriter.notifyMatchFailure(
293 op, "tosa.conv ops does not support unsigned integer input");
294
295 llvm::SmallVector<int64_t> inputSizeDims;
296 llvm::SmallVector<int64_t> kernelSizeDims;
297 for (int i = 1; i < resultTy.getRank() - 1; i++) {
298 inputSizeDims.push_back(i);
299 kernelSizeDims.push_back(i);
300 }
301
302 SmallVector<Value> filteredDims = inferDynamicDimsForConv(
303 loc, input, weight, resultTy, padAttr.asArrayRef(),
304 strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
305 inputSizeDims, kernelSizeDims, rewriter);
306
307 auto weightShape = weightTy.getShape();
308
309 // Apply padding as necessary.
310 TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
311 if (hasZp) {
312 int64_t intMin =
313 APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
314 .getSExtValue();
315 int64_t intMax =
316 APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
317 .getSExtValue();
318
319 if (inputZpVal < intMin || inputZpVal > intMax)
320 return rewriter.notifyMatchFailure(
321 op, "tosa.conv op quantization has zp outside of input range");
322
323 zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
324 }
325
326 llvm::SmallVector<int64_t> pad;
327 pad.resize(2, 0);
328 llvm::append_range(pad, padAttr.asArrayRef());
329 pad.resize(pad.size() + 2, 0);
330 input = applyPad(loc, input, pad, zeroAttr, rewriter);
331
332 if (4 == inputTy.getRank()) {
333 // For 2D convolutions, we need to check if the target convolution op
334 // wants a HWCF kernel layout.
335 bool wantHwcf =
336 hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
337 : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
338 if (wantHwcf) {
339 // Transpose the kernel to match dimension ordering of the linalg
340 // convolution operation.
341 // TODO(suderman): See if this can be efficiently folded - check whether
342 // the input is used anywhere else, if not fold the constant.
343 SmallVector<int32_t> weightPerm;
344 for (int i = 1; i < resultTy.getRank(); i++)
345 weightPerm.push_back(i);
346 weightPerm.push_back(0);
347
348 SmallVector<int64_t> newWeightShape;
349 for (auto dim : weightPerm)
350 newWeightShape.push_back(weightShape[dim]);
351 auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
352 Type newWeightTy =
353 RankedTensorType::get(newWeightShape, weightTy.getElementType());
354 weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
355 weightPermAttr);
356 }
357 }
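// Illustrative example (hypothetical sizes): a tosa.conv2d weight of type
// tensor<16x3x3x8xf32> (OC, KH, KW, IC) is transposed with perms [1, 2, 3, 0]
// into tensor<3x3x8x16xf32>, i.e. the HWCF layout expected by
// linalg.conv_2d_nhwc_hwcf.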
358
359 // For Conv3D, transpose the kernel to match the dimension ordering of the
360 // linalg convolution operation. Conv2D has a 1-1 mapping in linalg, so it is
361 // better to map directly and then transpose later if desired.
362 if (5 == inputTy.getRank()) {
363 // TODO(suderman): See if this can be efficiently folded - check whether
364 // the input is used anywhere else, if not fold the constant.
365 SmallVector<int32_t> weightPerm;
366 for (int i = 1; i < resultTy.getRank(); i++)
367 weightPerm.push_back(i);
368 weightPerm.push_back(0);
369
370 SmallVector<int64_t> newWeightShape;
371 for (auto dim : weightPerm)
372 newWeightShape.push_back(weightShape[dim]);
373 auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
374 Type newWeightTy =
375 RankedTensorType::get(newWeightShape, weightTy.getElementType());
376 weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
377 weightPermAttr);
378 }
379
380 // Extract the attributes for convolution.
381 ArrayRef<int64_t> stride = strideTosaAttr;
382 ArrayRef<int64_t> dilation = dilationTosaAttr;
383
384 // Create the convolution op.
385 auto strideAttr = rewriter.getI64TensorAttr(stride);
386 auto dilationAttr = rewriter.getI64TensorAttr(dilation);
387
388 Value biasEmptyTensor = tensor::EmptyOp::create(
389 rewriter, loc, resultTy.getShape(), accETy, filteredDims);
390
391 Value broadcastBias =
392 linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
393
394 if (hasZp) {
395 auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
396 auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
397
398 auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
399 auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);
400
401 Value conv = LinalgConvQOp::create(
402 rewriter, loc, resultTy,
403 ValueRange{input, weight, iZpVal, kZpVal},
404 ValueRange{broadcastBias}, strideAttr, dilationAttr)
405 ->getResult(0);
406
407 rewriter.replaceOp(op, conv);
408 return success();
409 }
410
411 Value conv = LinalgConvOp::create(
412 rewriter, loc, accTy, ValueRange{input, weight},
413 ValueRange{broadcastBias}, strideAttr, dilationAttr)
414 ->getResult(0);
415
416 // We may need to truncate back to the result type if the accumulator was
417 // wider than the result.
418 if (resultTy != accTy)
419 conv = tosa::CastOp::create(rewriter, loc, resultTy, conv);
420
421 rewriter.replaceOp(op, conv);
422 return success();
423 }
424};
425
426class DepthwiseConvConverter
427 : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
428public:
429 using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
430 LogicalResult
431 matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
432 ConversionPatternRewriter &rewriter) const final {
433 Location loc = op->getLoc();
434 Value input = op->getOperand(0);
435 Value weight = op->getOperand(1);
436 Value bias = op->getOperand(2);
437
438 ShapedType inputTy = cast<ShapedType>(input.getType());
439 ShapedType weightTy = cast<ShapedType>(weight.getType());
440 ShapedType biasTy = cast<ShapedType>(bias.getType());
441 ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
442 int64_t resultRank = resultTy.getRank();
443
444 Type inputETy = inputTy.getElementType();
445 Type resultETy = resultTy.getElementType();
446
447 auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
448 auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
449 auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));
450
451 Type accETy = op.getAccType();
452
453 if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
454 return rewriter.notifyMatchFailure(
455 op, "tosa.depthwise_conv ops require static shapes");
456
457 // Compute output dynamic dims
458 SmallVector<Value> filteredDims = inferDynamicDimsForConv(
459 loc, input, weight, resultTy, padAttr.asArrayRef(),
460 strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
461 /*inputSizeDims=*/{1, 2},
462 /*kernelSizeDims=*/{0, 1}, rewriter);
463
464 // Get and verify zero points.
465
466 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
467 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
468 if (failed(maybeIZp))
469 return rewriter.notifyMatchFailure(
470 op, "input zero point cannot be statically determined");
471 if (failed(maybeWZp))
472 return rewriter.notifyMatchFailure(
473 op, "weight zero point cannot be statically determined");
474
475 const int64_t inputZpVal = *maybeIZp;
476 const int64_t weightZpVal = *maybeWZp;
477
478 if (op.verifyInputZeroPoint(inputZpVal).failed())
479 return rewriter.notifyMatchFailure(
480 op, "input zero point must be zero for non-int8 integer types");
481
482 if (op.verifyWeightZeroPoint(weightZpVal).failed())
483 return rewriter.notifyMatchFailure(
484 op, "weight zero point must be zero for non-int8 integer types");
485
486 bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
487 auto weightShape = weightTy.getShape();
488 auto resultShape = resultTy.getShape();
489
490 // Apply padding as necessary.
491 TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
492 if (!hasNullZps) {
493 int64_t intMin =
494 APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
495 .getSExtValue();
496 int64_t intMax =
497 APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
498 .getSExtValue();
499
500 if (inputZpVal < intMin || inputZpVal > intMax)
501 return rewriter.notifyMatchFailure(
502 op, "tosa.depthwise_conv op quantization has zp outside of input "
503 "range");
504
505 zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
506 }
507
508 llvm::SmallVector<int64_t> pad;
509 pad.resize(2, 0);
510 llvm::append_range(pad, padAttr.asArrayRef());
511 pad.resize(pad.size() + 2, 0);
512
513 input = applyPad(loc, input, pad, zeroAttr, rewriter);
514
515 // Extract the attributes for convolution.
516 ArrayRef<int64_t> stride = strideTosaAttr;
517 ArrayRef<int64_t> dilation = dilationTosaAttr;
518
519 // Create the convolution op.
520 auto strideAttr = rewriter.getI64TensorAttr(stride);
521 auto dilationAttr = rewriter.getI64TensorAttr(dilation);
522 ShapedType linalgConvTy =
523 RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
524 weightShape[2], weightShape[3]},
525 accETy);
526
527 auto resultZeroAttr = rewriter.getZeroAttr(accETy);
528 Value emptyTensor = tensor::EmptyOp::create(
529 rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
530 Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
531 Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
532 ValueRange{emptyTensor})
533 .result();
534
535 Value biasEmptyTensor = tensor::EmptyOp::create(
536 rewriter, loc, resultTy.getShape(), resultETy, filteredDims);
537
538 // Broadcast the initial value to the output tensor before convolving.
539 SmallVector<AffineMap, 4> indexingMaps;
540 indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
541 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
542 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
543
544 if (hasNullZps) {
545 Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
546 rewriter, loc, linalgConvTy, ValueRange{input, weight},
547 ValueRange{zeroTensor}, strideAttr, dilationAttr)
548 .getResult(0);
549
550 // We may need to truncate back to the result type if the accumulator was
551 // wider than the result.
552 if (accETy != resultETy)
553 conv = tosa::CastOp::create(
554 rewriter, loc,
555 RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
556 resultETy),
557 conv);
558
559 SmallVector<ReassociationExprs, 4> reassociationMap;
560 createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
561 Value convReshape = tensor::CollapseShapeOp::create(
562 rewriter, loc, resultTy, conv, reassociationMap);
563
564 Value result =
565 linalg::GenericOp::create(
566 rewriter, loc, resultTy, ValueRange({bias, convReshape}),
567 biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank),
568 [&](OpBuilder &nestedBuilder, Location nestedLoc,
569 ValueRange args) {
570 Value added;
571 if (llvm::isa<FloatType>(inputETy))
572 added = arith::AddFOp::create(nestedBuilder, loc, args[0],
573 args[1]);
574 else
575 added = arith::AddIOp::create(nestedBuilder, loc, args[0],
576 args[1]);
577 linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
578 })
579 .getResult(0);
580 rewriter.replaceOp(op, result);
581 } else {
582 IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
583 IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
584 auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
585 auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
586 Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
587 rewriter, loc, linalgConvTy,
588 ValueRange{input, weight, iZpVal, kZpVal},
589 ValueRange{zeroTensor}, strideAttr, dilationAttr)
590 .getResult(0);
591 SmallVector<ReassociationExprs, 4> reassociationMap;
592 createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
593 Value convReshape = tensor::CollapseShapeOp::create(
594 rewriter, loc, resultTy, conv, reassociationMap);
595 Value result = linalgIntBroadcastExtSIAdd(
596 rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
597 rewriter.replaceOp(op, result);
598 }
599 return success();
600 }
601};
602
603class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
604public:
605 using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
606 LogicalResult
607 matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
608 ConversionPatternRewriter &rewriter) const final {
609 Location loc = op.getLoc();
610
611 auto outputTy = cast<ShapedType>(op.getType());
612 auto outputElementTy = outputTy.getElementType();
613
614 SmallVector<Value> dynDims;
615 dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
616
617 if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
618 dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0);
619 }
620
621 if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
622 dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1);
623 }
624
625 if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
626 dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2);
627 }
628
629 SmallVector<Value> filteredDims = condenseValues(dynDims);
630
631 auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
632 Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
633 auto emptyTensor =
634 tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
635 outputTy.getElementType(), filteredDims);
636 Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
637 ValueRange{emptyTensor})
638 .result();
639
640 FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
641 FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
642 if (failed(maybeAZp))
643 return rewriter.notifyMatchFailure(
644 op, "input a zero point cannot be statically determined");
645 if (failed(maybeBZp))
646 return rewriter.notifyMatchFailure(
647 op, "input b zero point cannot be statically determined");
648
649 const int64_t aZpVal = *maybeAZp;
650 const int64_t bZpVal = *maybeBZp;
651
652 if (op.verifyAZeroPoint(aZpVal).failed())
653 return rewriter.notifyMatchFailure(
654 op, "input a zero point must be zero for non-int8 integer types");
655
656 if (op.verifyBZeroPoint(bZpVal).failed())
657 return rewriter.notifyMatchFailure(
658 op, "input b zero point must be zero for non-int8 integer types");
659
660 if (aZpVal == 0 && bZpVal == 0) {
661 rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
662 op, TypeRange{op.getType()},
663 ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
664 return success();
665 }
666
667 auto aZp = arith::ConstantOp::create(rewriter, loc,
668 rewriter.getI32IntegerAttr(aZpVal));
669 auto bZp = arith::ConstantOp::create(rewriter, loc,
670 rewriter.getI32IntegerAttr(bZpVal));
671 rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
672 op, TypeRange{op.getType()},
673 ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
674
675 return success();
676 }
677};
678
679class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
680public:
681 using OpConversionPattern::OpConversionPattern;
682
683 // Compute the dynamic output sizes of the maxpool operation.
684 static SmallVector<Value>
685 computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
686 ConversionPatternRewriter &rewriter) {
687 TensorType resultTy = op.getType();
688 Location loc = op.getLoc();
689
690 Value input = adaptor.getInput();
691 ArrayRef<int64_t> kernel = op.getKernel();
692 ArrayRef<int64_t> pad = op.getPad();
693 ArrayRef<int64_t> stride = op.getStride();
694
695 SmallVector<Value> dynamicDims;
696
697 // Batch dimension
698 if (resultTy.isDynamicDim(0))
699 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
700
701 // Height/width dimensions
702 for (int64_t dim : {1, 2}) {
703 if (!resultTy.isDynamicDim(dim))
704 continue;
705
706 // Index into the attribute arrays
707 int64_t index = dim - 1;
708
709 // Input height/width
710 Value ihw = tensor::DimOp::create(rewriter, loc, input, dim);
711
712 // Kernel height/width
713 Value khw = arith::ConstantIndexOp::create(rewriter, loc, kernel[index]);
714
715 // Output height/width
716 Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
717 pad[index * 2 + 1], khw, stride[index],
718 /*dilationAttr=*/1, rewriter);
719 dynamicDims.push_back(ohw);
720 }
721
722 // Channel dimension
723 if (resultTy.isDynamicDim(3))
724 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3));
725
726 return dynamicDims;
727 }
728
729 LogicalResult
730 matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
731 ConversionPatternRewriter &rewriter) const final {
732 Location loc = op.getLoc();
733 Value input = adaptor.getInput();
734 ShapedType inputTy = cast<ShapedType>(input.getType());
735
736 bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
737 ShapedType resultTy =
738 getTypeConverter()->convertType<ShapedType>(op.getType());
739 if (!resultTy)
740 return rewriter.notifyMatchFailure(op, "failed to convert type");
741 Type resultETy = inputTy.getElementType();
742
743 SmallVector<Value> dynamicDims =
744 computeDynamicOutputSizes(op, adaptor, rewriter);
745
746 // Determine what the initial value needs to be for the max pool op.
747 TypedAttr initialAttr;
748 if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
749 initialAttr = rewriter.getFloatAttr(
750 resultETy, APFloat::getLargest(
751 cast<FloatType>(resultETy).getFloatSemantics(), true));
752
753 else if (isUnsigned)
754 initialAttr = rewriter.getIntegerAttr(
755 resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
756 else if (isa<IntegerType>(resultETy))
757 initialAttr = rewriter.getIntegerAttr(
758 resultETy,
759 APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
760
761 if (!initialAttr)
762 return rewriter.notifyMatchFailure(
763 op, "Unsupported initial value for tosa.maxpool_2d op");
764
765 // Apply padding as necessary.
766 llvm::SmallVector<int64_t> pad;
767 pad.resize(2, 0);
768 llvm::append_range(pad, op.getPad());
769 pad.resize(pad.size() + 2, 0);
770
771 Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
772
773 Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);
774
775 ArrayRef<int64_t> kernel = op.getKernel();
776 ArrayRef<int64_t> stride = op.getStride();
777
778 Attribute strideAttr = rewriter.getI64VectorAttr(stride);
779 Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
780
781 // Create the linalg op that performs pooling.
782 Value emptyTensor =
783 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
784 resultTy.getElementType(), dynamicDims);
785
786 Value filledEmptyTensor =
787 linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor)
788 .result();
789
790 Value fakeWindowDims =
791 tensor::EmptyOp::create(rewriter, loc, kernel, resultETy);
792
793 if (isUnsigned) {
794 rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
795 op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
796 filledEmptyTensor, strideAttr, dilationAttr);
797 return llvm::success();
798 }
799
800 auto resultOp = linalg::PoolingNhwcMaxOp::create(
801 rewriter, op->getLoc(), ArrayRef<Type>{resultTy},
802 ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
803 dilationAttr);
804
805 NanPropagationMode nanMode = op.getNanMode();
806 rewriter.replaceOp(op, resultOp);
807
808 // NaN propagation has no meaning for non floating point types.
809 if (!isa<FloatType>(getElementTypeOrSelf(inputTy)))
810 return success();
811
812 // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
813 // compare and select materialization is required.
814 //
815 // In the case of "IGNORE" we need to insert a compare and select. Since
816 // we've already produced a named op we will just take its body and modify
817 // it to include the appropriate checks. If the current value is NaN the
818 // old value of pool will be taken otherwise we use the result.
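// Roughly, the rewritten generic body then looks like (names illustrative):
//
//   %max = arith.maximumf %in, %acc : f32
//   %is_nan = arith.cmpf uno, %in, %in : f32
//   %sel = arith.select %is_nan, %acc, %max : f32
//   linalg.yield %sel : f32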
819 if (nanMode == NanPropagationMode::IGNORE) {
820 auto genericOp = linalg::GenericOp::create(
821 rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
822 resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
823 resultOp.getIteratorTypesArray(),
824 [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
825 IRMapping map;
826 auto oldBlock = resultOp.getRegion().begin();
827 auto oldArgs = oldBlock->getArguments();
828 auto &oldMaxOp = *resultOp.getBlock()->begin();
829 map.map(oldArgs, blockArgs);
830 auto *newOp = opBuilder.clone(oldMaxOp, map);
831 Value isNaN =
832 arith::CmpFOp::create(opBuilder, loc, arith::CmpFPredicate::UNO,
833 blockArgs.front(), blockArgs.front());
834 auto selectOp = arith::SelectOp::create(
835 opBuilder, loc, isNaN, blockArgs.back(), newOp->getResult(0));
836 linalg::YieldOp::create(opBuilder, loc, selectOp.getResult());
837 });
838 rewriter.replaceOp(resultOp, genericOp);
839 }
840
841 return success();
842 }
843};
844
845class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
846public:
847 using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;
848
849 LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
850 PatternRewriter &rewriter) const final {
851 Location loc = op.getLoc();
852 Value input = op.getInput();
853 ShapedType inputTy = cast<ShapedType>(input.getType());
854 Type inElementTy = inputTy.getElementType();
855
856 ShapedType resultTy = cast<ShapedType>(op.getType());
857 Type resultETy = cast<ShapedType>(op.getType()).getElementType();
858
859 Type accETy = op.getAccType();
860 ShapedType accTy = resultTy.clone(accETy);
861
862 auto dynamicDimsOr =
863 checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
864 if (!dynamicDimsOr.has_value())
865 return failure();
866 SmallVector<Value> dynamicDims = *dynamicDimsOr;
867
868 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
869 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
870 if (failed(maybeIZp))
871 return rewriter.notifyMatchFailure(
872 op, "input zero point could not be statically determined");
873 if (failed(maybeOZp))
874 return rewriter.notifyMatchFailure(
875 op, "output zero point could not be statically determined");
876
877 const int64_t inputZpVal = *maybeIZp;
878 const int64_t outputZpVal = *maybeOZp;
879
880 // Apply padding as necessary.
881 llvm::SmallVector<int64_t> pad;
882 pad.resize(2, 0);
883 llvm::append_range(pad, op.getPad());
884 pad.resize(pad.size() + 2, 0);
885 TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
886 // Unsupported element type
887 if (!padAttr)
888 return failure();
889 Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
890
891 auto initialAttr = rewriter.getZeroAttr(accETy);
892 Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);
893
894 ArrayRef<int64_t> kernel = op.getKernel();
895 ArrayRef<int64_t> stride = op.getStride();
896
897 Attribute strideAttr = rewriter.getI64VectorAttr(stride);
898 Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
899
900 // Create the linalg op that performs pooling.
901 Value poolEmptyTensor = tensor::EmptyOp::create(
902 rewriter, loc, accTy.getShape(), accETy, dynamicDims);
903
904 Value filledEmptyTensor =
905 linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
906 ValueRange{poolEmptyTensor})
907 .result();
908
909 Value fakeWindowDims =
910 tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
911
912 // Sum across the pooled region.
913 Value poolingOp = linalg::PoolingNhwcSumOp::create(
914 rewriter, loc, ArrayRef<Type>{accTy},
915 ValueRange{paddedInput, fakeWindowDims},
916 filledEmptyTensor, strideAttr, dilationAttr)
917 .getResult(0);
918
919 // Normalize the summed value by the number of elements grouped in each
920 // pool.
921 Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1);
922 Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2);
923
924 auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
925 iH = arith::SubIOp::create(rewriter, loc, iH, one);
926 iW = arith::SubIOp::create(rewriter, loc, iW, one);
927
928 Value genericEmptyTensor = tensor::EmptyOp::create(
929 rewriter, loc, resultTy.getShape(), resultETy, dynamicDims);
930
931 auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
932 auto genericOp = linalg::GenericOp::create(
933 rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
934 ValueRange{genericEmptyTensor},
935 ArrayRef<AffineMap>({affineMap, affineMap}),
936 getNParallelLoopsAttrs(resultTy.getRank()),
937 [&](OpBuilder &b, Location loc, ValueRange args) {
938 auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
939
940 // Determine what portion of the valid input is covered by the
941 // kernel.
942 auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
943 if (pad == 0)
944 return valid;
945
946 auto padVal = arith::ConstantIndexOp::create(rewriter, loc, pad);
947 Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal);
948
949 Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero);
950 return arith::AddIOp::create(rewriter, loc, valid, offset)
951 ->getResult(0);
952 };
953
954 auto coverageFn = [&](int64_t i, Value isize) -> Value {
955 Value strideVal =
956 arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]);
957 Value val =
958 arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]);
959
960 // Find the position relative to the input tensor's ends.
961 Value left = linalg::IndexOp::create(rewriter, loc, i);
962 Value right = arith::SubIOp::create(rewriter, loc, isize, left);
963 left = arith::MulIOp::create(rewriter, loc, left, strideVal);
964 right = arith::MulIOp::create(rewriter, loc, right, strideVal);
965
966 // Determine how much padding was included.
967 val = padFn(val, left, pad[i * 2]);
968 val = padFn(val, right, pad[i * 2 + 1]);
969 return arith::MaxSIOp::create(rewriter, loc, one, val);
970 };
971
972 // Compute the indices from either end.
973 Value kH3 = coverageFn(1, iH);
974 Value kW3 = coverageFn(2, iW);
975
976 // Compute the total number of elements and normalize.
977 auto count = arith::IndexCastOp::create(
978 rewriter, loc, rewriter.getI32Type(),
979 arith::MulIOp::create(rewriter, loc, kH3, kW3));
980
981 // Divide by the number of summed values. For floats this is just
982 // a division; for quantized values, input normalization also has
983 // to be applied.
984 Value poolVal = args[0];
985 if (isa<FloatType>(accETy)) {
986 auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count);
987 poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF)
988 ->getResult(0);
989 if (accETy.getIntOrFloatBitWidth() >
990 resultETy.getIntOrFloatBitWidth())
991 poolVal =
992 arith::TruncFOp::create(rewriter, loc, resultETy, poolVal);
993 } else {
994
995 // If we have quantization information we need to apply an offset
996 // for the input zp value.
997 if (inputZpVal != 0) {
998 auto inputZp = arith::ConstantOp::create(
999 rewriter, loc, b.getIntegerAttr(accETy, inputZpVal));
1000 Value offset =
1001 arith::MulIOp::create(rewriter, loc, accETy, count, inputZp);
1002 poolVal =
1003 arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset);
1004 }
1005
1006 // Compute: k = 32 - count_leading_zeros(value - 1)
1007 Value one32 = arith::ConstantOp::create(
1008 rewriter, loc, rewriter.getI32IntegerAttr(1));
1009 Value thirtyTwo32 = arith::ConstantOp::create(
1010 rewriter, loc, rewriter.getI32IntegerAttr(32));
1011
1012 Value countSubOne =
1013 arith::SubIOp::create(rewriter, loc, count, one32);
1014 Value leadingZeros =
1015 math::CountLeadingZerosOp::create(rewriter, loc, countSubOne);
1016 Value k =
1017 arith::SubIOp::create(rewriter, loc, thirtyTwo32, leadingZeros);
1018
1019 // Compute: numerator = ((1 << 30) + 1) << k
1020 Value k64 =
1021 arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k);
1022 Value thirtyShiftPlusOne = arith::ConstantOp::create(
1023 rewriter, loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
1024 Value numerator =
1025 arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64);
1026
1027 // Compute: scale.multiplier = numerator / value;
1028 Value count64 = arith::ExtUIOp::create(
1029 rewriter, loc, rewriter.getI64Type(), count);
1030 Value multiplier =
1031 arith::DivUIOp::create(rewriter, loc, numerator, count64);
1032 multiplier = arith::TruncIOp::create(
1033 rewriter, loc, rewriter.getI32Type(), multiplier);
1034
1035 // Compute: scale.shift = 30 + k
1036 Value k8 =
1037 arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k);
1038 Value thirty8 = arith::ConstantOp::create(
1039 rewriter, loc, rewriter.getI8IntegerAttr(30));
1040 Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
1041
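// Worked example (hypothetical count): for count = 9 (a full 3x3 window),
// k = 32 - clz(9 - 1) = 4, numerator = ((1 << 30) + 1) << 4 = 17179869200,
// multiplier = 17179869200 / 9 = 1908874355, and shift = 30 + 4 = 34, so the
// apply_scale below computes roughly (poolVal * 1908874355) >> 34, i.e.
// poolVal / 9.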
1042 auto roundingAttr = RoundingModeAttr::get(
1043 rewriter.getContext(), RoundingMode::SINGLE_ROUND);
1044
1045 auto scaled = tosa::ApplyScaleOp::create(
1046 rewriter, loc, rewriter.getI32Type(), poolVal,
1047 multiplier, shift, roundingAttr)
1048 .getResult();
1049
1050 // If we have quantization information we need to apply the output
1051 // zero point.
1052 if (outputZpVal != 0) {
1053 auto outputZp = arith::ConstantOp::create(
1054 rewriter, loc,
1055 b.getIntegerAttr(scaled.getType(), outputZpVal));
1056 scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp)
1057 .getResult();
1058 }
1059
1060 // Apply Clip.
1061 int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
1062
1063 auto min = arith::ConstantIntOp::create(
1064 rewriter, loc, accETy,
1065 APInt::getSignedMinValue(outBitwidth).getSExtValue());
1066 auto max = arith::ConstantIntOp::create(
1067 rewriter, loc, accETy,
1068 APInt::getSignedMaxValue(outBitwidth).getSExtValue());
1069 auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
1070 /*isUnsigned=*/false);
1071
1072 poolVal = clamp;
1073 // Convert type.
1074 if (resultETy != clamp.getType()) {
1075 poolVal =
1076 arith::TruncIOp::create(rewriter, loc, resultETy, poolVal);
1077 }
1078 }
1079
1080 linalg::YieldOp::create(rewriter, loc, poolVal);
1081 });
1082
1083 rewriter.replaceOp(op, genericOp.getResult(0));
1084 return success();
1085 }
1086};
1087
1088class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
1089public:
1090 using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
1091
1092 LogicalResult matchAndRewrite(tosa::TransposeOp op,
1093 PatternRewriter &rewriter) const final {
1094 const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();
1095
1096 Location loc = op.getLoc();
1097 // The verifier should have made sure we have a valid TOSA permutation
1098 // tensor. isPermutationVector doesn't actually check the TOSA perms we
1099 // expect.
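// Illustrative example (hypothetical shapes): perms = [1, 2, 0] applied to
// an input of type tensor<2x3x4xf32> produces an empty init of type
// tensor<3x4x2xf32> and a linalg.transpose with permutation [1, 2, 0]
// writing into it.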
1100 SmallVector<OpFoldResult> inputSizes =
1101 tensor::getMixedSizes(rewriter, loc, op.getInput1());
1102 auto permutedSizes =
1103 applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);
1104
1105 auto permutedInit =
1106 tensor::EmptyOp::create(rewriter, loc, permutedSizes,
1107 op.getInput1().getType().getElementType());
1108 rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
1109 op, op.getInput1(), permutedInit,
1110 llvm::to_vector(llvm::map_range(
1111 constantPerms, [](int32_t v) -> int64_t { return v; })));
1112 return success();
1113 }
1114};
1115} // namespace
1116
1117void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
1118 const TypeConverter &converter, RewritePatternSet *patterns,
1119 const TosaToLinalgNamedOptions &options) {
1120 if (options.preferConv2DKernelLayoutHWCF) {
1121 patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
1122 linalg::Conv2DNhwcHwcfQOp>>(
1123 patterns->getContext());
1124 } else {
1125 patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
1126 linalg::Conv2DNhwcFhwcQOp>>(
1127 patterns->getContext());
1128 }
1129 patterns->add<
1130 // clang-format off
1131 ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
1132 DepthwiseConvConverter,
1133 MatMulConverter,
1134 AvgPool2dConverter,
1135 TransposeConverter
1136 >(patterns->getContext());
1137
1138 patterns->add<
1139 MaxPool2dConverter
1140 >(converter, patterns->getContext());
1141 // clang-format on
1142}
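// Minimal usage sketch (assumes a conversion pass that already has a
// TypeConverter `typeConverter` and a ConversionTarget `target`; these names
// are illustrative, not part of this file):
//
//   RewritePatternSet patterns(ctx);
//   TosaToLinalgNamedOptions options;
//   options.preferConv2DKernelLayoutHWCF = false;
//   mlir::tosa::populateTosaToLinalgNamedConversionPatterns(typeConverter,
//                                                           &patterns,
//                                                           options);
//   if (failed(applyFullConversion(func, target, std::move(patterns))))
//     signalPassFailure();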