TosaToLinalgNamed.cpp
//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower TOSA operations to the corresponding Linalg named ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVectorExtras.h"

#include <type_traits>

using namespace mlir;
using namespace mlir::tosa;

static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                            TypedAttr padAttr, OpBuilder &rewriter) {
  // Input should be padded only if necessary.
  if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
    return input;

  ShapedType inputTy = cast<ShapedType>(input.getType());
  Type inputETy = inputTy.getElementType();
  auto inputShape = inputTy.getShape();

  assert((inputShape.size() * 2) == pad.size());

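  // `pad` holds one (low, high) pair per dimension, interleaved as
  // [lo0, hi0, lo1, hi1, ...], hence exactly 2 * rank entries.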
  SmallVector<int64_t, 4> paddedShape;
  SmallVector<OpFoldResult> lowIndices;
  SmallVector<OpFoldResult> highIndices;
  for (size_t i : llvm::seq(inputShape.size())) {
    auto lowPad = pad[i * 2];
    auto highPad = pad[i * 2 + 1];
    if (ShapedType::isDynamic(inputShape[i]))
      paddedShape.push_back(inputShape[i]);
    else
      paddedShape.push_back(inputShape[i] + highPad + lowPad);
    lowIndices.push_back(rewriter.getIndexAttr(lowPad));
    highIndices.push_back(rewriter.getIndexAttr(highPad));
  }

  Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr);

  return tensor::PadOp::create(rewriter, loc,
                               RankedTensorType::get(paddedShape, inputETy),
                               input, lowIndices, highIndices, padValue);
}

static mlir::Value
linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
                           Value conv, Value result,
                           ArrayRef<AffineMap> indexingMaps) {
  ShapedType resultTy = cast<ShapedType>(conv.getType());
  return linalg::GenericOp::create(
             rewriter, loc, resultTy, ValueRange({bias, conv}), result,
             indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
             [](OpBuilder &builder, Location loc, ValueRange args) {
               Value biasVal = args[0];
               Type resType = args[1].getType();
               if (resType != biasVal.getType()) {
                 biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal);
               }
               Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]);
               linalg::YieldOp::create(builder, loc, added);
             })
      .getResult(0);
}

// Construct the affine map that a linalg generic would use to broadcast the
// source tensor into the shape of the result tensor.
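// For example, a rank-1 bias of shape [C] (with C > 1) broadcast into a
// rank-4 NHWC result produces the map (d0, d1, d2, d3) -> (d3).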
static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
                                    Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  ShapedType sourceTy = cast<ShapedType>(source.getType());
  const int64_t resultRank = resultTy.getRank();
  const int64_t sourceRank = sourceTy.getRank();

  // The source tensor is broadcast to all the outer dimensions of the
  // result tensor.
  SmallVector<AffineExpr> sourceDims;
  // In the case of a rank-1 source tensor with a single element, TOSA
  // specifies that the value is broadcast, which requires an edge case using
  // a constant map.
  assert(sourceTy.hasStaticShape() &&
         "Dynamic broadcasting shapes not supported!");
  if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
    sourceDims.push_back(rewriter.getAffineConstantExpr(0));
  } else {
    for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
      auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
      sourceDims.push_back(expr);
    }
  }

  return AffineMap::get(/*dimCount=*/resultRank,
                        /*symbolCount=*/0, sourceDims, rewriter.getContext());
}

// Broadcast the source value to all the outer dimensions of the result value.
// If required, the element type is expanded using an arith.extsi or arith.extf
// operation as appropriate.
static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
                                              Location loc, Value source,
                                              Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  const int64_t resultRank = resultTy.getRank();
  // Create maps for the input and output of the broadcast-like generic op.
  SmallVector<AffineMap, 2> indexingMaps;
  indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
  indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

  // Build the broadcast-like operation as a linalg.generic.
  return linalg::GenericOp::create(
             rewriter, loc, resultTy, ValueRange({source}), result,
             indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
             [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
               Value biasVal = args[0];
               Type resType = args[1].getType();
               if (resType != biasVal.getType()) {
                 biasVal =
                     resultTy.getElementType().isFloat()
                         ? arith::ExtFOp::create(builder, loc, resType, biasVal)
                               .getResult()
                         : arith::ExtSIOp::create(builder, loc, resType, biasVal)
                               .getResult();
               }
               linalg::YieldOp::create(builder, loc, biasVal);
             })
      .getResult(0);
}

static mlir::Value reifyConstantDim(int64_t attr,
                                    ImplicitLocOpBuilder &builder) {
  return arith::ConstantIndexOp::create(builder, attr);
}

// Calculating the output width/height using the formula:
// H = ((IH + pad_top + pad_bottom - (dilation_y * (KH - 1) + 1)) / stride_y) + 1
// W = ((IW + pad_left + pad_right - (dilation_x * (KW - 1) + 1)) / stride_x) + 1
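// For example, IH = 14, pad_top = pad_bottom = 1, KH = 3, dilation_y = 1,
// stride_y = 2 gives H = ((14 + 1 + 1 - 3) / 2) + 1 = 7.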

static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
                                          int64_t padBeforeAttr,
                                          int64_t padAfterAttr, Value kernelDim,
                                          int64_t strideAttr,
                                          int64_t dilationAttr,
                                          OpBuilder &rewriter) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  auto one = arith::ConstantOp::create(rewriter, loc,
                                       IntegerAttr::get(inputDim.getType(), 1));
  Value padBefore = reifyConstantDim(padBeforeAttr, builder);
  Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore);
  Value padAfter = reifyConstantDim(padAfterAttr, builder);
  Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter);

  Value subOne = arith::SubIOp::create(builder, kernelDim, one);
  Value dilation = reifyConstantDim(dilationAttr, builder);
  Value dilated = arith::MulIOp::create(builder, dilation, subOne);
  Value addOne = arith::AddIOp::create(builder, dilated, one);

  Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne);
  Value stride = reifyConstantDim(strideAttr, builder);
  Value divide = arith::DivUIOp::create(builder, subtract, stride);
  return arith::AddIOp::create(builder, divide, one);
}

// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
static SmallVector<Value> inferDynamicDimsForConv(
    Location loc, Value input, Value weight, ShapedType resultTy,
    ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
    ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
    ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
  ShapedType inputTy = cast<ShapedType>(input.getType());
  int64_t inputRank = inputTy.getRank();

  SmallVector<Value> dynDims;
  dynDims.resize(resultTy.getRank());

  for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
    int64_t inputDim = inputSizeDims[i];
    int64_t kernelDim = kernelSizeDims[i];
    if (resultTy.isDynamicDim(inputDim)) {
      auto padTop = padAttr[i * 2];
      auto padBottom = padAttr[i * 2 + 1];
      auto stride = strideAttr[i];
      auto dilation = dilationAttr[i];
      Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim);
      Value kernelDynDim =
          tensor::DimOp::create(rewriter, loc, weight, kernelDim);
      // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
      dynDims[inputDim] =
          getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
                                 kernelDynDim, stride, dilation, rewriter);
    }
  }

  // Get the batch/channels dimensions.
  for (int i = 0; i < inputRank; i++) {
    if (resultTy.isDynamicDim(i) && !dynDims[i])
      dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i);
  }

  SmallVector<Value> filteredDims = condenseValues(dynDims);
  return filteredDims;
}

// Creates a map to collapse the last dimension of the depthwise convolution
// result, whose channel dimension is split into (C, M), back into the merged
// channel dimension of the TOSA output shape.
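// For a rank-4 output this produces the reassociation [[0], [1], [2], [3, 4]],
// collapsing the rank-5 (N, H, W, C, M) linalg result into (N, H, W, C * M).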
static void createDepthwiseConvCollapseMap(
    int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
    OpBuilder &rewriter) {
  reassociationMap.resize(outputRank);
  for (int i = 0; i < outputRank; i++) {
    reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
  }
  reassociationMap[outputRank - 1].push_back(
      rewriter.getAffineDimExpr(outputRank));
}

namespace {

template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
class ConvConverter : public OpConversionPattern<TosaConvOp> {
public:
  using OpConversionPattern<TosaConvOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

    Type inputETy = inputTy.getElementType();

    DenseI64ArrayAttr padAttr = op.getPadAttr();
    DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
    DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();

    Type accETy = op.getAccType();
    Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);

    // Get and verify zero points.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");

    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops require static shapes for weight and bias");

    if (inputETy.isUnsignedInteger())
      return rewriter.notifyMatchFailure(
          op, "tosa.conv ops do not support unsigned integer input");

    llvm::SmallVector<int64_t> inputSizeDims;
    llvm::SmallVector<int64_t> kernelSizeDims;
    for (int i = 1; i < resultTy.getRank() - 1; i++) {
      inputSizeDims.push_back(i);
      kernelSizeDims.push_back(i);
    }

    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        inputSizeDims, kernelSizeDims, rewriter);

    auto weightShape = weightTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (hasZp) {
      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.conv op quantization has zp outside of input range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);
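    // The `pad` vector now carries zero entries for the leading batch and
    // trailing channel dimensions of the NHWC input; TOSA only pads the
    // spatial dimensions.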
    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    if (inputTy.getRank() == 4) {
      // For 2D convolutions, we need to check if the target convolution op
      // wants a HWCF kernel layout.
      bool wantHwcf =
          hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
      if (wantHwcf) {
        // Transpose the kernel to match dimension ordering of the linalg
        // convolution operation.
        // TODO(suderman): See if this can be efficiently folded - check whether
        // the input is used anywhere else, if not fold the constant.
        SmallVector<int32_t> weightPerm;
        for (int i = 1; i < resultTy.getRank(); i++)
          weightPerm.push_back(i);
        weightPerm.push_back(0);
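        // For 2D convolution this is the permutation [1, 2, 3, 0], i.e. the
        // TOSA FHWC kernel reordered into the HWCF layout expected here.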

        SmallVector<int64_t> newWeightShape;
        for (auto dim : weightPerm)
          newWeightShape.push_back(weightShape[dim]);
        auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
        Type newWeightTy =
            RankedTensorType::get(newWeightShape, weightTy.getElementType());
        weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
                                           weightPermAttr);
      }
    }

    // For Conv3D, transpose the kernel to match the dimension ordering of the
    // linalg convolution operation. Conv2D has a 1-1 mapping in linalg, so it
    // is better to map directly and then transpose later if desired.
    if (inputTy.getRank() == 5) {
      // TODO(suderman): See if this can be efficiently folded - check whether
      // the input is used anywhere else, if not fold the constant.
      SmallVector<int32_t> weightPerm;
      for (int i = 1; i < resultTy.getRank(); i++)
        weightPerm.push_back(i);
      weightPerm.push_back(0);

      SmallVector<int64_t> newWeightShape;
      for (auto dim : weightPerm)
        newWeightShape.push_back(weightShape[dim]);
      auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
      Type newWeightTy =
          RankedTensorType::get(newWeightShape, weightTy.getElementType());
      weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
                                         weightPermAttr);
    }

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);

    Value biasEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, resultTy.getShape(), accETy, filteredDims);

    Value broadcastBias =
        linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);

    if (hasZp) {
      auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
      auto kZp = rewriter.getI32IntegerAttr(weightZpVal);

      auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
      auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);

      Value conv = LinalgConvQOp::create(
                       rewriter, loc, resultTy,
                       ValueRange{input, weight, iZpVal, kZpVal},
                       ValueRange{broadcastBias}, strideAttr, dilationAttr)
                       ->getResult(0);

      rewriter.replaceOp(op, conv);
      return success();
    }

    Value conv = LinalgConvOp::create(
                     rewriter, loc, accTy, ValueRange{input, weight},
                     ValueRange{broadcastBias}, strideAttr, dilationAttr)
                     ->getResult(0);

    // We may need to truncate back to the result type if the accumulator was
    // wider than the result.
    if (resultTy != accTy)
      conv = tosa::CastOp::create(rewriter, loc, resultTy, conv);

    rewriter.replaceOp(op, conv);
    return success();
  }
};

class DepthwiseConvConverter
    : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
public:
  using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op->getLoc();
    Value input = op->getOperand(0);
    Value weight = op->getOperand(1);
    Value bias = op->getOperand(2);

    ShapedType inputTy = cast<ShapedType>(input.getType());
    ShapedType weightTy = cast<ShapedType>(weight.getType());
    ShapedType biasTy = cast<ShapedType>(bias.getType());
    ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
    int64_t resultRank = resultTy.getRank();

    Type inputETy = inputTy.getElementType();
    Type resultETy = resultTy.getElementType();

    auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
    auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
    auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

    Type accETy = op.getAccType();

    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "tosa.depthwise_conv ops require static shapes");

    // Compute output dynamic dims.
    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
        loc, input, weight, resultTy, padAttr.asArrayRef(),
        strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
        /*inputSizeDims=*/{1, 2},
        /*kernelSizeDims=*/{0, 1}, rewriter);

    // Get and verify zero points.
    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point cannot be statically determined");
    if (failed(maybeWZp))
      return rewriter.notifyMatchFailure(
          op, "weight zero point cannot be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t weightZpVal = *maybeWZp;

    if (op.verifyInputZeroPoint(inputZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input zero point must be zero for non-int8 integer types");

    if (op.verifyWeightZeroPoint(weightZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "weight zero point must be zero for non-int8 integer types");

    bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
    auto weightShape = weightTy.getShape();
    auto resultShape = resultTy.getShape();

    // Apply padding as necessary.
    TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
    if (!hasNullZps) {
      int64_t intMin =
          APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();
      int64_t intMax =
          APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
              .getSExtValue();

      if (inputZpVal < intMin || inputZpVal > intMax)
        return rewriter.notifyMatchFailure(
            op, "tosa.depthwise_conv op quantization has zp outside of input "
                "range");

      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
    }

    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, padAttr.asArrayRef());
    pad.resize(pad.size() + 2, 0);

    input = applyPad(loc, input, pad, zeroAttr, rewriter);

    // Extract the attributes for convolution.
    ArrayRef<int64_t> stride = strideTosaAttr;
    ArrayRef<int64_t> dilation = dilationTosaAttr;

    // Create the convolution op.
    auto strideAttr = rewriter.getI64TensorAttr(stride);
    auto dilationAttr = rewriter.getI64TensorAttr(dilation);
    ShapedType linalgConvTy =
        RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                               weightShape[2], weightShape[3]},
                              accETy);
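    // linalg.depthwise_conv_2d_nhwc_hwcm keeps the channel (C) and channel
    // multiplier (M) dimensions separate, so the intermediate result is
    // rank 5: (N, H, W, C, M). It is collapsed to the TOSA shape below.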

    auto resultZeroAttr = rewriter.getZeroAttr(accETy);
    Value emptyTensor = tensor::EmptyOp::create(
        rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
    Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                              ValueRange{emptyTensor})
                           .result();

    Value biasEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, resultTy.getShape(), resultETy, filteredDims);

    // Broadcast the initial value to the output tensor before convolving.
    SmallVector<AffineMap, 4> indexingMaps;
    indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

    if (hasNullZps) {
      Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
                       rewriter, loc, linalgConvTy, ValueRange{input, weight},
                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);

      // We may need to truncate back to the result type if the accumulator was
      // wider than the result.
      if (accETy != resultETy)
        conv = tosa::CastOp::create(
            rewriter, loc,
            RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
                                  resultETy),
            conv);

      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = tensor::CollapseShapeOp::create(
          rewriter, loc, resultTy, conv, reassociationMap);

      Value result =
          linalg::GenericOp::create(
              rewriter, loc, resultTy, ValueRange({bias, convReshape}),
              biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank),
              [&](OpBuilder &nestedBuilder, Location nestedLoc,
                  ValueRange args) {
                Value added;
                if (llvm::isa<FloatType>(inputETy))
                  added = arith::AddFOp::create(nestedBuilder, loc, args[0],
                                                args[1]);
                else
                  added = arith::AddIOp::create(nestedBuilder, loc, args[0],
                                                args[1]);
                linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
              })
              .getResult(0);
      rewriter.replaceOp(op, result);
    } else {
      IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
      IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
      auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
      auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
      Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
                       rewriter, loc, linalgConvTy,
                       ValueRange{input, weight, iZpVal, kZpVal},
                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
                       .getResult(0);
      SmallVector<ReassociationExprs, 4> reassociationMap;
      createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
      Value convReshape = tensor::CollapseShapeOp::create(
          rewriter, loc, resultTy, conv, reassociationMap);
      Value result = linalgIntBroadcastExtSIAdd(
          rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
      rewriter.replaceOp(op, result);
    }
    return success();
  }
};

class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
public:
  using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();

    auto outputTy = cast<ShapedType>(op.getType());
    auto outputElementTy = outputTy.getElementType();

    SmallVector<Value> dynDims;
    dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());

    if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
      dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
      dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1);
    }

    if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
      dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2);
    }

    SmallVector<Value> filteredDims = condenseValues(dynDims);

    auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
    Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
                                outputTy.getElementType(), filteredDims);
    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
                                              ValueRange{emptyTensor})
                           .result();

    FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
    FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
    if (failed(maybeAZp))
      return rewriter.notifyMatchFailure(
          op, "input a zero point cannot be statically determined");
    if (failed(maybeBZp))
      return rewriter.notifyMatchFailure(
          op, "input b zero point cannot be statically determined");

    const int64_t aZpVal = *maybeAZp;
    const int64_t bZpVal = *maybeBZp;

    if (op.verifyAZeroPoint(aZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input a zero point must be zero for non-int8 integer types");

    if (op.verifyBZeroPoint(bZpVal).failed())
      return rewriter.notifyMatchFailure(
          op, "input b zero point must be zero for non-int8 integer types");

    if (aZpVal == 0 && bZpVal == 0) {
      rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
          op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
      return success();
    }

    auto aZp = arith::ConstantOp::create(rewriter, loc,
                                         rewriter.getI32IntegerAttr(aZpVal));
    auto bZp = arith::ConstantOp::create(rewriter, loc,
                                         rewriter.getI32IntegerAttr(bZpVal));
    rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
        op, TypeRange{op.getType()},
        ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);

    return success();
  }
};

class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  // Compute the dynamic output sizes of the maxpool operation.
  static SmallVector<Value>
  computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) {
    TensorType resultTy = op.getType();
    Location loc = op.getLoc();

    Value input = adaptor.getInput();
    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> pad = op.getPad();
    ArrayRef<int64_t> stride = op.getStride();

    SmallVector<Value> dynamicDims;

    // Batch dimension.
    if (resultTy.isDynamicDim(0))
      dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));

    // Height/width dimensions.
    for (int64_t dim : {1, 2}) {
      if (!resultTy.isDynamicDim(dim))
        continue;

      // Index into the attribute arrays.
      int64_t index = dim - 1;

      // Input height/width.
      Value ihw = tensor::DimOp::create(rewriter, loc, input, dim);

      // Kernel height/width.
      Value khw = arith::ConstantIndexOp::create(rewriter, loc, kernel[index]);

      // Output height/width.
      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
                                         pad[index * 2 + 1], khw, stride[index],
                                         /*dilationAttr=*/1, rewriter);
      dynamicDims.push_back(ohw);
    }

    // Channel dimension.
    if (resultTy.isDynamicDim(3))
      dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3));

    return dynamicDims;
  }

  LogicalResult
  matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = adaptor.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());

    bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
    ShapedType resultTy =
        getTypeConverter()->convertType<ShapedType>(op.getType());
    if (!resultTy)
      return rewriter.notifyMatchFailure(op, "failed to convert type");
    Type resultETy = inputTy.getElementType();

    SmallVector<Value> dynamicDims =
        computeDynamicOutputSizes(op, adaptor, rewriter);

    // Determine what the initial value needs to be for the max pool op.
    TypedAttr initialAttr;
    if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
      initialAttr = rewriter.getFloatAttr(
          resultETy, APFloat::getLargest(
                         cast<FloatType>(resultETy).getFloatSemantics(), true));
    else if (isUnsigned)
      initialAttr = rewriter.getIntegerAttr(
          resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
    else if (isa<IntegerType>(resultETy))
      initialAttr = rewriter.getIntegerAttr(
          resultETy,
          APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);

    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);

    Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value emptyTensor =
        tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
                                resultTy.getElementType(), dynamicDims);

    Value filledEmptyTensor =
        linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor)
            .result();

    Value fakeWindowDims =
        tensor::EmptyOp::create(rewriter, loc, kernel, resultETy);

    if (isUnsigned) {
      rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
          op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
          filledEmptyTensor, strideAttr, dilationAttr);
      return llvm::success();
    }

    auto resultOp = linalg::PoolingNhwcMaxOp::create(
        rewriter, op->getLoc(), ArrayRef<Type>{resultTy},
        ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
        dilationAttr);

    NanPropagationMode nanMode = op.getNanMode();
    rewriter.replaceOp(op, resultOp);

    // NaN propagation has no meaning for non floating point types.
    if (!isa<FloatType>(getElementTypeOrSelf(inputTy)))
      return success();

813 // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
814 // compare and select materialization is required.
815 //
816 // In the case of "IGNORE" we need to insert a compare and select. Since
817 // we've already produced a named op we will just take its body and modify
818 // it to include the appropriate checks. If the current value is NaN the
819 // old value of pool will be taken otherwise we use the result.
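    // arith.cmpf with the UNO (unordered) predicate compares a value against
    // itself and is true exactly when that value is NaN, which is what drives
    // the select below.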
    if (nanMode == NanPropagationMode::IGNORE) {
      auto genericOp = linalg::GenericOp::create(
          rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
          resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
          resultOp.getIteratorTypesArray(),
          [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
            IRMapping map;
            auto oldBlock = resultOp.getRegion().begin();
            auto oldArgs = oldBlock->getArguments();
            auto &oldMaxOp = *resultOp.getBlock()->begin();
            map.map(oldArgs, blockArgs);
            auto *newOp = opBuilder.clone(oldMaxOp, map);
            Value isNaN =
                arith::CmpFOp::create(opBuilder, loc, arith::CmpFPredicate::UNO,
                                      blockArgs.front(), blockArgs.front());
            auto selectOp = arith::SelectOp::create(
                opBuilder, loc, isNaN, blockArgs.back(), newOp->getResult(0));
            linalg::YieldOp::create(opBuilder, loc, selectOp.getResult());
          });
      rewriter.replaceOp(resultOp, genericOp);
    }

    return success();
  }
};

class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
public:
  using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.getInput();
    ShapedType inputTy = cast<ShapedType>(input.getType());
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = cast<ShapedType>(op.getType());
    Type resultETy = cast<ShapedType>(op.getType()).getElementType();

    Type accETy = op.getAccType();
    ShapedType accTy = resultTy.clone(accETy);

    auto dynamicDimsOr =
        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
    if (!dynamicDimsOr.has_value())
      return failure();
    SmallVector<Value> dynamicDims = *dynamicDimsOr;

    FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
    FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
    if (failed(maybeIZp))
      return rewriter.notifyMatchFailure(
          op, "input zero point could not be statically determined");
    if (failed(maybeOZp))
      return rewriter.notifyMatchFailure(
          op, "output zero point could not be statically determined");

    const int64_t inputZpVal = *maybeIZp;
    const int64_t outputZpVal = *maybeOZp;

    // Apply padding as necessary.
    llvm::SmallVector<int64_t> pad;
    pad.resize(2, 0);
    llvm::append_range(pad, op.getPad());
    pad.resize(pad.size() + 2, 0);
    TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
    // Unsupported element type.
    if (!padAttr)
      return failure();
    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);

    auto initialAttr = rewriter.getZeroAttr(accETy);
    Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);

    ArrayRef<int64_t> kernel = op.getKernel();
    ArrayRef<int64_t> stride = op.getStride();

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // Create the linalg op that performs pooling.
    Value poolEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, accTy.getShape(), accETy, dynamicDims);

    Value filledEmptyTensor =
        linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
                               ValueRange{poolEmptyTensor})
            .result();

    Value fakeWindowDims =
        tensor::EmptyOp::create(rewriter, loc, kernel, accETy);

    // Sum across the pooled region.
    Value poolingOp = linalg::PoolingNhwcSumOp::create(
                          rewriter, loc, ArrayRef<Type>{accTy},
                          ValueRange{paddedInput, fakeWindowDims},
                          filledEmptyTensor, strideAttr, dilationAttr)
                          .getResult(0);

    // Normalize the summed value by the number of elements grouped in each
    // pool.
    Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1);
    Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2);

    auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
    iH = arith::SubIOp::create(rewriter, loc, iH, one);
    iW = arith::SubIOp::create(rewriter, loc, iW, one);

    Value genericEmptyTensor = tensor::EmptyOp::create(
        rewriter, loc, resultTy.getShape(), resultETy, dynamicDims);

    auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
    auto genericOp = linalg::GenericOp::create(
        rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
        ValueRange{genericEmptyTensor},
        ArrayRef<AffineMap>({affineMap, affineMap}),
        getNParallelLoopsAttrs(resultTy.getRank()),
        [&](OpBuilder &b, Location loc, ValueRange args) {
          auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);

          // Determine what portion of the valid input is covered by the
          // kernel.
          auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
            if (pad == 0)
              return valid;

            auto padVal = arith::ConstantIndexOp::create(rewriter, loc, pad);
            Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal);

            Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero);
            return arith::AddIOp::create(rewriter, loc, valid, offset)
                ->getResult(0);
          };

          auto coverageFn = [&](int64_t i, Value isize) -> Value {
            Value strideVal =
                arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]);
            Value val =
                arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]);

            // Find the position relative to the input tensor's ends.
            Value left = linalg::IndexOp::create(rewriter, loc, i);
            Value right = arith::SubIOp::create(rewriter, loc, isize, left);
            left = arith::MulIOp::create(rewriter, loc, left, strideVal);
            right = arith::MulIOp::create(rewriter, loc, right, strideVal);

            // Determine how much padding was included.
            val = padFn(val, left, pad[i * 2]);
            val = padFn(val, right, pad[i * 2 + 1]);
            return arith::MaxSIOp::create(rewriter, loc, one, val);
          };

          // Compute the indices from either end.
          Value kH3 = coverageFn(1, iH);
          Value kW3 = coverageFn(2, iW);

          // Compute the total number of elements and normalize.
          auto count = arith::IndexCastOp::create(
              rewriter, loc, rewriter.getI32Type(),
              arith::MulIOp::create(rewriter, loc, kH3, kW3));

          // Divide by the number of summed values. For floats this is just a
          // division; for quantized values the normalization is instead
          // applied via a fixed-point multiplication.
          Value poolVal = args[0];
          if (isa<FloatType>(accETy)) {
            auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count);
            poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF)
                          ->getResult(0);
            if (accETy.getIntOrFloatBitWidth() >
                resultETy.getIntOrFloatBitWidth())
              poolVal =
                  arith::TruncFOp::create(rewriter, loc, resultETy, poolVal);
          } else {

            // If we have quantization information we need to apply an offset
            // for the input zp value.
            if (inputZpVal != 0) {
              auto inputZp = arith::ConstantOp::create(
                  rewriter, loc, b.getIntegerAttr(accETy, inputZpVal));
              Value offset =
                  arith::MulIOp::create(rewriter, loc, accETy, count, inputZp);
              poolVal =
                  arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset);
            }

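            // The steps below build an approximate reciprocal of `count` in
            // fixed point: poolVal / count is evaluated as
            // (poolVal * multiplier) >> shift, where multiplier ~= 2^shift /
            // count and shift = 30 + k is chosen so the multiplier uses the
            // full 32-bit range.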
            // Compute: k = 32 - count_leading_zeros(value - 1)
            Value one32 = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI32IntegerAttr(1));
            Value thirtyTwo32 = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI32IntegerAttr(32));

            Value countSubOne =
                arith::SubIOp::create(rewriter, loc, count, one32);
            Value leadingZeros =
                math::CountLeadingZerosOp::create(rewriter, loc, countSubOne);
            Value k =
                arith::SubIOp::create(rewriter, loc, thirtyTwo32, leadingZeros);

            // Compute: numerator = ((1 << 30) + 1) << k
            Value k64 =
                arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k);
            Value thirtyShiftPlusOne = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
            Value numerator =
                arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64);

            // Compute: scale.multiplier = numerator / value;
            Value count64 = arith::ExtUIOp::create(
                rewriter, loc, rewriter.getI64Type(), count);
            Value multiplier =
                arith::DivUIOp::create(rewriter, loc, numerator, count64);
            multiplier = arith::TruncIOp::create(
                rewriter, loc, rewriter.getI32Type(), multiplier);

            // Compute: scale.shift = 30 + k
            Value k8 =
                arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k);
            Value thirty8 = arith::ConstantOp::create(
                rewriter, loc, rewriter.getI8IntegerAttr(30));
            Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);

            auto roundingAttr = RoundingModeAttr::get(
                rewriter.getContext(), RoundingMode::SINGLE_ROUND);

            auto scaled = tosa::ApplyScaleOp::create(
                              rewriter, loc, rewriter.getI32Type(), poolVal,
                              multiplier, shift, roundingAttr)
                              .getResult();

            // If we have quantization information we need to apply the output
            // zero point.
            if (outputZpVal != 0) {
              auto outputZp = arith::ConstantOp::create(
                  rewriter, loc,
                  b.getIntegerAttr(scaled.getType(), outputZpVal));
              scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp)
                           .getResult();
            }

            // Apply clipping.
            int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();

            auto min = arith::ConstantIntOp::create(
                rewriter, loc, accETy,
                APInt::getSignedMinValue(outBitwidth).getSExtValue());
            auto max = arith::ConstantIntOp::create(
                rewriter, loc, accETy,
                APInt::getSignedMaxValue(outBitwidth).getSExtValue());
            auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
                                        /*isUnsigned=*/false);

            poolVal = clamp;
            // Convert type.
            if (resultETy != clamp.getType()) {
              poolVal =
                  arith::TruncIOp::create(rewriter, loc, resultETy, poolVal);
            }
          }

          linalg::YieldOp::create(rewriter, loc, poolVal);
        });

    rewriter.replaceOp(op, genericOp.getResult(0));
    return success();
  }
};

class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
public:
  using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TransposeOp op,
                                PatternRewriter &rewriter) const final {
    const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();

    Location loc = op.getLoc();
    // The verifier should have made sure we have a valid TOSA permutation
    // tensor. isPermutationVector doesn't actually check the TOSA perms we
    // expect.
    SmallVector<OpFoldResult> inputSizes =
        tensor::getMixedSizes(rewriter, loc, op.getInput1());
    auto permutedSizes =
        applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);

    auto permutedInit =
        tensor::EmptyOp::create(rewriter, loc, permutedSizes,
                                op.getInput1().getType().getElementType());
    rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
        op, op.getInput1(), permutedInit,
        llvm::map_to_vector(constantPerms,
                            [](int32_t v) -> int64_t { return v; }));
    return success();
  }
};
} // namespace

void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns,
    const TosaToLinalgNamedOptions &options) {
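  // tosa.conv2d kernels arrive in FHWC order; targeting the HWCF named op
  // requires the kernel transpose inserted by ConvConverter, while the FHWC
  // named op maps the kernel directly.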
  if (options.preferConv2DKernelLayoutHWCF) {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
                                linalg::Conv2DNhwcHwcfQOp>>(
        patterns->getContext());
  } else {
    patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
                                linalg::Conv2DNhwcFhwcQOp>>(
        patterns->getContext());
  }
  patterns->add<
      // clang-format off
      ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
      DepthwiseConvConverter,
      MatMulConverter,
      AvgPool2dConverter,
      TransposeConverter
  >(patterns->getContext());

  patterns->add<
      MaxPool2dConverter
  >(converter, patterns->getContext());
  // clang-format on
}