MLIR

Multi-Level IR Compiler Framework

Tensor Operator Set Architecture (TOSA) Dialect

Rationale 

The MLIR TOSA dialect implements the TOSA specification. This document describes the decision process for how TOSA expresses operators in high level dialects.

TOSA was developed after parallel efforts to rationalize the top-down picture from multiple high-level frameworks, as well as a bottom-up view of different hardware target concerns (CPU, GPU and NPU), and reflects a set of choices that attempt to manage both sets of requirements.

TOSA and Tensor Level Expressiveness 

TOSA endeavors to provide an operator set that fulfills the following expressiveness goals at the tensor level of abstraction:

Complete 

This goal is driven by the top-down perspective: the need to express as much of multiple high-level frameworks as possible fully in TOSA. The operator set was originally derived from an operator frequency analysis of dozens of high-level networks in different frameworks. That analysis selected the most frequently occurring operators and established a common set of tensor-level operators that could express them.

TOSA categorizes its operator set into classes and attempts to address major functional operations at the tensor level, including compute, reduction, elementwise transformations, comparison and control flow.

Minimal 

This goal takes the bottom-up approach: keep the TOSA operator set minimal. Doing so bounds the design of hardware, operator kernels, code generation strategies, and associated considerations that affect the executability of TOSA content.

In this regard, TOSA seeks to avoid creating compound operators, instead leaving it to the compiler backend to fuse multiple TOSA ops if required. This choice also benefits the numerical precision goal, since it is easier to fuse the numerical functionality of successive operators than to split the numerical functionality of a compound operator.

Numerical Precision 

TOSA began as a means to address operator-level numerical precision for code generation and hardware development. It therefore incorporates precision detail into the operator set.

In this regard, TOSA operators are best understood as a combination of the visible quantization information embedded within an operation, together with the functional information about how that information is used, as described in the specification of the operation.

TOSA Operator Rationale 

The general basis of selection of the operator set that constitutes TOSA is described in the TOSA specification document under Section 1.3 Operator Selection. Explanation of the thinking behind some operators is listed here:

COND_IF and WHILE_LOOP 

Several neural networks express conditional control flow at the tensor level. A survey of multiple high level frameworks indicated that conditional if and a loop construct are common in all major frameworks, with some variation. Since TOSA endeavors to be complete in expressing tensor level functionality including control flow, it implements these constructs.

The COND_IF and WHILE_LOOP operators implement such structured control flow forms and should be lowerable to corresponding ops in the scf dialect. Since the dialect seeks to remain isomorphic with an external, serialized form, the decision was to keep these ops in the dialect (as opposed to deferring completely to scf). This decision may be re-evaluated if it turns out not to yield the expected value.

Using TOSA In A Compiler 

The TOSA specification describes each operator in functional detail. It is expected that compilers that use TOSA will use its builders to construct the operators so that the quantization information for the operator is correctly generated.

The functional steps described in the pseudocode of the specification enable the construction of code generation for that operation, or decisions on the design of underlying hardware. The functional pseudocode also describes how the quantization parameters are utilized within the operation.

Quantization Parameters in Ops vs Tensors 

TOSA uses the quantization parameters embedded in the input and output tensors to construct the quantization attributes that sit within the operator. Once these attributes are constructed, the quantization information within the tensors is no longer necessary for code generation.

This enables the tensors to be subsequently interpreted simply as contiguous buffers containing raw data, with no ‘meta information’ in the form of the quantization_type. Precision-related manipulation of the input or output is instead described by the operator itself, which specifies, for example, when the zero point is applied or when the scale multiplication is done.

However, TOSA does not eliminate the existing MLIR QuantOps quantization type information within the tensors; this leaves the choice of how to handle quantization information to later backend code generation steps.

Maintaining the ability to overlap these different representations of quantization parameters (i.e. tensor-carried vs op-carried) is an important capability when considering progressive lowering between uses that expect one scheme vs the other.

Operation definitions 


tosa.abs (mlir::tosa::AbsOp) 

Elementwise abs op

Syntax:

operation ::= `tosa.abs` operands attr-dict `:` functional-type(operands, results)

Elementwise absolute value operation

Example:

%out = tosa.abs %in : (tensor<21x3xf32>) -> tensor<21x3xf32>

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.add (mlir::tosa::AddOp) 

Elementwise addition operator

Syntax:

operation ::= `tosa.add` operands attr-dict `:` functional-type(operands, results)

Elementwise addition of input1 and input2. Axis of size 1 will be broadcast, as necessary.

Example:

// Elementwise addition.
%out = tosa.add %in1, %in2 : (tensor<12x6xf32>, tensor<12x6xf32>) -> tensor<12x6xf32>

// Elementwise addition with broadcasting.
%out = tosa.add %in1, %in2 : (tensor<12x6xsi32>, tensor<1x1xsi32>) -> tensor<12x6xsi32>
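
The broadcasting rule used by tosa.add and the other binary elementwise ops can be sketched as a shape computation. This is a minimal sketch that assumes both operands have the same rank, as in the examples above. For each axis, the extents must be equal or one of them must be 1. The function name broadcast_shape is hypothetical.

```python
def broadcast_shape(shape1, shape2):
    """Result shape of a binary elementwise op such as tosa.add.

    Assumption (sketch only): both operands have the same rank, and for
    each axis the extents are equal or one of them is 1.
    """
    if len(shape1) != len(shape2):
        raise ValueError("operands must have the same rank")
    result = []
    for d1, d2 in zip(shape1, shape2):
        if d1 == d2 or d2 == 1:
            result.append(d1)      # equal extents, or operand 2 is broadcast
        elif d1 == 1:
            result.append(d2)      # operand 1 is broadcast along this axis
        else:
            raise ValueError(f"incompatible extents {d1} and {d2}")
    return tuple(result)


# Same shapes as the broadcasting example: 12x6 combined with 1x1.
print(broadcast_shape((12, 6), (1, 1)))  # (12, 6)
```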

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.apply_scale (mlir::tosa::ApplyScaleOp) 

Rescale scalar operator for Tosa tensor operators

Syntax:

operation ::= `tosa.apply_scale` operands attr-dict `:` functional-type(operands, results)

Applies rescaling for fixed point values. This behavior is replicated in multiple quantized operations (mul, convolution, rescale, matmul, pooling).

The common implementation uses i64 operations to avoid integer overflow. Target-specific implementations can use native operations to avoid wider-than-necessary types.
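
The rescale arithmetic can be sketched in Python, whose unbounded integers stand in for the i64 intermediate. This is a sketch based on my reading of the apply_scale_32 pseudocode in the TOSA specification. The preconditions and the double_round adjustment (applied only when shift > 31) are assumptions to verify against the spec. The function name is hypothetical.

```python
def apply_scale_32(value: int, multiplier: int, shift: int,
                   double_round: bool = False) -> int:
    """Fixed-point rescale: round(value * multiplier / 2**shift).

    Assumed preconditions (from my reading of the spec, not verified):
    multiplier >= 0 and 2 <= shift <= 62.
    """
    assert multiplier >= 0, "multiplier must be non-negative"
    assert 2 <= shift <= 62, "shift must be in [2, 62]"
    round_val = 1 << (shift - 1)  # adds 0.5 in the fixed-point domain
    if double_round and shift > 31:
        # Assumed extra rounding step, applied only for large shifts.
        round_val += (1 << 30) if value >= 0 else -(1 << 30)
    # Python ints do not overflow, standing in for the i64 intermediate;
    # >> on a negative int floors, like an arithmetic shift.
    result = (value * multiplier + round_val) >> shift
    assert -(1 << 31) <= result < (1 << 31), "result must fit in int32"
    return result


# multiplier = 1 << 30 with shift = 31 encodes a scale of 0.5.
print(apply_scale_32(10, 1 << 30, 31))  # 5
print(apply_scale_32(3, 1 << 30, 31))   # 2  (1.5 rounds up)
print(apply_scale_32(-3, 1 << 30, 31))  # -1 (-1.5 rounds toward +infinity)
```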

Traits: AlwaysSpeculatableImplTrait, Elementwise, Scalarizable, Tensorizable, Vectorizable

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface, VectorUnrollOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
double_round | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
------- | -----------
value | signless-integer-like
multiplier | signless-integer-like
shift | signless-integer-8-bit-like

Results:

Result | Description
------ | -----------
output | signless-integer-like

tosa.argmax (mlir::tosa::ArgMaxOp) 

Perform argmax on the input.

Syntax:

operation ::= `tosa.argmax` operands attr-dict `:` functional-type(operands, results)

This returns the index with the largest value across the given axis of the input tensor.
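
A reference sketch of argmax along one axis of a 2-D input is shown below. It assumes two things:

- The reduced axis is removed from the output shape.
- Ties resolve to the first (lowest) index.

Both are assumptions to check against the TOSA specification. The function name argmax_2d is hypothetical.

```python
def argmax_2d(rows, axis):
    """argmax of a 2-D nested list along `axis` (0 or 1).

    Assumptions (sketch only): the reduced axis is dropped from the result,
    and ties resolve to the first index (Python's max keeps the first
    maximal element it sees).
    """
    if axis == 1:
        return [max(range(len(r)), key=lambda i: r[i]) for r in rows]
    if axis == 0:
        cols = list(zip(*rows))
        return [max(range(len(c)), key=lambda i: c[i]) for c in cols]
    raise ValueError("axis must be 0 or 1")


print(argmax_2d([[1, 5, 3], [7, 2, 7]], axis=1))  # [1, 0]
```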

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
axis | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand | Description
------- | -----------
input | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.arithmetic_right_shift (mlir::tosa::ArithmeticRightShiftOp) 

Elementwise Arithmetic Right Shift

Syntax:

operation ::= `tosa.arithmetic_right_shift` operands attr-dict `:` functional-type(operands, results)

Elementwise arithmetic right shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary.
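
The effect of the round attribute can be sketched on scalars. This sketch assumes that, when round is true and the shift amount is positive, the result is incremented if the last bit shifted out was 1 (round half up). That is my reading of the specification and should be verified. The function name is hypothetical.

```python
def arithmetic_right_shift(value1: int, value2: int, round_: bool) -> int:
    """Arithmetic right shift of value1 by value2 bits, optionally rounded.

    Assumption (sketch only): rounding adds back the last bit shifted out.
    """
    result = value1 >> value2  # Python's >> on ints preserves the sign
    if round_ and value2 > 0 and ((value1 >> (value2 - 1)) & 1):
        result += 1  # the last bit shifted out was 1: round half up
    return result


print(arithmetic_right_shift(5, 1, False))  # 2
print(arithmetic_right_shift(5, 1, True))   # 3  (2.5 rounds up)
```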

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
round | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.avg_pool2d (mlir::tosa::AvgPool2dOp) 

Performs average pooling on the input.

Syntax:

operation ::= `tosa.avg_pool2d` operands attr-dict `:` functional-type(operands, results)

This performs average pooling over the given input tensor. A sliding window of the size given by the kernel attribute is passed over the input tensor, and the mean value of each window is placed in the output tensor.
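
The sliding-window behavior can be sketched for a single H x W channel. The sketch makes simplifying assumptions:

- Padding is zero, so every window lies inside the input and the divisor is kh * kw.
- Division is done in floating point, so the integer rounding behavior of the real operator is not modeled.

The function name is hypothetical. The kernel and stride parameters mirror the attributes of the same names.

```python
def avg_pool2d_single_channel(image, kernel, stride):
    """Average pooling over one H x W channel, assuming zero padding.

    kernel = (kh, kw), stride = (sh, sw).
    """
    kh, kw = kernel
    sh, sw = stride
    height, width = len(image), len(image[0])
    out_h = (height - kh) // sh + 1  # number of window positions vertically
    out_w = (width - kw) // sw + 1   # number of window positions horizontally
    out = []
    for oy in range(out_h):
        row = []
        for ox in range(out_w):
            window = [image[oy * sh + ky][ox * sw + kx]
                      for ky in range(kh) for kx in range(kw)]
            row.append(sum(window) / (kh * kw))  # mean of the window
        out.append(row)
    return out


image = [[4 * r + c for c in range(4)] for r in range(4)]  # values 0..15
print(avg_pool2d_single_channel(image, (2, 2), (2, 2)))  # [[2.5, 4.5], [10.5, 12.5]]
```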

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
kernel | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements
acc_type | ::mlir::TypeAttr | type attribute of 32-bit signless integer or 32-bit signed integer or 16-bit float or 32-bit float
quantization_info | mlir::tosa::UnaryOpQuantizationAttr | Attribute for UnaryOp quantization information.

Operands:

Operand | Description
------- | -----------
input | 4-d tensor

Results:

Result | Description
------ | -----------
output | 4-d tensor

tosa.bitwise_and (mlir::tosa::BitwiseAndOp) 

Bitwise AND operator

Syntax:

operation ::= `tosa.bitwise_and` operands attr-dict `:` functional-type(operands, results)

Elementwise bitwise AND of input1 and input2. Axis of size 1 will be broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.bitwise_not (mlir::tosa::BitwiseNotOp) 

Bitwise NOT operator

Syntax:

operation ::= `tosa.bitwise_not` operands attr-dict `:` functional-type(operands, results)

Elementwise bitwise NOT of input tensor.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.bitwise_or (mlir::tosa::BitwiseOrOp) 

Bitwise OR operator

Syntax:

operation ::= `tosa.bitwise_or` operands attr-dict `:` functional-type(operands, results)

Elementwise bitwise OR of input1 and input2. Axis of size 1 will be broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.bitwise_xor (mlir::tosa::BitwiseXorOp) 

Bitwise XOR operator

Syntax:

operation ::= `tosa.bitwise_xor` operands attr-dict `:` functional-type(operands, results)

Elementwise bitwise XOR of input1 and input2. Axis of size 1 will be broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.cast (mlir::tosa::CastOp) 

Cast operation

Syntax:

operation ::= `tosa.cast` operands attr-dict `:` functional-type(operands, results)

Performs a set of permissible cast operations:

Mode | Input | Output
---- | ----- | ------
signed 8 to bool | int8 | Boolean
signed 16 to bool | int16 | Boolean
signed 32 to bool | int32 | Boolean
bool to 8 | Boolean | int8
bool to 16 | Boolean | int16
bool to 32 | Boolean | int32
signed 8 to signed 16 | int8 | int16
signed 8 to signed 32 | int8 | int32
signed 16 to signed 8 | int16 | int8
signed 16 to signed 32 | int16 | int32
signed 32 to signed 8 | int32 | int8
signed 32 to signed 16 | int32 | int16
float to signed 8 | float | int8
float to signed 16 | float | int16
signed 8 to float | int8 | float
signed 16 to float | int16 | float
float 32 to float 64 | float32 | float64
float 64 to float 32 | float64 | float32

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input | tensor of number_plus_f64 values

Results:

Result | Description
------ | -----------
output | tensor of number_plus_f64 values

tosa.ceil (mlir::tosa::CeilOp) 

Elementwise ceil op

Syntax:

operation ::= `tosa.ceil` operands attr-dict `:` functional-type(operands, results)

Elementwise ceiling operation

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.clamp (mlir::tosa::ClampOp) 

Computes clamp(features, min, max).

Syntax:

operation ::= `tosa.clamp` operands attr-dict `:` functional-type(operands, results)

Clamp to an arbitrary minimum and maximum value. Maximum and minimum values are specified as values in the range of the input type. No zero point subtraction is done to the values, thus to clamp to the zero point value, the zero point itself should be supplied as the minimum value.
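
The zero-point rule above can be illustrated with a scalar sketch. The int8 zero point of -20 is a hypothetical value chosen for illustration. Supplying the zero point itself as the minimum clamps to the quantized representation of real 0.0 (a ReLU in the real domain). The clamp function here is a plain Python stand-in, not an MLIR API.

```python
def clamp(values, min_val, max_val):
    """Elementwise clamp; bounds are given in the input's own value range.

    For integer tensors the bounds would come from min_int/max_int, for
    floating-point tensors from min_fp/max_fp. No zero point is subtracted.
    """
    return [min(max(v, min_val), max_val) for v in values]


# Hypothetical int8 data with a zero point of -20: to clamp at real 0.0,
# pass the zero point itself as the minimum.
zero_point = -20
print(clamp([-128, -20, 5, 127], zero_point, 127))  # [-20, -20, 5, 127]
```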

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
min_int | ::mlir::IntegerAttr | 64-bit signless integer attribute
max_int | ::mlir::IntegerAttr | 64-bit signless integer attribute
min_fp | ::mlir::FloatAttr | arbitrary float attribute
max_fp | ::mlir::FloatAttr | arbitrary float attribute

Operands:

Operand | Description
------- | -----------
input | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.clz (mlir::tosa::ClzOp) 

Elementwise count leading zero op

Syntax:

operation ::= `tosa.clz` operands attr-dict `:` functional-type(operands, results)

Elementwise count leading zeros operation
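
A scalar sketch of count-leading-zeros follows. It assumes 32-bit two's-complement inputs, so clz(0) is 32 and any negative value has 0 leading zeros. The function name clz32 is hypothetical.

```python
def clz32(x: int) -> int:
    """Count leading zeros of x viewed as a 32-bit two's-complement value."""
    bits = x & 0xFFFFFFFF  # reinterpret negative inputs as their 32-bit pattern
    return 32 - bits.bit_length()


print(clz32(1))  # 31
```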

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.concat (mlir::tosa::ConcatOp) 

Concatenates tensors along one dimension.

Syntax:

operation ::= `tosa.concat` operands attr-dict `:` functional-type(operands, results)

Concatenate a variadic number of tensors along a given axis. No data conversion happens during a concat operation.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
axis | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand | Description
------- | -----------
input1 | variadic of tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.const (mlir::tosa::ConstOp) 

Constant op.

A node containing constant data for use as the input to an operation. May hold data in any of the supported data formats.

Example:

// Generic form
%out = "tosa.const"() {value = dense<0> : tensor<2x3xi32>} : () -> tensor<2x3xi32>

Traits: AlwaysSpeculatableImplTrait, ConstantLike, FirstAttrDerivedResultType

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
value | ::mlir::ElementsAttr | constant vector/tensor attribute

Results:

Result | Description
------ | -----------
output | tensor of number_plus_f64 values

tosa.conv2d (mlir::tosa::Conv2DOp) 

2D Convolution Operator

Syntax:

operation ::= `tosa.conv2d` operands attr-dict `:` functional-type(operands, results)

Performs a 2D convolution over the given tensor input, using the weight tensor.
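
The relationship between the input size, kernel size, and the pad, stride, and dilation attributes can be sketched as an output-shape calculation. This sketch makes three assumptions to confirm against the specification:

- The 4-element pad attribute is ordered [top, bottom, left, right].
- Division is floor division; the specification may additionally require it to be exact.
- The inputs are hypothetical kernel sizes rather than a weight tensor.

The function name is hypothetical.

```python
def conv2d_output_hw(in_hw, kernel_hw, pad, stride, dilation):
    """Output (OH, OW) of a 2D convolution.

    Assumed layouts (sketch only): pad = [top, bottom, left, right],
    stride = [sy, sx], dilation = [dy, dx].
    """
    ih, iw = in_hw
    kh, kw = kernel_hw
    top, bottom, left, right = pad
    sy, sx = stride
    dy, dx = dilation
    # A dilated kernel spans (k - 1) * d + 1 input elements.
    oh = (ih - 1 + top + bottom - (kh - 1) * dy) // sy + 1
    ow = (iw - 1 + left + right - (kw - 1) * dx) // sx + 1
    return oh, ow


# 8x8 input, 3x3 kernel, 1-pixel padding on every side, stride 1.
print(conv2d_output_hw((8, 8), (3, 3), [1, 1, 1, 1], [1, 1], [1, 1]))  # (8, 8)
```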

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements
stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
dilation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
quantization_info | mlir::tosa::ConvOpQuantizationAttr | Attribute for Conv type op quantization information.
local_bound | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
------- | -----------
input | 4-d tensor
weight | 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values
bias | 1-d tensor

Results:

Result | Description
------ | -----------
output | 4-d tensor

tosa.conv3d (mlir::tosa::Conv3DOp) 

3D Convolution operator

Syntax:

operation ::= `tosa.conv3d` operands attr-dict `:` functional-type(operands, results)

Performs a 3D convolution over the given input tensor.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 6 elements
stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 3 elements
dilation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 3 elements
quantization_info | mlir::tosa::ConvOpQuantizationAttr | Attribute for Conv type op quantization information.
local_bound | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
------- | -----------
input | 5-d tensor
weight | 5D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values
bias | 1-d tensor

Results:

Result | Description
------ | -----------
output | 5-d tensor

tosa.cos (mlir::tosa::CosOp) 

Elementwise cos op

Syntax:

operation ::= `tosa.cos` operands attr-dict `:` functional-type(operands, results)

Elementwise cosine operation for values given in radians.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input | tensor of 32-bit float or 16-bit float or bfloat16 type values

Results:

Result | Description
------ | -----------
output | tensor of 32-bit float or 16-bit float or bfloat16 type values

tosa.custom (mlir::tosa::CustomOp) 

Custom operator wrapper for Tosa

Syntax:

operation ::= `tosa.custom` operands attr-dict `:` functional-type(operands, results)

Hardware implementing TOSA may choose to add additional custom operators that are not expressed in the existing TOSA operations. These operators are not expected to be portable across TOSA implementations. The input and output signatures must be expressed in the corresponding TOSA node.

operator_name is a string that tells the backend which custom operator is being called.

domain_name is a string identifier that can help avoid name collisions on the identifier field.

implementation_attrs is a string containing a backend- and identifier-specific set of attributes for the custom operator.

inputs is the set of tensor inputs to the custom operator.

outputs is the list of tensors returned by the operator. The number of outputs is backend specific.

Example:

%out = tosa.custom %in {domain_name = "tosa_mlir_test", operator_name =
       "custom_test", implementation_attrs = ""}: (tensor<10xi32>) ->
       (tensor<10xi32>)

Interfaces: TosaOpInterface

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
operator_name | ::mlir::StringAttr | string attribute
domain_name | ::mlir::StringAttr | string attribute
implementation_attrs | ::mlir::StringAttr | string attribute

Operands:

Operand | Description
------- | -----------
inputs | variadic of tensor of number values

Results:

Result | Description
------ | -----------
outputs | variadic of tensor of number values

tosa.depthwise_conv2d (mlir::tosa::DepthwiseConv2DOp) 

Depthwise 2D Convolution operator

Syntax:

operation ::= `tosa.depthwise_conv2d` operands attr-dict `:` functional-type(operands, results)

Performs 2D convolutions separately over each channel of the given tensor input, using the weight tensor.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements
stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
dilation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
quantization_info | mlir::tosa::ConvOpQuantizationAttr | Attribute for Conv type op quantization information.
local_bound | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
------- | -----------
input | 4-d tensor
weight | 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values
bias | 1-d tensor

Results:

Result | Description
------ | -----------
output | 4-d tensor

tosa.div (mlir::tosa::DivOp) 

Integer divide operator

Syntax:

operation ::= `tosa.div` operands attr-dict `:` functional-type(operands, results)

Elementwise integer division of input1 by input2. Axis of size 1 will be broadcast, as necessary.
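
Integer division rounding differs between languages, so a scalar sketch is useful. This sketch assumes that tosa.div truncates toward zero (C semantics), which is my understanding of the specification and should be verified. Python's // operator floors instead, so -7 // 2 is -4 while the truncating result is -3. The function name is hypothetical.

```python
def tosa_int_div(a: int, b: int) -> int:
    """Integer division truncating toward zero (assumed C semantics).

    The quotient is computed on magnitudes and the sign restored, because
    Python's // rounds toward negative infinity.
    """
    if b == 0:
        raise ZeroDivisionError("divisor must be non-zero")
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q


print(tosa_int_div(-7, 2))  # -3
```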

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of 32-bit signless integer values
input2 | tensor of 32-bit signless integer values

Results:

Result | Description
------ | -----------
output | tensor of 32-bit signless integer values

tosa.equal (mlir::tosa::EqualOp) 

Returns the truth value of (x == y) element-wise.

Syntax:

operation ::= `tosa.equal` operands attr-dict `:` functional-type(operands, results)

Elementwise comparison operation

Traits: AlwaysSpeculatableImplTrait, Commutative, InferTensorType, ResultsBroadcastableShape, SameOperandsElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of 1-bit signless integer values

tosa.erf (mlir::tosa::ErfOp) 

Computes the Gauss error function of the input

Syntax:

operation ::= `tosa.erf` operands attr-dict `:` functional-type(operands, results)

Gauss error function: $ \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} \, dt $. For quantized integer data types, the TABLE operator should be used instead. In that case, the erf_table has 513 entries, each of 16-bit/8-bit precision, covering the input range -4.0 to +4.0 in steps of 1/64.
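
The 513-entry count follows from the grid itself: 8.0 / (1/64) = 512 steps, plus one for the closed endpoint. The sketch below builds that input grid and the floating-point erf values at each point. It deliberately does not model the conversion of those values to 8-bit or 16-bit table entries, since the scaling is not described here.

```python
import math

# Input grid: -4.0 to +4.0 inclusive, in steps of 1/64.
STEPS_PER_UNIT = 64
grid = [-4.0 + i / STEPS_PER_UNIT for i in range(8 * STEPS_PER_UNIT + 1)]
erf_values = [math.erf(x) for x in grid]  # float values; quantization not modeled

print(len(grid))          # 513
print(grid[0], grid[-1])  # -4.0 4.0
```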

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.exp (mlir::tosa::ExpOp) 

Elementwise exp op

Syntax:

operation ::= `tosa.exp` operands attr-dict `:` functional-type(operands, results)

Elementwise natural exponential operation (e raised to the power of each input element)

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.fft2d (mlir::tosa::FFT2dOp) 

Performs FFT2D operation on the input.

Syntax:

operation ::= `tosa.fft2d` $input_real `,` $input_imag attr-dict `:` `(` type($input_real) `,`
              type($input_imag) `)` `->` `(` type($output_real) `,` type($output_imag) `)`

Performs a batched complex 2D Fast Fourier Transform over the input. The complex input values are constructed from the corresponding values in the input_real and input_imag tensors. The resulting values in the output are split into the output_real and output_imag tensors. No normalization is applied on either the forward or inverse versions of the operation.

Example:

%out_real, %out_imag = tosa.fft2d %in_real, %in_imag {inverse = false} : (tensor<1x8x8xf32>, tensor<1x8x8xf32>) -> (tensor<1x8x8xf32>, tensor<1x8x8xf32>)
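
A naive reference sketch of the transform on one H x W plane (a single batch entry) is shown below. It illustrates the point that neither direction is normalized: a forward transform followed by an inverse transform returns the input scaled by H * W. The sign convention (negative exponent for the forward transform, positive for the inverse) is the conventional one and is an assumption here; check it against the specification. The function name is hypothetical.

```python
import math


def fft2d(in_real, in_imag, inverse=False):
    """Naive, unnormalized 2-D DFT of one H x W plane.

    Assumption (sketch only): forward uses exp(-2*pi*i*(...)),
    inverse uses exp(+2*pi*i*(...)); neither divides by H * W.
    """
    h, w = len(in_real), len(in_real[0])
    sign = 1.0 if inverse else -1.0
    out_real = [[0.0] * w for _ in range(h)]
    out_imag = [[0.0] * w for _ in range(h)]
    for oy in range(h):
        for ox in range(w):
            acc = 0j
            for iy in range(h):
                for ix in range(w):
                    angle = sign * 2 * math.pi * (iy * oy / h + ix * ox / w)
                    value = complex(in_real[iy][ix], in_imag[iy][ix])
                    acc += value * complex(math.cos(angle), math.sin(angle))
            out_real[oy][ox] = acc.real  # split the complex result back
            out_imag[oy][ox] = acc.imag  # into real and imaginary planes
    return out_real, out_imag
```

Applying fft2d and then fft2d(..., inverse=True) to a 2x2 plane yields the original values multiplied by 4.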

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
inverse | ::mlir::BoolAttr | bool attribute
local_bound | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
------- | -----------
input_real | 3-d tensor
input_imag | 3-d tensor

Results:

Result | Description
------ | -----------
output_real | 3-d tensor
output_imag | 3-d tensor

tosa.floor (mlir::tosa::FloorOp) 

Elementwise floor op

Syntax:

operation ::= `tosa.floor` operands attr-dict `:` functional-type(operands, results)

Elementwise floor operation

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.fully_connected (mlir::tosa::FullyConnectedOp) 

Fully Connected operator

Syntax:

operation ::= `tosa.fully_connected` operands attr-dict `:` functional-type(operands, results)

Performs a fully connected (dense) layer computation over the input, using the weight and bias tensors.
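
The computation can be sketched on nested lists. The layouts are assumptions based on my understanding of TOSA; please confirm them against the specification:

- input is [N, IC]
- weight is [OC, IC]
- bias is [OC]
- output is [N, OC]

Quantization (zero points) is not modeled. The function name is hypothetical.

```python
def fully_connected(inp, weight, bias):
    """output[n][oc] = bias[oc] + sum over ic of inp[n][ic] * weight[oc][ic].

    Assumed layouts (sketch only): inp [N, IC], weight [OC, IC], bias [OC].
    """
    return [
        [bias[oc] + sum(x * w for x, w in zip(row, weight[oc]))
         for oc in range(len(weight))]
        for row in inp
    ]


print(fully_connected([[1, 2]], [[1, 0], [0, 1], [1, 1]], [10, 20, 30]))  # [[11, 22, 33]]
```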

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
--------- | --------- | -----------
quantization_info | mlir::tosa::ConvOpQuantizationAttr | Attribute for Conv type op quantization information.

Operands:

Operand | Description
------- | -----------
input | 2-d tensor
weight | 2D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values
bias | 1-d tensor

Results:

Result | Description
------ | -----------
output | 2-d tensor

tosa.gather (mlir::tosa::GatherOp) 

Gather operation

Syntax:

operation ::= `tosa.gather` operands attr-dict `:` functional-type(operands, results)

Generate a tensor for which each element in the output is a slice of the values tensor based on the value of indices.
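
The indexing can be sketched on nested lists. The sketch assumes values has shape [N, K, C], indices has shape [N, W] with entries in [0, K), and the output has shape [N, W, C]. These layouts are consistent with the 3-d/2-d operand types listed for this op but are not stated here, so treat them as assumptions. The function name is hypothetical.

```python
def gather(values, indices):
    """output[n][w][c] = values[n][indices[n][w]][c].

    Assumed shapes (sketch only): values [N, K, C], indices [N, W],
    output [N, W, C].
    """
    return [[list(values[n][k]) for k in indices[n]] for n in range(len(values))]


values = [[[1, 2], [3, 4], [5, 6]]]  # N=1, K=3, C=2
print(gather(values, [[2, 0, 2]]))   # [[[5, 6], [1, 2], [5, 6]]]
```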

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
values | 3-d tensor
indices | 2D tensor of 32-bit signless integer values

Results:

Result | Description
------ | -----------
output | 3-d tensor

tosa.greater_equal (mlir::tosa::GreaterEqualOp) 

Returns the truth value of (x >= y) element-wise.

Syntax:

operation ::= `tosa.greater_equal` operands attr-dict `:` functional-type(operands, results)

Elementwise comparison operation

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of 1-bit signless integer values

tosa.greater (mlir::tosa::GreaterOp) 

Returns the truth value of (x > y) element-wise.

Syntax:

operation ::= `tosa.greater` operands attr-dict `:` functional-type(operands, results)

Elementwise greater than comparison operation

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of 1-bit signless integer values

tosa.identity (mlir::tosa::IdentityOp) 

Identity operator

Syntax:

operation ::= `tosa.identity` operands attr-dict `:` functional-type(operands, results)

Returns a tensor with the same shape, size, type and content as the input.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand | Description
------- | -----------
input1 | tensor of number values

Results:

Result | Description
------ | -----------
output | tensor of number values

tosa.cond_if (mlir::tosa::IfOp) 

Conditional if operator

Evaluates a Boolean condition and then takes one of two distinct execution paths. This implements the semantics of an if-then-else structure.

Traits: InferShapedTypeOpAdaptor, RecursiveMemoryEffects, SingleBlockImplicitTerminator<YieldOp>, SingleBlock

Interfaces: InferShapedTypeOpInterface, TosaOpInterface

Operands: 

Operand | Description
cond | tensor of 1-bit signless integer values
inputs | variadic of tensor of number values

Results:

Result | Description
output | variadic of tensor of number values

tosa.log (mlir::tosa::LogOp) 

Elementwise log op

Syntax:

operation ::= `tosa.log` operands attr-dict `:` functional-type(operands, results)

Elementwise natural logarithm operation

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.logical_and (mlir::tosa::LogicalAndOp) 

Returns the truth value of x AND y element-wise.

Syntax:

operation ::= `tosa.logical_and` operands attr-dict `:` functional-type(operands, results)

Elementwise logical AND of input1 and input2. Axes of size 1 are broadcast as necessary.
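The size-1 broadcasting rule can be illustrated with a small Python sketch. It assumes both operands have the same rank; `broadcast_shape` and `logical_and_2d` are hypothetical helper names, and nested lists stand in for tensors.

```python
from typing import List, Sequence


def broadcast_shape(shape1: Sequence[int], shape2: Sequence[int]) -> List[int]:
    """Result shape of a binary elementwise op under size-1 broadcasting.

    Assumes equal ranks: a dimension of size 1 in one operand is stretched
    to the size of the other operand in that dimension.
    """
    if len(shape1) != len(shape2):
        raise ValueError("operands must have the same rank in this sketch")
    result = []
    for d1, d2 in zip(shape1, shape2):
        if d1 == d2 or d2 == 1:
            result.append(d1)
        elif d1 == 1:
            result.append(d2)
        else:
            raise ValueError(f"incompatible dimensions {d1} and {d2}")
    return result


def logical_and_2d(a, b):
    """Elementwise AND of two 2-D boolean lists, broadcasting size-1 axes."""
    rows, cols = broadcast_shape([len(a), len(a[0])], [len(b), len(b[0])])
    # Index 0 is reused along any axis whose size is 1 (the broadcast axis).
    return [[a[i if len(a) > 1 else 0][j if len(a[0]) > 1 else 0]
             and b[i if len(b) > 1 else 0][j if len(b[0]) > 1 else 0]
             for j in range(cols)] for i in range(rows)]
```

For example, a 2x1 operand combined with a 1x3 operand yields a 2x3 result.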

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of 1-bit signless integer values
input2 | tensor of 1-bit signless integer values

Results:

Result | Description
z | tensor of 1-bit signless integer values

tosa.logical_left_shift (mlir::tosa::LogicalLeftShiftOp) 

Elementwise Logical Left Shift

Syntax:

operation ::= `tosa.logical_left_shift` operands attr-dict `:` functional-type(operands, results)

Elementwise logical left shift of input1 by the amount specified in input2. Axes of size 1 are broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.logical_not (mlir::tosa::LogicalNotOp) 

Returns the truth value of NOT x element-wise.

Syntax:

operation ::= `tosa.logical_not` operands attr-dict `:` functional-type(operands, results)

Elementwise logical NOT of input.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of 1-bit signless integer values

Results:

Result | Description
output | tensor of 1-bit signless integer values

tosa.logical_or (mlir::tosa::LogicalOrOp) 

Returns the truth value of x OR y element-wise.

Syntax:

operation ::= `tosa.logical_or` operands attr-dict `:` functional-type(operands, results)

Elementwise logical OR of input1 and input2. Axes of size 1 are broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of 1-bit signless integer values
input2 | tensor of 1-bit signless integer values

Results:

Result | Description
z | tensor of 1-bit signless integer values

tosa.logical_right_shift (mlir::tosa::LogicalRightShiftOp) 

Elementwise Logical Right Shift

Syntax:

operation ::= `tosa.logical_right_shift` operands attr-dict `:` functional-type(operands, results)

Elementwise logical right shift of input1 by the amount specified in input2. Axes of size 1 are broadcast as necessary.
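The difference between the logical shifts and arithmetic shifts is easiest to see on negative inputs. The following Python sketch models both logical shifts on a fixed-width two's-complement integer. The helper names and the assumption that shift amounts lie in [0, bits) are illustrative, not taken from the dialect.

```python
def _wrap_signed(value: int, bits: int) -> int:
    """Reinterpret the low `bits` bits of value as a two's-complement integer."""
    value &= (1 << bits) - 1
    return value - (1 << bits) if value >= 1 << (bits - 1) else value


def logical_left_shift(value: int, shift: int, bits: int = 32) -> int:
    # Bits shifted past the top of the type are discarded.
    if not 0 <= shift < bits:
        raise ValueError("shift out of range")
    return _wrap_signed(value << shift, bits)


def logical_right_shift(value: int, shift: int, bits: int = 32) -> int:
    # Treat the input as unsigned so zeros, not copies of the sign bit,
    # are shifted in from the top.
    if not 0 <= shift < bits:
        raise ValueError("shift out of range")
    unsigned = value & ((1 << bits) - 1)
    return _wrap_signed(unsigned >> shift, bits)
```

For instance, a logical right shift of -1 (all bits set) by 28 on a 32-bit value leaves only the low four bits set, giving 15; an arithmetic shift would instead keep the result at -1.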

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.logical_xor (mlir::tosa::LogicalXorOp) 

Returns the truth value of x XOR y element-wise.

Syntax:

operation ::= `tosa.logical_xor` operands attr-dict `:` functional-type(operands, results)

Elementwise logical XOR of input1 and input2. Axes of size 1 are broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of 1-bit signless integer values
input2 | tensor of 1-bit signless integer values

Results:

Result | Description
z | tensor of 1-bit signless integer values

tosa.matmul (mlir::tosa::MatMulOp) 

Matrix multiplication operator

Syntax:

operation ::= `tosa.matmul` operands attr-dict `:` functional-type(operands, results)

Performs a batched two-dimensional matrix multiplication: a has shape [N, H, C], b has shape [N, C, W], and c has shape [N, H, W]. This allows both inputs to be activations, rather than reserving weights as an attribute in the FULLY_CONNECTED operator.
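A minimal Python sketch of the batched product on nested lists follows. It assumes the [N, H, C] x [N, C, W] -> [N, H, W] layout and ignores quantization_info (zero points). The function name is hypothetical.

```python
from typing import List

Tensor3 = List[List[List[float]]]


def matmul(a: Tensor3, b: Tensor3) -> Tensor3:
    """Batched matmul: a is [N, H, C], b is [N, C, W], result is [N, H, W]."""
    if len(a) != len(b):
        raise ValueError("batch sizes differ")
    out = []
    for a_n, b_n in zip(a, b):
        c_dim = len(b_n)      # shared inner dimension C
        width = len(b_n[0])   # output width W
        if any(len(row) != c_dim for row in a_n):
            raise ValueError("inner dimensions differ")
        # Each output element is the dot product of a row of a_n and a column of b_n.
        out.append([[sum(row[k] * b_n[k][w] for k in range(c_dim))
                     for w in range(width)] for row in a_n])
    return out
```

Each batch entry n is an independent H x C by C x W product, which is why both operands can be runtime activations.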

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
quantization_info | mlir::tosa::MatMulOpQuantizationAttr | Attribute for MatMulOp quantization information.

Operands:

Operand | Description
a | 3-d tensor
b | 3-d tensor

Results:

Result | Description
c | 3-d tensor

tosa.max_pool2d (mlir::tosa::MaxPool2dOp) 

Performs max pooling on the input.

Syntax:

operation ::= `tosa.max_pool2d` operands attr-dict `:` functional-type(operands, results)

Performs max pooling over the given input tensor. A sliding window of size given by kernel is passed over the input tensor, and the maximum value in each window is placed in the output tensor.
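The sliding-window behaviour can be sketched for a single H x W plane (one batch entry and one channel of the 4-d input). The sketch assumes pad is ordered [top, bottom, left, right] and that padded positions never contribute to the maximum; `max_pool2d_plane` is a hypothetical name.

```python
from typing import List, Sequence


def max_pool2d_plane(plane: List[List[float]], kernel: Sequence[int],
                     stride: Sequence[int], pad: Sequence[int]) -> List[List[float]]:
    """Max-pool one H x W plane; pad is assumed to be [top, bottom, left, right]."""
    (kh, kw), (sy, sx) = kernel, stride
    top, bottom, left, right = pad
    in_h, in_w = len(plane), len(plane[0])
    out_h = (in_h + top + bottom - kh) // sy + 1
    out_w = (in_w + left + right - kw) // sx + 1
    out = []
    for oy in range(out_h):
        row = []
        for ox in range(out_w):
            best = float("-inf")  # padded positions never win the max
            for ky in range(kh):
                for kx in range(kw):
                    y, x = oy * sy + ky - top, ox * sx + kx - left
                    if 0 <= y < in_h and 0 <= x < in_w:
                        best = max(best, plane[y][x])
            row.append(best)
        out.append(row)
    return out
```

With a 4x4 plane holding 0..15 in row-major order, a 2x2 kernel and stride 2 select the bottom-right element of each window: [[5, 7], [13, 15]].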

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
kernel | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements

Operands:

Operand | Description
input | 4-d tensor

Results:

Result | Description
output | 4-d tensor

tosa.maximum (mlir::tosa::MaximumOp) 

Elementwise Maximum

Syntax:

operation ::= `tosa.maximum` operands attr-dict `:` functional-type(operands, results)

Elementwise maximum of input1 and input2. Axes of size 1 are broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.minimum (mlir::tosa::MinimumOp) 

Elementwise Minimum

Syntax:

operation ::= `tosa.minimum` operands attr-dict `:` functional-type(operands, results)

Elementwise minimum of input1 and input2. Axes of size 1 are broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.mul (mlir::tosa::MulOp) 

Multiplication operator

Syntax:

operation ::= `tosa.mul` operands attr-dict `:` functional-type(operands, results)

Elementwise multiplication (Hadamard product) of input1 and input2. Axes of size 1 are broadcast as necessary. i8 and i16 inputs may produce an i32 result.
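The role of the shift attribute can be sketched for 32-bit integer inputs, assuming that a non-zero shift rounds the full-precision product and then shifts it right (round half up). This follows our reading of the TOSA specification rather than the MLIR code. The sketch simply rejects results that do not fit in int32 and does not model any wrap-around behaviour. `mul_int32` is a hypothetical name.

```python
def mul_int32(a: int, b: int, shift: int = 0) -> int:
    """Multiply two int32 values and optionally round-shift right by `shift`."""
    product = a * b  # exact: Python integers do not overflow
    if shift > 0:
        # Add half of the divisor, then floor-shift: round half up.
        product = (product + (1 << (shift - 1))) >> shift
    if not -(2**31) <= product < 2**31:
        raise OverflowError("result does not fit in int32")
    return product
```

For example, 3 * 5 with shift 1 is 7.5, which rounds up to 8.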

Traits: AlwaysSpeculatableImplTrait, Commutative, MulOperandsAndResultElementType, ResultsBroadcastableShape

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
shift | ::mlir::IntegerAttr | 8-bit signless integer attribute

Operands:

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.negate (mlir::tosa::NegateOp) 

Elementwise negate op

Syntax:

operation ::= `tosa.negate` operands attr-dict `:` functional-type(operands, results)

Elementwise negation operation

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
quantization_info | mlir::tosa::UnaryOpQuantizationAttr | Attribute for UnaryOp quantization information.

Operands:

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.pad (mlir::tosa::PadOp) 

Pads a tensor with the specified value.

Syntax:

operation ::= `tosa.pad` operands attr-dict `:` functional-type(operands, results)

The tosa.pad operation pads a tensor along the borders of each dimension with pad_const (zero by default). The padding operand gives the low and high padding amounts for each dimension.

Example:

%0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<4x9xf32>)

Example 2:

%0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
tosa.pad %arg0, %0 : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<?x9xf32>)
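The output shape in the first example follows from adding each dimension's low and high padding to its input size: (1 + 1 + 2, 2 + 3 + 4) = (4, 9). The Python sketch below shows that arithmetic, plus a 2-D padding routine. It assumes each row of padding is [low, high] for one dimension and handles only non-negative padding, so the negative case in the second example is out of scope. Both function names are hypothetical.

```python
from typing import List, Sequence


def pad_output_shape(shape: Sequence[int], padding: Sequence[Sequence[int]]) -> List[int]:
    """Each output dimension is low padding + input size + high padding."""
    return [low + dim + high for dim, (low, high) in zip(shape, padding)]


def pad_2d(x: List[List[float]], padding: Sequence[Sequence[int]],
           pad_const: float = 0) -> List[List[float]]:
    """Pad a 2-D list with pad_const; only non-negative padding is handled."""
    (top, bottom), (left, right) = padding
    if min(top, bottom, left, right) < 0:
        raise ValueError("negative padding is not handled in this sketch")
    width = left + len(x[0]) + right
    body = [[pad_const] * left + list(row) + [pad_const] * right for row in x]
    # Fresh lists for every padding row so rows are not aliased.
    return ([[pad_const] * width for _ in range(top)] + body
            + [[pad_const] * width for _ in range(bottom)])
```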

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
quantization_info | mlir::tosa::PadOpQuantizationAttr | Attribute for PadOp quantization information.

Operands:

Operand | Description
input1 | ranked tensor of number values
padding | tensor of 32-bit signless integer or 64-bit signless integer values
pad_const | 0D tensor of number values

Results:

Result | Description
output | ranked tensor of number values

tosa.pow (mlir::tosa::PowOp) 

Computes the power of one value to another.

Syntax:

operation ::= `tosa.pow` operands attr-dict `:` functional-type(operands, results)

Elementwise input1 raised to the power of input2. Axes of size 1 are broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
z | tensor of number values

tosa.rfft2d (mlir::tosa::RFFT2dOp) 

Performs RFFT2D operation on the input.

Syntax:

operation ::= `tosa.rfft2d` $input attr-dict `:` `(` type($input) `)` `->` `(` type($output_real) `,` type($output_imag) `)`

Performs a batched 2D real-valued Fast Fourier Transform over the input, where the input tensor consists of real values and the output is complex-valued. The complex output values are split into the output_real and output_imag tensors. RFFT2D takes advantage of Hermitian symmetry to calculate only the first half of the final output axis. Imaginary values at locations (0,0), (0,W/2), (H/2,0) and (H/2,W/2) are zero.

Example:

%real, %imag = tosa.rfft2d %in : (tensor<1x8x16xf32>) -> (tensor<1x8x9xf32>, tensor<1x8x9xf32>)
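A naive Python sketch of the transform for one H x W plane (one batch entry) illustrates the half-spectrum output: W // 2 + 1 columns, so a width of 16 yields 9 output columns. It assumes the usual forward-DFT sign convention (imaginary part negated) and uses a direct O(H²W²) sum rather than an FFT. `rfft2d_plane` is a hypothetical name.

```python
import math
from typing import List, Tuple

Plane = List[List[float]]


def rfft2d_plane(x: Plane) -> Tuple[Plane, Plane]:
    """Forward 2-D DFT of one real H x W plane, keeping W // 2 + 1 columns."""
    h, w = len(x), len(x[0])
    real = [[0.0] * (w // 2 + 1) for _ in range(h)]
    imag = [[0.0] * (w // 2 + 1) for _ in range(h)]
    for oy in range(h):
        for ox in range(w // 2 + 1):
            for iy in range(h):
                for ix in range(w):
                    angle = 2 * math.pi * (iy * oy / h + ix * ox / w)
                    real[oy][ox] += x[iy][ix] * math.cos(angle)
                    imag[oy][ox] -= x[iy][ix] * math.sin(angle)
    return real, imag
```

The zero imaginary parts at (0,0), (0,W/2), (H/2,0) and (H/2,W/2) fall out of this formula: at those locations every angle is a multiple of π, so every sine term vanishes.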

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
local_bound | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
input | 3-d tensor

Results:

Result | Description
output_real | 3-d tensor
output_imag | 3-d tensor

tosa.reciprocal (mlir::tosa::ReciprocalOp) 

Elementwise reciprocal op

Syntax:

operation ::= `tosa.reciprocal` operands attr-dict `:` functional-type(operands, results)

Elementwise reciprocal operation. For integer operation, a TABLE should be used with the appropriate ranges.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.reduce_all (mlir::tosa::ReduceAllOp) 

Reduce All operator

Syntax:

operation ::= `tosa.reduce_all` operands attr-dict `:` functional-type(operands, results)

Reduce a tensor along the given axis with a logical AND operation.
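All of the reduce_* operators share one pattern: combine the elements along axis with a binary operation and keep that axis with size 1, so the rank is unchanged. The Python sketch below shows this for 2-D nested lists. `reduce_2d` is a hypothetical helper. Passing operator.and_ models reduce_all, operator.or_ models reduce_any, max and min model reduce_max and reduce_min, operator.mul models reduce_prod, and operator.add models reduce_sum.

```python
import operator
from functools import reduce
from typing import Callable, List


def reduce_2d(x: List[list], axis: int, op: Callable) -> List[list]:
    """Reduce a 2-D list along `axis`, keeping that axis with size 1."""
    if axis == 0:
        # Combine down each column; the result has a single row.
        return [[reduce(op, col) for col in zip(*x)]]
    if axis == 1:
        # Combine across each row; every row shrinks to one element.
        return [[reduce(op, row)] for row in x]
    raise ValueError("axis must be 0 or 1 for a 2-D input")
```

For example, summing [[1, 2, 3], [4, 5, 6]] along axis 0 gives [[5, 7, 9]], a 1x3 tensor rather than a rank-1 vector.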

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.reduce_any (mlir::tosa::ReduceAnyOp) 

Reduce Any operator

Syntax:

operation ::= `tosa.reduce_any` operands attr-dict `:` functional-type(operands, results)

Reduce a tensor along the given axis with a logical OR operation.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.reduce_max (mlir::tosa::ReduceMaxOp) 

Reduce Max operator

Syntax:

operation ::= `tosa.reduce_max` operands attr-dict `:` functional-type(operands, results)

Reduce a tensor along the given axis with a maximum operation.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.reduce_min (mlir::tosa::ReduceMinOp) 

Reduce Min operator

Syntax:

operation ::= `tosa.reduce_min` operands attr-dict `:` functional-type(operands, results)

Reduce a tensor along the given axis with a minimum operation.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.reduce_prod (mlir::tosa::ReduceProdOp) 

Reduce Prod operator

Syntax:

operation ::= `tosa.reduce_prod` operands attr-dict `:` functional-type(operands, results)

Reduce a tensor along the given axis by computing the product along that axis.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.reduce_sum (mlir::tosa::ReduceSumOp) 

Reduce Sum operator

Syntax:

operation ::= `tosa.reduce_sum` operands attr-dict `:` functional-type(operands, results)

Reduce a tensor along the given axis by computing the sum along that axis.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.rescale (mlir::tosa::RescaleOp) 

TOSA rescale operator

Syntax:

operation ::= `tosa.rescale` operands attr-dict `:` functional-type(operands, results)

Rescale quantized values into a new domain. Supported rescalings are:

Mode | Input | Output
signed 8 to 8 | int8 | int8
signed 8 to 16 | int8 | int16
signed 8 to 32 | int8 | int32
signed 16 to 8 | int16 | int8
signed 16 to 16 | int16 | int16
signed 16 to 32 | int16 | int32
signed 32 to 8 | int32 | int8
signed 32 to 16 | int32 | int16
signed 32 to 32 | int32 | int32
signed 48 to 8 | int48 | int8
signed 48 to 16 | int48 | int16
signed 48 to 32 | int48 | int32
unsigned 8 to signed 8 | uint8 | int8
signed 8 to unsigned 8 | int8 | uint8
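The per-element arithmetic behind these modes can be sketched in Python. The sketch follows our reading of the TOSA specification's apply_scale_32 pseudocode: subtract input_zp, multiply by multiplier, round-shift right by shift (with an optional extra rounding step when double_round is set and shift > 31), add output_zp, then clip to the output type's range. It covers a single channel with one multiplier/shift pair. The per_channel and scale32 handling is omitted, and the function names are hypothetical.

```python
def apply_scale_32(value: int, multiplier: int, shift: int, double_round: bool) -> int:
    """Compute round(value * multiplier / 2**shift) using integer arithmetic."""
    if not (0 <= multiplier < 2**31 and 2 <= shift <= 62):
        raise ValueError("multiplier or shift out of range")
    rounding = 1 << (shift - 1)  # adds 0.5 in the scaled domain before flooring
    if double_round and shift > 31:
        # Move the rounding term by 2**30 in the direction of value's sign.
        rounding += (1 << 30) if value >= 0 else -(1 << 30)
    result = (value * multiplier + rounding) >> shift  # Python >> floors
    if not -(2**31) <= result < 2**31:
        raise OverflowError("scaled value does not fit in int32")
    return result


def rescale(value: int, input_zp: int, output_zp: int, multiplier: int,
            shift: int, double_round: bool, out_min: int, out_max: int) -> int:
    """Rescale one element: remove input_zp, scale, add output_zp, clip."""
    scaled = apply_scale_32(value - input_zp, multiplier, shift, double_round)
    return max(out_min, min(out_max, scaled + output_zp))
```

Here multiplier = 2**30 with shift = 31 encodes a scale of 0.5, so rescaling an int32 value of 1000 to int8 yields 500, which is then clipped to 127.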

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
input_zp | ::mlir::IntegerAttr | 32-bit signless integer attribute
output_zp | ::mlir::IntegerAttr | 32-bit signless integer attribute
multiplier | ::mlir::DenseI32ArrayAttr | i32 dense array attribute
shift | ::mlir::DenseI8ArrayAttr | i8 dense array attribute
scale32 | ::mlir::BoolAttr | bool attribute
double_round | ::mlir::BoolAttr | bool attribute
per_channel | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.reshape (mlir::tosa::ReshapeOp) 

Reshape operator

Syntax:

operation ::= `tosa.reshape` operands attr-dict `:` functional-type(operands, results)

Returns a tensor with the same type and values as the input, with a new shape specified by the new_shape attribute. Reshape may operate on tensors of any rank. No data conversion happens during a reshape operation.
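Because no data conversion happens, a reshape can be modelled as flattening the input in row-major order and regrouping the same elements into new_shape. The Python sketch below does this on nested lists. It assumes row-major element order and requires the element count to be unchanged. Both function names are hypothetical.

```python
from math import prod
from typing import Sequence


def flatten(x) -> list:
    """Row-major flattening of a nested list."""
    return [v for item in x for v in flatten(item)] if isinstance(x, list) else [x]


def reshape(x, new_shape: Sequence[int]):
    """Regroup the row-major elements of x into new_shape; values are unchanged."""
    flat = flatten(x)
    if prod(new_shape) != len(flat):
        raise ValueError("element count must not change")

    def build(offset: int, dims: Sequence[int]):
        if not dims:
            return flat[offset]
        step = prod(dims[1:])  # elements covered by one index of dims[0]
        return [build(offset + i * step, dims[1:]) for i in range(dims[0])]

    return build(0, list(new_shape))
```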

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
new_shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | ranked tensor of number values

tosa.resize (mlir::tosa::ResizeOp) 

Resize operation, supports various resize/upsample modes

Syntax:

operation ::= `tosa.resize` operands attr-dict `:` functional-type(operands, results)

Resizes a tensor. Resize is only allowed in the H and W dimensions. In expected use, the height dimension is scaled by the factor (scale_y_n/scale_y_d) and the width dimension by the factor (scale_x_n/scale_x_d), so the output dimensions can be derived from the input dimensions by inverting the scale. The [border_y, border_x] values adjust the output size to allow fractional sampling beyond integer input position (IH-1,IW-1).
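The output-size calculation can be sketched in Python. This follows the formula as we understand it from the TOSA specification: OH = ((IH - 1) * scale_y_n - offset_y + border_y) / scale_y_d + 1, where the division must be exact, and OW is computed the same way. It also assumes scale is ordered [scale_y_n, scale_y_d, scale_x_n, scale_x_d], offset is [offset_y, offset_x], and border is [border_y, border_x]. `resize_output_hw` is a hypothetical name.

```python
from typing import Sequence, Tuple


def resize_output_hw(in_h: int, in_w: int, scale: Sequence[int],
                     offset: Sequence[int], border: Sequence[int]) -> Tuple[int, int]:
    """Output (OH, OW) from scale=[y_n, y_d, x_n, x_d], offset=[y, x], border=[y, x]."""
    y_n, y_d, x_n, x_d = scale

    def dim(size: int, num: int, den: int, off: int, bord: int) -> int:
        numerator = (size - 1) * num - off + bord
        if numerator % den:
            raise ValueError("output size is not an exact integer")
        return numerator // den + 1

    return (dim(in_h, y_n, y_d, offset[0], border[0]),
            dim(in_w, x_n, x_d, offset[1], border[1]))
```

For a 2x upscale of a 4x4 input (scale [2, 1, 2, 1], zero offset), a border of 1 extends the output from 7x7 to 8x8 by sampling half a step beyond position IH-1.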

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
scale | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements
offset | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
border | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
mode | ::mlir::StringAttr | Supported resize/upsampling strategies

Operands:

Operand | Description
input | 4-d tensor

Results:

Result | Description
output | 4-d tensor

tosa.reverse (mlir::tosa::ReverseOp) 

Reverse operator

Syntax:

operation ::= `tosa.reverse` operands attr-dict `:` functional-type(operands, results)

Returns a tensor with the same type/values as the input, with the data reversed along the given axis. No data conversion happens during a reverse operation.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.rsqrt (mlir::tosa::RsqrtOp) 

Elementwise 1/sqrt op

Syntax:

operation ::= `tosa.rsqrt` operands attr-dict `:` functional-type(operands, results)

Elementwise reciprocal square root operation. For integer operation, a TABLE should be used with the appropriate ranges.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.scatter (mlir::tosa::ScatterOp) 

Scatter operation

Syntax:

operation ::= `tosa.scatter` operands attr-dict `:` functional-type(operands, results)

The values_out tensor is set to the values_in tensor with data modified as follows: data from the input tensor is inserted at the positions specified by the indices tensor.
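The following Python sketch models scatter on nested lists, assuming the TOSA layout values_in [N, K, C], indices [N, W], input [N, W, C]. For each batch n and position w, row indices[n][w] of the output is replaced by input[n][w]. The sketch rejects out-of-range and duplicate indices, which we believe the TOSA specification also disallows. The function is hypothetical, and `inputs` is used instead of `input` to avoid shadowing a Python builtin.

```python
import copy
from typing import List


def scatter(values_in: List[List[List[float]]], indices: List[List[int]],
            inputs: List[List[List[float]]]) -> List[List[List[float]]]:
    """values_in is [N, K, C], indices is [N, W], inputs is [N, W, C]."""
    values_out = copy.deepcopy(values_in)  # start from an untouched copy
    for n, (idx_row, in_rows) in enumerate(zip(indices, inputs)):
        written = set()
        for k, row in zip(idx_row, in_rows):
            if not 0 <= k < len(values_out[n]) or k in written:
                raise ValueError("index out of range or written twice")
            written.add(k)
            values_out[n][k] = list(row)  # replace all C channels at position k
    return values_out
```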

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
values_in | 3-d tensor
indices | 2D tensor of 32-bit signless integer values
input | 3-d tensor

Results:

Result | Description
values_out | 3-d tensor

tosa.select (mlir::tosa::SelectOp) 

Elementwise select operator

Syntax:

operation ::= `tosa.select` operands attr-dict `:` `(` type($pred) `,` type($on_true) `,` type($on_false)
              `)` `->` type($output)

Elementwise select: each output element is taken from on_true where pred is true, and from on_false otherwise.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
pred | tensor of 1-bit signless integer values
on_true | tensor of number values
on_false | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.sigmoid (mlir::tosa::SigmoidOp) 

Computes elementwise sigmoid of input.

Syntax:

operation ::= `tosa.sigmoid` operands attr-dict `:` functional-type(operands, results)

Sigmoid function: output = 1 / (1 + exp(-input)). For quantized integer data types, the TABLE operator should be used instead. The sigmoid table has 513 entries, each of 16-bit precision, covering the input range -16.0 to +16.0 in steps of 1/16.
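The 513 entries line up with the stated range: 32.0 / (1/16) = 512 steps, plus one endpoint. A Python sketch of generating such a table follows. The output scaling (multiply by 32768, round half up, clip to the int16 range) is an assumption based on our reading of the TOSA specification's table-generation pseudocode. `sigmoid_table` is a hypothetical name.

```python
import math
from typing import List


def round_half_up(x: float) -> int:
    return math.floor(x + 0.5)


def sigmoid_table() -> List[int]:
    """513 int16 entries for inputs -16.0 .. +16.0 in steps of 1/16."""
    table = []
    for i in range(-256, 257):
        v = 1.0 / (1.0 + math.exp(-i / 16.0))  # entry i samples input i / 16
        # Assumed fixed-point scaling: 1.0 maps to 32768, clipped to int16.
        table.append(max(-32768, min(32767, round_half_up(32768.0 * v))))
    return table
```

Under these assumptions, the middle entry (input 0.0) is 16384, which represents 0.5.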

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.sin (mlir::tosa::SinOp) 

Elementwise sin op

Syntax:

operation ::= `tosa.sin` operands attr-dict `:` functional-type(operands, results)

Elementwise sine operation for values given in radians.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input | tensor of 32-bit float or 16-bit float or bfloat16 type values

Results:

Result | Description
output | tensor of 32-bit float or 16-bit float or bfloat16 type values

tosa.slice (mlir::tosa::SliceOp) 

Slice operator

Syntax:

operation ::= `tosa.slice` operands attr-dict `:` functional-type(operands, results)

Extracts a slice of input, beginning at the start coordinates and extending for size elements in each dimension. No data conversion happens during a slice operation.

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
start | ::mlir::DenseI64ArrayAttr | i64 dense array attribute
size | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.sub (mlir::tosa::SubOp) 

Elementwise subtraction operator

Syntax:

operation ::= `tosa.sub` operands attr-dict `:` functional-type(operands, results)

Elementwise subtraction of input2 from input1. Axes of size 1 are broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.table (mlir::tosa::TableOp) 

Table lookup op

Syntax:

operation ::= `tosa.table` $input `,` $table attr-dict `:` `(` type($input) `,` type($table) `)` `->` type($output)

Interpolated table lookup operation. Input values are scaled to create a fixed-point 9.7 value. The high 9 bits are used to index into the table. The fractional bits are used to interpolate between the looked-up value and the value at index+1 in the table. The TABLE operator then returns a 16.7 interpolated value. Note that there must be 513 values to handle the full range of inputs.

The TABLE operator is expected to be used as follows:

  • A RESCALE node is expected before the TABLE operator to scale the input to a full int16_t range for the table lookup
  • If an int16_t result is required then follow the TABLE operator with a RESCALE with a right shift of 7
  • If an int8_t result is required then follow the TABLE operator with a RESCALE with a right shift of 15
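The 9.7 split can be sketched in Python for a single int16 input. The sketch offsets the clipped input by 32768 so the high 9 bits select an entry in 0..511. It then interpolates toward the next entry using the low 7 bits, returning a value with 7 fractional bits. This is our reading of the TOSA specification's lookup pseudocode, not the MLIR lowering itself. `table_lookup_int16` is a hypothetical name.

```python
from typing import Sequence


def table_lookup_int16(table: Sequence[int], value: int) -> int:
    """Interpolated lookup of one int16 input; the result has 7 fractional bits."""
    if len(table) != 513:
        raise ValueError("table must have 513 entries")
    clipped = max(-32768, min(32767, value))
    index = (clipped + 32768) >> 7  # high 9 bits select the table entry
    fraction = clipped & 0x7F       # low 7 bits weight the interpolation
    base, nxt = table[index], table[index + 1]
    return (base << 7) + (nxt - base) * fraction
```

This also explains the 513-entry requirement: the largest index is 511, and interpolation reads index + 1 = 512. With an identity table (entry i equal to i), the lookup returns the input offset by 32768.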

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input | tensor of number values
table | 1-d tensor

Results:

Result | Description
output | tensor of number values

tosa.tanh (mlir::tosa::TanhOp) 

Computes elementwise hyperbolic tangent of input

Syntax:

operation ::= `tosa.tanh` operands attr-dict `:` functional-type(operands, results)

Elementwise hyperbolic tangent. For quantized integer data types, the TABLE operator should be used instead. The tanh table has 513 entries, each of 16-bit precision, covering the input range -8.0 to +8.0 in steps of 1/32.
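A corresponding tanh table covers 16.0 / (1/32) = 512 steps plus one endpoint. The Python sketch below generates one, assuming the same fixed-point convention as for the other activation tables: scale by 32768, round half up, and clip to int16. `tanh_table` is a hypothetical name.

```python
import math
from typing import List


def tanh_table() -> List[int]:
    """513 int16 entries for inputs -8.0 .. +8.0 in steps of 1/32."""
    table = []
    for i in range(-256, 257):
        scaled = 32768.0 * math.tanh(i / 32.0)  # entry i samples input i / 32
        table.append(max(-32768, min(32767, math.floor(scaled + 0.5))))
    return table
```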

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.tile (mlir::tosa::TileOp) 

Tile operator

Syntax:

operation ::= `tosa.tile` operands attr-dict `:` functional-type(operands, results)

Replicates input1 along each dimension i, multiples[i] times.
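For a 2-D input, tiling amounts to repeating each row's contents multiples[1] times and then repeating the whole block of rows multiples[0] times. Each output dimension i is therefore the input dimension multiplied by multiples[i]. The Python sketch below is hypothetical and limited to rank 2.

```python
from typing import List, Sequence


def tile_2d(x: List[List], multiples: Sequence[int]) -> List[List]:
    """Repeat a 2-D list multiples[0] times along rows and multiples[1] along columns."""
    rows = [list(row) * multiples[1] for row in x]  # widen each row
    # Stack copies of the widened block; fresh lists avoid aliasing rows.
    return [list(row) for _ in range(multiples[0]) for row in rows]
```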

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
multiples | ::mlir::DenseI64ArrayAttr | i64 dense array attribute

Operands:

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.transpose_conv2d (mlir::tosa::TransposeConv2DOp) 

Transpose 2D Convolution operator.

Syntax:

operation ::= `tosa.transpose_conv2d` operands attr-dict `:` functional-type(operands, results)

Performs a 2D transposed convolution over the given tensor input, using the weights tensor.
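The spatial size of the result can be sketched with the output-size formula as we understand it from the TOSA specification: OH = (IH - 1) * stride_y + out_pad_top + out_pad_bottom + KH, with OW computed the same way. The sketch assumes out_pad is ordered [top, bottom, left, right] and stride is [y, x]. The result would be expected to agree with the H and W entries of out_shape. The function name is hypothetical.

```python
from typing import Sequence, Tuple


def transpose_conv2d_output_hw(in_h: int, in_w: int, kernel_h: int, kernel_w: int,
                               stride: Sequence[int],
                               out_pad: Sequence[int]) -> Tuple[int, int]:
    """out_pad is assumed to be [top, bottom, left, right]; stride is [y, x]."""
    top, bottom, left, right = out_pad
    out_h = (in_h - 1) * stride[0] + top + bottom + kernel_h
    out_w = (in_w - 1) * stride[1] + left + right + kernel_w
    return out_h, out_w
```

For example, a 4x4 input with a 3x3 kernel, stride 2 and no padding produces a 9x9 output.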

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
out_pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements
stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
out_shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with at least 4 elements
quantization_info | mlir::tosa::ConvOpQuantizationAttr | Attribute for Conv type op quantization information.
local_bound | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
input | 4-d tensor
filter | 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values
bias | 1-d tensor

Results:

Result | Description
output | 4-d tensor

tosa.transpose (mlir::tosa::TransposeOp) 

Transpose operator

Syntax:

operation ::= `tosa.transpose` operands attr-dict `:` functional-type(operands, results)

Permutes the dimensions of input1 according to perms: dimension i of the output is dimension perms[i] of the input.
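The permutation rule can be expressed as two small Python helpers. One computes the output shape. The other maps an output coordinate back to the input coordinate it reads from. Both assume that output dimension i corresponds to input dimension perms[i], and both names are hypothetical.

```python
from typing import List, Sequence, Tuple


def transpose_shape(shape: Sequence[int], perms: Sequence[int]) -> List[int]:
    """Output dimension i takes the size of input dimension perms[i]."""
    if sorted(perms) != list(range(len(shape))):
        raise ValueError("perms must be a permutation of the input axes")
    return [shape[p] for p in perms]


def source_index(out_index: Sequence[int], perms: Sequence[int]) -> Tuple[int, ...]:
    """Input coordinate read when producing the given output coordinate."""
    src = [0] * len(perms)
    for i, p in enumerate(perms):
        src[p] = out_index[i]  # output axis i walks input axis perms[i]
    return tuple(src)
```

For a [2, 3, 4] input with perms [2, 0, 1], the output shape is [4, 2, 3], and output element (3, 1, 2) is read from input element (1, 2, 3).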

Traits: AlwaysSpeculatableImplTrait, InferShapedTypeOpAdaptor

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
perms | tensor of 32-bit signless integer or 64-bit signless integer values

Results:

Result | Description
output | tensor of number values

tosa.variable (mlir::tosa::VariableOp) 

Defines a variable

Syntax:

operation ::= `tosa.variable` $name
              attr-dict
              custom<TypeOrAttr>($type, $initial_value)

Defines a new TOSA variable. This is a mutable value. Modifications are expressed using read/write semantics.

Interfaces: TosaOpInterface

Attributes: 

Attribute | MLIR Type | Description
name | ::mlir::StringAttr | string attribute
type | ::mlir::TypeAttr | any type attribute
initial_value | ::mlir::Attribute | any attribute

tosa.variable.read (mlir::tosa::VariableReadOp) 

Read_buffer operator

Syntax:

operation ::= `tosa.variable.read` $name attr-dict `:` type($value)

Reads the value from a pseudo-buffer resource holding a mutable tensor.

Interfaces: TosaOpInterface

Attributes: 

Attribute | MLIR Type | Description
name | ::mlir::StringAttr | string attribute

Results: 

Result | Description
value | any type

tosa.variable.write (mlir::tosa::VariableWriteOp) 

Write_buffer operator

Syntax:

operation ::= `tosa.variable.write` $name attr-dict `,` $value `:` type($value)

Assigns a value to the pseudo-buffer resource holding a mutable tensor.

Interfaces: TosaOpInterface

Attributes: 

Attribute | MLIR Type | Description
name | ::mlir::StringAttr | string attribute

Operands: 

Operand | Description
value | any type
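Taken together, tosa.variable, tosa.variable.read and tosa.variable.write behave like a store of named mutable tensors. The Python class below is a hypothetical model for illustration, not an MLIR or TOSA API: declare plays the role of tosa.variable, read of tosa.variable.read, and write of tosa.variable.write. Tensors are represented as flat lists, and the rule that a write must keep the declared shape is an assumption of the sketch.

```python
class VariableStore:
    """Hypothetical model of TOSA variable semantics (not an MLIR API)."""

    def __init__(self):
        self._values = {}
        self._shapes = {}

    def declare(self, name, shape, initial_value):
        # Counterpart of tosa.variable: a name, a type (here just a shape)
        # and an initial value.
        if name in self._values:
            raise ValueError(f"variable {name!r} already defined")
        self._shapes[name] = tuple(shape)
        self._values[name] = list(initial_value)

    def read(self, name):
        # Counterpart of tosa.variable.read. Returns a copy so that callers
        # cannot mutate the stored tensor behind the store's back.
        return list(self._values[name])

    def write(self, name, shape, value):
        # Counterpart of tosa.variable.write. Assumption: the written value
        # must match the declared shape.
        if tuple(shape) != self._shapes[name]:
            raise ValueError(f"shape mismatch writing {name!r}")
        self._values[name] = list(value)
```

Because every update goes through an explicit write, the mutable state stays visible in the IR as ordered read/write operations instead of hidden aliasing between tensors.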

tosa.while_loop (mlir::tosa::WhileOp) 

Output = input; While (Cond(output)) {output = Body(output)}

Generates and evaluates a Boolean condition and either executes the loop body or exits to another control point. This is repeated, with the condition updated and re-evaluated on every iteration. The operation implements the semantics of a foreach or while iterative loop structure.
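The summary line "Output = input; While (Cond(output)) {output = Body(output)}" can be read directly as Python. In the sketch below, cond and body are plain callables standing in for the op's condition and body regions, each of which ends in tosa.yield; the function name and the tuple representation of the loop-carried values are assumptions made for illustration.

```python
def while_loop(inputs, cond, body):
    """Sketch of tosa.while_loop semantics on Python values.

    cond(values) -> bool stands in for the condition region's yielded value;
    body(values) -> tuple stands in for the body region's yielded values.
    """
    output = tuple(inputs)             # Output = input
    while cond(output):                # While (Cond(output))
        nxt = tuple(body(output))      # output = Body(output)
        if len(nxt) != len(output):
            raise ValueError("body must yield as many values as it receives")
        output = nxt
    return output
```

For instance, while_loop((0, 1), lambda s: s[0] < 5, lambda s: (s[0] + 1, s[1] * 2)) runs the body five times and returns (5, 32). If the condition is false on entry, the inputs are returned unchanged, so the body may execute zero times.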

Traits: InferShapedTypeOpAdaptor, RecursiveMemoryEffects, SingleBlockImplicitTerminator<YieldOp>, SingleBlock

Interfaces: InferShapedTypeOpInterface, LoopLikeOpInterface, TosaOpInterface

Operands: 

Operand | Description
inputs | variadic of tensor of number values

Results: 

Result | Description
output | variadic of tensor of number values

tosa.yield (mlir::tosa::YieldOp) 

Yield operator

Syntax:

operation ::= `tosa.yield` $inputs attr-dict `:` type($inputs)

Return operation within the condition and body regions of structured control flow. The operation takes variadic operands but produces no results of its own.

Traits: AlwaysSpeculatableImplTrait, Terminator

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
inputs | variadic of tensor of number values