#define DEBUG_TYPE "tosa-to-spirv-tosa-ops-pattern"

/// Largest operand count used for a single `spirv::TosaConcatOp`.
/// A `tosa.concat` with more inputs than this is lowered by `splitConcat`
/// into a chain of concats, each taking at most this many operands.
constexpr unsigned maxConcatOpInputs{64};
25template <
typename OpAdaptor>
26spirv::TosaExtNaNPropagationModeType getNanMode(OpAdaptor adaptor) {
27 return static_cast<spirv::TosaExtNaNPropagationModeType
>(
28 adaptor.getNanMode());
31template <
typename OpAdaptor>
32spirv::TosaExtResizeModeType getResizeMode(OpAdaptor adaptor) {
33 return static_cast<spirv::TosaExtResizeModeType
>(adaptor.getMode());
36template <
typename OpAdaptor>
37spirv::TosaExtRoundingModeType getRoundingMode(OpAdaptor adaptor) {
38 return static_cast<spirv::TosaExtRoundingModeType
>(adaptor.getRoundingMode());
41spirv::TosaExtAccType getAccType(Type accType) {
42 if (accType.isInteger(32))
43 return spirv::TosaExtAccType::INT32;
44 else if (accType.isF16())
45 return spirv::TosaExtAccType::FP16;
46 else if (accType.isF32())
47 return spirv::TosaExtAccType::FP32;
48 else if (accType.isInteger(48))
49 return spirv::TosaExtAccType::INT48;
50 llvm_unreachable(
"unknown accumulator type");
53DenseIntElementsAttr getI32TensorArmAttr(ArrayRef<int32_t> values,
54 ConversionPatternRewriter &rewriter) {
57 IntegerType::get(rewriter.getContext(), 32)),
63DenseIntElementsAttr getI32TensorArmAttr(ArrayRef<int64_t> values,
64 ConversionPatternRewriter &rewriter) {
65 SmallVector<int32_t> i32Values(values.begin(), values.end());
66 return getI32TensorArmAttr(i32Values, rewriter);
69FailureOr<DenseElementsAttr>
71 Type convertedElementType = convertedType.getElementType();
72 if (values.getElementType() == convertedElementType)
73 return values.reshape(convertedType);
78 auto integerType = dyn_cast<IntegerType>(convertedElementType);
84 if (values.empty() && values.getElementType().isIndex())
87 DenseElementsAttr convertedValues =
88 values.
mapValues(integerType, [&](
const APInt &value) {
89 return value.sextOrTrunc(integerType.getWidth());
91 return convertedValues.
reshape(convertedType);
96LogicalResult splitConcat(tosa::ConcatOp op, Type resultType, int32_t axis,
98 ConversionPatternRewriter &rewriter) {
99 auto resultTensorType = dyn_cast<spirv::TensorArmType>(resultType);
100 if (!resultTensorType)
101 return rewriter.notifyMatchFailure(op,
"expected tensor result type");
102 if (!resultTensorType.hasRank())
103 return rewriter.notifyMatchFailure(op,
104 "expected ranked tensor result type");
106 SmallVector<Value> concatInputs;
107 SmallVector<int64_t> concatShape(resultTensorType.getShape());
108 concatShape[axis] = 0;
110 for (
auto [index, input] : llvm::enumerate(inputs)) {
111 auto inputType = dyn_cast<spirv::TensorArmType>(input.getType());
113 return rewriter.notifyMatchFailure(op,
"expected tensor input type");
114 if (!inputType.hasRank())
115 return rewriter.notifyMatchFailure(op,
116 "expected ranked tensor input type");
118 int64_t inputAxisDim = inputType.getShape()[axis];
119 if (ShapedType::isDynamic(inputAxisDim) ||
120 ShapedType::isDynamic(concatShape[axis]))
121 concatShape[axis] = ShapedType::kDynamic;
123 concatShape[axis] += inputAxisDim;
125 concatInputs.push_back(input);
126 if (concatInputs.size() != maxConcatOpInputs || index == inputs.size() - 1)
130 concatShape, resultTensorType.getElementType());
131 auto concat = spirv::TosaConcatOp::create(rewriter, op.getLoc(), concatType,
133 concatInputs.clear();
134 concatInputs.push_back(concat.getOutput());
137 rewriter.replaceOpWithNewOp<spirv::TosaConcatOp>(op, resultType, axis,
142template <
typename SourceOp, auto Replace>
143struct TosaOpConvert final :
public OpConversionPattern<SourceOp> {
144 using OpConversionPattern<SourceOp>::OpConversionPattern;
147 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
148 ConversionPatternRewriter &rewriter)
const override {
149 Type type = this->getTypeConverter()->convertType(op.getType());
151 return rewriter.notifyMatchFailure(op,
"type conversion failed");
152 return Replace(op, adaptor, type, rewriter);
156template <
typename SourceOp, auto Replace>
157struct TosaMultiResultOpConvert final :
public OpConversionPattern<SourceOp> {
158 using OpConversionPattern<SourceOp>::OpConversionPattern;
161 matchAndRewrite(SourceOp op,
typename SourceOp::Adaptor adaptor,
162 ConversionPatternRewriter &rewriter)
const override {
163 SmallVector<Type> types;
164 if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
166 return rewriter.notifyMatchFailure(op,
"type conversion failed");
167 return Replace(op, adaptor, types, rewriter);
171template <
typename SourceOp,
typename TargetOp>
172LogicalResult replaceUnaryInput1(SourceOp op,
173 typename SourceOp::Adaptor adaptor, Type type,
174 ConversionPatternRewriter &rewriter) {
175 rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput1());
179template <
typename SourceOp,
typename TargetOp>
180LogicalResult replaceUnaryInput(SourceOp op,
typename SourceOp::Adaptor adaptor,
182 ConversionPatternRewriter &rewriter) {
183 rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput());
187template <
typename SourceOp,
typename TargetOp>
189replaceBinaryElementwise(SourceOp op,
typename SourceOp::Adaptor adaptor,
190 Type type, ConversionPatternRewriter &rewriter) {
191 rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getInput1(),
192 adaptor.getInput2());
196template <
typename SourceOp,
typename TargetOp>
198replaceBinaryNanModeElementwise(SourceOp op,
typename SourceOp::Adaptor adaptor,
200 ConversionPatternRewriter &rewriter) {
201 rewriter.replaceOpWithNewOp<TargetOp>(
202 op, type, getNanMode(adaptor), adaptor.getInput1(), adaptor.getInput2());
206template <
typename SourceOp,
typename TargetOp>
207LogicalResult replaceReduction(SourceOp op,
typename SourceOp::Adaptor adaptor,
208 Type type, ConversionPatternRewriter &rewriter) {
209 rewriter.replaceOpWithNewOp<TargetOp>(op, type, adaptor.getAxis(),
214template <
typename SourceOp,
typename TargetOp>
216replaceNanModeReduction(SourceOp op,
typename SourceOp::Adaptor adaptor,
217 Type type, ConversionPatternRewriter &rewriter) {
218 rewriter.replaceOpWithNewOp<TargetOp>(
219 op, type, adaptor.getAxis(), getNanMode(adaptor), adaptor.getInput());
223LogicalResult replaceAvgPool2d(tosa::AvgPool2dOp op,
224 tosa::AvgPool2dOpAdaptor adaptor, Type type,
225 ConversionPatternRewriter &rewriter) {
226 rewriter.replaceOpWithNewOp<spirv::TosaAvgPool2DOp>(
227 op, type, getI32TensorArmAttr(adaptor.getKernel(), rewriter),
228 getI32TensorArmAttr(adaptor.getStride(), rewriter),
229 getI32TensorArmAttr(adaptor.getPad(), rewriter),
230 getAccType(adaptor.getAccType()), adaptor.getInput(),
231 adaptor.getInputZp(), adaptor.getOutputZp());
235template <
typename SourceOp,
typename TargetOp>
236LogicalResult replaceConvolution(SourceOp op,
237 typename SourceOp::Adaptor adaptor, Type type,
238 ConversionPatternRewriter &rewriter) {
239 rewriter.replaceOpWithNewOp<TargetOp>(
240 op, type, getI32TensorArmAttr(adaptor.getPad(), rewriter),
241 getI32TensorArmAttr(adaptor.getStride(), rewriter),
242 getI32TensorArmAttr(adaptor.getDilation(), rewriter),
243 getAccType(adaptor.getAccType()), adaptor.getLocalBound(),
244 adaptor.getInput(), adaptor.getWeight(), adaptor.getBias(),
245 adaptor.getInputZp(), adaptor.getWeightZp());
249LogicalResult replaceFFT2d(tosa::FFT2dOp op, tosa::FFT2dOpAdaptor adaptor,
250 ArrayRef<Type> types,
251 ConversionPatternRewriter &rewriter) {
253 auto result = spirv::TosaFFT2DOp::create(
254 rewriter, op.getLoc(), structType, adaptor.getInverse(),
255 adaptor.getLocalBound(), adaptor.getInputReal(), adaptor.getInputImag());
257 spirv::CompositeExtractOp::create(rewriter, op.getLoc(),
result, {0});
259 spirv::CompositeExtractOp::create(rewriter, op.getLoc(),
result, {1});
260 rewriter.replaceOp(op, {outputReal, outputImag});
264LogicalResult replaceMatMul(tosa::MatMulOp op, tosa::MatMulOpAdaptor adaptor,
265 Type type, ConversionPatternRewriter &rewriter) {
266 rewriter.replaceOpWithNewOp<spirv::TosaMatMulOp>(
267 op, type, adaptor.getA(), adaptor.getB(), adaptor.getAZp(),
272LogicalResult replaceMaxPool2d(tosa::MaxPool2dOp op,
273 tosa::MaxPool2dOpAdaptor adaptor, Type type,
274 ConversionPatternRewriter &rewriter) {
275 rewriter.replaceOpWithNewOp<spirv::TosaMaxPool2DOp>(
276 op, type, getI32TensorArmAttr(adaptor.getKernel(), rewriter),
277 getI32TensorArmAttr(adaptor.getStride(), rewriter),
278 getI32TensorArmAttr(adaptor.getPad(), rewriter), getNanMode(adaptor),
283LogicalResult replaceRFFT2d(tosa::RFFT2dOp op, tosa::RFFT2dOpAdaptor adaptor,
284 ArrayRef<Type> types,
285 ConversionPatternRewriter &rewriter) {
287 auto result = spirv::TosaRFFT2DOp::create(rewriter, op.getLoc(), structType,
288 adaptor.getLocalBound(),
289 adaptor.getInputReal());
291 spirv::CompositeExtractOp::create(rewriter, op.getLoc(),
result, {0});
293 spirv::CompositeExtractOp::create(rewriter, op.getLoc(),
result, {1});
294 rewriter.replaceOp(op, {outputReal, outputImag});
298LogicalResult replaceTransposeConv2d(tosa::TransposeConv2DOp op,
299 tosa::TransposeConv2DOpAdaptor adaptor,
301 ConversionPatternRewriter &rewriter) {
302 rewriter.replaceOpWithNewOp<spirv::TosaTransposeConv2DOp>(
303 op, type, getI32TensorArmAttr(adaptor.getOutPad(), rewriter),
304 getI32TensorArmAttr(adaptor.getStride(), rewriter),
305 getAccType(adaptor.getAccType()), adaptor.getLocalBound(),
306 adaptor.getInput(), adaptor.getWeight(), adaptor.getBias(),
307 adaptor.getInputZp(), adaptor.getWeightZp());
311LogicalResult replaceClamp(tosa::ClampOp op, tosa::ClampOpAdaptor adaptor,
312 Type type, ConversionPatternRewriter &rewriter) {
313 rewriter.replaceOpWithNewOp<spirv::TosaClampOp>(
314 op, type, adaptor.getMinVal(), adaptor.getMaxVal(), getNanMode(adaptor),
320replaceArithmeticRightShift(tosa::ArithmeticRightShiftOp op,
321 tosa::ArithmeticRightShiftOpAdaptor adaptor,
322 Type type, ConversionPatternRewriter &rewriter) {
323 rewriter.replaceOpWithNewOp<spirv::TosaArithmeticRightShiftOp>(
324 op, type, adaptor.getRound(), adaptor.getInput1(), adaptor.getInput2());
328LogicalResult replaceMul(tosa::MulOp op, tosa::MulOpAdaptor adaptor, Type type,
329 ConversionPatternRewriter &rewriter) {
330 rewriter.replaceOpWithNewOp<spirv::TosaMulOp>(
331 op, type, adaptor.getInput1(), adaptor.getInput2(), adaptor.getShift());
335LogicalResult replaceTable(tosa::TableOp op, tosa::TableOpAdaptor adaptor,
336 Type type, ConversionPatternRewriter &rewriter) {
337 rewriter.replaceOpWithNewOp<spirv::TosaTableOp>(op, type, adaptor.getInput1(),
342LogicalResult replaceNegate(tosa::NegateOp op, tosa::NegateOpAdaptor adaptor,
343 Type type, ConversionPatternRewriter &rewriter) {
344 rewriter.replaceOpWithNewOp<spirv::TosaNegateOp>(
345 op, type, adaptor.getInput1(), adaptor.getInput1Zp(),
346 adaptor.getOutputZp());
350LogicalResult replaceSelect(tosa::SelectOp op, tosa::SelectOpAdaptor adaptor,
351 Type type, ConversionPatternRewriter &rewriter) {
352 rewriter.replaceOpWithNewOp<spirv::TosaSelectOp>(
353 op, type, adaptor.getInput1(), adaptor.getInput2(), adaptor.getInput3());
357LogicalResult replaceConcat(tosa::ConcatOp op, tosa::ConcatOpAdaptor adaptor,
358 Type type, ConversionPatternRewriter &rewriter) {
362 if (adaptor.getInput1().size() > maxConcatOpInputs)
363 return splitConcat(op, type, adaptor.getAxis(), adaptor.getInput1(),
366 rewriter.replaceOpWithNewOp<spirv::TosaConcatOp>(op, type, adaptor.getAxis(),
367 adaptor.getInput1());
371LogicalResult replacePad(tosa::PadOp op, tosa::PadOpAdaptor adaptor, Type type,
372 ConversionPatternRewriter &rewriter) {
373 rewriter.replaceOpWithNewOp<spirv::TosaPadOp>(op, type, adaptor.getInput1(),
374 adaptor.getPadding(),
375 adaptor.getPadConst());
379LogicalResult replaceReshape(tosa::ReshapeOp op, tosa::ReshapeOpAdaptor adaptor,
380 Type type, ConversionPatternRewriter &rewriter) {
381 rewriter.replaceOpWithNewOp<spirv::TosaReshapeOp>(
382 op, type, adaptor.getInput1(), adaptor.getShape());
386LogicalResult replaceReverse(tosa::ReverseOp op, tosa::ReverseOpAdaptor adaptor,
387 Type type, ConversionPatternRewriter &rewriter) {
388 rewriter.replaceOpWithNewOp<spirv::TosaReverseOp>(op, type, adaptor.getAxis(),
389 adaptor.getInput1());
393LogicalResult replaceSlice(tosa::SliceOp op, tosa::SliceOpAdaptor adaptor,
394 Type type, ConversionPatternRewriter &rewriter) {
395 rewriter.replaceOpWithNewOp<spirv::TosaSliceOp>(
396 op, type, adaptor.getInput1(), adaptor.getStart(), adaptor.getSize());
400LogicalResult replaceTile(tosa::TileOp op, tosa::TileOpAdaptor adaptor,
401 Type type, ConversionPatternRewriter &rewriter) {
402 rewriter.replaceOpWithNewOp<spirv::TosaTileOp>(op, type, adaptor.getInput1(),
403 adaptor.getMultiples());
407LogicalResult replaceTranspose(tosa::TransposeOp op,
408 tosa::TransposeOpAdaptor adaptor, Type type,
409 ConversionPatternRewriter &rewriter) {
410 DenseIntElementsAttr perms =
411 getI32TensorArmAttr(adaptor.getPerms(), rewriter);
412 rewriter.replaceOpWithNewOp<spirv::TosaTransposeOp>(op, type, perms,
413 adaptor.getInput1());
417LogicalResult replaceGather(tosa::GatherOp op, tosa::GatherOpAdaptor adaptor,
418 Type type, ConversionPatternRewriter &rewriter) {
419 rewriter.replaceOpWithNewOp<spirv::TosaGatherOp>(
420 op, type, adaptor.getValues(), adaptor.getIndices());
424LogicalResult replaceScatter(tosa::ScatterOp op, tosa::ScatterOpAdaptor adaptor,
425 Type type, ConversionPatternRewriter &rewriter) {
426 rewriter.replaceOpWithNewOp<spirv::TosaScatterOp>(
427 op, type, adaptor.getValuesIn(), adaptor.getIndices(),
432LogicalResult replaceResize(tosa::ResizeOp op, tosa::ResizeOpAdaptor adaptor,
433 Type type, ConversionPatternRewriter &rewriter) {
434 rewriter.replaceOpWithNewOp<spirv::TosaResizeOp>(
435 op, type, getResizeMode(adaptor), adaptor.getInput(), adaptor.getScale(),
436 adaptor.getOffset(), adaptor.getBorder());
440LogicalResult replaceRescale(tosa::RescaleOp op, tosa::RescaleOpAdaptor adaptor,
441 Type type, ConversionPatternRewriter &rewriter) {
442 rewriter.replaceOpWithNewOp<spirv::TosaRescaleOp>(
443 op, type, adaptor.getScale32(), getRoundingMode(adaptor),
444 adaptor.getPerChannel(), adaptor.getInputUnsigned(),
445 adaptor.getOutputUnsigned(), adaptor.getInput(), adaptor.getMultiplier(),
446 adaptor.getShift(), adaptor.getInputZp(), adaptor.getOutputZp());
450template <
typename SourceOp>
451LogicalResult replaceConstant(SourceOp op,
typename SourceOp::Adaptor adaptor,
452 Type type, ConversionPatternRewriter &rewriter) {
453 if (
auto graphConstantId = op->template getAttrOfType<IntegerAttr>(
455 rewriter.replaceOpWithNewOp<spirv::GraphConstantARMOp>(op, type,
460 auto convertedType = dyn_cast<ShapedType>(type);
461 auto values = dyn_cast<DenseElementsAttr>(adaptor.getValues());
462 if (!convertedType || !values)
465 FailureOr<DenseElementsAttr> convertedValues =
467 if (failed(convertedValues))
470 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, type, *convertedValues);
474LogicalResult replaceIdentity(tosa::IdentityOp op,
475 tosa::IdentityOpAdaptor adaptor, Type type,
476 ConversionPatternRewriter &rewriter) {
477 rewriter.replaceOp(op, adaptor.getInput1());
486 TosaOpConvert<tosa::ArgMaxOp, replaceNanModeReduction<
487 tosa::ArgMaxOp, spirv::TosaArgMaxOp>>,
488 TosaOpConvert<tosa::AvgPool2dOp, replaceAvgPool2d>,
489 TosaOpConvert<tosa::Conv2DOp,
490 replaceConvolution<tosa::Conv2DOp, spirv::TosaConv2DOp>>,
491 TosaOpConvert<tosa::Conv3DOp,
492 replaceConvolution<tosa::Conv3DOp, spirv::TosaConv3DOp>>,
493 TosaOpConvert<tosa::DepthwiseConv2DOp,
494 replaceConvolution<tosa::DepthwiseConv2DOp,
495 spirv::TosaDepthwiseConv2DOp>>,
496 TosaMultiResultOpConvert<tosa::FFT2dOp, replaceFFT2d>,
497 TosaOpConvert<tosa::MatMulOp, replaceMatMul>,
498 TosaOpConvert<tosa::MaxPool2dOp, replaceMaxPool2d>,
499 TosaMultiResultOpConvert<tosa::RFFT2dOp, replaceRFFT2d>,
500 TosaOpConvert<tosa::TransposeConv2DOp, replaceTransposeConv2d>,
501 TosaOpConvert<tosa::ClampOp, replaceClamp>,
502 TosaOpConvert<tosa::ErfOp,
503 replaceUnaryInput<tosa::ErfOp, spirv::TosaErfOp>>,
504 TosaOpConvert<tosa::SigmoidOp,
505 replaceUnaryInput<tosa::SigmoidOp, spirv::TosaSigmoidOp>>,
506 TosaOpConvert<tosa::TanhOp,
507 replaceUnaryInput<tosa::TanhOp, spirv::TosaTanhOp>>,
508 TosaOpConvert<tosa::AddOp,
509 replaceBinaryElementwise<tosa::AddOp, spirv::TosaAddOp>>,
510 TosaOpConvert<tosa::ArithmeticRightShiftOp, replaceArithmeticRightShift>,
511 TosaOpConvert<tosa::BitwiseAndOp,
512 replaceBinaryElementwise<tosa::BitwiseAndOp,
513 spirv::TosaBitwiseAndOp>>,
516 replaceBinaryElementwise<tosa::BitwiseOrOp, spirv::TosaBitwiseOrOp>>,
517 TosaOpConvert<tosa::BitwiseXorOp,
518 replaceBinaryElementwise<tosa::BitwiseXorOp,
519 spirv::TosaBitwiseXorOp>>,
520 TosaOpConvert<tosa::IntDivOp, replaceBinaryElementwise<
521 tosa::IntDivOp, spirv::TosaIntDivOp>>,
522 TosaOpConvert<tosa::LogicalAndOp,
523 replaceBinaryElementwise<tosa::LogicalAndOp,
524 spirv::TosaLogicalAndOp>>,
525 TosaOpConvert<tosa::LogicalLeftShiftOp,
526 replaceBinaryElementwise<tosa::LogicalLeftShiftOp,
527 spirv::TosaLogicalLeftShiftOp>>,
528 TosaOpConvert<tosa::LogicalRightShiftOp,
529 replaceBinaryElementwise<tosa::LogicalRightShiftOp,
530 spirv::TosaLogicalRightShiftOp>>,
533 replaceBinaryElementwise<tosa::LogicalOrOp, spirv::TosaLogicalOrOp>>,
534 TosaOpConvert<tosa::LogicalXorOp,
535 replaceBinaryElementwise<tosa::LogicalXorOp,
536 spirv::TosaLogicalXorOp>>,
537 TosaOpConvert<tosa::MaximumOp,
538 replaceBinaryNanModeElementwise<tosa::MaximumOp,
539 spirv::TosaMaximumOp>>,
540 TosaOpConvert<tosa::MinimumOp,
541 replaceBinaryNanModeElementwise<tosa::MinimumOp,
542 spirv::TosaMinimumOp>>,
543 TosaOpConvert<tosa::MulOp, replaceMul>,
544 TosaOpConvert<tosa::PowOp,
545 replaceBinaryElementwise<tosa::PowOp, spirv::TosaPowOp>>,
546 TosaOpConvert<tosa::SubOp,
547 replaceBinaryElementwise<tosa::SubOp, spirv::TosaSubOp>>,
548 TosaOpConvert<tosa::TableOp, replaceTable>,
549 TosaOpConvert<tosa::AbsOp,
550 replaceUnaryInput1<tosa::AbsOp, spirv::TosaAbsOp>>,
553 replaceUnaryInput1<tosa::BitwiseNotOp, spirv::TosaBitwiseNotOp>>,
554 TosaOpConvert<tosa::CeilOp,
555 replaceUnaryInput1<tosa::CeilOp, spirv::TosaCeilOp>>,
556 TosaOpConvert<tosa::ClzOp,
557 replaceUnaryInput1<tosa::ClzOp, spirv::TosaClzOp>>,
558 TosaOpConvert<tosa::CosOp,
559 replaceUnaryInput1<tosa::CosOp, spirv::TosaCosOp>>,
560 TosaOpConvert<tosa::ExpOp,
561 replaceUnaryInput1<tosa::ExpOp, spirv::TosaExpOp>>,
562 TosaOpConvert<tosa::FloorOp,
563 replaceUnaryInput1<tosa::FloorOp, spirv::TosaFloorOp>>,
564 TosaOpConvert<tosa::LogOp,
565 replaceUnaryInput1<tosa::LogOp, spirv::TosaLogOp>>,
568 replaceUnaryInput1<tosa::LogicalNotOp, spirv::TosaLogicalNotOp>>,
569 TosaOpConvert<tosa::NegateOp, replaceNegate>,
572 replaceUnaryInput1<tosa::ReciprocalOp, spirv::TosaReciprocalOp>>,
573 TosaOpConvert<tosa::RsqrtOp,
574 replaceUnaryInput1<tosa::RsqrtOp, spirv::TosaRsqrtOp>>,
575 TosaOpConvert<tosa::SinOp,
576 replaceUnaryInput1<tosa::SinOp, spirv::TosaSinOp>>,
577 TosaOpConvert<tosa::SelectOp, replaceSelect>,
578 TosaOpConvert<tosa::EqualOp, replaceBinaryElementwise<
579 tosa::EqualOp, spirv::TosaEqualOp>>,
582 replaceBinaryElementwise<tosa::GreaterOp, spirv::TosaGreaterOp>>,
583 TosaOpConvert<tosa::GreaterEqualOp,
584 replaceBinaryElementwise<tosa::GreaterEqualOp,
585 spirv::TosaGreaterEqualOp>>,
588 replaceReduction<tosa::ReduceAllOp, spirv::TosaReduceAllOp>>,
591 replaceReduction<tosa::ReduceAnyOp, spirv::TosaReduceAnyOp>>,
594 replaceNanModeReduction<tosa::ReduceMaxOp, spirv::TosaReduceMaxOp>>,
597 replaceNanModeReduction<tosa::ReduceMinOp, spirv::TosaReduceMinOp>>,
599 tosa::ReduceProductOp,
600 replaceReduction<tosa::ReduceProductOp, spirv::TosaReduceProductOp>>,
603 replaceReduction<tosa::ReduceSumOp, spirv::TosaReduceSumOp>>,
604 TosaOpConvert<tosa::ConcatOp, replaceConcat>,
605 TosaOpConvert<tosa::PadOp, replacePad>,
606 TosaOpConvert<tosa::ReshapeOp, replaceReshape>,
607 TosaOpConvert<tosa::ReverseOp, replaceReverse>,
608 TosaOpConvert<tosa::SliceOp, replaceSlice>,
609 TosaOpConvert<tosa::TileOp, replaceTile>,
610 TosaOpConvert<tosa::TransposeOp, replaceTranspose>,
611 TosaOpConvert<tosa::GatherOp, replaceGather>,
612 TosaOpConvert<tosa::ScatterOp, replaceScatter>,
613 TosaOpConvert<tosa::ResizeOp, replaceResize>,
614 TosaOpConvert<tosa::CastOp,
615 replaceUnaryInput<tosa::CastOp, spirv::TosaCastOp>>,
616 TosaOpConvert<tosa::RescaleOp, replaceRescale>,
617 TosaOpConvert<tosa::ConstOp, replaceConstant<tosa::ConstOp>>,
618 TosaOpConvert<tosa::ConstShapeOp, replaceConstant<tosa::ConstShapeOp>>,
619 TosaOpConvert<tosa::IdentityOp, replaceIdentity>>(typeConverter,
static llvm::Constant * convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr, llvm::Type *llvmType, const ModuleTranslation &moduleTranslation)
Convert a dense elements attribute to an LLVM IR constant using its raw data storage if possible.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
DenseElementsAttr mapValues(Type newElementType, function_ref< APInt(const APInt &)> mapping) const
Generates a new DenseElementsAttr by mapping each value attribute, and constructing the DenseElements...
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Type conversion from builtin types to SPIR-V types for shader interface.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
constexpr llvm::StringLiteral graphARMGraphConstantIdAttrName
void populateTosaToSPIRVTosaOpsConversionPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)