MLIR

Multi-Level IR Compiler Framework

Tensor Operator Set Architecture (TOSA) Dialect

Rationale 

The MLIR TOSA dialect implements the TOSA specification. This document describes the decision process for how TOSA expresses operators in high level dialects.

TOSA was developed after parallel efforts to rationalize the top-down picture from multiple high-level frameworks, as well as a bottom-up view of different hardware target concerns (CPU, GPU and NPU), and reflects a set of choices that attempt to manage both sets of requirements.

TOSA and Tensor Level Expressiveness 

TOSA endeavors to provide an operator set that fulfils the following expressiveness goals at the tensor level of abstraction:

Complete 

This is driven by the top-down perspective: the need to express as much of multiple high-level frameworks in TOSA as possible. The operator set was originally derived from an operator frequency analysis of dozens of high-level networks in different frameworks, selecting the most frequently occurring operators to establish a common set of tensor-level operators that could express them.

TOSA categorizes its operator set into classes and attempts to address major functional operations at the tensor level, including compute, reduction, elementwise transformations, comparison and control flow.

Minimal 

This takes the bottom-up approach: keep the TOSA operator set minimal in order to bound the design of hardware, operator kernels, code generation strategies and associated considerations that affect the executability of TOSA content.

In this regard, TOSA seeks to avoid creating compound operators, instead leaving it to the compiler backend to fuse multiple TOSA ops if required. This choice also benefits the numerical precision goal, since it is easier to fuse the numerical functionality of successive operators than to split the numerical functionality of a compound operator.

Numerical Precision 

TOSA began as a means to address operator-level numerical precision for code generation and hardware development. It therefore incorporates precision detail into the operator set.

In this regard, TOSA operators are best understood as a combination of the visible quantization information embedded within an operation, together with the functional information about how that information is used, as described in the specification of the operation.

TOSA Operator Rationale 

The general basis of selection of the operator set that constitutes TOSA is described in the TOSA specification document under Section 1.3 Operator Selection. Explanation of the thinking behind some operators is listed here:

COND_IF and WHILE_LOOP 

Several neural networks express conditional control flow at the tensor level. A survey of multiple high level frameworks indicated that conditional if and a loop construct are common in all major frameworks, with some variation. Since TOSA endeavors to be complete in expressing tensor level functionality including control flow, it implements these constructs.

The COND_IF and WHILE_LOOP operators implement such structured control flow forms and should be lowerable to corresponding ops in the scf dialect. Since the dialect seeks to remain isomorphic with an external, serialized form, the decision was to keep these ops in the dialect (as opposed to deferring completely to scf), and this may be re-evaluated if this turns out to not yield the expected value.

Using TOSA In A Compiler 

The TOSA specification describes each operator in functional detail. It is expected that compilers that use TOSA will use its builders to construct the operators so that the quantization information for the operator is correctly generated.

The functional steps described in the pseudocode of the specification enable the construction of code generation for that operation, or decisions on the design of underlying hardware. The functional pseudocode also describes how the quantization parameters are utilized within the operation.

Quantization Parameters in Ops vs Tensors 

TOSA uses the quantization parameters embedded in the input and output tensors to construct the quantization attributes that sit within the operator. Once these attributes are constructed, the quantization information within the tensors is no longer necessary for code generation.

This enables the tensors to be subsequently interpreted simply as contiguous buffers containing raw data, with no ‘meta information’ in the form of the quantization_type. Precision-related manipulation of the input or output is instead described by the operator itself, which specifies, for example, when the zero point is applied, or when the scale multiplication is done.

However, TOSA does not eliminate the existing MLIR QuantOps quantization type information within the tensors; this leaves the choice of how to handle quantization information to later backend code generation steps.

Maintaining the ability to overlap these different representations of quantization parameters (i.e. tensor-carried vs op-carried) is an important capability when considering progressive lowering between uses that expect one scheme vs the other.

Operation definitions 

tosa.abs (mlir::tosa::AbsOp) 

Elementwise abs operator.

Syntax:

operation ::= `tosa.abs` operands attr-dict `:` functional-type(operands, results)

Elementwise absolute value operation.

Example:

%out = tosa.abs %in : (tensor<21x3xf32>) -> tensor<21x3xf32>

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.add (mlir::tosa::AddOp) 

Elementwise addition operator.

Syntax:

operation ::= `tosa.add` operands attr-dict `:` functional-type(operands, results)

Elementwise addition of input1 and input2. Axis of size 1 will be broadcast, as necessary. Rank of input tensors must match.

Example:

// Elementwise addition.
%out = tosa.add %in1, %in2 : (tensor<12x6xf32>, tensor<12x6xf32>) -> tensor<12x6xf32>

// Elementwise addition with broadcasting.
%out = tosa.add %in1, %in2 : (tensor<12x6xsi32>, tensor<1x1xsi32>) -> tensor<12x6xsi32>
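The broadcasting rule described for tosa.add (equal ranks, with axes of size 1 stretched to match the other operand) can be sketched in Python as a shape computation. This is only an illustration of the rule as stated above; the function name `tosa_broadcast_shape` is hypothetical and not part of MLIR or TOSA.

```python
def tosa_broadcast_shape(shape1, shape2):
    """Return the result shape for two equal-rank operand shapes.

    Hypothetical helper: models only the rule stated in the text
    (ranks must match, axes of size 1 are broadcast).
    """
    if len(shape1) != len(shape2):
        raise ValueError("rank of input tensors must match")
    result = []
    for dim1, dim2 in zip(shape1, shape2):
        if dim1 == dim2:
            result.append(dim1)
        elif dim1 == 1:
            result.append(dim2)  # axis of size 1 in input1 is broadcast
        elif dim2 == 1:
            result.append(dim1)  # axis of size 1 in input2 is broadcast
        else:
            raise ValueError(f"cannot broadcast dimensions {dim1} and {dim2}")
    return result


if __name__ == "__main__":
    print(tosa_broadcast_shape([12, 6], [1, 1]))
```

Unlike NumPy-style broadcasting, this sketch rejects operands of different rank, which matches the "rank of input tensors must match" requirement.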

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.apply_scale (mlir::tosa::ApplyScaleOp) 

Rescale scalar operator for Tosa tensor operators

Syntax:

operation ::= `tosa.apply_scale` operands attr-dict `:` functional-type(operands, results)

Applies rescaling for fixed point values. This behavior is replicated in multiple quantized operations (mul, convolution, rescale, matmul, pooling).

The common implementation uses i64 operations to avoid integer overflow; target-specific implementations can use native operations to avoid wider-than-necessary types.
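A minimal Python sketch of fixed-point rescaling follows, assuming the result is `value * multiplier / 2**shift` with round-half-up and a wide intermediate product. The function name `apply_scale_32`, the accepted shift range, and the single-rounding behaviour are assumptions made for illustration; the `rounding_mode` attribute is not modelled, and the TOSA specification pseudocode remains the authority.

```python
def apply_scale_32(value, multiplier, shift):
    """Scale a 32-bit value by multiplier / 2**shift (round half up).

    Hypothetical reference sketch; range checks are assumptions.
    """
    if multiplier < 0:
        raise ValueError("multiplier must be non-negative")
    if not 2 <= shift <= 62:
        raise ValueError("shift must be in the range [2, 62]")
    rounding = 1 << (shift - 1)
    # Python integers are unbounded, so this product plays the role of the
    # i64 intermediate that avoids overflow in a fixed-width implementation.
    result = (value * multiplier + rounding) >> shift
    if not -(1 << 31) <= result < (1 << 31):
        raise OverflowError("result does not fit in 32 bits")
    return result


if __name__ == "__main__":
    # multiplier = 2**30 and shift = 31 together represent a scale of 0.5.
    print(apply_scale_32(100, 1 << 30, 31))
```

The product of two 32-bit operands needs up to 64 bits before the shift, which is why a fixed-width implementation typically widens to i64 unless the target offers a native high-multiply.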

Traits: AlwaysSpeculatableImplTrait, Elementwise, Scalarizable, Tensorizable, TosaResolvableShapeOperands, Vectorizable

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface, VectorUnrollOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| rounding_mode | ::mlir::StringAttr | Supported rounding modes |

Operands:

| Operand | Description |
| --- | --- |
| value | signless-integer-like |
| multiplier | signless-integer-like |
| shift | signless-integer-8-bit-like |

Results:

| Result | Description |
| --- | --- |
| output | signless-integer-like |

tosa.argmax (mlir::tosa::ArgMaxOp) 

Perform argmax on the input.

Syntax:

operation ::= `tosa.argmax` operands attr-dict `:` functional-type(operands, results)

This returns the index with the largest value across the given axis of the input tensor. If multiple locations have equal values, returns the first match along the search axis.
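The first-match tie-breaking rule can be illustrated with a small Python sketch over nested lists. The helper names are hypothetical, only rank-1 and rank-2 inputs are handled, and NaN handling (the `nan_mode` attribute) is not modelled.

```python
def argmax_along_axis(values):
    """Index of the largest value in a 1-D sequence; ties pick the first."""
    if not values:
        raise ValueError("search axis must be non-empty")
    best_index = 0
    for index, value in enumerate(values):
        # Strict '>' keeps the earliest index when values are equal.
        if value > values[best_index]:
            best_index = index
    return best_index


def argmax_2d(matrix, axis):
    """Argmax of a rank-2 nested list along axis 0 or 1."""
    if axis == 0:
        columns = [list(column) for column in zip(*matrix)]
        return [argmax_along_axis(column) for column in columns]
    if axis == 1:
        return [argmax_along_axis(row) for row in matrix]
    raise ValueError("axis must be 0 or 1 for a rank-2 input")


if __name__ == "__main__":
    print(argmax_along_axis([1, 5, 5, 2]))
    print(argmax_2d([[1, 9], [7, 9]], axis=0))
```

Note that the result has one fewer dimension than the input: the search axis is removed.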

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 32-bit signless integer attribute |
| nan_mode | ::mlir::StringAttr | Supported NaN propagation strategies |

Operands:

| Operand | Description |
| --- | --- |
| input | tosa-conformant tensor of at least rank 1 |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.arithmetic_right_shift (mlir::tosa::ArithmeticRightShiftOp) 

Elementwise Arithmetic Right Shift.

Syntax:

operation ::= `tosa.arithmetic_right_shift` operands attr-dict `:` functional-type(operands, results)

Elementwise arithmetic right shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary. Rank of input tensors must match.
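The sketch below shows one plausible reading of the per-element behaviour, including the `round` attribute listed for this op. The rounding rule used here (add one when the last bit shifted out is set) is an assumption based on common fixed-point practice and should be checked against the TOSA specification; the function name is hypothetical.

```python
def arithmetic_right_shift(value, shift, round_result=False):
    """Sign-preserving right shift of one element.

    Assumption: when round_result is True, the result is incremented if the
    most significant bit shifted out was 1.
    """
    if shift < 0:
        raise ValueError("shift amount must be non-negative")
    # Python's >> on int is already an arithmetic (sign-extending) shift.
    result = value >> shift
    if round_result and shift > 0 and ((value >> (shift - 1)) & 1):
        result += 1
    return result


if __name__ == "__main__":
    print(arithmetic_right_shift(-5, 1))
    print(arithmetic_right_shift(-5, 1, round_result=True))
```

Without rounding, the shift rounds toward negative infinity; with rounding, the result is the nearest integer with halves rounded up.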

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| round | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.avg_pool2d (mlir::tosa::AvgPool2dOp) 

Performs average pooling on the input.

Syntax:

operation ::= `tosa.avg_pool2d` operands attr-dict `:` functional-type(operands, results)

This performs an average pooling over the given input tensor. A sliding window of the size given by the kernel attribute is passed over the input tensor, with the mean value being placed in the output tensor. When calculating the average, the divisor counts only the valid input tensor values covered by the window; padding values are excluded.
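The divisor rule can be made concrete with a floating-point sketch over a single-channel 2-D image. The function name, the nested-list representation, and the `(top, bottom, left, right)` padding order are assumptions for illustration; zero points, accumulator types and integer rounding are not modelled.

```python
def avg_pool2d(image, kernel, stride, pad):
    """Average-pool a single-channel 2-D image given as nested lists.

    kernel = (kernel_y, kernel_x), stride = (stride_y, stride_x),
    pad = (top, bottom, left, right). Hypothetical float-only sketch.
    """
    height, width = len(image), len(image[0])
    kernel_y, kernel_x = kernel
    stride_y, stride_x = stride
    pad_top, pad_bottom, pad_left, pad_right = pad
    out_height = (height + pad_top + pad_bottom - kernel_y) // stride_y + 1
    out_width = (width + pad_left + pad_right - kernel_x) // stride_x + 1

    output = []
    for oy in range(out_height):
        row = []
        for ox in range(out_width):
            total, count = 0.0, 0
            for ky in range(kernel_y):
                for kx in range(kernel_x):
                    y = oy * stride_y - pad_top + ky
                    x = ox * stride_x - pad_left + kx
                    # Padding positions add to neither the sum nor the divisor.
                    if 0 <= y < height and 0 <= x < width:
                        total += image[y][x]
                        count += 1
            if count == 0:
                raise ValueError("window lies entirely in padding")
            row.append(total / count)
        output.append(row)
    return output


if __name__ == "__main__":
    for row in avg_pool2d([[1, 2], [3, 4]], (2, 2), (1, 1), (1, 1, 1, 1)):
        print(row)
```

In the corner windows only one real element is covered, so the divisor is 1 rather than the full window size of 4.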

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| kernel | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements |
| acc_type | ::mlir::TypeAttr | type attribute of 32-bit signless integer or 48-bit signless integer or 16-bit float or 32-bit float |

Operands:

| Operand | Description |
| --- | --- |
| input | 4-d tosa-conformant tensor |
| input_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |
| output_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |

Results:

| Result | Description |
| --- | --- |
| output | 4-d tosa-conformant tensor |

tosa.bitwise_and (mlir::tosa::BitwiseAndOp) 

Bitwise AND operator.

Syntax:

operation ::= `tosa.bitwise_and` operands attr-dict `:` functional-type(operands, results)

Elementwise bitwise AND of input1 and input2. Axis of size 1 will be broadcast as necessary. Rank of input tensors must match.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.bitwise_not (mlir::tosa::BitwiseNotOp) 

Bitwise NOT operator.

Syntax:

operation ::= `tosa.bitwise_not` operands attr-dict `:` functional-type(operands, results)

Elementwise bitwise NOT of input tensor.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.bitwise_or (mlir::tosa::BitwiseOrOp) 

Bitwise OR operator.

Syntax:

operation ::= `tosa.bitwise_or` operands attr-dict `:` functional-type(operands, results)

Elementwise bitwise OR of input1 and input2. Axis of size 1 will be broadcast as necessary. Rank of input tensors must match.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.bitwise_xor (mlir::tosa::BitwiseXorOp) 

Bitwise XOR operator.

Syntax:

operation ::= `tosa.bitwise_xor` operands attr-dict `:` functional-type(operands, results)

Elementwise bitwise XOR of input1 and input2. Axis of size 1 will be broadcast as necessary. Rank of input tensors must match.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.cast (mlir::tosa::CastOp) 

Cast operation.

Syntax:

operation ::= `tosa.cast` operands attr-dict `:` functional-type(operands, results)

Casts a tensor from one data type to another.

  • The table below lists the conversions supported by the TOSA Specification.
  • The MLIR dialect can also be used to represent other conversions.

| Mode | Input | Output |
| --- | --- | --- |
| fp16 to fp32 | float16 | float32 |
| fp16 to int 16 | float16 | int16 |
| fp16 to int 32 | float16 | int32 |
| fp16 to int 8 | float16 | int8 |
| fp32 to fp16 | float32 | float16 |
| fp32 to int 16 | float32 | int16 |
| fp32 to int 32 | float32 | int32 |
| fp32 to int 8 | float32 | int8 |
| int 16 to fp16 | int16 | float16 |
| int 16 to fp32 | int16 | float32 |
| int 32 to fp16 | int32 | float16 |
| int 32 to fp32 | int32 | float32 |
| int 8 to fp16 | int8 | float16 |
| int 8 to fp32 | int8 | float32 |
| bool to int 16 | Boolean | int16 |
| bool to int 32 | Boolean | int32 |
| bool to int 8 | Boolean | int8 |
| int 16 to bool | int16 | Boolean |
| int 16 to int 32 | int16 | int32 |
| int 16 to int 8 | int16 | int8 |
| int 32 to bool | int32 | Boolean |
| int 32 to int 16 | int32 | int16 |
| int 32 to int 8 | int32 | int8 |
| int 8 to bool | int8 | Boolean |
| int 8 to int 16 | int8 | int16 |
| int 8 to int 32 | int8 | int32 |
| bf16 to fp32 | bf16 | float32 |
| bf16 to int 16 | bf16 | int16 |
| bf16 to int 32 | bf16 | int32 |
| bf16 to int 8 | bf16 | int8 |
| fp32 to bf16 | float32 | bf16 |
| int 16 to bf16 | int16 | bf16 |
| int 32 to bf16 | int32 | bf16 |
| int 8 to bf16 | int8 | bf16 |
| bf16 to fp8e4m3 | bf16 | fp8e4m3 |
| fp8e4m3 to bf16 | fp8e4m3 | bf16 |
| bf16 to fp8e5m2 | bf16 | fp8e5m2 |
| fp8e5m2 to bf16 | fp8e5m2 | bf16 |
| fp16 to fp8e4m3 | float16 | fp8e4m3 |
| fp32 to fp8e4m3 | float32 | fp8e4m3 |
| fp8e4m3 to fp16 | fp8e4m3 | float16 |
| fp8e4m3 to fp32 | fp8e4m3 | float32 |
| fp16 to fp8e5m2 | float16 | fp8e5m2 |
| fp32 to fp8e5m2 | float32 | fp8e5m2 |
| fp8e5m2 to fp16 | fp8e5m2 | float16 |
| fp8e5m2 to fp32 | fp8e5m2 | float32 |

Traits: AlwaysSpeculatableImplTrait, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.ceil (mlir::tosa::CeilOp) 

Elementwise ceil operator.

Syntax:

operation ::= `tosa.ceil` operands attr-dict `:` functional-type(operands, results)

Elementwise ceiling operation.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.clamp (mlir::tosa::ClampOp) 

Computes clamp(features, min, max).

Syntax:

operation ::= `tosa.clamp` operands attr-dict `:` functional-type(operands, results)

Clamp to an arbitrary minimum and maximum value. Maximum and minimum values are specified as values in the range of the input type. No zero point subtraction is done to the values, thus to clamp to the zero point value, the zero point itself should be supplied as the minimum value.
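As a small illustration of the zero-point remark, the Python sketch below clamps raw quantized values directly, so a ReLU-like lower bound is expressed by passing the zero point itself as the minimum. The function name and the zero point value of 3 are hypothetical, and NaN handling (`nan_mode`) is not modelled.

```python
def clamp(values, min_val, max_val):
    """Clamp each raw element to [min_val, max_val]; no zero point is applied."""
    if min_val > max_val:
        raise ValueError("min_val must not exceed max_val")
    return [min(max(value, min_val), max_val) for value in values]


if __name__ == "__main__":
    zero_point = 3  # hypothetical zero point of an int8 quantized tensor
    raw = [-10, 0, 3, 50, 127]
    # Clamping at the zero point maps every "negative" real value to real 0.
    print(clamp(raw, zero_point, 127))
```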

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| min_val | ::mlir::Attribute | arbitrary integer attribute or arbitrary float attribute |
| max_val | ::mlir::Attribute | arbitrary integer attribute or arbitrary float attribute |
| nan_mode | ::mlir::StringAttr | Supported NaN propagation strategies |

Operands:

| Operand | Description |
| --- | --- |
| input | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.clz (mlir::tosa::ClzOp) 

Elementwise count leading zero operator.

Syntax:

operation ::= `tosa.clz` operands attr-dict `:` functional-type(operands, results)

Elementwise count leading zeros operation.
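For a single element, counting leading zeros can be sketched in Python as below. The 32-bit default width and the function name are assumptions for illustration; negative inputs are interpreted through their two's-complement bit pattern.

```python
def count_leading_zeros(value, bit_width=32):
    """Number of zero bits above the most significant set bit."""
    # Mask to the fixed width so negative values use their two's-complement
    # bit pattern (for example, -1 becomes all ones).
    unsigned = value & ((1 << bit_width) - 1)
    return bit_width - unsigned.bit_length()


if __name__ == "__main__":
    for sample in (0, 1, 255, -1):
        print(sample, count_leading_zeros(sample))
```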

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.concat (mlir::tosa::ConcatOp) 

Concatenates tensors along one dimension.

Syntax:

operation ::= `tosa.concat` operands attr-dict `:` functional-type(operands, results)

Concatenate a list of tensors along a given axis. No data conversion happens during a concat operation.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 32-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input1 | variadic of tosa-conformant tensor of at least rank 1 |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of at least rank 1 |

tosa.const (mlir::tosa::ConstOp) 

Constant operator.

A node containing constant data for use as the input to an operation. May hold data in any of the supported data formats.

Example:

// Generic form
%out = "tosa.const"() {values = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>

Traits: AlwaysSpeculatableImplTrait, ConstantLike, FirstAttrDerivedResultType, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| values | ::mlir::ElementsAttr | constant vector/tensor attribute |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.const_shape (mlir::tosa::ConstShapeOp) 

Constant Shape operator.

Syntax:

operation ::= `tosa.const_shape` operands attr-dict `:` functional-type(operands, results)

A node containing a constant shape.

Example:

// Generic form
%out = "tosa.const_shape"() {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>

Traits: AlwaysSpeculatableImplTrait, ConstantLike, TosaResolvableShapeOperands, TosaShapeOperator

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| values | ::mlir::DenseIntElementsAttr | index elements attribute |

Results:

| Result | Description |
| --- | --- |
| output | Shape with static rank and Index element type |

tosa.conv2d (mlir::tosa::Conv2DOp) 

2D Convolution operator.

Syntax:

operation ::= `tosa.conv2d` operands attr-dict `:` functional-type(operands, results)

Performs a 2D convolution over the given tensor input, using the weight tensor. Implementations may choose to skip calculation of multiplies in the padding area.
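The remark about skipping multiplies in the padding area can be illustrated with a single-channel Python sketch: taps that fall outside the input are simply not accumulated, which gives the same result as multiplying by zero-valued padding. The function name, the nested-list layout, the `(top, bottom, left, right)` padding order and the output-size formula are assumptions for illustration; batching, channels, zero points and accumulator types are not modelled.

```python
def conv2d_single_channel(image, weights, pad=(0, 0, 0, 0), stride=(1, 1),
                          dilation=(1, 1), bias=0):
    """2-D convolution of one channel, skipping taps that land in padding."""
    height, width = len(image), len(image[0])
    kernel_h, kernel_w = len(weights), len(weights[0])
    pad_top, pad_bottom, pad_left, pad_right = pad
    stride_y, stride_x = stride
    dilation_y, dilation_x = dilation
    # Assumed output-size formula for a strided, dilated, padded convolution.
    out_h = (height - 1 + pad_top + pad_bottom
             - (kernel_h - 1) * dilation_y) // stride_y + 1
    out_w = (width - 1 + pad_left + pad_right
             - (kernel_w - 1) * dilation_x) // stride_x + 1

    output = []
    for oy in range(out_h):
        row = []
        for ox in range(out_w):
            acc = bias
            for ky in range(kernel_h):
                for kx in range(kernel_w):
                    y = oy * stride_y - pad_top + ky * dilation_y
                    x = ox * stride_x - pad_left + kx * dilation_x
                    # Skip the multiply entirely when the tap is in padding.
                    if 0 <= y < height and 0 <= x < width:
                        acc += image[y][x] * weights[ky][kx]
            row.append(acc)
        output.append(row)
    return output


if __name__ == "__main__":
    result = conv2d_single_channel([[1, 2], [3, 4]], [[1, 1], [1, 1]],
                                   pad=(1, 1, 1, 1))
    for row in result:
        print(row)
```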

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, SameVariadicOperandSize, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements |
| stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| dilation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| acc_type | ::mlir::TypeAttr | type attribute of 32-bit signless integer or 48-bit signless integer or 16-bit float or 32-bit float |
| local_bound | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | 4-d tosa-conformant tensor |
| weight | 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values |
| bias | 1-d tosa-conformant tensor |
| input_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |
| weight_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |

Results:

| Result | Description |
| --- | --- |
| output | 4-d tosa-conformant tensor |

tosa.conv3d (mlir::tosa::Conv3DOp) 

3D Convolution operator.

Syntax:

operation ::= `tosa.conv3d` operands attr-dict `:` functional-type(operands, results)

Performs a 3D convolution over the given input tensor. Implementations may choose to skip calculation of multiplies in the padding area.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, SameVariadicOperandSize, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 6 elements |
| stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 3 elements |
| dilation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 3 elements |
| acc_type | ::mlir::TypeAttr | type attribute of 32-bit signless integer or 48-bit signless integer or 16-bit float or 32-bit float |
| local_bound | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | 5-d tosa-conformant tensor |
| weight | 5D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values |
| bias | 1-d tosa-conformant tensor |
| input_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |
| weight_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |

Results:

| Result | Description |
| --- | --- |
| output | 5-d tosa-conformant tensor |

tosa.cos (mlir::tosa::CosOp) 

Elementwise cos operator.

Syntax:

operation ::= `tosa.cos` operands attr-dict `:` functional-type(operands, results)

Elementwise cosine operation for values given in radians.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of floating-point values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of floating-point values |

tosa.custom (mlir::tosa::CustomOp) 

Custom operator wrapper for Tosa

Syntax:

operation ::= `tosa.custom` operands attr-dict `:` functional-type(operands, results)

Hardware implementing TOSA may choose to add additional custom operators that are not expressed in the existing TOSA operations. These operators are not expected to be portable across TOSA implementations. The input and output signatures must be expressed in the corresponding TOSA node.

operator_name is a string that tells the backend which custom operator is being called.

domain_name is a string identifier which can help avoid name collisions on the identifier field.

implementation_attrs is a string which is a backend and identifier specific set of attributes to the custom operator.

input_list is the set of tensor inputs to the custom operator.

output_list is the list of tensors returned by the operator. The number of operators is backend specific.

Example:

%out = tosa.custom %in {domain_name = "tosa_mlir_test", operator_name =
       "custom_test", implementation_attrs = ""}: (tensor<10xi32>) ->
       (tensor<10xi32>)

Traits: TosaResolvableShapeOperands

Interfaces: QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| operator_name | ::mlir::StringAttr | string attribute |
| domain_name | ::mlir::StringAttr | string attribute |
| implementation_attrs | ::mlir::StringAttr | string attribute |

Operands:

| Operand | Description |
| --- | --- |
| input_list | variadic of tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output_list | variadic of tosa-conformant tensor of number values |

tosa.depthwise_conv2d (mlir::tosa::DepthwiseConv2DOp) 

Depthwise 2D Convolution operator.

Syntax:

operation ::= `tosa.depthwise_conv2d` operands attr-dict `:` functional-type(operands, results)

Performs 2D convolutions separately over each channel of the given tensor input, using the weight tensor. Implementations may choose to skip calculation of multiplies in the padding area.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, SameVariadicOperandSize, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements |
| stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| dilation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| acc_type | ::mlir::TypeAttr | type attribute of 32-bit signless integer or 48-bit signless integer or 16-bit float or 32-bit float |
| local_bound | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | 4-d tosa-conformant tensor |
| weight | 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values |
| bias | 1-d tosa-conformant tensor |
| input_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |
| weight_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |

Results:

| Result | Description |
| --- | --- |
| output | 4-d tosa-conformant tensor |

tosa.equal (mlir::tosa::EqualOp) 

Returns the truth value of (input1 == input2) element-wise.

Syntax:

operation ::= `tosa.equal` operands attr-dict `:` functional-type(operands, results)

Elementwise comparison operation.

Traits: AlwaysSpeculatableImplTrait, Commutative, InferTensorType, ResultsBroadcastableShape, SameOperandsAndResultRank, SameOperandsElementType, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of 1-bit signless integer values |

tosa.erf (mlir::tosa::ErfOp) 

Computes gauss error function of input.

Syntax:

operation ::= `tosa.erf` operands attr-dict `:` functional-type(operands, results)

Gauss error function: $\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^{2}} \, dt$. For quantized integer data types, the TABLE operator should be used instead with the following definition. The ERF table has 513 entries, each of 16-bit precision, covering the input range -4.0 to +4.0 in steps of 1/64.
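The table layout described above (513 entries from -4.0 to +4.0 in steps of 1/64) can be sketched in Python using the standard library's `math.erf`. The function name and the quantization scale of 32767 are assumptions chosen so that values fit in a signed 16-bit range; the text does not state how entries are scaled.

```python
import math


def build_erf_table(scale=32767):
    """Build a 513-entry erf lookup table over [-4.0, +4.0].

    Entry i corresponds to x = -4.0 + i / 64. The scale is a hypothetical
    choice that maps erf's (-1, 1) range onto signed 16-bit integers.
    """
    return [round(math.erf(-4.0 + index / 64.0) * scale)
            for index in range(513)]


if __name__ == "__main__":
    table = build_erf_table()
    print(len(table), table[0], table[256], table[512])
```

The entry count follows from the range and step: 8.0 / (1/64) = 512 intervals, hence 513 sample points including both ends.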

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.exp (mlir::tosa::ExpOp) 

Elementwise exp operator.

Syntax:

operation ::= `tosa.exp` operands attr-dict `:` functional-type(operands, results)

Elementwise exponential operation, computing e raised to the power of each input element.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.fft2d (mlir::tosa::FFT2dOp) 

Performs FFT2D operation on the input.

Syntax:

operation ::= `tosa.fft2d` $input_real `,` $input_imag attr-dict `:` `(` type($input_real) `,`
              type($input_imag) `)` `->` `(` type($output_real) `,` type($output_imag) `)`

Performs a batched complex 2D Fast Fourier Transform over the input. The complex input values are constructed from the corresponding values in the input_real and input_imag tensors. The resulting values in the output are split into the output_real and output_imag tensors. No normalization is applied on either the forward or inverse versions of the operation.

Example:

 %out_real, %out_imag = tosa.fft2d %in_real, %in_imag {inverse = false} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> (tensor<1x8x8xf32>, tensor<1x8x8xf32>)
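The "no normalization" remark can be checked with a small reference sketch. The Python code below is a naive, unoptimized 2D DFT over a single H x W slice (one batch entry), not the dialect's implementation. The function name `dft2d` and the sign convention (negative exponent for the forward transform) are assumptions of this sketch.

```python
import cmath


def dft2d(real, imag, inverse=False):
    """Naive unnormalized 2D DFT over one H x W slice given as nested lists."""
    h, w = len(real), len(real[0])
    sign = 1.0 if inverse else -1.0
    out_real = [[0.0] * w for _ in range(h)]
    out_imag = [[0.0] * w for _ in range(h)]
    for oy in range(h):
        for ox in range(w):
            acc = 0j
            for iy in range(h):
                for ix in range(w):
                    angle = sign * 2.0 * cmath.pi * (iy * oy / h + ix * ox / w)
                    acc += complex(real[iy][ix], imag[iy][ix]) * cmath.exp(1j * angle)
            out_real[oy][ox] = acc.real
            out_imag[oy][ox] = acc.imag
    return out_real, out_imag


x_re = [[1.0, 2.0], [3.0, 4.0]]
x_im = [[0.0, 0.0], [0.0, 0.0]]
f_re, f_im = dft2d(x_re, x_im)                 # forward transform
r_re, r_im = dft2d(f_re, f_im, inverse=True)   # inverse transform
```

Because neither direction is normalized, the forward-then-inverse round trip returns the input scaled by H * W (here 4), so a caller that needs the original values has to divide explicitly.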

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, ResultsAreFloatLike, SameOperandsAndResultElementType, SameOperandsAndResultShape, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| inverse | ::mlir::BoolAttr | bool attribute |
| local_bound | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| --- | --- |
| input_real | 3-d tosa-conformant tensor |
| input_imag | 3-d tosa-conformant tensor |

Results:

| Result | Description |
| --- | --- |
| output_real | 3-d tosa-conformant tensor |
| output_imag | 3-d tosa-conformant tensor |

tosa.floor (mlir::tosa::FloorOp) 

Elementwise floor operator.

Syntax:

operation ::= `tosa.floor` operands attr-dict `:` functional-type(operands, results)

Elementwise floor operation.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.gather (mlir::tosa::GatherOp) 

Gather operation.

Syntax:

operation ::= `tosa.gather` operands attr-dict `:` functional-type(operands, results)

Generate a tensor for which each element in the output is a subtensor of the values tensor, selected by the indices. N is the number of batches, W the number of indices in each batch, K the range of each index, and C the number of data channels for each index.
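A minimal Python sketch of this indexing follows. It assumes the layouts values [N, K, C], indices [N, W] and output [N, W, C], which match the symbol definitions above and the 3-d/2-d operand ranks but are not spelled out in this section. The function name `gather` and the nested-list representation are illustrative only.

```python
def gather(values, indices):
    """values: [N][K][C], indices: [N][W] -> output: [N][W][C]."""
    output = []
    for n, batch_indices in enumerate(indices):
        rows = []
        for k in batch_indices:
            # Each index must fall inside the range K of its own batch.
            if not 0 <= k < len(values[n]):
                raise IndexError(f"index {k} outside [0, K) for batch {n}")
            rows.append(list(values[n][k]))   # copy the C channel values
        output.append(rows)
    return output


values = [[[10, 11], [20, 21], [30, 31]]]   # N=1, K=3, C=2
indices = [[2, 0, 2, 1]]                    # N=1, W=4
out = gather(values, indices)
```

An index may appear more than once, in which case the same channel vector is copied into several output positions.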

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| values | 3-d tosa-conformant tensor |
| indices | 2D tensor of 32-bit signless integer values |

Results:

| Result | Description |
| --- | --- |
| output | 3-d tosa-conformant tensor |

tosa.greater_equal (mlir::tosa::GreaterEqualOp) 

Returns the truth value of (input1 >= input2) element-wise.

Syntax:

operation ::= `tosa.greater_equal` operands attr-dict `:` functional-type(operands, results)

Elementwise comparison operation.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultRank, SameOperandsElementType, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of 1-bit signless integer values |

tosa.greater (mlir::tosa::GreaterOp) 

Returns the truth value of (input1 > input2) element-wise.

Syntax:

operation ::= `tosa.greater` operands attr-dict `:` functional-type(operands, results)

Elementwise greater than comparison operation.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultRank, SameOperandsElementType, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of 1-bit signless integer values |

tosa.identity (mlir::tosa::IdentityOp) 

Identity operator.

Syntax:

operation ::= `tosa.identity` operands attr-dict `:` functional-type(operands, results)

Returns a tensor with the same shape, type, and contents as the input.

Traits: AlwaysSpeculatableImplTrait, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.cond_if (mlir::tosa::IfOp) 

Conditional if operator.

Evaluates a Boolean condition and then takes one of two distinct execution paths. This implements if-then-else semantics.

Traits: InferShapedTypeOpAdaptor, RecursiveMemoryEffects, SingleBlockImplicitTerminator<YieldOp>, SingleBlock, TosaResolvableShapeOperands

Interfaces: InferShapedTypeOpInterface, QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Operands: 

| Operand | Description |
| --- | --- |
| condition | tosa-conformant tensor of 1-bit signless integer values |
| input_list | variadic of tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output_list | variadic of tosa-conformant tensor of number values |

tosa.int_div (mlir::tosa::IntDivOp) 

Integer divide operator.

Syntax:

operation ::= `tosa.int_div` operands attr-dict `:` functional-type(operands, results)

Elementwise integer divide of input1 by input2. Axis of size 1 will be broadcast as necessary. Rank of input tensors must match. The result of the divide is truncated towards zero. Expected use is for operations on non-scaled integers. Floating point divide should use RECIPROCAL and MUL. Quantized integer divide should use TABLE (for 1/x) and MUL.
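Truncation towards zero differs from floor division for operands of opposite sign. The short Python sketch below shows the difference on scalars; `int_div` is a hypothetical helper name, and Python's own `//` operator is used only as the floor-division contrast.

```python
def int_div(a, b):
    """Integer division that truncates toward zero (unlike Python's //)."""
    if b == 0:
        raise ZeroDivisionError("division by zero")
    quotient = abs(a) // abs(b)
    # The result is negative exactly when the operand signs differ.
    return quotient if (a < 0) == (b < 0) else -quotient


truncated = int_div(-7, 2)   # rounds toward zero
floored = -7 // 2            # Python floor division rounds toward -infinity
```

For non-negative operands the two forms agree; they only diverge when the exact quotient is negative and not an integer.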

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of 32-bit signless integer values |
| input2 | tosa-conformant tensor of 32-bit signless integer values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of 32-bit signless integer values |

tosa.log (mlir::tosa::LogOp) 

Elementwise log operator.

Syntax:

operation ::= `tosa.log` operands attr-dict `:` functional-type(operands, results)

Elementwise natural logarithm operation.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.logical_and (mlir::tosa::LogicalAndOp) 

Returns the truth value of input1 AND input2 element-wise.

Syntax:

operation ::= `tosa.logical_and` operands attr-dict `:` functional-type(operands, results)

Elementwise logical AND of input1 and input2. Axis of size 1 will be broadcast, as necessary. Rank of input tensors must match.
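The broadcasting rule stated above (ranks must match, and an axis of size 1 is stretched to the other operand's size) can be sketched in Python as follows. The helper names `broadcast_shape` and `logical_and_2d` are hypothetical, and the example is restricted to rank-2 nested lists for brevity.

```python
def broadcast_shape(shape1, shape2):
    """Result shape for two same-rank shapes where size-1 axes broadcast."""
    if len(shape1) != len(shape2):
        raise ValueError("rank of input tensors must match")
    result = []
    for d1, d2 in zip(shape1, shape2):
        if d1 != d2 and 1 not in (d1, d2):
            raise ValueError(f"incompatible dimensions {d1} and {d2}")
        result.append(d2 if d1 == 1 else d1)
    return result


def logical_and_2d(a, b):
    """Elementwise AND of two rank-2 boolean tensors with broadcasting."""
    rows, cols = broadcast_shape([len(a), len(a[0])], [len(b), len(b[0])])
    # Indexing modulo the axis length maps every index to 0 on size-1 axes.
    return [
        [a[i % len(a)][j % len(a[0])] and b[i % len(b)][j % len(b[0])]
         for j in range(cols)]
        for i in range(rows)
    ]


column = [[True], [False]]        # shape 2x1
row = [[True, False, True]]       # shape 1x3
result = logical_and_2d(column, row)
```

The same shape rule applies to the other binary elementwise operators in this dialect that carry the broadcasting remark; only the per-element function changes.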

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of 1-bit signless integer values |
| input2 | tosa-conformant tensor of 1-bit signless integer values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of 1-bit signless integer values |

tosa.logical_left_shift (mlir::tosa::LogicalLeftShiftOp) 

Elementwise Logical Left Shift.

Syntax:

operation ::= `tosa.logical_left_shift` operands attr-dict `:` functional-type(operands, results)

Elementwise logical left-shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary. Rank of input tensors must match.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.logical_not (mlir::tosa::LogicalNotOp) 

Returns the truth value of NOT input1 element-wise.

Syntax:

operation ::= `tosa.logical_not` operands attr-dict `:` functional-type(operands, results)

Elementwise logical NOT of input.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of 1-bit signless integer values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of 1-bit signless integer values |

tosa.logical_or (mlir::tosa::LogicalOrOp) 

Returns the truth value of input1 OR input2 element-wise.

Syntax:

operation ::= `tosa.logical_or` operands attr-dict `:` functional-type(operands, results)

Elementwise logical OR of input1 and input2. Axis of size 1 will be broadcast as necessary. Rank of input tensors must match.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of 1-bit signless integer values |
| input2 | tosa-conformant tensor of 1-bit signless integer values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of 1-bit signless integer values |

tosa.logical_right_shift (mlir::tosa::LogicalRightShiftOp) 

Elementwise Logical Right Shift.

Syntax:

operation ::= `tosa.logical_right_shift` operands attr-dict `:` functional-type(operands, results)

Elementwise logical right shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary. Rank of input tensors must match.
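A logical right shift fills the vacated high bits with zeros, in contrast to an arithmetic shift, which replicates the sign bit. The Python sketch below illustrates this for a single 32-bit value; the 32-bit width, the shift range check and the helper name are assumptions of the sketch rather than statements from this section.

```python
def logical_right_shift_i32(value, shift):
    """Shift a signless 32-bit value right, filling vacated bits with zeros."""
    if not 0 <= shift <= 31:
        raise ValueError("shift must be in [0, 31]")
    # Masking reinterprets a negative Python int as its 32-bit bit pattern.
    return (value & 0xFFFFFFFF) >> shift


logical = logical_right_shift_i32(-8, 1)   # zero-filled result
arithmetic = -8 >> 1                       # Python's >> keeps the sign
```

The result is returned as an unsigned bit pattern; reinterpreting it as a signed value is left to the caller.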

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.logical_xor (mlir::tosa::LogicalXorOp) 

Returns the truth value of input1 XOR input2 element-wise.

Syntax:

operation ::= `tosa.logical_xor` operands attr-dict `:` functional-type(operands, results)

Elementwise logical XOR of input1 and input2. Axis of size 1 will be broadcast as necessary. Rank of input tensors must match.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of 1-bit signless integer values |
| input2 | tosa-conformant tensor of 1-bit signless integer values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of 1-bit signless integer values |

tosa.matmul (mlir::tosa::MatMulOp) 

Matrix multiplication operator.

Syntax:

operation ::= `tosa.matmul` operands attr-dict `:` functional-type(operands, results)

Performs two-dimensional matrix multiplications.
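The following Python sketch shows one way to read this operation for 3-d operands. It assumes the layouts a [N, H, C], b [N, C, W] and output [N, H, W], and assumes that the zero points `a_zp` and `b_zp` are subtracted from the operands before multiplying. Both assumptions follow the TOSA specification's convention as best understood here and are not stated in this section.

```python
def matmul(a, b, a_zp=0, b_zp=0):
    """Batched matrix multiply: a [N][H][C] x b [N][C][W] -> [N][H][W]."""
    out = []
    for n in range(len(a)):
        h, c, w = len(a[n]), len(b[n]), len(b[n][0])
        out.append([
            # Zero points are removed from each operand before the product.
            [sum((a[n][i][k] - a_zp) * (b[n][k][j] - b_zp) for k in range(c))
             for j in range(w)]
            for i in range(h)
        ])
    return out


a = [[[1, 2], [3, 4]]]   # N=1, H=2, C=2
b = [[[5, 6], [7, 8]]]   # N=1, C=2, W=2
plain = matmul(a, b)
shifted = matmul(a, b, a_zp=1)
```

Each batch entry is multiplied independently, so the batch dimension N never mixes data between slices.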

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| a | 3-d tosa-conformant tensor |
| b | 3-d tosa-conformant tensor |
| a_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |
| b_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |

Results:

| Result | Description |
| --- | --- |
| output | 3-d tosa-conformant tensor |

tosa.max_pool2d (mlir::tosa::MaxPool2dOp) 

Performs max pooling on the input.

Syntax:

operation ::= `tosa.max_pool2d` operands attr-dict `:` functional-type(operands, results)

This performs max pooling over the given input tensor. A sliding window of the size given by the kernel attribute is passed over the input tensor, with the maximum value in each window being placed in the output tensor.
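The sliding-window behaviour can be sketched in Python for a single 2D plane (one batch entry and one channel). The attribute orderings kernel = [ky, kx], stride = [sy, sx] and pad = [top, bottom, left, right], and the rule that padded positions are ignored when taking the maximum, are assumptions of this sketch. The function name is hypothetical.

```python
def max_pool2d(plane, kernel, stride, pad):
    """Max pooling over one H x W plane given as a nested list."""
    ky, kx = kernel
    sy, sx = stride
    pad_top, pad_bottom, pad_left, pad_right = pad
    ih, iw = len(plane), len(plane[0])
    oh = (ih + pad_top + pad_bottom - ky) // sy + 1
    ow = (iw + pad_left + pad_right - kx) // sx + 1
    out = []
    for oy in range(oh):
        row = []
        for ox in range(ow):
            y0 = oy * sy - pad_top
            x0 = ox * sx - pad_left
            window = [
                plane[y][x]
                for y in range(y0, y0 + ky)
                for x in range(x0, x0 + kx)
                if 0 <= y < ih and 0 <= x < iw   # skip padded positions
            ]
            row.append(max(window))
        out.append(row)
    return out


plane = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
pooled = max_pool2d(plane, kernel=(2, 2), stride=(2, 2), pad=(0, 0, 0, 0))
padded = max_pool2d(plane, kernel=(2, 2), stride=(2, 2), pad=(1, 1, 1, 1))
```

The sketch relies on every window overlapping at least one real input element, which holds whenever each pad amount is smaller than the corresponding kernel size.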

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| kernel | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements |
| nan_mode | ::mlir::StringAttr | Supported NaN propagation strategies |

Operands:

| Operand | Description |
| --- | --- |
| input | 4-d tosa-conformant tensor |

Results:

| Result | Description |
| --- | --- |
| output | 4-d tosa-conformant tensor |

tosa.maximum (mlir::tosa::MaximumOp) 

Elementwise Maximum.

Syntax:

operation ::= `tosa.maximum` operands attr-dict `:` functional-type(operands, results)

Elementwise max of input1 and input2. Axis of size 1 will be broadcast, as necessary. Rank of input tensors must match.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| nan_mode | ::mlir::StringAttr | Supported NaN propagation strategies |

Operands:

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.minimum (mlir::tosa::MinimumOp) 

Elementwise Minimum.

Syntax:

operation ::= `tosa.minimum` operands attr-dict `:` functional-type(operands, results)

Elementwise minimum of input1 and input2. Axis of size 1 will be broadcast, as necessary. Rank of input tensors must match.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| nan_mode | ::mlir::StringAttr | Supported NaN propagation strategies |

Operands:

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.mul (mlir::tosa::MulOp) 

Multiplication operator.

Syntax:

operation ::= `tosa.mul` operands attr-dict `:` functional-type(operands, results)

Elementwise multiplication (Hadamard product) of input1 and input2. Axis of size 1 will be broadcast, as necessary. Rank of input tensors must match.

Traits: AlwaysSpeculatableImplTrait, Commutative, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |
| shift | tosa-conformant scalar tensor of 8-bit signless integer values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.negate (mlir::tosa::NegateOp) 

Elementwise negate operator.

Syntax:

operation ::= `tosa.negate` operands attr-dict `:` functional-type(operands, results)

Elementwise negation operation.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input1_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |
| output_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.pad (mlir::tosa::PadOp) 

Pads a tensor with a specified value.

Syntax:

operation ::= `tosa.pad` operands attr-dict `:` functional-type(operands, results)

Pads a tensor along the borders of each dimension with a supplied value. Returns a new tensor with the padding included. The pad_const value includes the zero point if the tensor uses a zero point.

Example:

%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
%padding = tosa.const_shape {values = dense<[1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
tosa.pad %arg0, %padding, %pad_const : (tensor<1x2xf32>, !tosa.shape<4>, tensor<1xf32>) -> (tensor<4x9xf32>)

Example 2:

%pad_const = "tosa.const"() {values = dense<3.14> : tensor<1xf32>} : () -> tensor<1xf32>
%padding = tosa.const_shape {values = dense<[-1, 2, 3, 4]> : tensor<4xindex> } : () -> !tosa.shape<4>
tosa.pad %arg0, %padding, %pad_const : (tensor<1x2xf32>, !tosa.shape<4>, tensor<1xf32>)  -> (tensor<?x9xf32>)
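The result shapes in the two examples above can be reproduced with a short Python sketch. It assumes the padding values are ordered as [before_0, after_0, before_1, after_1, ...] and that a -1 entry marks an unknown amount, which makes the corresponding output dimension dynamic. The helper name `padded_shape` and the use of `None` for a dynamic dimension are illustrative choices.

```python
def padded_shape(input_shape, padding):
    """Compute the output shape of a pad from per-dimension (before, after) pairs."""
    if len(padding) != 2 * len(input_shape):
        raise ValueError("padding needs two entries per input dimension")
    out = []
    for i, dim in enumerate(input_shape):
        before, after = padding[2 * i], padding[2 * i + 1]
        # A -1 entry marks an unknown amount, so the dimension is dynamic.
        if -1 in (before, after, dim):
            out.append(None)
        else:
            out.append(dim + before + after)
    return out


static_shape = padded_shape([1, 2], [1, 2, 3, 4])     # first example
dynamic_shape = padded_shape([1, 2], [-1, 2, 3, 4])   # second example
```

With the static padding the 1x2 input grows to 4x9 (1 + 1 + 2 and 2 + 3 + 4), and with the -1 entry only the second dimension can be resolved.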

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of at least rank 1 |
| padding | Shape with static rank and Index element type |
| pad_const | tosa-conformant scalar tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of at least rank 1 |

tosa.pow (mlir::tosa::PowOp) 

Computes the power of one value to another.

Syntax:

operation ::= `tosa.pow` operands attr-dict `:` functional-type(operands, results)

Elementwise input1 value raised to the power of input2. Axis of size 1 will be broadcast, as necessary. Rank of input tensors must match.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |
| input2 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.rfft2d (mlir::tosa::RFFT2dOp) 

Performs RFFT2D operation on the input.

Syntax:

operation ::= `tosa.rfft2d` $input_real attr-dict `:` `(` type($input_real) `)` `->` `(` type($output_real) `,` type($output_imag) `)`

Performs a batched 2D real-valued Fast Fourier Transform over the input where the input tensor consists of real values producing complex valued output. The complex output values will be split into the output_real and output_imag tensor arguments. RFFT2D takes advantage of Hermitian symmetry to only calculate the first half of the final output axis. Implementations may choose to skip calculation of the imaginary values at (0,0), (0,W/2), (H/2,0), and (H/2, W/2). If the calculation is skipped, the result at that location must be zero.

Example:

 %real, %imag = tosa.rfft2d %in : (tensor<1x8x16xf32>) -> (tensor<1x8x9xf32>, tensor<1x8x9xf32>)
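To make the Hermitian-symmetry remark concrete, the Python sketch below computes a naive real-input 2D DFT over one H x W plane and keeps only the first W/2 + 1 columns of the last axis, which is why an input width of 16 yields an output width of 9 in the example above. The function name and the forward sign convention are assumptions of this sketch.

```python
import cmath


def rfft2d(plane):
    """Naive real-input 2D DFT keeping only the first W // 2 + 1 output columns."""
    h, w = len(plane), len(plane[0])
    out_w = w // 2 + 1
    out_real = [[0.0] * out_w for _ in range(h)]
    out_imag = [[0.0] * out_w for _ in range(h)]
    for oy in range(h):
        for ox in range(out_w):
            acc = 0j
            for iy in range(h):
                for ix in range(w):
                    angle = -2.0 * cmath.pi * (iy * oy / h + ix * ox / w)
                    acc += plane[iy][ix] * cmath.exp(1j * angle)
            out_real[oy][ox] = acc.real
            out_imag[oy][ox] = acc.imag
    return out_real, out_imag


plane = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]   # H=2, W=4
re, im = rfft2d(plane)
```

For real input the imaginary parts at (0,0), (0,W/2), (H/2,0) and (H/2,W/2) are mathematically zero, which is why an implementation may skip computing them and write zero instead.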

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, ResultsAreFloatLike, SameOperandsAndResultElementType, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| local_bound | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| --- | --- |
| input_real | 3-d tosa-conformant tensor |

Results:

| Result | Description |
| --- | --- |
| output_real | 3-d tosa-conformant tensor |
| output_imag | 3-d tosa-conformant tensor |

tosa.reciprocal (mlir::tosa::ReciprocalOp) 

Elementwise reciprocal operator.

Syntax:

operation ::= `tosa.reciprocal` operands attr-dict `:` functional-type(operands, results)

Elementwise reciprocal operation. For integer operation, a TABLE should be used with the appropriate ranges.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

| Operand | Description |
| --- | --- |
| input1 | tosa-conformant tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.reduce_all (mlir::tosa::ReduceAllOp) 

Reduce All operator.

Syntax:

operation ::= `tosa.reduce_all` operands attr-dict `:` functional-type(operands, results)

Reduce a tensor along the given axis with a logical AND operation.
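A minimal Python sketch of an axis reduction on a rank-2 tensor follows. It assumes that the output keeps the same rank as the input with the reduced axis set to size 1, which is consistent with the "at least rank 1" result type but is not stated in this section. The helper names are hypothetical.

```python
from functools import reduce


def reduce_axis_2d(tensor, axis, combine):
    """Reduce a rank-2 nested list along `axis`, keeping that axis with size 1."""
    rows, cols = len(tensor), len(tensor[0])
    if axis == 0:
        return [[reduce(combine, (tensor[r][c] for r in range(rows)))
                 for c in range(cols)]]
    if axis == 1:
        return [[reduce(combine, row)] for row in tensor]
    raise ValueError("axis must be 0 or 1 for a rank-2 tensor")


def reduce_all_2d(tensor, axis):
    """Logical AND reduction along one axis."""
    return reduce_axis_2d(tensor, axis, lambda x, y: x and y)


t = [[True, False], [True, True]]
along_rows = reduce_all_2d(t, 0)   # shape 1x2
along_cols = reduce_all_2d(t, 1)   # shape 2x1
```

The other reduction operators follow the same pattern with a different combining function, for example logical OR, maximum, minimum, product or sum.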

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 32-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | tosa-conformant tensor of at least rank 1 |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of at least rank 1 |

tosa.reduce_any (mlir::tosa::ReduceAnyOp) 

Reduce Any operator.

Syntax:

operation ::= `tosa.reduce_any` operands attr-dict `:` functional-type(operands, results)

Reduce a tensor along the given axis with a logical OR operation.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 32-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | tosa-conformant tensor of at least rank 1 |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of at least rank 1 |

tosa.reduce_max (mlir::tosa::ReduceMaxOp) 

Reduce Max operator.

Syntax:

operation ::= `tosa.reduce_max` operands attr-dict `:` functional-type(operands, results)

Reduce a tensor along the given axis with a maximum operation.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 32-bit signless integer attribute |
| nan_mode | ::mlir::StringAttr | Supported NaN propagation strategies |

Operands:

| Operand | Description |
| --- | --- |
| input | tosa-conformant tensor of at least rank 1 |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of at least rank 1 |

tosa.reduce_min (mlir::tosa::ReduceMinOp) 

Reduce Min operator.

Syntax:

operation ::= `tosa.reduce_min` operands attr-dict `:` functional-type(operands, results)

Reduce a tensor along the given axis with a minimum operation.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 32-bit signless integer attribute |
| nan_mode | ::mlir::StringAttr | Supported NaN propagation strategies |

Operands:

| Operand | Description |
| --- | --- |
| input | tosa-conformant tensor of at least rank 1 |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of at least rank 1 |

tosa.reduce_product (mlir::tosa::ReduceProductOp) 

Reduce Product operator.

Syntax:

operation ::= `tosa.reduce_product` operands attr-dict `:` functional-type(operands, results)

Reduce a tensor along the given axis by computing the product of the axis.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 32-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | tosa-conformant tensor of at least rank 1 |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of at least rank 1 |

tosa.reduce_sum (mlir::tosa::ReduceSumOp) 

Reduce Sum operator.

Syntax:

operation ::= `tosa.reduce_sum` operands attr-dict `:` functional-type(operands, results)

Reduce a tensor along the given axis by computing the sum of the axis.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 32-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | tosa-conformant tensor of at least rank 1 |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of at least rank 1 |

tosa.rescale (mlir::tosa::RescaleOp) 

TOSA rescale operator.

Syntax:

operation ::= `tosa.rescale` operands attr-dict `:` functional-type(operands, results)

Rescale quantized values into a new domain. Supported rescalings are:

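| Mode | Input | Output | Unsigned input | Unsigned output |
| --- | --- | --- | --- | --- |
| signed 16 to 16 | int16 | int16 | false | false |
| signed 16 to 32 | int16 | int32 | false | false |
| signed 16 to 8 | int16 | int8 | false | false |
| signed 32 to 16 | int32 | int16 | false | false |
| signed 32 to 32 | int32 | int32 | false | false |
| signed 32 to 8 | int32 | int8 | false | false |
| signed 8 to 16 | int8 | int16 | false | false |
| signed 8 to 32 | int8 | int32 | false | false |
| signed 8 to 8 | int8 | int8 | false | false |
| signed 48 to 16 | int48 | int16 | false | false |
| signed 48 to 32 | int48 | int32 | false | false |
| signed 48 to 8 | int48 | int8 | false | false |
| unsigned 8 to signed 8 | uint8 | int8 | true | false |
| signed 8 to unsigned 8 | int8 | uint8 | false | true |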
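As a rough illustration of how the multiplier and shift operands can express a fixed-point scale, the Python sketch below rescales one value as (value - input_zp) * multiplier / 2^shift, rounds half up, adds the output zero point and clamps to the output range. This formula, the single rounding step and the default int8 clamp range are assumptions of the sketch. The actual rounding behaviour is selected by the rounding_mode attribute and is not described in this section.

```python
def rescale_value(value, multiplier, shift, input_zp=0, output_zp=0,
                  out_min=-128, out_max=127):
    """Approximate (value - input_zp) * multiplier / 2**shift, then clamp."""
    if shift < 1:
        raise ValueError("shift must be at least 1")
    scaled = (value - input_zp) * multiplier
    # Adding half of 2**shift before shifting rounds the quotient half up.
    rounded = (scaled + (1 << (shift - 1))) >> shift
    return max(out_min, min(out_max, rounded + output_zp))


HALF = 1 << 30   # with shift = 31 this multiplier represents a scale of 0.5

halved = rescale_value(100, HALF, 31)
negative = rescale_value(-100, HALF, 31)
clamped = rescale_value(1000, HALF, 31)
offset = rescale_value(100, HALF, 31, output_zp=10)
```

Python integers do not overflow, so the intermediate product needs no special handling here; a fixed-width implementation would need a wide enough accumulator for it.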

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| scale32 | ::mlir::BoolAttr | bool attribute |
| rounding_mode | ::mlir::StringAttr | Supported rounding modes |
| per_channel | ::mlir::BoolAttr | bool attribute |
| input_unsigned | ::mlir::BoolAttr | bool attribute |
| output_unsigned | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | tosa-conformant tensor of number values |
| multiplier | 1D tensor of 16-bit signless integer or 32-bit signless integer values |
| shift | 1D tensor of 8-bit signless integer values |
| input_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |
| output_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values |

Results:

| Result | Description |
| --- | --- |
| output | tosa-conformant tensor of number values |

tosa.reshape (mlir::tosa::ReshapeOp) 

Reshape operator.

Syntax:

operation ::= `tosa.reshape` operands attr-dict `:` functional-type(operands, results)

Returns a tensor with the same type/values as the input, with a new shape specified by the shape argument. Reshape may operate on tensors of any rank. No data conversion happens during a reshape operation.
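The statement that no data conversion happens can be illustrated with a small Python sketch that flattens a nested list and regroups the same values under a new shape. Row-major element order is an assumption of this sketch, and the helper names are hypothetical.

```python
from math import prod


def flatten(tensor):
    """Return the elements of a nested list in row-major order."""
    if not isinstance(tensor, list):
        return [tensor]
    return [v for item in tensor for v in flatten(item)]


def build(values, shape):
    """Group a flat list of values into nested lists of the given shape."""
    if not shape:
        return values[0]
    if len(shape) == 1:
        return list(values)
    chunk = prod(shape[1:])
    return [build(values[i * chunk:(i + 1) * chunk], shape[1:])
            for i in range(shape[0])]


def reshape(tensor, new_shape):
    """Reinterpret the same values under a new shape."""
    values = flatten(tensor)
    if prod(new_shape) != len(values):
        raise ValueError("element count must be preserved")
    return build(values, list(new_shape))


source = [[1, 2, 3], [4, 5, 6]]      # shape 2x3
as_3x2 = reshape(source, [3, 2])
as_flat = reshape(source, [6])
```

Only the grouping of elements changes; the values and their order in the flattened sequence stay the same.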

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tosa-conformant tensor of number values
shape | Shape with static rank and Index element type

Results:

Result | Description
output | tosa-conformant ranked tensor of number values

tosa.resize (mlir::tosa::ResizeOp) 

Resize operation, supports various resize/upsample modes.

Syntax:

operation ::= `tosa.resize` operands attr-dict `:` functional-type(operands, results)

Resizes a tensor. Resize is only allowed in the H and W dimensions.

The height dimension is scaled by factor (scale_y_n/scale_y_d). The width dimension is scaled by factor (scale_x_n/scale_x_d).

The NEAREST_NEIGHBOR mode returns the value of the input tensor closest to the calculated sample position for both floating-point and integer data formats.

Floating-point BILINEAR mode returns a bilinearly interpolated output value based on the four closest input sample positions.

For integer BILINEAR interpolation mode, the output value must be scaled by 1/(scale_y_n * scale_x_n) in a following operation to complete the interpolation (for example with a RESCALE operator).

The output dimensions can be derived from the input dimensions by inverting the scale as described in the pseudocode. The [border_y, border_x] values adjust the output size to allow fractional sampling beyond integer input position (IH - 1,IW - 1).

The limit MAX_SCALE is applied to each scale ratio after reduction of the ratio. Individual scale numerator and denominator values are allowed to be larger than MAX_SCALE.
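The relation between input and output sizes can be sketched in Python as below. This is a minimal sketch that assumes the relation given in the TOSA specification pseudocode, which this page does not reproduce. The function name `resize_output_dim` is hypothetical, and the exact-division check reflects the assumption that the scaled extent must divide evenly by the scale denominator.

```python
def resize_output_dim(in_dim, scale_n, scale_d, offset, border):
    """Output size along one spatial dimension (H or W).

    Assumed relation (not stated on this page):
        out = ((in - 1) * scale_n - offset + border) / scale_d + 1
    """
    numerator = (in_dim - 1) * scale_n - offset + border
    if numerator % scale_d != 0:
        raise ValueError("scaled extent must be an exact multiple of scale_d")
    return numerator // scale_d + 1


# Height uses (scale_y_n, scale_y_d, offset_y, border_y);
# width uses (scale_x_n, scale_x_d, offset_x, border_x).
out_height = resize_output_dim(4, scale_n=2, scale_d=1, offset=0, border=1)
```

With a scale of 1/1 and zero offset and border, the output size equals the input size, which is a quick sanity check on the relation.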

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
mode | ::mlir::StringAttr | Supported resize/upsampling strategies

Operands:

Operand | Description
input | 4-d tosa-conformant tensor
scale | Tosa shape type of rank 4
offset | Tosa shape type of rank 2
border | Tosa shape type of rank 2

Results:

Result | Description
output | 4-d tosa-conformant tensor

tosa.reverse (mlir::tosa::ReverseOp) 

Reverse operator.

Syntax:

operation ::= `tosa.reverse` operands attr-dict `:` functional-type(operands, results)

Returns a tensor with the same type/values as the input, with the data reversed along the given axis. No data conversion happens during a reverse operation.
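As an illustration of the semantics only, the following Python sketch reverses a tensor represented as nested lists along one axis. The nested-list representation and the function name `reverse` are assumptions made for the example, not part of the dialect.

```python
def reverse(tensor, axis):
    """Reverse a nested-list tensor along `axis` (0 is the outermost dimension)."""
    if axis == 0:
        # Element values are untouched; only their order along this axis changes.
        return tensor[::-1]
    return [reverse(sub, axis - 1) for sub in tensor]


rows_reversed = reverse([[1, 2, 3], [4, 5, 6]], axis=0)
cols_reversed = reverse([[1, 2, 3], [4, 5, 6]], axis=1)
```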

Traits: AlwaysSpeculatableImplTrait, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand | Description
input1 | tosa-conformant tensor of at least rank 1

Results:

Result | Description
output | tosa-conformant tensor of at least rank 1

tosa.rsqrt (mlir::tosa::RsqrtOp) 

Elementwise 1/sqrt operator.

Syntax:

operation ::= `tosa.rsqrt` operands attr-dict `:` functional-type(operands, results)

Elementwise reciprocal square root operation. For integer operations, the TABLE operator should be used with the appropriate ranges.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tosa-conformant tensor of number values

Results:

Result | Description
output | tosa-conformant tensor of number values

tosa.scatter (mlir::tosa::ScatterOp) 

Scatter operation.

Syntax:

operation ::= `tosa.scatter` operands attr-dict `:` functional-type(operands, results)

The values_out tensor is set to the values_in tensor with data modified as follows: data from the input tensor is inserted at the positions specified by the indices tensor. N is the number of batches, W the number of indices in each batch, K the range of each index and C the number of data channels for each index. The same output index may not be repeated within a single SCATTER operation, so each output index occurs at most once. It follows that K >= W. Use cases that require multiple updates to the same output position must be decomposed into multiple SCATTER operations.
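The update rule can be sketched in Python on nested lists. This is a minimal sketch under the assumption, taken from the TOSA specification rather than this page, that values_in and values_out have shape [N, K, C], indices has shape [N, W] and the input operand has shape [N, W, C]. The function name `scatter` and the parameter name `updates` (standing in for the `input` operand) are hypothetical.

```python
import copy


def scatter(values_in, indices, updates):
    """values_out[n][indices[n][w]] = updates[n][w], with no repeated index per batch."""
    values_out = copy.deepcopy(values_in)
    for n, batch_indices in enumerate(indices):
        k_range = len(values_in[n])
        seen = set()
        for w, k in enumerate(batch_indices):
            if not 0 <= k < k_range:
                raise ValueError(f"index {k} is outside [0, {k_range})")
            if k in seen:
                raise ValueError(f"index {k} is repeated in batch {n}")
            seen.add(k)
            values_out[n][k] = list(updates[n][w])
    return values_out


result = scatter(
    values_in=[[[0, 0], [0, 0], [0, 0]]],  # N=1, K=3, C=2
    indices=[[2, 0]],                       # W=2
    updates=[[[1, 2], [3, 4]]],
)
```

Rejecting a repeated index inside one batch is what makes K >= W a necessary condition: W distinct indices must fit in a range of size K.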

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
values_in | 3-d tosa-conformant tensor
indices | 2D tensor of 32-bit signless integer values
input | 3-d tosa-conformant tensor

Results:

Result | Description
values_out | 3-d tosa-conformant tensor

tosa.select (mlir::tosa::SelectOp) 

Elementwise select operator.

Syntax:

operation ::= `tosa.select` operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3)
              `)` `->` type($output)

Elementwise select of the output based on a condition.
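A minimal Python sketch of the per-element rule follows, assuming that input1 is the condition, input2 supplies the value where the condition is true and input3 supplies it where the condition is false. The sketch works on flat lists of equal length and omits broadcasting; the function and parameter names are hypothetical.

```python
def select(condition, on_true, on_false):
    """Pick on_true[i] where condition[i] holds, otherwise on_false[i]."""
    if not len(condition) == len(on_true) == len(on_false):
        raise ValueError("this sketch requires equal lengths (no broadcasting)")
    return [t if c else f for c, t, f in zip(condition, on_true, on_false)]


picked = select([True, False, True], [1, 2, 3], [10, 20, 30])
```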

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tosa-conformant tensor of 1-bit signless integer values
input2 | tosa-conformant tensor of number values
input3 | tosa-conformant tensor of number values

Results:

Result | Description
output | tosa-conformant tensor of number values

tosa.sigmoid (mlir::tosa::SigmoidOp) 

Computes elementwise sigmoid of input.

Syntax:

operation ::= `tosa.sigmoid` operands attr-dict `:` functional-type(operands, results)

Applies the sigmoid logistic function to each element of the input tensor: $\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}$.

For quantized integer data types, the TABLE operator should be used instead. Each implementation may choose an appropriate TABLE given the scale and zero point of the input data. Eight or sixteen bit precision tables may be used based on the input tensor to the sigmoid function. The sigmoid table has 513 entries each of 16-bit precision and covering the input range -16.0 to +16.0 in steps of 1/16.
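One way such a table might be generated is sketched below in Python. The entry count, input range and step size come from the text above. The output scaling (sigmoid values multiplied by 2^15 and clamped to the int16 maximum) is an assumption chosen for illustration, and the function names are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def build_sigmoid_table():
    """513 entries covering inputs from -16.0 to +16.0 in steps of 1/16."""
    table = []
    for i in range(513):
        x = (i - 256) / 16.0
        # Assumed output format: fixed point with 15 fractional bits, clamped to int16.
        quantized = round(sigmoid(x) * 32768.0)
        table.append(min(quantized, 32767))
    return table


sigmoid_table = build_sigmoid_table()
```

An implementation is free to pick a different scaling to match the scale and zero point of its data.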

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input | tosa-conformant tensor of number values

Results:

Result | Description
output | tosa-conformant tensor of number values

tosa.sin (mlir::tosa::SinOp) 

Elementwise sin operator.

Syntax:

operation ::= `tosa.sin` operands attr-dict `:` functional-type(operands, results)

Elementwise sine operation for values given in radians.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tosa-conformant tensor of floating-point values

Results:

Result | Description
output | tosa-conformant tensor of floating-point values

tosa.slice (mlir::tosa::SliceOp) 

Slice operator.

Syntax:

operation ::= `tosa.slice` operands attr-dict `:` functional-type(operands, results)

Extracts a slice of input1, beginning at the start coordinates, and extending for size elements in each direction. No data conversion happens during a slice operation.
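The extraction can be sketched in Python on a nested-list tensor. The representation, the bounds checks and the function name `slice_tensor` are assumptions made for the example.

```python
def slice_tensor(tensor, start, size):
    """Take size[d] elements beginning at start[d] along every dimension d."""
    if len(start) != len(size):
        raise ValueError("start and size must have the same rank")
    if not start:
        return tensor  # all dimensions consumed: this is a single element
    begin, count = start[0], size[0]
    if begin < 0 or count <= 0 or begin + count > len(tensor):
        raise ValueError("slice exceeds the bounds of the input")
    return [slice_tensor(sub, start[1:], size[1:])
            for sub in tensor[begin:begin + count]]


window = slice_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], start=[1, 0], size=[2, 2])
```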

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tosa-conformant tensor of at least rank 1
start | Shape with static rank and Index element type
size | Shape with static rank and Index element type

Results:

Result | Description
output | tosa-conformant tensor of at least rank 1

tosa.sub (mlir::tosa::SubOp) 

Elementwise subtraction operator.

Syntax:

operation ::= `tosa.sub` operands attr-dict `:` functional-type(operands, results)

Elementwise subtraction of input2 from input1. Axes of size 1 are broadcast as necessary. The ranks of the input tensors must match.
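The broadcast rule can be sketched in Python on nested lists: both inputs have the same rank, and along each axis the sizes are either equal or one of them is 1. The representation and the function name `sub` are assumptions made for the example.

```python
def sub(a, b):
    """Compute a - b elementwise, broadcasting any axis of size 1."""
    a_is_list, b_is_list = isinstance(a, list), isinstance(b, list)
    if a_is_list != b_is_list:
        raise ValueError("input ranks must match")
    if not a_is_list:
        return a - b
    if len(a) != len(b) and 1 not in (len(a), len(b)):
        raise ValueError("axis sizes must be equal or 1")
    n = max(len(a), len(b))
    # An axis of size 1 is reused for every position along the larger axis.
    return [sub(a[i % len(a)], b[i % len(b)]) for i in range(n)]


diff = sub([[1, 2, 3], [4, 5, 6]], [[1], [2]])
```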

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tosa-conformant tensor of number values
input2 | tosa-conformant tensor of number values

Results:

Result | Description
output | tosa-conformant tensor of number values

tosa.table (mlir::tosa::TableOp) 

Table lookup operator.

Syntax:

operation ::= `tosa.table` $input1 `,` $table attr-dict `:` `(` type($input1) `,` type($table) `)` `->` type($output)

Table lookup operation. For an int8_t TABLE operation, perform a 256-entry table lookup returning an int8_t value. For int16_t tables, the int16_t input is treated as a fixed-point 9.7 value: the most significant 9 bits index into the table, and the fractional 7 bits interpolate between table[index] and table[index+1]. For int16_t inputs, the TABLE operator returns a 16.7 interpolated value in an int32_t. This value can then be passed to the RESCALE operator to scale it to the required output data type. Note that the int16_t table has 513 values so that table[index+1] is defined when index=511.

An int16_t to int16_t table lookup can be constructed in TOSA as follows:

  • Use the TABLE operator to produce a fixed point 16.7 interpolated result
  • Use RESCALE (in_t=int32_t, out_t=int16_t, scale=1<<14, shift=21) to scale the output to int16_t range (or alternate scale as required)
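The two steps above can be sketched in Python. The lookup follows the 9.7 split described in the text. The rescale step assumes a round-half-up right shift, which is one plausible reading of RESCALE and not a statement of its full definition. Both function names are hypothetical.

```python
def table_lookup_int16(table, value):
    """Return the 16.7 fixed-point interpolated result for an int16 input."""
    if len(table) != 513:
        raise ValueError("an int16 table needs 513 entries")
    clipped = max(-32768, min(32767, value))
    index = (clipped + 32768) >> 7   # most significant 9 bits, 0..511
    fraction = clipped & 0x7F        # fractional 7 bits, 0..127
    base = table[index]
    slope = table[index + 1] - base
    return (base << 7) + slope * fraction


def rescale_to_int16(value, multiplier=1 << 14, shift=21):
    """Assumed rescale: multiply, add half of the divisor, shift right, clamp."""
    rounded = (value * multiplier + (1 << (shift - 1))) >> shift
    return max(-32768, min(32767, rounded))


# Identity-like table for checking: entry i maps to the input value at that index.
identity_table = [min((i - 256) * 128, 32767) for i in range(513)]
interpolated = table_lookup_int16(identity_table, 12345)
restored = rescale_to_int16(interpolated)
```

With multiplier 1<<14 and shift 21 the net effect is a division by 2^7, which removes the 7 fractional bits of the 16.7 result.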

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tosa-conformant tensor of number values
table | 1-d tosa-conformant tensor

Results:

Result | Description
output | tosa-conformant tensor of number values

tosa.tanh (mlir::tosa::TanhOp) 

Computes elementwise hyperbolic tangent of input.

Syntax:

operation ::= `tosa.tanh` operands attr-dict `:` functional-type(operands, results)

Parameterized hyperbolic tangent: $\tanh(x) = \frac{1 - e^{-2x}}{1 + e^{-2x}}$.

For quantized integer data types, the TABLE operator should be used instead. Each implementation may choose an appropriate TABLE given the scale and zero point of the input data. Eight or sixteen bit precision tables may be used based on the input tensor to the tanh function.
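For floating-point data, the definition tanh(x) = (1 - e^(-2x)) / (1 + e^(-2x)) can be checked numerically against a library implementation. The Python sketch below is illustrative only; the function name `tanh_reference` is hypothetical, and the direct formula overflows for large negative inputs, which a production kernel would need to handle.

```python
import math


def tanh_reference(x):
    """Evaluate tanh directly from its exponential definition."""
    e = math.exp(-2.0 * x)
    return (1.0 - e) / (1.0 + e)


max_error = max(abs(tanh_reference(x) - math.tanh(x))
                for x in (-3.0, -0.5, 0.0, 0.5, 3.0))
```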

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType, SameOperandsAndResultRank, SameOperandsAndResultShape, TosaElementwiseOperator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input | tosa-conformant tensor of number values

Results:

Result | Description
output | tosa-conformant tensor of number values

tosa.tile (mlir::tosa::TileOp) 

Tile operator.

Syntax:

operation ::= `tosa.tile` operands attr-dict `:` functional-type(operands, results)

Replicates input1 along each dimension, repeating it the number of times given by the corresponding entry of multiples.
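A minimal Python sketch of the replication on a nested-list tensor follows. The representation and the function name `tile` are assumptions made for the example.

```python
def tile(tensor, multiples):
    """Repeat the tensor multiples[d] times along every dimension d."""
    if not multiples:
        return tensor  # all dimensions consumed: this is a single element
    inner = [tile(sub, multiples[1:]) for sub in tensor]
    # List repetition reuses the same inner lists, which is fine for reading.
    return inner * multiples[0]


tiled = tile([[1, 2]], multiples=[2, 2])
```

Each output dimension has size input_size * multiples[d], so the example turns a 1x2 tensor into a 2x4 tensor.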

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tosa-conformant tensor of at least rank 1
multiples | Shape with static rank and Index element type

Results:

Result | Description
output | tosa-conformant tensor of at least rank 1

tosa.transpose_conv2d (mlir::tosa::TransposeConv2DOp) 

Transpose 2D Convolution operator.

Syntax:

operation ::= `tosa.transpose_conv2d` operands attr-dict `:` functional-type(operands, results)

Performs a 2D transposed convolution over the given tensor input, using the weights tensor. Implementations may choose to skip calculation of multiplies by zero at fractional input positions.
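The spatial output size of a transposed convolution can be sketched in Python as below. The relation used here is an assumption taken from the TOSA specification rather than from this page, as is the ordering of out_pad as [top, bottom, left, right] and of stride as [y, x]. The function name is hypothetical.

```python
def transpose_conv2d_output_hw(in_hw, kernel_hw, stride, out_pad):
    """Assumed relation: out = (in - 1) * stride + pad_before + pad_after + kernel."""
    in_h, in_w = in_hw
    kernel_h, kernel_w = kernel_hw
    stride_y, stride_x = stride
    pad_top, pad_bottom, pad_left, pad_right = out_pad
    out_h = (in_h - 1) * stride_y + pad_top + pad_bottom + kernel_h
    out_w = (in_w - 1) * stride_x + pad_left + pad_right + kernel_w
    return out_h, out_w


out_hw = transpose_conv2d_output_hw((4, 4), (3, 3), stride=(2, 2), out_pad=(0, 0, 0, 0))
```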

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, SameVariadicOperandSize, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
out_pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements
stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
acc_type | ::mlir::TypeAttr | type attribute of 32-bit signless integer or 48-bit signless integer or 16-bit float or 32-bit float
local_bound | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
input | 4-d tosa-conformant tensor
weight | 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values
bias | 1-d tosa-conformant tensor
input_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values
weight_zp | tosa-conformant scalar tensor of unsigned integer or signless integer or floating-point values

Results:

Result | Description
output | 4-d tosa-conformant tensor

tosa.transpose (mlir::tosa::TransposeOp) 

Transpose operator.

Syntax:

operation ::= `tosa.transpose` operands attr-dict `:` functional-type(operands, results)

Permutes the dimensions of the input tensor input1 based on the perms argument. Each value in the perms list must be a valid dimension of the input tensor and may not be repeated.
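The constraint on perms and its effect on the shape can be sketched in Python. This sketch assumes that output dimension i takes the size of input dimension perms[i]; the function name `transposed_shape` is hypothetical.

```python
def transposed_shape(shape, perms):
    """Validate perms and return the output shape of the transpose."""
    if sorted(perms) != list(range(len(shape))):
        raise ValueError("perms must list every input dimension exactly once")
    return [shape[p] for p in perms]


new_shape = transposed_shape([1, 2, 3], perms=[2, 0, 1])
```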

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, ReifyRankedShapedTypeOpInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
perms | ::mlir::DenseI32ArrayAttr | i32 dense array attribute

Operands:

Operand | Description
input1 | tosa-conformant tensor of at least rank 1

Results:

Result | Description
output | tosa-conformant tensor of at least rank 1

tosa.variable (mlir::tosa::VariableOp) 

Defines a variable

Syntax:

operation ::= `tosa.variable` $name
              attr-dict
              custom<TypeOrAttr>($type, $initial_value)

Defines a new TOSA variable. This is a mutable value. Modifications are expressed using read/write semantics.

Traits: TosaResolvableShapeOperands

Interfaces: QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Attributes: 

Attribute | MLIR Type | Description
name | ::mlir::StringAttr | string attribute
type | ::mlir::TypeAttr | any type attribute
initial_value | ::mlir::Attribute | any attribute

tosa.variable.read (mlir::tosa::VariableReadOp) 

Read_buffer operator

Syntax:

operation ::= `tosa.variable.read` $name attr-dict `:` type($value)

Reads the value from a pseudo-buffer resource holding a mutable tensor.

Traits: TosaResolvableShapeOperands

Interfaces: QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Attributes: 

Attribute | MLIR Type | Description
name | ::mlir::StringAttr | string attribute

Results:

Result | Description
value | any type

tosa.variable.write (mlir::tosa::VariableWriteOp) 

Write_buffer operator

Syntax:

operation ::= `tosa.variable.write` $name attr-dict `,` $value `:` type($value)

Assigns a value to a pseudo-buffer resource holding a mutable tensor.

Traits: TosaResolvableShapeOperands

Interfaces: QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Attributes: 

Attribute | MLIR Type | Description
name | ::mlir::StringAttr | string attribute

Operands:

Operand | Description
value | any type

tosa.while_loop (mlir::tosa::WhileOp) 

Output = input; While (Cond(output)) {output = Body(output)}

Generates and evaluates a Boolean condition and either executes a loop body or exits the loop. This action is repeated, with the condition updated and re-evaluated on every iteration. This implements the semantics of a foreach or while loop.
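The loop semantics summarized as "Output = input; While (Cond(output)) {output = Body(output)}" can be sketched in Python. Here `cond` and `body` are hypothetical callables standing in for the condition and body regions, and plain Python values stand in for tensors.

```python
def while_loop(inputs, cond, body):
    """Carry a list of values through body for as long as cond holds."""
    outputs = list(inputs)
    while cond(*outputs):
        # The body yields the next set of loop-carried values.
        outputs = list(body(*outputs))
    return outputs


# Count i up to n; both values are loop-carried.
final = while_loop([0, 10], cond=lambda i, n: i < n, body=lambda i, n: (i + 1, n))
```

If the condition is false on entry, the body never runs and the outputs equal the inputs.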

Traits: InferShapedTypeOpAdaptor, RecursiveMemoryEffects, SingleBlockImplicitTerminator<YieldOp>, SingleBlock, TosaResolvableShapeOperands

Interfaces: InferShapedTypeOpInterface, LoopLikeOpInterface, QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Operands: 

Operand | Description
input_list | variadic of tosa-conformant tensor of number values

Results:

Result | Description
output_list | variadic of tosa-conformant tensor of number values

tosa.yield (mlir::tosa::YieldOp) 

Yield operator

Syntax:

operation ::= `tosa.yield` $inputs attr-dict `:` type($inputs)

Return operation within the condition and body regions of structured control flow. The operation takes variadic operands but produces no results of its own.

Traits: AlwaysSpeculatableImplTrait, Terminator, TosaResolvableShapeOperands

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), QueryExtensionInterface, QueryProfileInterface, TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
inputs | variadic of tosa-conformant tensor of number values