MLIR 23.0.0git
TosaToLinalgNamed.cpp
1//===- TosaToLinalgNamed.cpp - Lowering Tosa to Linalg Named Ops ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These rewriters lower from the Tosa dialect to the Linalg named ops.
10//
11//===----------------------------------------------------------------------===//
12
23#include "llvm/ADT/SmallVectorExtras.h"
24
25#include <type_traits>
26
27using namespace mlir;
28using namespace mlir::tosa;
29
30static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
31 TypedAttr padAttr, OpBuilder &rewriter) {
32 // Input should be padded only if necessary.
33 if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
34 return input;
35
36 ShapedType inputTy = cast<ShapedType>(input.getType());
37 Type inputETy = inputTy.getElementType();
38 auto inputShape = inputTy.getShape();
39
40 assert((inputShape.size() * 2) == pad.size());
41
42 SmallVector<int64_t, 4> paddedShape;
43 SmallVector<OpFoldResult, 8> lowIndices;
44 SmallVector<OpFoldResult, 8> highIndices;
45 for (size_t i : llvm::seq(inputShape.size())) {
46 auto lowPad = pad[i * 2];
47 auto highPad = pad[i * 2 + 1];
48 if (ShapedType::isDynamic(inputShape[i]))
49 paddedShape.push_back(inputShape[i]);
50 else
51 paddedShape.push_back(inputShape[i] + highPad + lowPad);
52 lowIndices.push_back(rewriter.getIndexAttr(lowPad));
53 highIndices.push_back(rewriter.getIndexAttr(highPad));
54 }
55
56 Value padValue = arith::ConstantOp::create(rewriter, loc, padAttr);
57
58 return tensor::PadOp::create(rewriter, loc,
59 RankedTensorType::get(paddedShape, inputETy),
60 input, lowIndices, highIndices, padValue);
61}
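// Illustrative example (not part of the original source): for a hypothetical
// NHWC input of type tensor<1x8x8x3xi8> with pad = {0, 0, 1, 1, 1, 1, 0, 0}
// and a zero padAttr, this emits a tensor.pad whose result type is
// tensor<1x10x10x3xi8>: each spatial dimension grows by lowPad + highPad,
// while the batch and channel dimensions stay untouched.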
62
63static mlir::Value
64linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
65 Value conv, Value result,
66 ArrayRef<AffineMap> indexingMaps) {
67 ShapedType resultTy = cast<ShapedType>(conv.getType());
68 return linalg::GenericOp::create(
69 rewriter, loc, resultTy, ValueRange({bias, conv}), result,
70 indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
71 [](OpBuilder &builder, Location loc, ValueRange args) {
72 Value biasVal = args[0];
73 Type resType = args[1].getType();
74 if (resType != biasVal.getType()) {
75 biasVal =
76 arith::ExtSIOp::create(builder, loc, resType, biasVal);
77 }
78 Value added =
79 arith::AddIOp::create(builder, loc, biasVal, args[1]);
80 linalg::YieldOp::create(builder, loc, added);
81 })
82 .getResult(0);
83}
84
85// Construct the affine map that a linalg generic would use to broadcast the
86// source tensor into the shape of the result tensor.
87static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
88 Value result) {
89 ShapedType resultTy = cast<ShapedType>(result.getType());
90 ShapedType sourceTy = cast<ShapedType>(source.getType());
91 const int64_t resultRank = resultTy.getRank();
92 const int64_t sourceRank = sourceTy.getRank();
93
94 // The source tensor is broadcast to all the outer dimensions of the
95 // result tensor.
96 SmallVector<AffineExpr> sourceDims;
97 // In the case of a rank-1 source tensor with a single element, TOSA
98 // specifies that the value be broadcast, meaning we need an edge case for
99 // a constant map.
100 assert(sourceTy.hasStaticShape() &&
101 "Dynamic broadcasting shapes not supported!");
102 if (sourceRank == 1 && sourceTy.getDimSize(0) == 1) {
103 sourceDims.push_back(rewriter.getAffineConstantExpr(0));
104 } else {
105 for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
106 auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
107 sourceDims.push_back(expr);
108 }
109 }
110
111 return AffineMap::get(/*dimCount=*/resultRank,
112 /*symbolCount=*/0, sourceDims, rewriter.getContext());
113}
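// Illustrative example (not from the original source): broadcasting a bias of
// type tensor<8xf32> into a tensor<1x56x56x8xf32> result yields the map
// (d0, d1, d2, d3) -> (d3); a single-element tensor<1xf32> bias instead yields
// the constant map (d0, d1, d2, d3) -> (0).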
114
115// Broadcast the source value to all the outer dimensions of the result value.
116// If required, the element type is expanded using an arith.extsi or arith.extf
117// operation as appropriate.
118static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
119 Location loc, Value source,
120 Value result) {
121 ShapedType resultTy = cast<ShapedType>(result.getType());
122 const int64_t resultRank = resultTy.getRank();
123 // Creating maps for the input and output of the broadcast-like generic op.
124 SmallVector<AffineMap, 2> indexingMaps;
125 indexingMaps.push_back(getBroadcastingMap(rewriter, source, result));
126 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
127
128 // Build the broadcast-like operation as a linalg.generic.
129 return linalg::GenericOp::create(
130 rewriter, loc, resultTy, ValueRange({source}), result,
131 indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
132 [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
133 Value biasVal = args[0];
134 Type resType = args[1].getType();
135 if (resType != biasVal.getType()) {
136 biasVal =
137 resultTy.getElementType().isFloat()
138 ? arith::ExtFOp::create(builder, loc, resType, biasVal)
139 .getResult()
140 : arith::ExtSIOp::create(builder, loc, resType,
141 biasVal)
142 .getResult();
143 }
144 linalg::YieldOp::create(builder, loc, biasVal);
145 })
146 .getResult(0);
147}
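// Illustrative sketch of the IR this helper produces (hypothetical types, not
// taken from the original source): an i8 bias broadcast into an i32
// accumulator becomes roughly
//   linalg.generic {indexing_maps = [(d0, d1, d2, d3) -> (d3),
//                                    (d0, d1, d2, d3) -> (d0, d1, d2, d3)],
//                   iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
//       ins(%bias : tensor<8xi8>) outs(%acc : tensor<1x56x56x8xi32>) {
//   ^bb0(%in: i8, %out: i32):
//     %0 = arith.extsi %in : i8 to i32
//     linalg.yield %0 : i32
//   }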
148
149static mlir::Value reifyConstantDim(int64_t attr,
150 ImplicitLocOpBuilder &builder) {
151 return arith::ConstantIndexOp::create(builder, attr);
152}
153
154// Calculating the output width/height using the formula:
155// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
156// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
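// For example (hypothetical values): IH = 112, pad_top = pad_bottom = 1,
// KH = 3, dilation_y = 1, stride_y = 2 gives
//   H = ((112 + 1 + 1 - (1*(3-1)+1)) / 2) + 1 = (111 / 2) + 1 = 56.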
157
158static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
159 int64_t padBeforeAttr,
160 int64_t padAfterAttr, Value kernelDim,
161 int64_t strideAttr,
162 int64_t dilationAttr,
163 OpBuilder &rewriter) {
164 ImplicitLocOpBuilder builder(loc, rewriter);
165 auto one = arith::ConstantOp::create(rewriter, loc,
166 IntegerAttr::get(inputDim.getType(), 1));
167 Value padBefore = reifyConstantDim(padBeforeAttr, builder);
168 Value paddedBefore = arith::AddIOp::create(builder, inputDim, padBefore);
169 Value padAfter = reifyConstantDim(padAfterAttr, builder);
170 Value paddedAfter = arith::AddIOp::create(builder, paddedBefore, padAfter);
171
172 Value subOne = arith::SubIOp::create(builder, kernelDim, one);
173 Value dilation = reifyConstantDim(dilationAttr, builder);
174 Value dilated = arith::MulIOp::create(builder, dilation, subOne);
175 Value addOne = arith::AddIOp::create(builder, dilated, one);
176
177 Value subtract = arith::SubIOp::create(builder, paddedAfter, addOne);
178 Value stride = reifyConstantDim(strideAttr, builder);
179 Value divide = arith::DivUIOp::create(builder, subtract, stride);
180 return arith::AddIOp::create(builder, divide, one);
181}
182
183// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
184static SmallVector<Value> inferDynamicDimsForConv(
185 Location loc, Value input, Value weight, ShapedType resultTy,
186 ArrayRef<int64_t> padAttr, ArrayRef<int64_t> strideAttr,
187 ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
188 ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
189 ShapedType inputTy = cast<ShapedType>(input.getType());
190 int64_t inputRank = inputTy.getRank();
191
192 SmallVector<Value> dynDims;
193 dynDims.resize(resultTy.getRank());
194
195 for (uint32_t i = 0, s = inputSizeDims.size(); i < s; ++i) {
196 int64_t inputDim = inputSizeDims[i];
197 int64_t kernelDim = kernelSizeDims[i];
198 if (resultTy.isDynamicDim(inputDim)) {
199 auto padTop = padAttr[i * 2];
200 auto padBottom = padAttr[i * 2 + 1];
201 auto stride = strideAttr[i];
202 auto dilation = dilationAttr[i];
203 Value initDynDim = tensor::DimOp::create(rewriter, loc, input, inputDim);
204 Value kernelDynDim =
205 tensor::DimOp::create(rewriter, loc, weight, kernelDim);
206 // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
207 dynDims[inputDim] =
208 getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
209 kernelDynDim, stride, dilation, rewriter);
210 }
211 }
212
213 // Get the batch/channels dimensions.
214 for (int i = 0; i < inputRank; i++) {
215 if (resultTy.isDynamicDim(i) && !dynDims[i])
216 dynDims[i] = tensor::DimOp::create(rewriter, loc, input, i);
217 }
218
219 SmallVector<Value> filteredDims = condenseValues(dynDims);
220 return filteredDims;
221}
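// Illustrative example (hypothetical types): for a NHWC input of type
// tensor<?x?x?x4xf32> and a conv2d result of type tensor<?x?x?x8xf32>, the
// spatial result dims 1 and 2 are computed with getConvOrPoolOutputDim from
// the dynamic input and kernel extents, while the dynamic batch dim 0 is read
// directly with tensor.dim; the static channel dim needs no value, so three
// dynamic sizes are returned.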
222
223// Creates a map to collapse the last dimension of the Depthwise convolution op
224// due to a shape mismatch
225static void createDepthwiseConvCollapseMap(
226 int64_t outputRank, SmallVector<ReassociationExprs, 4> &reassociationMap,
227 OpBuilder &rewriter) {
228 reassociationMap.resize(outputRank);
229 for (int i = 0; i < outputRank; i++) {
230 reassociationMap[i].push_back(rewriter.getAffineDimExpr(i));
231 }
232 reassociationMap[outputRank - 1].push_back(
233 rewriter.getAffineDimExpr(outputRank));
234}
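// For example (illustrative): with outputRank = 4 the reassociation map is
// [[d0], [d1], [d2], [d3, d4]], collapsing the rank-5 depthwise result
// (N, H, W, C, M) back into the rank-4 (N, H, W, C*M) output.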
235
236namespace {
237
238template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
239class ConvConverter : public OpConversionPattern<TosaConvOp> {
240public:
241 using OpConversionPattern<TosaConvOp>::OpConversionPattern;
242 LogicalResult
243 matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
244 ConversionPatternRewriter &rewriter) const final {
245 Location loc = op->getLoc();
246 Value input = op->getOperand(0);
247 Value weight = op->getOperand(1);
248 Value bias = op->getOperand(2);
249
250 ShapedType inputTy = cast<ShapedType>(input.getType());
251 ShapedType weightTy = cast<ShapedType>(weight.getType());
252 ShapedType biasTy = cast<ShapedType>(bias.getType());
253 ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
254
255 Type inputETy = inputTy.getElementType();
256
257 DenseI64ArrayAttr padAttr = op.getPadAttr();
258 DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
259 DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
260
261 Type accETy = op.getAccType();
262 Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);
263
264 // Get and verify zero points.
265 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
266 if (failed(maybeIZp))
267 return rewriter.notifyMatchFailure(
268 op, "input zero point cannot be statically determined");
269
270 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
271 if (failed(maybeWZp))
272 return rewriter.notifyMatchFailure(
273 op, "weight zero point cannot be statically determined");
274
275 const int64_t inputZpVal = *maybeIZp;
276 const int64_t weightZpVal = *maybeWZp;
277
278 if (op.verifyInputZeroPoint(inputZpVal).failed())
279 return rewriter.notifyMatchFailure(
280 op, "input zero point must be zero for non-int8 integer types");
281
282 if (op.verifyWeightZeroPoint(weightZpVal).failed())
283 return rewriter.notifyMatchFailure(
284 op, "weight zero point must be zero for non-int8 integer types");
285
286 bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
287
288 if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
289 return rewriter.notifyMatchFailure(
290 op, "tosa.conv ops require static shapes for weight and bias");
291
292 if (inputETy.isUnsignedInteger())
293 return rewriter.notifyMatchFailure(
294 op, "tosa.conv ops does not support unsigned integer input");
295
296 llvm::SmallVector<int64_t> inputSizeDims;
297 llvm::SmallVector<int64_t> kernelSizeDims;
298 for (int i = 1; i < resultTy.getRank() - 1; i++) {
299 inputSizeDims.push_back(i);
300 kernelSizeDims.push_back(i);
301 }
302
303 SmallVector<Value> filteredDims = inferDynamicDimsForConv(
304 loc, input, weight, resultTy, padAttr.asArrayRef(),
305 strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
306 inputSizeDims, kernelSizeDims, rewriter);
307
308 auto weightShape = weightTy.getShape();
309
310 // Apply padding as necessary.
311 TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
312 if (hasZp) {
313 int64_t intMin =
314 APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
315 .getSExtValue();
316 int64_t intMax =
317 APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
318 .getSExtValue();
319
320 if (inputZpVal < intMin || inputZpVal > intMax)
321 return rewriter.notifyMatchFailure(
322 op, "tosa.conv op quantization has zp outside of input range");
323
324 zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
325 }
326
327 llvm::SmallVector<int64_t> pad;
328 pad.resize(2, 0);
329 llvm::append_range(pad, padAttr.asArrayRef());
330 pad.resize(pad.size() + 2, 0);
331 input = applyPad(loc, input, pad, zeroAttr, rewriter);
332
333 if (4 == inputTy.getRank()) {
334 // For 2D convolutions, we need to check if the target convolution op
335 // wants a HWCF kernel layout.
336 bool wantHwcf =
337 hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
338 : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
339 if (wantHwcf) {
340 // Transpose the kernel to match dimension ordering of the linalg
341 // convolution operation.
342 // TODO(suderman): See if this can be efficiently folded - check whether
343 // the input is used anywhere else, if not fold the constant.
344 SmallVector<int32_t> weightPerm;
345 for (int i = 1; i < resultTy.getRank(); i++)
346 weightPerm.push_back(i);
347 weightPerm.push_back(0);
348
349 SmallVector<int64_t> newWeightShape;
350 for (auto dim : weightPerm)
351 newWeightShape.push_back(weightShape[dim]);
352 auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
353 Type newWeightTy =
354 RankedTensorType::get(newWeightShape, weightTy.getElementType());
355 weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
356 weightPermAttr);
357 }
358 }
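// Illustrative example (hypothetical shapes): a TOSA OHWI weight of type
// tensor<8x3x3x4xf32> is transposed with perms = [1, 2, 3, 0] into the HWCF
// layout tensor<3x3x4x8xf32> expected by linalg.conv_2d_nhwc_hwcf.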
359
360 // For Conv3D, transpose the kernel to match the dimension ordering of the
361 // linalg convolution operation. Conv2D has a 1-1 mapping in linalg, so it is
362 // better to map directly and then transpose later if desired.
363 if (5 == inputTy.getRank()) {
364 // TODO(suderman): See if this can be efficiently folded - check whether
365 // the input is used anywhere else, if not fold the constant.
366 SmallVector<int32_t> weightPerm;
367 for (int i = 1; i < resultTy.getRank(); i++)
368 weightPerm.push_back(i);
369 weightPerm.push_back(0);
370
371 SmallVector<int64_t> newWeightShape;
372 for (auto dim : weightPerm)
373 newWeightShape.push_back(weightShape[dim]);
374 auto weightPermAttr = rewriter.getDenseI32ArrayAttr(weightPerm);
375 Type newWeightTy =
376 RankedTensorType::get(newWeightShape, weightTy.getElementType());
377 weight = tosa::TransposeOp::create(rewriter, loc, newWeightTy, weight,
378 weightPermAttr);
379 }
380
381 // Extract the attributes for convolution.
382 ArrayRef<int64_t> stride = strideTosaAttr;
383 ArrayRef<int64_t> dilation = dilationTosaAttr;
384
385 // Create the convolution op.
386 auto strideAttr = rewriter.getI64TensorAttr(stride);
387 auto dilationAttr = rewriter.getI64TensorAttr(dilation);
388
389 Value biasEmptyTensor = tensor::EmptyOp::create(
390 rewriter, loc, resultTy.getShape(), accETy, filteredDims);
391
392 Value broadcastBias =
393 linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
394
395 if (hasZp) {
396 auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
397 auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
398
399 auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
400 auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);
401
402 Value conv = LinalgConvQOp::create(
403 rewriter, loc, resultTy,
404 ValueRange{input, weight, iZpVal, kZpVal},
405 ValueRange{broadcastBias}, strideAttr, dilationAttr)
406 ->getResult(0);
407
408 rewriter.replaceOp(op, conv);
409 return success();
410 }
411
412 Value conv = LinalgConvOp::create(
413 rewriter, loc, accTy, ValueRange{input, weight},
414 ValueRange{broadcastBias}, strideAttr, dilationAttr)
415 ->getResult(0);
416
417 // We may need to truncate back to the result type if the accumulator was
418 // wider than the result.
419 if (resultTy != accTy)
420 conv = tosa::CastOp::create(rewriter, loc, resultTy, conv);
421
422 rewriter.replaceOp(op, conv);
423 return success();
424 }
425};
426
427class DepthwiseConvConverter
428 : public OpConversionPattern<tosa::DepthwiseConv2DOp> {
429public:
430 using OpConversionPattern<tosa::DepthwiseConv2DOp>::OpConversionPattern;
431 LogicalResult
432 matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor,
433 ConversionPatternRewriter &rewriter) const final {
434 Location loc = op->getLoc();
435 Value input = op->getOperand(0);
436 Value weight = op->getOperand(1);
437 Value bias = op->getOperand(2);
438
439 ShapedType inputTy = cast<ShapedType>(input.getType());
440 ShapedType weightTy = cast<ShapedType>(weight.getType());
441 ShapedType biasTy = cast<ShapedType>(bias.getType());
442 ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
443 int64_t resultRank = resultTy.getRank();
444
445 Type inputETy = inputTy.getElementType();
446 Type resultETy = resultTy.getElementType();
447
448 auto padAttr = cast<DenseI64ArrayAttr>(op->getAttr("pad"));
449 auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
450 auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));
451
452 Type accETy = op.getAccType();
453
454 if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
455 return rewriter.notifyMatchFailure(
456 op, "tosa.depthwise_conv ops require static shapes");
457
458 // Compute output dynamic dims
459 SmallVector<Value> filteredDims = inferDynamicDimsForConv(
460 loc, input, weight, resultTy, padAttr.asArrayRef(),
461 strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
462 /*inputSizeDims=*/{1, 2},
463 /*kernelSizeDims=*/{0, 1}, rewriter);
464
465 // Get and verify zero points.
466
467 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
468 FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
469 if (failed(maybeIZp))
470 return rewriter.notifyMatchFailure(
471 op, "input zero point cannot be statically determined");
472 if (failed(maybeWZp))
473 return rewriter.notifyMatchFailure(
474 op, "weight zero point cannot be statically determined");
475
476 const int64_t inputZpVal = *maybeIZp;
477 const int64_t weightZpVal = *maybeWZp;
478
479 if (op.verifyInputZeroPoint(inputZpVal).failed())
480 return rewriter.notifyMatchFailure(
481 op, "input zero point must be zero for non-int8 integer types");
482
483 if (op.verifyWeightZeroPoint(weightZpVal).failed())
484 return rewriter.notifyMatchFailure(
485 op, "weight zero point must be zero for non-int8 integer types");
486
487 bool hasNullZps = (inputZpVal == 0) && (weightZpVal == 0);
488 auto weightShape = weightTy.getShape();
489 auto resultShape = resultTy.getShape();
490
491 // Apply padding as necessary.
492 TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
493 if (!hasNullZps) {
494 int64_t intMin =
495 APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
496 .getSExtValue();
497 int64_t intMax =
498 APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
499 .getSExtValue();
500
501 if (inputZpVal < intMin || inputZpVal > intMax)
502 return rewriter.notifyMatchFailure(
503 op, "tosa.depthwise_conv op quantization has zp outside of input "
504 "range");
505
506 zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
507 }
508
509 llvm::SmallVector<int64_t> pad;
510 pad.resize(2, 0);
511 llvm::append_range(pad, padAttr.asArrayRef());
512 pad.resize(pad.size() + 2, 0);
513
514 input = applyPad(loc, input, pad, zeroAttr, rewriter);
515
516 // Extract the attributes for convolution.
517 ArrayRef<int64_t> stride = strideTosaAttr;
518 ArrayRef<int64_t> dilation = dilationTosaAttr;
519
520 // Create the convolution op.
521 auto strideAttr = rewriter.getI64TensorAttr(stride);
522 auto dilationAttr = rewriter.getI64TensorAttr(dilation);
523 ShapedType linalgConvTy =
524 RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
525 weightShape[2], weightShape[3]},
526 accETy);
527
528 auto resultZeroAttr = rewriter.getZeroAttr(accETy);
529 Value emptyTensor = tensor::EmptyOp::create(
530 rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
531 Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
532 Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
533 ValueRange{emptyTensor})
534 .result();
535
536 Value biasEmptyTensor = tensor::EmptyOp::create(
537 rewriter, loc, resultTy.getShape(), resultETy, filteredDims);
538
539 // Broadcast the initial value to the output tensor before convolving.
540 SmallVector<AffineMap, 4> indexingMaps;
541 indexingMaps.push_back(getBroadcastingMap(rewriter, bias, biasEmptyTensor));
542 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
543 indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
544
545 if (hasNullZps) {
546 Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
547 rewriter, loc, linalgConvTy, ValueRange{input, weight},
548 ValueRange{zeroTensor}, strideAttr, dilationAttr)
549 .getResult(0);
550
551 // We may need to truncate back to the result type if the accumulator was
552 // wider than the result.
553 if (accETy != resultETy)
554 conv = tosa::CastOp::create(
555 rewriter, loc,
556 RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
557 resultETy),
558 conv);
559
560 SmallVector<ReassociationExprs, 4> reassociationMap;
561 createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
562 Value convReshape = tensor::CollapseShapeOp::create(
563 rewriter, loc, resultTy, conv, reassociationMap);
564
565 Value result =
566 linalg::GenericOp::create(
567 rewriter, loc, resultTy, ValueRange({bias, convReshape}),
568 biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank),
569 [&](OpBuilder &nestedBuilder, Location nestedLoc,
570 ValueRange args) {
571 Value added;
572 if (llvm::isa<FloatType>(inputETy))
573 added = arith::AddFOp::create(nestedBuilder, loc, args[0],
574 args[1]);
575 else
576 added = arith::AddIOp::create(nestedBuilder, loc, args[0],
577 args[1]);
578 linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
579 })
580 .getResult(0);
581 rewriter.replaceOp(op, result);
582 } else {
583 IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
584 IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
585 auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
586 auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
587 Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
588 rewriter, loc, linalgConvTy,
589 ValueRange{input, weight, iZpVal, kZpVal},
590 ValueRange{zeroTensor}, strideAttr, dilationAttr)
591 .getResult(0);
592 SmallVector<ReassociationExprs, 4> reassociationMap;
593 createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
594 Value convReshape = tensor::CollapseShapeOp::create(
595 rewriter, loc, resultTy, conv, reassociationMap);
596 Value result = linalgIntBroadcastExtSIAdd(
597 rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
598 rewriter.replaceOp(op, result);
599 }
600 return success();
601 }
602};
603
604class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
605public:
606 using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
607 LogicalResult
608 matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
609 ConversionPatternRewriter &rewriter) const final {
610 Location loc = op.getLoc();
611
612 auto outputTy = cast<ShapedType>(op.getType());
613 auto outputElementTy = outputTy.getElementType();
614
615 SmallVector<Value> dynDims;
616 dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
617
618 if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
619 dynDims[0] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 0);
620 }
621
622 if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
623 dynDims[1] = tensor::DimOp::create(rewriter, loc, op->getOperand(0), 1);
624 }
625
626 if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
627 dynDims[2] = tensor::DimOp::create(rewriter, loc, op->getOperand(1), 2);
628 }
629
630 SmallVector<Value> filteredDims = condenseValues(dynDims);
631
632 auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
633 Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr);
634 auto emptyTensor =
635 tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
636 outputTy.getElementType(), filteredDims);
637 Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
638 ValueRange{emptyTensor})
639 .result();
640
641 FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
642 FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
643 if (failed(maybeAZp))
644 return rewriter.notifyMatchFailure(
645 op, "input a zero point cannot be statically determined");
646 if (failed(maybeBZp))
647 return rewriter.notifyMatchFailure(
648 op, "input b zero point cannot be statically determined");
649
650 const int64_t aZpVal = *maybeAZp;
651 const int64_t bZpVal = *maybeBZp;
652
653 if (op.verifyAZeroPoint(aZpVal).failed())
654 return rewriter.notifyMatchFailure(
655 op, "input a zero point must be zero for non-int8 integer types");
656
657 if (op.verifyBZeroPoint(bZpVal).failed())
658 return rewriter.notifyMatchFailure(
659 op, "input b zero point must be zero for non-int8 integer types");
660
661 if (aZpVal == 0 && bZpVal == 0) {
662 rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
663 op, TypeRange{op.getType()},
664 ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
665 return success();
666 }
667
668 auto aZp = arith::ConstantOp::create(rewriter, loc,
669 rewriter.getI32IntegerAttr(aZpVal));
670 auto bZp = arith::ConstantOp::create(rewriter, loc,
671 rewriter.getI32IntegerAttr(bZpVal));
672 rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
673 op, TypeRange{op.getType()},
674 ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
675
676 return success();
677 }
678};
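// Illustrative lowering (hypothetical types, not from the original source):
// with both zero points equal to zero, a tosa.matmul of tensor<1x16x32xf32>
// by tensor<1x32x8xf32> becomes a zero-filled tensor<1x16x8xf32> accumulator
// consumed by linalg.batch_matmul; with non-zero zero points the pattern
// emits linalg.quantized_batch_matmul with the two i32 zero-point constants
// appended to the inputs instead.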
679
680class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
681public:
682 using OpConversionPattern::OpConversionPattern;
683
684 // Compute the dynamic output sizes of the maxpool operation.
685 static SmallVector<Value>
686 computeDynamicOutputSizes(tosa::MaxPool2dOp op, OpAdaptor adaptor,
687 ConversionPatternRewriter &rewriter) {
688 TensorType resultTy = op.getType();
689 Location loc = op.getLoc();
690
691 Value input = adaptor.getInput();
692 ArrayRef<int64_t> kernel = op.getKernel();
693 ArrayRef<int64_t> pad = op.getPad();
694 ArrayRef<int64_t> stride = op.getStride();
695
696 SmallVector<Value> dynamicDims;
697
698 // Batch dimension
699 if (resultTy.isDynamicDim(0))
700 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0));
701
702 // Height/width dimensions
703 for (int64_t dim : {1, 2}) {
704 if (!resultTy.isDynamicDim(dim))
705 continue;
706
707 // Index into the attribute arrays
708 int64_t index = dim - 1;
709
710 // Input height/width
711 Value ihw = tensor::DimOp::create(rewriter, loc, input, dim);
712
713 // Kernel height/width
714 Value khw = arith::ConstantIndexOp::create(rewriter, loc, kernel[index]);
715
716 // Output height/width
717 Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
718 pad[index * 2 + 1], khw, stride[index],
719 /*dilationAttr=*/1, rewriter);
720 dynamicDims.push_back(ohw);
721 }
722
723 // Channel dimension
724 if (resultTy.isDynamicDim(3))
725 dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 3));
726
727 return dynamicDims;
728 }
729
730 LogicalResult
731 matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor,
732 ConversionPatternRewriter &rewriter) const final {
733 Location loc = op.getLoc();
734 Value input = adaptor.getInput();
735 ShapedType inputTy = cast<ShapedType>(input.getType());
736
737 bool isUnsigned = op.getType().getElementType().isUnsignedInteger();
738 ShapedType resultTy =
739 getTypeConverter()->convertType<ShapedType>(op.getType());
740 if (!resultTy)
741 return rewriter.notifyMatchFailure(op, "failed to convert type");
742 Type resultETy = inputTy.getElementType();
743
744 SmallVector<Value> dynamicDims =
745 computeDynamicOutputSizes(op, adaptor, rewriter);
746
747 // Determine what the initial value needs to be for the max pool op.
748 TypedAttr initialAttr;
749 if (resultETy.isF32() || resultETy.isBF16() || resultETy.isF16())
750 initialAttr = rewriter.getFloatAttr(
751 resultETy, APFloat::getLargest(
752 cast<FloatType>(resultETy).getFloatSemantics(), true));
753
754 else if (isUnsigned)
755 initialAttr = rewriter.getIntegerAttr(
756 resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth()));
757 else if (isa<IntegerType>(resultETy))
758 initialAttr = rewriter.getIntegerAttr(
759 resultETy,
760 APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth()));
761
762 if (!initialAttr)
763 return rewriter.notifyMatchFailure(
764 op, "Unsupported initial value for tosa.maxpool_2d op");
765
766 // Apply padding as necessary.
767 llvm::SmallVector<int64_t> pad;
768 pad.resize(2, 0);
769 llvm::append_range(pad, op.getPad());
770 pad.resize(pad.size() + 2, 0);
771
772 Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
773
774 Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);
775
776 ArrayRef<int64_t> kernel = op.getKernel();
777 ArrayRef<int64_t> stride = op.getStride();
778
779 Attribute strideAttr = rewriter.getI64VectorAttr(stride);
780 Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
781
782 // Create the linalg op that performs pooling.
783 Value emptyTensor =
784 tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(),
785 resultTy.getElementType(), dynamicDims);
786
787 Value filledEmptyTensor =
788 linalg::FillOp::create(rewriter, loc, initialValue, emptyTensor)
789 .result();
790
791 Value fakeWindowDims =
792 tensor::EmptyOp::create(rewriter, loc, kernel, resultETy);
793
794 if (isUnsigned) {
795 rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(
796 op, ArrayRef<Type>{resultTy}, ValueRange{paddedInput, fakeWindowDims},
797 filledEmptyTensor, strideAttr, dilationAttr);
798 return llvm::success();
799 }
800
801 auto resultOp = linalg::PoolingNhwcMaxOp::create(
802 rewriter, op->getLoc(), ArrayRef<Type>{resultTy},
803 ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
804 dilationAttr);
805
806 NanPropagationMode nanMode = op.getNanMode();
807 rewriter.replaceOp(op, resultOp);
808
809 // NaN propagation has no meaning for non floating point types.
810 if (!isa<FloatType>(getElementTypeOrSelf(inputTy)))
811 return success();
812
813 // "PROPAGATE" mode matches the behaviour of the LinAlg named op, so no
814 // compare and select materialization is required.
815 //
816 // In the case of "IGNORE" we need to insert a compare and select. Since
817 // we've already produced a named op we will just take its body and modify
818 // it to include the appropriate checks. If the current value is NaN the
819 // old value of pool will be taken otherwise we use the result.
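 // Illustrative result (hypothetical IR): inside the rebuilt region the cloned
 // max operation is followed by
 //   %is_nan = arith.cmpf uno, %in, %in : f32
 //   %sel = arith.select %is_nan, %acc, %max : f32
 //   linalg.yield %sel : f32
 // so a NaN input leaves the running maximum unchanged.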
820 if (nanMode == NanPropagationMode::IGNORE) {
821 auto genericOp = linalg::GenericOp::create(
822 rewriter, loc, resultOp.getType(0), resultOp.getInputs(),
823 resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
824 resultOp.getIteratorTypesArray(),
825 [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
826 IRMapping map;
827 auto oldBlock = resultOp.getRegion().begin();
828 auto oldArgs = oldBlock->getArguments();
829 auto &oldMaxOp = *resultOp.getBlock()->begin();
830 map.map(oldArgs, blockArgs);
831 auto *newOp = opBuilder.clone(oldMaxOp, map);
832 Value isNaN =
833 arith::CmpFOp::create(opBuilder, loc, arith::CmpFPredicate::UNO,
834 blockArgs.front(), blockArgs.front());
835 auto selectOp = arith::SelectOp::create(
836 opBuilder, loc, isNaN, blockArgs.back(), newOp->getResult(0));
837 linalg::YieldOp::create(opBuilder, loc, selectOp.getResult());
838 });
839 rewriter.replaceOp(resultOp, genericOp);
840 }
841
842 return success();
843 }
844};
845
846class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
847public:
848 using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;
849
850 LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
851 PatternRewriter &rewriter) const final {
852 Location loc = op.getLoc();
853 Value input = op.getInput();
854 ShapedType inputTy = cast<ShapedType>(input.getType());
855 Type inElementTy = inputTy.getElementType();
856
857 ShapedType resultTy = cast<ShapedType>(op.getType());
858 Type resultETy = cast<ShapedType>(op.getType()).getElementType();
859
860 Type accETy = op.getAccType();
861 ShapedType accTy = resultTy.clone(accETy);
862
863 auto dynamicDimsOr =
864 checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
865 if (!dynamicDimsOr.has_value())
866 return failure();
867 SmallVector<Value> dynamicDims = *dynamicDimsOr;
868
869 FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
870 FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
871 if (failed(maybeIZp))
872 return rewriter.notifyMatchFailure(
873 op, "input zero point could not be statically determined");
874 if (failed(maybeOZp))
875 return rewriter.notifyMatchFailure(
876 op, "output zero point could not be statically determined");
877
878 const int64_t inputZpVal = *maybeIZp;
879 const int64_t outputZpVal = *maybeOZp;
880
881 // Apply padding as necessary.
882 llvm::SmallVector<int64_t> pad;
883 pad.resize(2, 0);
884 llvm::append_range(pad, op.getPad());
885 pad.resize(pad.size() + 2, 0);
886 TypedAttr padAttr = rewriter.getZeroAttr(inElementTy);
887 // Unsupported element type
888 if (!padAttr)
889 return failure();
890 Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
891
892 auto initialAttr = rewriter.getZeroAttr(accETy);
893 Value initialValue = arith::ConstantOp::create(rewriter, loc, initialAttr);
894
895 ArrayRef<int64_t> kernel = op.getKernel();
896 ArrayRef<int64_t> stride = op.getStride();
897
898 Attribute strideAttr = rewriter.getI64VectorAttr(stride);
899 Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
900
901 // Create the linalg op that performs pooling.
902 Value poolEmptyTensor = tensor::EmptyOp::create(
903 rewriter, loc, accTy.getShape(), accETy, dynamicDims);
904
905 Value filledEmptyTensor =
906 linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
907 ValueRange{poolEmptyTensor})
908 .result();
909
910 Value fakeWindowDims =
911 tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
912
913 // Sum across the pooled region.
914 Value poolingOp = linalg::PoolingNhwcSumOp::create(
915 rewriter, loc, ArrayRef<Type>{accTy},
916 ValueRange{paddedInput, fakeWindowDims},
917 filledEmptyTensor, strideAttr, dilationAttr)
918 .getResult(0);
919
920 // Normalize the summed value by the number of elements grouped in each
921 // pool.
922 Value iH = tensor::DimOp::create(rewriter, loc, poolingOp, 1);
923 Value iW = tensor::DimOp::create(rewriter, loc, poolingOp, 2);
924
925 auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);
926 iH = arith::SubIOp::create(rewriter, loc, iH, one);
927 iW = arith::SubIOp::create(rewriter, loc, iW, one);
928
929 Value genericEmptyTensor = tensor::EmptyOp::create(
930 rewriter, loc, resultTy.getShape(), resultETy, dynamicDims);
931
932 auto affineMap = rewriter.getMultiDimIdentityMap(resultTy.getRank());
933 auto genericOp = linalg::GenericOp::create(
934 rewriter, loc, ArrayRef<Type>({resultTy}), ValueRange{poolingOp},
935 ValueRange{genericEmptyTensor},
936 ArrayRef<AffineMap>({affineMap, affineMap}),
937 getNParallelLoopsAttrs(resultTy.getRank()),
938 [&](OpBuilder &b, Location loc, ValueRange args) {
939 auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
940
941 // Determines what portion of the valid input is covered by the
942 // kernel.
943 auto padFn = [&](Value valid, Value pos, int64_t pad) -> Value {
944 if (pad == 0)
945 return valid;
946
947 auto padVal = arith::ConstantIndexOp::create(rewriter, loc, pad);
948 Value dpos = arith::SubIOp::create(rewriter, loc, pos, padVal);
949
950 Value offset = arith::MinSIOp::create(rewriter, loc, dpos, zero);
951 return arith::AddIOp::create(rewriter, loc, valid, offset)
952 ->getResult(0);
953 };
954
955 auto coverageFn = [&](int64_t i, Value isize) -> Value {
956 Value strideVal =
957 arith::ConstantIndexOp::create(rewriter, loc, stride[i - 1]);
958 Value val =
959 arith::ConstantIndexOp::create(rewriter, loc, kernel[i - 1]);
960
961 // Find the position relative to the input tensor's ends.
962 Value left = linalg::IndexOp::create(rewriter, loc, i);
963 Value right = arith::SubIOp::create(rewriter, loc, isize, left);
964 left = arith::MulIOp::create(rewriter, loc, left, strideVal);
965 right = arith::MulIOp::create(rewriter, loc, right, strideVal);
966
967 // Determine how much padding was included.
968 val = padFn(val, left, pad[i * 2]);
969 val = padFn(val, right, pad[i * 2 + 1]);
970 return arith::MaxSIOp::create(rewriter, loc, one, val);
971 };
972
973 // Compute the indices from either end.
974 Value kH3 = coverageFn(1, iH);
975 Value kW3 = coverageFn(2, iW);
976
977 // Compute the total number of elements and normalize.
978 auto count = arith::IndexCastOp::create(
979 rewriter, loc, rewriter.getI32Type(),
980 arith::MulIOp::create(rewriter, loc, kH3, kW3));
981
982 // Divide by the number of summed values. For floats this is just
983 // a division; for quantized values the input normalization has to
984 // be applied first.
985 Value poolVal = args[0];
986 if (isa<FloatType>(accETy)) {
987 auto countF = arith::SIToFPOp::create(rewriter, loc, accETy, count);
988 poolVal = arith::DivFOp::create(rewriter, loc, poolVal, countF)
989 ->getResult(0);
990 if (accETy.getIntOrFloatBitWidth() >
991 resultETy.getIntOrFloatBitWidth())
992 poolVal =
993 arith::TruncFOp::create(rewriter, loc, resultETy, poolVal);
994 } else {
995
996 // If we have quantization information we need to apply an offset
997 // for the input zp value.
998 if (inputZpVal != 0) {
999 auto inputZp = arith::ConstantOp::create(
1000 rewriter, loc, b.getIntegerAttr(accETy, inputZpVal));
1001 Value offset =
1002 arith::MulIOp::create(rewriter, loc, accETy, count, inputZp);
1003 poolVal =
1004 arith::SubIOp::create(rewriter, loc, accETy, poolVal, offset);
1005 }
1006
1007 // Compute: k = 32 - count_leading_zeros(value - 1)
1008 Value one32 = arith::ConstantOp::create(
1009 rewriter, loc, rewriter.getI32IntegerAttr(1));
1010 Value thirtyTwo32 = arith::ConstantOp::create(
1011 rewriter, loc, rewriter.getI32IntegerAttr(32));
1012
1013 Value countSubOne =
1014 arith::SubIOp::create(rewriter, loc, count, one32);
1015 Value leadingZeros =
1016 math::CountLeadingZerosOp::create(rewriter, loc, countSubOne);
1017 Value k =
1018 arith::SubIOp::create(rewriter, loc, thirtyTwo32, leadingZeros);
1019
1020 // Compute: numerator = ((1 << 30) + 1) << k
1021 Value k64 =
1022 arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), k);
1023 Value thirtyShiftPlusOne = arith::ConstantOp::create(
1024 rewriter, loc, rewriter.getI64IntegerAttr((1 << 30) + 1));
1025 Value numerator =
1026 arith::ShLIOp::create(rewriter, loc, thirtyShiftPlusOne, k64);
1027
1028 // Compute: scale.multiplier = numerator / value;
1029 Value count64 = arith::ExtUIOp::create(
1030 rewriter, loc, rewriter.getI64Type(), count);
1031 Value multiplier =
1032 arith::DivUIOp::create(rewriter, loc, numerator, count64);
1033 multiplier = arith::TruncIOp::create(
1034 rewriter, loc, rewriter.getI32Type(), multiplier);
1035
1036 // Compute: scale.shift = 30 + k
1037 Value k8 =
1038 arith::TruncIOp::create(rewriter, loc, rewriter.getI8Type(), k);
1039 Value thirty8 = arith::ConstantOp::create(
1040 rewriter, loc, rewriter.getI8IntegerAttr(30));
1041 Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
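 // Worked example (illustrative): for count = 9 the sequence computes
 // k = 32 - clz(9 - 1) = 4, numerator = ((1 << 30) + 1) << 4,
 // multiplier = numerator / 9 = 1908874355 and shift = 30 + 4 = 34, so the
 // subsequent apply_scale multiplies by roughly 1908874355 / 2^34 ~= 1/9.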
1042
1043 auto roundingAttr = RoundingModeAttr::get(
1044 rewriter.getContext(), RoundingMode::SINGLE_ROUND);
1045
1046 auto scaled = tosa::ApplyScaleOp::create(
1047 rewriter, loc, rewriter.getI32Type(), poolVal,
1048 multiplier, shift, roundingAttr)
1049 .getResult();
1050
1051 // If we have quantization information we need to apply the output
1052 // zero point.
1053 if (outputZpVal != 0) {
1054 auto outputZp = arith::ConstantOp::create(
1055 rewriter, loc,
1056 b.getIntegerAttr(scaled.getType(), outputZpVal));
1057 scaled = arith::AddIOp::create(rewriter, loc, scaled, outputZp)
1058 .getResult();
1059 }
1060
1061 // Apply Clip.
1062 int64_t outBitwidth = resultETy.getIntOrFloatBitWidth();
1063
1064 auto min = arith::ConstantIntOp::create(
1065 rewriter, loc, accETy,
1066 APInt::getSignedMinValue(outBitwidth).getSExtValue());
1067 auto max = arith::ConstantIntOp::create(
1068 rewriter, loc, accETy,
1069 APInt::getSignedMaxValue(outBitwidth).getSExtValue());
1070 auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
1071 /*isUnsigned=*/false);
1072
1073 poolVal = clamp;
1074 // Convert type.
1075 if (resultETy != clamp.getType()) {
1076 poolVal =
1077 arith::TruncIOp::create(rewriter, loc, resultETy, poolVal);
1078 }
1079 }
1080
1081 linalg::YieldOp::create(rewriter, loc, poolVal);
1082 });
1083
1084 rewriter.replaceOp(op, genericOp.getResult(0));
1085 return success();
1086 }
1087};
1088
1089class TransposeConverter : public OpRewritePattern<tosa::TransposeOp> {
1090public:
1091 using OpRewritePattern<tosa::TransposeOp>::OpRewritePattern;
1092
1093 LogicalResult matchAndRewrite(tosa::TransposeOp op,
1094 PatternRewriter &rewriter) const final {
1095 const llvm::ArrayRef<int32_t> constantPerms = op.getPerms();
1096
1097 Location loc = op.getLoc();
1098 // The verifier should have made sure we have a valid TOSA permutation
1099 // tensor. isPermutationVector doesn't actually check the TOSA perms we
1100 // expect.
1101 SmallVector<OpFoldResult> inputSizes =
1102 tensor::getMixedSizes(rewriter, loc, op.getInput1());
1103 auto permutedSizes =
1104 applyTOSAPermutation<OpFoldResult>(inputSizes, constantPerms);
1105
1106 auto permutedInit =
1107 tensor::EmptyOp::create(rewriter, loc, permutedSizes,
1108 op.getInput1().getType().getElementType());
1109 rewriter.replaceOpWithNewOp<linalg::TransposeOp>(
1110 op, op.getInput1(), permutedInit,
1111 llvm::map_to_vector(constantPerms,
1112 [](int32_t v) -> int64_t { return v; }));
1113 return success();
1114 }
1115};
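// Illustrative example (hypothetical types): a tosa.transpose with
// perms = [0, 2, 3, 1] on tensor<1x3x224x224xf32> produces a tensor.empty of
// shape 1x224x224x3 and a linalg.transpose with permutation = [0, 2, 3, 1].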
1116} // namespace
1117
1118void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
1119 const TypeConverter &converter, RewritePatternSet *patterns,
1120 const TosaToLinalgNamedOptions &options) {
1121 if (options.preferConv2DKernelLayoutHWCF) {
1122 patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp,
1123 linalg::Conv2DNhwcHwcfQOp>>(
1124 patterns->getContext());
1125 } else {
1126 patterns->add<ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp,
1127 linalg::Conv2DNhwcFhwcQOp>>(
1128 patterns->getContext());
1129 }
1130 patterns->add<
1131 // clang-format off
1132 ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
1133 DepthwiseConvConverter,
1134 MatMulConverter,
1135 AvgPool2dConverter,
1136 TransposeConverter
1137 >(patterns->getContext());
1138
1139 patterns->add<
1140 MaxPool2dConverter
1141 >(converter, patterns->getContext());
1142 // clang-format on
1143}
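// Typical usage (illustrative sketch, not part of this file): a conversion
// pass collects these patterns and hands them to the dialect conversion
// driver, e.g.
//   RewritePatternSet patterns(&getContext());
//   TosaToLinalgNamedOptions options;
//   tosa::populateTosaToLinalgNamedConversionPatterns(typeConverter, &patterns,
//                                                     options);
//   if (failed(applyFullConversion(getOperation(), target,
//                                  std::move(patterns))))
//     signalPassFailure();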