MLIR 23.0.0git
TosaToSPIRVTosaOps.cpp
Go to the documentation of this file.
1//===- TosaToSPIRVTosaOps.cpp - TOSA to SPIR-V Graph/TOSA patterns --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert TOSA IR to SPIR-V Graph/TOSA.
10//
11//===----------------------------------------------------------------------===//
12
17
18#define DEBUG_TYPE "tosa-to-spirv-tosa-ops-pattern"
19
20namespace mlir::tosa {
21namespace {
22
// Largest number of inputs placed on a single generated concat op; larger
// TOSA concats are split into chunks of at most this many inputs.
23constexpr unsigned maxConcatOpInputs = 64;
24
25template <typename OpAdaptor>
26spirv::TosaExtNaNPropagationModeType getNanMode(OpAdaptor adaptor) {
27 return static_cast<spirv::TosaExtNaNPropagationModeType>(
28 adaptor.getNanMode());
29}
30
31template <typename OpAdaptor>
32spirv::TosaExtResizeModeType getResizeMode(OpAdaptor adaptor) {
33 return static_cast<spirv::TosaExtResizeModeType>(adaptor.getMode());
34}
35
36template <typename OpAdaptor>
37spirv::TosaExtRoundingModeType getRoundingMode(OpAdaptor adaptor) {
38 return static_cast<spirv::TosaExtRoundingModeType>(adaptor.getRoundingMode());
39}
40
41spirv::TosaExtAccType getAccType(Type accType) {
42 if (accType.isInteger(32))
43 return spirv::TosaExtAccType::INT32;
44 else if (accType.isF16())
45 return spirv::TosaExtAccType::FP16;
46 else if (accType.isF32())
47 return spirv::TosaExtAccType::FP32;
48 else if (accType.isInteger(48))
49 return spirv::TosaExtAccType::INT48;
50 llvm_unreachable("unknown accumulator type");
51}
52
53DenseIntElementsAttr getI32TensorArmAttr(ArrayRef<int32_t> values,
54 ConversionPatternRewriter &rewriter) {
56 spirv::TensorArmType::get(static_cast<int64_t>(values.size()),
57 IntegerType::get(rewriter.getContext(), 32)),
58 values);
59}
60
61// TOSA stores many integer array attributes as i64 in MLIR, while the
62// SPIR-V TOSA extended instruction set models the same attributes as i32.
63DenseIntElementsAttr getI32TensorArmAttr(ArrayRef<int64_t> values,
64 ConversionPatternRewriter &rewriter) {
65 SmallVector<int32_t> i32Values(values.begin(), values.end());
66 return getI32TensorArmAttr(i32Values, rewriter);
67}
68
69FailureOr<DenseElementsAttr>
70convertDenseElementsAttr(DenseElementsAttr values, ShapedType convertedType) {
71 Type convertedElementType = convertedType.getElementType();
72 if (values.getElementType() == convertedElementType)
73 return values.reshape(convertedType);
74
75 // Constant attributes still have the source TOSA element type. Rebuild them
76 // for the converted SPIR-V tensor type, including integer type-converter
77 // changes such as index to i32, i4 to i8, and i48 to i64.
78 auto integerType = dyn_cast<IntegerType>(convertedElementType);
79 if (!integerType)
80 return failure();
81
82 // The SPIR-V ARM tensor type does not represent scalar shape constants
83 // directly, so model tensor<0xindex> as a one-element i32 tensor.
84 if (values.empty() && values.getElementType().isIndex())
85 return DenseIntElementsAttr::get(convertedType, {1});
86
87 DenseElementsAttr convertedValues =
88 values.mapValues(integerType, [&](const APInt &value) {
89 return value.sextOrTrunc(integerType.getWidth());
90 });
91 return convertedValues.reshape(convertedType);
92}
93
94// Split a large concat into smaller concat operations so the generated SPIR-V
95// instructions stay below the binary operand count limit.
96LogicalResult splitConcat(tosa::ConcatOp op, Type resultType, int32_t axis,
97 ValueRange inputs,
98 ConversionPatternRewriter &rewriter) {
99 auto resultTensorType = dyn_cast<spirv::TensorArmType>(resultType);
100 if (!resultTensorType)
101 return rewriter.notifyMatchFailure(op, "expected tensor result type");
102 if (!resultTensorType.hasRank())
103 return rewriter.notifyMatchFailure(op,
104 "expected ranked tensor result type");
105
106 SmallVector<Value> concatInputs;
107 SmallVector<int64_t> concatShape(resultTensorType.getShape());
108 concatShape[axis] = 0;
109
110 for (auto [index, input] : llvm::enumerate(inputs)) {
111 auto inputType = dyn_cast<spirv::TensorArmType>(input.getType());
112 if (!inputType)
113 return rewriter.notifyMatchFailure(op, "expected tensor input type");
114 if (!inputType.hasRank())
115 return rewriter.notifyMatchFailure(op,
116 "expected ranked tensor input type");
117
118 int64_t inputAxisDim = inputType.getShape()[axis];
119 if (ShapedType::isDynamic(inputAxisDim) ||
120 ShapedType::isDynamic(concatShape[axis]))
121 concatShape[axis] = ShapedType::kDynamic;
122 else
123 concatShape[axis] += inputAxisDim;
124
125 concatInputs.push_back(input);
126 if (concatInputs.size() != maxConcatOpInputs || index == inputs.size() - 1)
127 continue;
128
129 Type concatType = spirv::TensorArmType::get(
130 concatShape, resultTensorType.getElementType());
131 auto concat = spirv::TosaConcatOp::create(rewriter, op.getLoc(), concatType,
132 axis, concatInputs);
133 concatInputs.clear();
134 concatInputs.push_back(concat.getOutput());
135 }
136
137 rewriter.replaceOpWithNewOp<spirv::TosaConcatOp>(op, resultType, axis,
138 concatInputs);
139 return success();
140}
141
142template <typename SourceOp, auto Replace>
143struct TosaOpConvert final : public OpConversionPattern<SourceOp> {
144 using OpConversionPattern<SourceOp>::OpConversionPattern;
145
146 LogicalResult
147 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
148 ConversionPatternRewriter &rewriter) const override {
149 Type type = this->getTypeConverter()->convertType(op.getType());
150 if (!type)
151 return rewriter.notifyMatchFailure(op, "type conversion failed");
152 return Replace(op, adaptor, type, rewriter);
153 }
154};
155
156template <typename SourceOp, auto Replace>
157struct TosaMultiResultOpConvert final : public OpConversionPattern<SourceOp> {
158 using OpConversionPattern<SourceOp>::OpConversionPattern;
159
160 LogicalResult
161 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
162 ConversionPatternRewriter &rewriter) const override {
163 SmallVector<Type> types;
164 if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
165 types)))
166 return rewriter.notifyMatchFailure(op, "type conversion failed");
167 return Replace(op, adaptor, types, rewriter);
168 }
169};
170
171template <typename SourceOp, typename TargetOp>
172LogicalResult replaceUnaryInput1(SourceOp op,
173 typename SourceOp::Adaptor adaptor, Type type,
174 ConversionPatternRewriter &rewriter) {
175 rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput1());
176 return success();
177}
178
179template <typename SourceOp, typename TargetOp>
180LogicalResult replaceUnaryInput(SourceOp op, typename SourceOp::Adaptor adaptor,
181 Type type,
182 ConversionPatternRewriter &rewriter) {
183 rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput());
184 return success();
185}
186
187template <typename SourceOp, typename TargetOp>
188LogicalResult
189replaceBinaryElementwise(SourceOp op, typename SourceOp::Adaptor adaptor,
190 Type type, ConversionPatternRewriter &rewriter) {
191 rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput1(),
192 adaptor.getInput2());
193 return success();
194}
195
196template <typename SourceOp, typename TargetOp>
197LogicalResult
198replaceBinaryNanModeElementwise(SourceOp op, typename SourceOp::Adaptor adaptor,
199 Type type,
200 ConversionPatternRewriter &rewriter) {
201 rewriter.replaceOpWithNewOp<TargetOp>(
202 op, type, getNanMode(adaptor), adaptor.getInput1(), adaptor.getInput2());
203 return success();
204}
205
206template <typename SourceOp, typename TargetOp>
207LogicalResult replaceReduction(SourceOp op, typename SourceOp::Adaptor adaptor,
208 Type type, ConversionPatternRewriter &rewriter) {
209 rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getAxis(),
210 adaptor.getInput());
211 return success();
212}
213
214template <typename SourceOp, typename TargetOp>
215LogicalResult
216replaceNanModeReduction(SourceOp op, typename SourceOp::Adaptor adaptor,
217 Type type, ConversionPatternRewriter &rewriter) {
218 rewriter.replaceOpWithNewOp<TargetOp>(
219 op, type, adaptor.getAxis(), getNanMode(adaptor), adaptor.getInput());
220 return success();
221}
222
223LogicalResult replaceAvgPool2d(tosa::AvgPool2dOp op,
224 tosa::AvgPool2dOpAdaptor adaptor, Type type,
225 ConversionPatternRewriter &rewriter) {
226 rewriter.replaceOpWithNewOp<spirv::TosaAvgPool2DOp>(
227 op, type, getI32TensorArmAttr(adaptor.getKernel(), rewriter),
228 getI32TensorArmAttr(adaptor.getStride(), rewriter),
229 getI32TensorArmAttr(adaptor.getPad(), rewriter),
230 getAccType(adaptor.getAccType()), adaptor.getInput(),
231 adaptor.getInputZp(), adaptor.getOutputZp());
232 return success();
233}
234
235template <typename SourceOp, typename TargetOp>
236LogicalResult replaceConvolution(SourceOp op,
237 typename SourceOp::Adaptor adaptor, Type type,
238 ConversionPatternRewriter &rewriter) {
239 rewriter.replaceOpWithNewOp<TargetOp>(
240 op, type, getI32TensorArmAttr(adaptor.getPad(), rewriter),
241 getI32TensorArmAttr(adaptor.getStride(), rewriter),
242 getI32TensorArmAttr(adaptor.getDilation(), rewriter),
243 getAccType(adaptor.getAccType()), adaptor.getLocalBound(),
244 adaptor.getInput(), adaptor.getWeight(), adaptor.getBias(),
245 adaptor.getInputZp(), adaptor.getWeightZp());
246 return success();
247}
248
249LogicalResult replaceFFT2d(tosa::FFT2dOp op, tosa::FFT2dOpAdaptor adaptor,
250 ArrayRef<Type> types,
251 ConversionPatternRewriter &rewriter) {
252 auto structType = spirv::StructType::get(types);
253 auto result = spirv::TosaFFT2DOp::create(
254 rewriter, op.getLoc(), structType, adaptor.getInverse(),
255 adaptor.getLocalBound(), adaptor.getInputReal(), adaptor.getInputImag());
256 auto outputReal =
257 spirv::CompositeExtractOp::create(rewriter, op.getLoc(), result, {0});
258 auto outputImag =
259 spirv::CompositeExtractOp::create(rewriter, op.getLoc(), result, {1});
260 rewriter.replaceOp(op, {outputReal, outputImag});
261 return success();
262}
263
264LogicalResult replaceMatMul(tosa::MatMulOp op, tosa::MatMulOpAdaptor adaptor,
265 Type type, ConversionPatternRewriter &rewriter) {
266 rewriter.replaceOpWithNewOp<spirv::TosaMatMulOp>(
267 op, type, adaptor.getA(), adaptor.getB(), adaptor.getAZp(),
268 adaptor.getBZp());
269 return success();
270}
271
272LogicalResult replaceMaxPool2d(tosa::MaxPool2dOp op,
273 tosa::MaxPool2dOpAdaptor adaptor, Type type,
274 ConversionPatternRewriter &rewriter) {
275 rewriter.replaceOpWithNewOp<spirv::TosaMaxPool2DOp>(
276 op, type, getI32TensorArmAttr(adaptor.getKernel(), rewriter),
277 getI32TensorArmAttr(adaptor.getStride(), rewriter),
278 getI32TensorArmAttr(adaptor.getPad(), rewriter), getNanMode(adaptor),
279 adaptor.getInput());
280 return success();
281}
282
283LogicalResult replaceRFFT2d(tosa::RFFT2dOp op, tosa::RFFT2dOpAdaptor adaptor,
284 ArrayRef<Type> types,
285 ConversionPatternRewriter &rewriter) {
286 auto structType = spirv::StructType::get(types);
287 auto result = spirv::TosaRFFT2DOp::create(rewriter, op.getLoc(), structType,
288 adaptor.getLocalBound(),
289 adaptor.getInputReal());
290 auto outputReal =
291 spirv::CompositeExtractOp::create(rewriter, op.getLoc(), result, {0});
292 auto outputImag =
293 spirv::CompositeExtractOp::create(rewriter, op.getLoc(), result, {1});
294 rewriter.replaceOp(op, {outputReal, outputImag});
295 return success();
296}
297
298LogicalResult replaceTransposeConv2d(tosa::TransposeConv2DOp op,
299 tosa::TransposeConv2DOpAdaptor adaptor,
300 Type type,
301 ConversionPatternRewriter &rewriter) {
302 rewriter.replaceOpWithNewOp<spirv::TosaTransposeConv2DOp>(
303 op, type, getI32TensorArmAttr(adaptor.getOutPad(), rewriter),
304 getI32TensorArmAttr(adaptor.getStride(), rewriter),
305 getAccType(adaptor.getAccType()), adaptor.getLocalBound(),
306 adaptor.getInput(), adaptor.getWeight(), adaptor.getBias(),
307 adaptor.getInputZp(), adaptor.getWeightZp());
308 return success();
309}
310
311LogicalResult replaceClamp(tosa::ClampOp op, tosa::ClampOpAdaptor adaptor,
312 Type type, ConversionPatternRewriter &rewriter) {
313 rewriter.replaceOpWithNewOp<spirv::TosaClampOp>(
314 op, type, adaptor.getMinVal(), adaptor.getMaxVal(), getNanMode(adaptor),
315 adaptor.getInput());
316 return success();
317}
318
319LogicalResult
320replaceArithmeticRightShift(tosa::ArithmeticRightShiftOp op,
321 tosa::ArithmeticRightShiftOpAdaptor adaptor,
322 Type type, ConversionPatternRewriter &rewriter) {
323 rewriter.replaceOpWithNewOp<spirv::TosaArithmeticRightShiftOp>(
324 op, type, adaptor.getRound(), adaptor.getInput1(), adaptor.getInput2());
325 return success();
326}
327
328LogicalResult replaceMul(tosa::MulOp op, tosa::MulOpAdaptor adaptor, Type type,
329 ConversionPatternRewriter &rewriter) {
330 rewriter.replaceOpWithNewOp<spirv::TosaMulOp>(
331 op, type, adaptor.getInput1(), adaptor.getInput2(), adaptor.getShift());
332 return success();
333}
334
335LogicalResult replaceTable(tosa::TableOp op, tosa::TableOpAdaptor adaptor,
336 Type type, ConversionPatternRewriter &rewriter) {
337 rewriter.replaceOpWithNewOp<spirv::TosaTableOp>(op, type, adaptor.getInput1(),
338 adaptor.getTable());
339 return success();
340}
341
342LogicalResult replaceNegate(tosa::NegateOp op, tosa::NegateOpAdaptor adaptor,
343 Type type, ConversionPatternRewriter &rewriter) {
344 rewriter.replaceOpWithNewOp<spirv::TosaNegateOp>(
345 op, type, adaptor.getInput1(), adaptor.getInput1Zp(),
346 adaptor.getOutputZp());
347 return success();
348}
349
350LogicalResult replaceSelect(tosa::SelectOp op, tosa::SelectOpAdaptor adaptor,
351 Type type, ConversionPatternRewriter &rewriter) {
352 rewriter.replaceOpWithNewOp<spirv::TosaSelectOp>(
353 op, type, adaptor.getInput1(), adaptor.getInput2(), adaptor.getInput3());
354 return success();
355}
356
357LogicalResult replaceConcat(tosa::ConcatOp op, tosa::ConcatOpAdaptor adaptor,
358 Type type, ConversionPatternRewriter &rewriter) {
359 // Large TOSA concats can produce SPIR-V instructions with too many
360 // operands and fail validation. Split them into conservative 64-input
361 // chunks to keep the generated SPIR-V valid.
362 if (adaptor.getInput1().size() > maxConcatOpInputs)
363 return splitConcat(op, type, adaptor.getAxis(), adaptor.getInput1(),
364 rewriter);
365
366 rewriter.replaceOpWithNewOp<spirv::TosaConcatOp>(op, type, adaptor.getAxis(),
367 adaptor.getInput1());
368 return success();
369}
370
371LogicalResult replacePad(tosa::PadOp op, tosa::PadOpAdaptor adaptor, Type type,
372 ConversionPatternRewriter &rewriter) {
373 rewriter.replaceOpWithNewOp<spirv::TosaPadOp>(op, type, adaptor.getInput1(),
374 adaptor.getPadding(),
375 adaptor.getPadConst());
376 return success();
377}
378
379LogicalResult replaceReshape(tosa::ReshapeOp op, tosa::ReshapeOpAdaptor adaptor,
380 Type type, ConversionPatternRewriter &rewriter) {
381 rewriter.replaceOpWithNewOp<spirv::TosaReshapeOp>(
382 op, type, adaptor.getInput1(), adaptor.getShape());
383 return success();
384}
385
386LogicalResult replaceReverse(tosa::ReverseOp op, tosa::ReverseOpAdaptor adaptor,
387 Type type, ConversionPatternRewriter &rewriter) {
388 rewriter.replaceOpWithNewOp<spirv::TosaReverseOp>(op, type, adaptor.getAxis(),
389 adaptor.getInput1());
390 return success();
391}
392
393LogicalResult replaceSlice(tosa::SliceOp op, tosa::SliceOpAdaptor adaptor,
394 Type type, ConversionPatternRewriter &rewriter) {
395 rewriter.replaceOpWithNewOp<spirv::TosaSliceOp>(
396 op, type, adaptor.getInput1(), adaptor.getStart(), adaptor.getSize());
397 return success();
398}
399
400LogicalResult replaceTile(tosa::TileOp op, tosa::TileOpAdaptor adaptor,
401 Type type, ConversionPatternRewriter &rewriter) {
402 rewriter.replaceOpWithNewOp<spirv::TosaTileOp>(op, type, adaptor.getInput1(),
403 adaptor.getMultiples());
404 return success();
405}
406
407LogicalResult replaceTranspose(tosa::TransposeOp op,
408 tosa::TransposeOpAdaptor adaptor, Type type,
409 ConversionPatternRewriter &rewriter) {
410 DenseIntElementsAttr perms =
411 getI32TensorArmAttr(adaptor.getPerms(), rewriter);
412 rewriter.replaceOpWithNewOp<spirv::TosaTransposeOp>(op, type, perms,
413 adaptor.getInput1());
414 return success();
415}
416
417LogicalResult replaceGather(tosa::GatherOp op, tosa::GatherOpAdaptor adaptor,
418 Type type, ConversionPatternRewriter &rewriter) {
419 rewriter.replaceOpWithNewOp<spirv::TosaGatherOp>(
420 op, type, adaptor.getValues(), adaptor.getIndices());
421 return success();
422}
423
424LogicalResult replaceScatter(tosa::ScatterOp op, tosa::ScatterOpAdaptor adaptor,
425 Type type, ConversionPatternRewriter &rewriter) {
426 rewriter.replaceOpWithNewOp<spirv::TosaScatterOp>(
427 op, type, adaptor.getValuesIn(), adaptor.getIndices(),
428 adaptor.getInput());
429 return success();
430}
431
432LogicalResult replaceResize(tosa::ResizeOp op, tosa::ResizeOpAdaptor adaptor,
433 Type type, ConversionPatternRewriter &rewriter) {
434 rewriter.replaceOpWithNewOp<spirv::TosaResizeOp>(
435 op, type, getResizeMode(adaptor), adaptor.getInput(), adaptor.getScale(),
436 adaptor.getOffset(), adaptor.getBorder());
437 return success();
438}
439
440LogicalResult replaceRescale(tosa::RescaleOp op, tosa::RescaleOpAdaptor adaptor,
441 Type type, ConversionPatternRewriter &rewriter) {
442 rewriter.replaceOpWithNewOp<spirv::TosaRescaleOp>(
443 op, type, adaptor.getScale32(), getRoundingMode(adaptor),
444 adaptor.getPerChannel(), adaptor.getInputUnsigned(),
445 adaptor.getOutputUnsigned(), adaptor.getInput(), adaptor.getMultiplier(),
446 adaptor.getShift(), adaptor.getInputZp(), adaptor.getOutputZp());
447 return success();
448}
449
// Replaces a TOSA constant-like op (`SourceOp`) with a SPIR-V constant.
//
// If the op carries an IntegerAttr identifying a graph constant, it is
// replaced by spirv::GraphConstantARMOp referencing that id. Otherwise the
// op's dense `values` attribute is rebuilt for the converted type via
// convertDenseElementsAttr and materialized as spirv::ConstantOp.
//
// Returns failure when `type` is not a ShapedType, when the values attribute
// is not a DenseElementsAttr, or when the element conversion fails.
450template <typename SourceOp>
451LogicalResult replaceConstant(SourceOp op, typename SourceOp::Adaptor adaptor,
452 Type type, ConversionPatternRewriter &rewriter) {
// NOTE(review): the line carrying the attribute-name argument and the closing
// `)) {` of this condition is absent from this listing (number 454 is
// skipped); presumably it names graphARMGraphConstantIdAttrName - confirm
// against the original file.
453 if (auto graphConstantId = op->template getAttrOfType<IntegerAttr>(
455 rewriter.replaceOpWithNewOp<spirv::GraphConstantARMOp>(op, type,
456 graphConstantId);
457 return success();
458 }
459
// Inline constant path: both the converted type and the source attribute must
// be dense/shaped for the element rewrite below to apply.
460 auto convertedType = dyn_cast<ShapedType>(type);
461 auto values = dyn_cast<DenseElementsAttr>(adaptor.getValues());
462 if (!convertedType || !values)
463 return failure();
464
// Re-encode the elements so they match the converted element type and shape.
465 FailureOr<DenseElementsAttr> convertedValues =
466 convertDenseElementsAttr(values, convertedType);
467 if (failed(convertedValues))
468 return failure();
469
470 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, type, *convertedValues);
471 return success();
472}
473
474LogicalResult replaceIdentity(tosa::IdentityOp op,
475 tosa::IdentityOpAdaptor adaptor, Type type,
476 ConversionPatternRewriter &rewriter) {
477 rewriter.replaceOp(op, adaptor.getInput1());
478 return success();
479}
480
481} // namespace
482
484 SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
485 patterns.add<
486 TosaOpConvert<tosa::ArgMaxOp, replaceNanModeReduction<
487 tosa::ArgMaxOp, spirv::TosaArgMaxOp>>,
488 TosaOpConvert<tosa::AvgPool2dOp, replaceAvgPool2d>,
489 TosaOpConvert<tosa::Conv2DOp,
490 replaceConvolution<tosa::Conv2DOp, spirv::TosaConv2DOp>>,
491 TosaOpConvert<tosa::Conv3DOp,
492 replaceConvolution<tosa::Conv3DOp, spirv::TosaConv3DOp>>,
493 TosaOpConvert<tosa::DepthwiseConv2DOp,
494 replaceConvolution<tosa::DepthwiseConv2DOp,
495 spirv::TosaDepthwiseConv2DOp>>,
496 TosaMultiResultOpConvert<tosa::FFT2dOp, replaceFFT2d>,
497 TosaOpConvert<tosa::MatMulOp, replaceMatMul>,
498 TosaOpConvert<tosa::MaxPool2dOp, replaceMaxPool2d>,
499 TosaMultiResultOpConvert<tosa::RFFT2dOp, replaceRFFT2d>,
500 TosaOpConvert<tosa::TransposeConv2DOp, replaceTransposeConv2d>,
501 TosaOpConvert<tosa::ClampOp, replaceClamp>,
502 TosaOpConvert<tosa::ErfOp,
503 replaceUnaryInput<tosa::ErfOp, spirv::TosaErfOp>>,
504 TosaOpConvert<tosa::SigmoidOp,
505 replaceUnaryInput<tosa::SigmoidOp, spirv::TosaSigmoidOp>>,
506 TosaOpConvert<tosa::TanhOp,
507 replaceUnaryInput<tosa::TanhOp, spirv::TosaTanhOp>>,
508 TosaOpConvert<tosa::AddOp,
509 replaceBinaryElementwise<tosa::AddOp, spirv::TosaAddOp>>,
510 TosaOpConvert<tosa::ArithmeticRightShiftOp, replaceArithmeticRightShift>,
511 TosaOpConvert<tosa::BitwiseAndOp,
512 replaceBinaryElementwise<tosa::BitwiseAndOp,
513 spirv::TosaBitwiseAndOp>>,
514 TosaOpConvert<
515 tosa::BitwiseOrOp,
516 replaceBinaryElementwise<tosa::BitwiseOrOp, spirv::TosaBitwiseOrOp>>,
517 TosaOpConvert<tosa::BitwiseXorOp,
518 replaceBinaryElementwise<tosa::BitwiseXorOp,
519 spirv::TosaBitwiseXorOp>>,
520 TosaOpConvert<tosa::IntDivOp, replaceBinaryElementwise<
521 tosa::IntDivOp, spirv::TosaIntDivOp>>,
522 TosaOpConvert<tosa::LogicalAndOp,
523 replaceBinaryElementwise<tosa::LogicalAndOp,
524 spirv::TosaLogicalAndOp>>,
525 TosaOpConvert<tosa::LogicalLeftShiftOp,
526 replaceBinaryElementwise<tosa::LogicalLeftShiftOp,
527 spirv::TosaLogicalLeftShiftOp>>,
528 TosaOpConvert<tosa::LogicalRightShiftOp,
529 replaceBinaryElementwise<tosa::LogicalRightShiftOp,
530 spirv::TosaLogicalRightShiftOp>>,
531 TosaOpConvert<
532 tosa::LogicalOrOp,
533 replaceBinaryElementwise<tosa::LogicalOrOp, spirv::TosaLogicalOrOp>>,
534 TosaOpConvert<tosa::LogicalXorOp,
535 replaceBinaryElementwise<tosa::LogicalXorOp,
536 spirv::TosaLogicalXorOp>>,
537 TosaOpConvert<tosa::MaximumOp,
538 replaceBinaryNanModeElementwise<tosa::MaximumOp,
539 spirv::TosaMaximumOp>>,
540 TosaOpConvert<tosa::MinimumOp,
541 replaceBinaryNanModeElementwise<tosa::MinimumOp,
542 spirv::TosaMinimumOp>>,
543 TosaOpConvert<tosa::MulOp, replaceMul>,
544 TosaOpConvert<tosa::PowOp,
545 replaceBinaryElementwise<tosa::PowOp, spirv::TosaPowOp>>,
546 TosaOpConvert<tosa::SubOp,
547 replaceBinaryElementwise<tosa::SubOp, spirv::TosaSubOp>>,
548 TosaOpConvert<tosa::TableOp, replaceTable>,
549 TosaOpConvert<tosa::AbsOp,
550 replaceUnaryInput1<tosa::AbsOp, spirv::TosaAbsOp>>,
551 TosaOpConvert<
552 tosa::BitwiseNotOp,
553 replaceUnaryInput1<tosa::BitwiseNotOp, spirv::TosaBitwiseNotOp>>,
554 TosaOpConvert<tosa::CeilOp,
555 replaceUnaryInput1<tosa::CeilOp, spirv::TosaCeilOp>>,
556 TosaOpConvert<tosa::ClzOp,
557 replaceUnaryInput1<tosa::ClzOp, spirv::TosaClzOp>>,
558 TosaOpConvert<tosa::CosOp,
559 replaceUnaryInput1<tosa::CosOp, spirv::TosaCosOp>>,
560 TosaOpConvert<tosa::ExpOp,
561 replaceUnaryInput1<tosa::ExpOp, spirv::TosaExpOp>>,
562 TosaOpConvert<tosa::FloorOp,
563 replaceUnaryInput1<tosa::FloorOp, spirv::TosaFloorOp>>,
564 TosaOpConvert<tosa::LogOp,
565 replaceUnaryInput1<tosa::LogOp, spirv::TosaLogOp>>,
566 TosaOpConvert<
567 tosa::LogicalNotOp,
568 replaceUnaryInput1<tosa::LogicalNotOp, spirv::TosaLogicalNotOp>>,
569 TosaOpConvert<tosa::NegateOp, replaceNegate>,
570 TosaOpConvert<
571 tosa::ReciprocalOp,
572 replaceUnaryInput1<tosa::ReciprocalOp, spirv::TosaReciprocalOp>>,
573 TosaOpConvert<tosa::RsqrtOp,
574 replaceUnaryInput1<tosa::RsqrtOp, spirv::TosaRsqrtOp>>,
575 TosaOpConvert<tosa::SinOp,
576 replaceUnaryInput1<tosa::SinOp, spirv::TosaSinOp>>,
577 TosaOpConvert<tosa::SelectOp, replaceSelect>,
578 TosaOpConvert<tosa::EqualOp, replaceBinaryElementwise<
579 tosa::EqualOp, spirv::TosaEqualOp>>,
580 TosaOpConvert<
581 tosa::GreaterOp,
582 replaceBinaryElementwise<tosa::GreaterOp, spirv::TosaGreaterOp>>,
583 TosaOpConvert<tosa::GreaterEqualOp,
584 replaceBinaryElementwise<tosa::GreaterEqualOp,
585 spirv::TosaGreaterEqualOp>>,
586 TosaOpConvert<
587 tosa::ReduceAllOp,
588 replaceReduction<tosa::ReduceAllOp, spirv::TosaReduceAllOp>>,
589 TosaOpConvert<
590 tosa::ReduceAnyOp,
591 replaceReduction<tosa::ReduceAnyOp, spirv::TosaReduceAnyOp>>,
592 TosaOpConvert<
593 tosa::ReduceMaxOp,
594 replaceNanModeReduction<tosa::ReduceMaxOp, spirv::TosaReduceMaxOp>>,
595 TosaOpConvert<
596 tosa::ReduceMinOp,
597 replaceNanModeReduction<tosa::ReduceMinOp, spirv::TosaReduceMinOp>>,
598 TosaOpConvert<
599 tosa::ReduceProductOp,
600 replaceReduction<tosa::ReduceProductOp, spirv::TosaReduceProductOp>>,
601 TosaOpConvert<
602 tosa::ReduceSumOp,
603 replaceReduction<tosa::ReduceSumOp, spirv::TosaReduceSumOp>>,
604 TosaOpConvert<tosa::ConcatOp, replaceConcat>,
605 TosaOpConvert<tosa::PadOp, replacePad>,
606 TosaOpConvert<tosa::ReshapeOp, replaceReshape>,
607 TosaOpConvert<tosa::ReverseOp, replaceReverse>,
608 TosaOpConvert<tosa::SliceOp, replaceSlice>,
609 TosaOpConvert<tosa::TileOp, replaceTile>,
610 TosaOpConvert<tosa::TransposeOp, replaceTranspose>,
611 TosaOpConvert<tosa::GatherOp, replaceGather>,
612 TosaOpConvert<tosa::ScatterOp, replaceScatter>,
613 TosaOpConvert<tosa::ResizeOp, replaceResize>,
614 TosaOpConvert<tosa::CastOp,
615 replaceUnaryInput<tosa::CastOp, spirv::TosaCastOp>>,
616 TosaOpConvert<tosa::RescaleOp, replaceRescale>,
617 TosaOpConvert<tosa::ConstOp, replaceConstant<tosa::ConstOp>>,
618 TosaOpConvert<tosa::ConstShapeOp, replaceConstant<tosa::ConstShapeOp>>,
619 TosaOpConvert<tosa::IdentityOp, replaceIdentity>>(typeConverter,
620 patterns.getContext());
621}
622
623} // namespace mlir::tosa
return success()
static llvm::Constant * convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, llvm::Type *llvmType, const ModuleTranslation &moduleTranslation)
Convert a dense elements attribute to an LLVM IR constant using its raw data storage if possible.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APInt &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
constexpr llvm::StringLiteral graphARMGraphConstantIdAttrName
void populateTosaToSPIRVTosaOpsConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)