MLIR

Multi-Level IR Compiler Framework

Tensor Operator Set Architecture (TOSA) Dialect

Rationale 

The MLIR TOSA dialect implements the TOSA specification. This document describes the decision process for how TOSA expresses operators in high level dialects.

TOSA was developed after parallel efforts to rationalize the top-down picture from multiple high-level frameworks, as well as a bottom-up view of different hardware target concerns (CPU, GPU and NPU), and reflects a set of choices that attempt to manage both sets of requirements.

TOSA and Tensor Level Expressiveness 

TOSA endeavors to provide an operator set that fulfils the following expressiveness goals at the tensor level of abstraction:

Complete 

This is driven by the top-down perspective: the need to express as much of multiple high-level frameworks as possible fully in TOSA. The operator set was originally derived from an operator frequency analysis of dozens of high-level networks in different frameworks, selecting the most frequently occurring operators to establish a common set of tensor-level operators that could express them.

TOSA categorizes its operator set into classes and attempts to address major functional operations at the tensor level, including compute, reduction, elementwise transformations, comparison and control flow.

Minimal 

This takes the bottom-up approach: keep the TOSA operator set minimal in order to bound the design of hardware, operator kernels, code generation strategies and associated considerations that affect the executability of TOSA content.

In this regard TOSA seeks to avoid creating compound operators, instead leaving it to the compiler backend to fuse multiple TOSA ops if required. This choice also benefits the numerical precision goal, since it is easier to fuse the numerical functionality of successive operators than to split the numerical functionality of a compound operator.

Numerical Precision 

TOSA began as a means to address operator-level numerical precision for code generation and hardware development. It therefore incorporates precision detail into the operator set.

In this regard, TOSA operators are best understood as a combination of the visible quantization information embedded within an operation, together with the functional information about how that information is used, as described in the specification of the operation.

TOSA Operator Rationale 

The general basis of selection of the operator set that constitutes TOSA is described in the TOSA specification document under Section 1.3 Operator Selection. Explanation of the thinking behind some operators is listed here:

COND_IF and WHILE_LOOP 

Several neural networks express conditional control flow at the tensor level. A survey of multiple high level frameworks indicated that conditional if and a loop construct are common in all major frameworks, with some variation. Since TOSA endeavors to be complete in expressing tensor level functionality including control flow, it implements these constructs.

The COND_IF and WHILE_LOOP operators implement such structured control flow forms and should be lowerable to corresponding ops in the scf dialect. Since the dialect seeks to remain isomorphic with an external, serialized form, the decision was to keep these ops in the dialect (as opposed to deferring completely to scf), and this may be re-evaluated if this turns out to not yield the expected value.

Using TOSA In A Compiler 

The TOSA specification describes each operator in functional detail. It is expected that compilers that use TOSA will use its builders to construct the operators so that the quantization information for the operator is correctly generated.

The functional steps described in the pseudocode of the specification enable the construction of code generation for that operation, or decisions on the design of underlying hardware. The functional pseudocode also describes how the quantization parameters are utilized within the operation.

Quantization Parameters in Ops vs Tensors 

TOSA uses the quantization parameters embedded in the input and output tensors to construct the quantization attributes that sit within the operator. Once these attributes are constructed, the quantization information within the tensors is no longer necessary for code generation.

This enables the tensors to be subsequently interpreted simply as contiguous buffers containing raw data, with no ‘meta information’ in the form of the quantization_type. Precision-related manipulation of the input or output is instead described by the operator itself, which specifies, for example, when the zero point is applied or when the scale multiplication is done.

However, TOSA does not eliminate the existing MLIR QuantOps quantization type information within the tensors; this leaves the choice of how to handle quantization information to later backend code generation steps.

Maintaining the ability to overlap these different representations of quantization parameters (i.e. tensor-carried vs op-carried) is an important capability when considering progressive lowering between uses that expect one scheme vs the other.

Operation definitions 

tosa.abs (mlir::tosa::AbsOp) 

Elementwise abs op

Elementwise absolute value operation

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.add (mlir::tosa::AddOp) 

Elementwise addition operator

Elementwise addition of input1 and input2. Axis of size 1 will be broadcast, as necessary.
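As a rough illustration of the size-1 broadcast rule, the following Python sketch computes the result shape for two operands. The helper name `broadcast_shape` is hypothetical and not part of the dialect, and the sketch assumes both operands have the same rank.

```python
def broadcast_shape(shape1, shape2):
    """Return the broadcast result shape for two equal-rank shapes.

    Hypothetical helper: models only the "axis of size 1 is broadcast" rule.
    """
    if len(shape1) != len(shape2):
        raise ValueError("this sketch assumes operands of equal rank")
    result = []
    for d1, d2 in zip(shape1, shape2):
        if d1 == d2 or d2 == 1:
            result.append(d1)
        elif d1 == 1:
            result.append(d2)
        else:
            raise ValueError(f"incompatible dimensions {d1} and {d2}")
    return tuple(result)


print(broadcast_shape((1, 3), (2, 1)))  # (2, 3)
```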

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.apply_scale (mlir::tosa::ApplyScaleOp) 

Rescale scalar operator for Tosa tensor operators

Applies rescaling for fixed point values. This behavior is replicated in multiple quantized operations (mul, convolution, rescale, matmul, pooling).

The commonplace implementation uses i64 operations to avoid integer overflow; target-specific implementations can use native operations to avoid wider-than-necessary types.
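The rescaling step can be sketched in Python as follows. This is a minimal model, assuming the usual fixed-point form (multiply, add a rounding constant, arithmetic shift right) and assuming that `double_round` adds an extra rounding term at bit 30 when the shift exceeds 31. The preconditions and the exact rounding details are assumptions here and should be checked against the TOSA specification.

```python
def apply_scale(value, multiplier, shift, double_round=False):
    """Fixed-point rescale: roughly value * multiplier / 2**shift, rounded."""
    # Preconditions assumed by this sketch.
    assert multiplier >= 0
    assert 2 <= shift <= 62
    # Rounding constant: half of the final divisor.
    rounding = 1 << (shift - 1)
    if double_round and shift > 31:
        # Assumed extra rounding term at bit 30, signed by the input value.
        rounding += (1 << 30) if value >= 0 else -(1 << 30)
    # Python integers do not overflow, which stands in for the i64
    # intermediate mentioned in the description.
    result = value * multiplier + rounding
    return result >> shift  # arithmetic (sign-preserving) right shift


print(apply_scale(100, 1 << 30, 31))  # 50, i.e. 100 * 0.5
print(apply_scale(-3, 1 << 30, 31))   # -1, i.e. -1.5 rounded half up
```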

Traits: AlwaysSpeculatableImplTrait, Elementwise, Scalarizable, Tensorizable, Vectorizable

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface, VectorUnrollOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
double_round | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
value | signless-integer-like
multiplier | signless-integer-like
shift | signless-integer-8-bit-like

Results:

Result | Description
output | signless-integer-like

tosa.argmax (mlir::tosa::ArgMaxOp) 

Perform argmax on the input.

This returns the index with the largest value across the given axis of the input tensor.
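For a 2D input, the reduction can be sketched in Python as below. The function name `argmax_2d` is hypothetical, and returning the lowest index on ties is an assumption of this sketch rather than something stated above.

```python
def argmax_2d(matrix, axis):
    """Index of the largest value along `axis` (0 or 1) of a 2D list."""
    rows, cols = len(matrix), len(matrix[0])
    if axis == 0:
        # Reduce over rows: one index per column.
        return [max(range(rows), key=lambda r: matrix[r][c]) for c in range(cols)]
    if axis == 1:
        # Reduce over columns: one index per row.
        return [max(range(cols), key=lambda c: matrix[r][c]) for r in range(rows)]
    raise ValueError("axis must be 0 or 1 for a 2D input")


data = [[1, 9, 3],
        [7, 2, 8]]
print(argmax_2d(data, axis=0))  # [1, 0, 1]
print(argmax_2d(data, axis=1))  # [1, 2]
```

Note that the output has one dimension fewer than the input, which matches the 0D to 4D result ranks listed for this op.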

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | 64-bit signless integer attribute

Operands:

Operand | Description
input | unranked tensor of number values or 1D/2D/3D/4D tensor of number values

Results:

Result | Description
output | unranked tensor of number values or 0D/1D/2D/3D/4D tensor of number values

tosa.arithmetic_right_shift (mlir::tosa::ArithmeticRightShiftOp) 

Elementwise Arithmetic Right Shift

Elementwise arithmetic right shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary.
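A scalar Python sketch of the shift is shown below. The handling of the `round` attribute (add one when the last bit shifted out is set) is an assumption based on a reading of the TOSA specification, and the parameter is named `rounding` here only to avoid shadowing Python's built-in `round`.

```python
def arithmetic_right_shift(value, shift, rounding=False):
    """Shift `value` right by `shift` bits, preserving the sign."""
    result = value >> shift  # Python's >> is arithmetic for negative ints
    # Assumed rounding rule: add 1 if the last bit shifted out was set.
    if rounding and shift > 0 and (value >> (shift - 1)) & 1:
        result += 1
    return result


print(arithmetic_right_shift(-7, 1))                 # -4
print(arithmetic_right_shift(-7, 1, rounding=True))  # -3
```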

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
round | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.avg_pool2d (mlir::tosa::AvgPool2dOp) 

Performs average pooling on the input.

This performs an average pooling over the given input tensor. A sliding window of the size given by the kernel attribute is passed over the input tensor, with the mean value being placed in the output tensor.
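The sliding-window computation can be sketched in Python for a single 2D plane. This simplified model ignores the batch and channel dimensions, the pad attribute, the accumulator type and any quantization. The function below is illustrative only.

```python
def avg_pool2d(image, kernel, stride):
    """Average-pool a 2D list of numbers (single channel, no padding)."""
    kh, kw = kernel
    sy, sx = stride
    height, width = len(image), len(image[0])
    out_h = (height - kh) // sy + 1
    out_w = (width - kw) // sx + 1
    output = []
    for oy in range(out_h):
        row = []
        for ox in range(out_w):
            # Collect the values under the window anchored at (oy*sy, ox*sx).
            window = [
                image[oy * sy + ky][ox * sx + kx]
                for ky in range(kh)
                for kx in range(kw)
            ]
            row.append(sum(window) / len(window))
        output.append(row)
    return output


image = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
]
print(avg_pool2d(image, kernel=(2, 2), stride=(2, 2)))  # [[3.5, 5.5], [11.5, 13.5]]
```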

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
kernel | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements
acc_type | ::mlir::TypeAttr | type attribute of 32-bit signless integer or 32-bit signed integer or 16-bit float or 32-bit float
quantization_info | mlir::tosa::UnaryOpQuantizationAttr | Attribute for UnaryOp quantization information.

Operands:

Operand | Description
input | unranked tensor of number values or 4D tensor of number values

Results:

Result | Description
output | unranked tensor of number values or 4D tensor of number values

tosa.bitwise_and (mlir::tosa::BitwiseAndOp) 

Bitwise AND operator

Elementwise bitwise AND of input1 and input2. Axis of size 1 will be broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.bitwise_not (mlir::tosa::BitwiseNotOp) 

Bitwise NOT operator

Elementwise bitwise NOT of input tensor.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.bitwise_or (mlir::tosa::BitwiseOrOp) 

Bitwise OR operator

Elementwise bitwise OR of input1 and input2. Axis of size 1 will be broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.bitwise_xor (mlir::tosa::BitwiseXorOp) 

Bitwise XOR operator

Elementwise bitwise XOR of input1 and input2. Axis of size 1 will be broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.cast (mlir::tosa::CastOp) 

Cast operation

Performs a set of permissible cast operations

Mode | Input | Output
signed 8 to bool | int8 | Boolean
signed 16 to bool | int16 | Boolean
signed 32 to bool | int32 | Boolean
bool to 8 | Boolean | int8
bool to 16 | Boolean | int16
bool to 32 | Boolean | int32
signed 8 to signed 16 | int8 | int16
signed 8 to signed 32 | int8 | int32
signed 16 to signed 8 | int16 | int8
signed 16 to signed 32 | int16 | int32
signed 32 to signed 8 | int32 | int8
signed 32 to signed 16 | int32 | int16
float to signed 8 | float | int8
float to signed 16 | float | int16
signed 8 to float | int8 | float
signed 16 to float | int16 | float
float 32 to float 64 | float32 | float64
float 64 to float 32 | float64 | float32

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input | tensor of number_plus_f64 values

Results:

Result | Description
output | tensor of number_plus_f64 values

tosa.ceil (mlir::tosa::CeilOp) 

Elementwise ceil op

Elementwise ceiling operation

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.clamp (mlir::tosa::ClampOp) 

Computes clamp(features, min, max).

Clamp to an arbitrary minimum and maximum value. Maximum and minimum values are specified as values in the range of the input type. No zero point subtraction is done to the values, thus to clamp to the zero point value, the zero point itself should be supplied as the minimum value.
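To illustrate the note about zero points, the Python sketch below clamps quantized values so that nothing falls below the zero point, which behaves like a ReLU in the quantized domain. The zero point of -10 and the int8 upper bound of 127 are made-up example values.

```python
def clamp(values, min_val, max_val):
    """Clamp each value to the closed range [min_val, max_val]."""
    return [min(max(v, min_val), max_val) for v in values]


zero_point = -10  # hypothetical zero point of an int8 quantized tensor
# No zero point subtraction happens inside clamp, so the zero point
# itself is passed as the minimum.
print(clamp([-50, -10, 0, 100], zero_point, 127))  # [-10, -10, 0, 100]
```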

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
min_int | ::mlir::IntegerAttr | 64-bit signless integer attribute
max_int | ::mlir::IntegerAttr | 64-bit signless integer attribute
min_fp | ::mlir::FloatAttr | 32-bit float attribute
max_fp | ::mlir::FloatAttr | 32-bit float attribute

Operands:

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.clz (mlir::tosa::ClzOp) 

Elementwise count leading zero op

Elementwise count leading zeros operation
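For a single element, counting leading zeros can be sketched in Python as follows, assuming a 32-bit element type (the width is an assumption of this example, not something stated above).

```python
def clz32(value):
    """Count leading zero bits of `value` viewed as a 32-bit pattern."""
    bits = value & 0xFFFFFFFF  # two's complement view of negative inputs
    return 32 - bits.bit_length()


print(clz32(1))   # 31
print(clz32(0))   # 32
print(clz32(-1))  # 0
```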

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.concat (mlir::tosa::ConcatOp) 

Concatenates tensors along one dimension.

Concatenate a variadic number of tensors along a given axis. No data conversion happens during a concat operation.

Traits: AlwaysSpeculatableImplTrait, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | 64-bit signless integer attribute

Operands:

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.const (mlir::tosa::ConstOp) 

Constant op.

A node containing constant data for use as the input to an operation. May hold data in any of the supported data formats.

Traits: AlwaysSpeculatableImplTrait, ConstantLike, FirstAttrDerivedResultType

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
value | ::mlir::ElementsAttr | constant vector/tensor attribute

Results:

Result | Description
output | tensor of number_plus_f64 values

tosa.conv2d (mlir::tosa::Conv2DOp) 

2D Convolution Operator

Performs a 2D convolution over the given tensor input, using the weight tensor.
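How the pad, stride and dilation attributes determine the output size can be sketched in Python. The layouts assumed here (NHWC input, weights as [OC, KH, KW, IC], pad ordered top, bottom, left, right, and stride and dilation ordered y, x) are assumptions based on a reading of the TOSA specification. The function name is hypothetical.

```python
def conv2d_output_shape(input_shape, weight_shape, pad, stride, dilation):
    """Compute the output shape of a 2D convolution (sketch)."""
    n, ih, iw, _ = input_shape    # assumed NHWC layout
    oc, kh, kw, _ = weight_shape  # assumed [OC, KH, KW, IC] layout
    pad_top, pad_bottom, pad_left, pad_right = pad
    stride_y, stride_x = stride
    dilation_y, dilation_x = dilation
    # Floor division is used here; a stricter implementation might
    # require the division to be exact.
    oh = (ih - 1 + pad_top + pad_bottom - (kh - 1) * dilation_y) // stride_y + 1
    ow = (iw - 1 + pad_left + pad_right - (kw - 1) * dilation_x) // stride_x + 1
    return (n, oh, ow, oc)


# A 3x3 kernel with padding 1 and stride 1 preserves the spatial size.
print(conv2d_output_shape((1, 8, 8, 3), (16, 3, 3, 3),
                          pad=(1, 1, 1, 1), stride=(1, 1), dilation=(1, 1)))
# (1, 8, 8, 16)
```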

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements
stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
dilation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
quantization_info | mlir::tosa::ConvOpQuantizationAttr | Attribute for Conv type op quantization information.

Operands:

Operand | Description
input | unranked tensor of number values or 4D tensor of number values
weight | unranked tensor of number values or 4D tensor of number values
bias | unranked tensor of number values or 1D tensor of number values

Results:

Result | Description
output | unranked tensor of number values or 4D tensor of number values

tosa.conv3d (mlir::tosa::Conv3DOp) 

3D Convolution operator

Performs a 3D convolution over the given input tensor.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 6 elements
stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 3 elements
dilation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 3 elements
quantization_info | mlir::tosa::ConvOpQuantizationAttr | Attribute for Conv type op quantization information.

Operands:

Operand | Description
input | unranked tensor of number values or 5D tensor of number values
weight | unranked tensor of number values or 5D tensor of number values
bias | unranked tensor of number values or 1D tensor of number values

Results:

Result | Description
output | unranked tensor of number values or 5D tensor of number values

tosa.custom (mlir::tosa::CustomOp) 

Custom operator wrapper for Tosa

Hardware implementing TOSA may choose to add additional custom operators that are not expressed in the existing TOSA operations. These operators are not expected to be portable across TOSA implementations. The input and output signatures must be expressed in the corresponding TOSA node.

identifier is a string that tells the backend which custom operator is being called.

config is a string identifier which can help avoid name collisions on the identifier field.

implementation_attrs is a string which is a backend and identifier specific set of attributes to the custom operator.

inputs is the set of tensor inputs to the custom operator.

outputs is the list of tensors returned by the operator. The number of outputs is backend specific.

Interfaces: TosaOpInterface

Attributes: 

Attribute | MLIR Type | Description
identifier | ::mlir::StringAttr | string attribute
config | ::mlir::StringAttr | string attribute
implementation_attrs | ::mlir::StringAttr | string attribute

Operands:

Operand | Description
inputs | tensor of number values

Results:

Result | Description
outputs | tensor of number values

tosa.depthwise_conv2d (mlir::tosa::DepthwiseConv2DOp) 

Depthwise 2D Convolution operator

Performs 2D convolutions separately over each channel of the given tensor input, using the weight tensor.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements
stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
dilation | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements
quantization_info | mlir::tosa::ConvOpQuantizationAttr | Attribute for Conv type op quantization information.

Operands:

Operand | Description
input | unranked tensor of number values or 4D tensor of number values
weight | unranked tensor of number values or 4D tensor of number values
bias | unranked tensor of number values or 1D tensor of number values

Results:

Result | Description
output | unranked tensor of number values or 4D tensor of number values

tosa.div (mlir::tosa::DivOp) 

Integer divide operator

Elementwise integer divide operator of input1 by input2. Axis of size 1 will be broadcast, as necessary.
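The rounding direction of the divide is worth spelling out. The Python sketch below assumes the result is truncated toward zero, which is how the TOSA specification is understood here. Note that Python's own `//` operator rounds toward negative infinity instead, so it cannot be used directly.

```python
def int_div(a, b):
    """Integer division that truncates toward zero (unlike Python's //)."""
    quotient = abs(a) // abs(b)
    # Negate when the operands have opposite signs.
    return quotient if (a >= 0) == (b >= 0) else -quotient


print(int_div(-7, 2))  # -3, whereas -7 // 2 == -4 in Python
print(int_div(7, 2))   # 3
```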

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of 32-bit signless integer values
input2 | tensor of 32-bit signless integer values

Results:

Result | Description
output | tensor of 32-bit signless integer values

tosa.equal (mlir::tosa::EqualOp) 

Returns the truth value of (x == y) element-wise.

Elementwise comparison operation

Traits: AlwaysSpeculatableImplTrait, Commutative, InferTensorType, ResultsBroadcastableShape, SameOperandsElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of 1-bit signless integer values

tosa.erf (mlir::tosa::ErfOp) 

Computes gauss error function of input

Gauss error function: $ \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} \, dt $. For quantized integer data types, the TABLE operator should be used instead, with the following definition: the erf_table has 513 entries, each of 16-bit/8-bit precision, covering the input range -4.0 to +4.0 in steps of 1/64.
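The table dimensions follow from the range and step: 8.0 / (1/64) + 1 = 513 entries. A Python sketch of building such a table is shown below. The scaling of the results to a signed 16-bit range via `scale=32767` is an assumption made for illustration, and the function name is hypothetical.

```python
import math


def build_erf_table(scale=32767):
    """Build a 513-entry erf lookup table over [-4.0, +4.0] in steps of 1/64."""
    # Entry i corresponds to x = (i - 256) / 64, so i = 0 maps to -4.0
    # and i = 512 maps to +4.0.
    return [round(math.erf((i - 256) / 64.0) * scale) for i in range(513)]


table = build_erf_table()
print(len(table))   # 513
print(table[256])   # 0, since erf(0) == 0
```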

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.exp (mlir::tosa::ExpOp) 

Elementwise exp op

Elementwise e to the x operation

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.fft2d (mlir::tosa::FFT2dOp) 

Performs FFT2D operation on the input.

Performs a batched complex 2D Fast Fourier Transform over the input. The complex input values are constructed from the corresponding values in the input_real and input_imag tensors. The resulting values in the output are split into the output_real and output_imag tensors. No normalization is applied on either the forward or inverse versions of the operation.
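The split real/imaginary representation and the absence of normalization can be illustrated with a naive 2D discrete Fourier transform in Python. This sketch handles a single H x W plane rather than a batch, runs in quadratic time, and assumes the conventional sign choice (negative exponent for the forward transform). It is not an implementation of the op itself.

```python
import cmath


def fft2d(real, imag, inverse=False):
    """Naive 2D DFT over one H x W plane given as nested lists."""
    height, width = len(real), len(real[0])
    sign = 1.0 if inverse else -1.0  # assumed sign convention
    out_real = [[0.0] * width for _ in range(height)]
    out_imag = [[0.0] * width for _ in range(height)]
    for oy in range(height):
        for ox in range(width):
            total = 0j
            for iy in range(height):
                for ix in range(width):
                    angle = 2.0 * cmath.pi * (iy * oy / height + ix * ox / width)
                    sample = complex(real[iy][ix], imag[iy][ix])
                    total += sample * cmath.exp(sign * 1j * angle)
            # No 1/(H*W) normalization in either direction.
            out_real[oy][ox] = total.real
            out_imag[oy][ox] = total.imag
    return out_real, out_imag
```

Because no normalization is applied, running the forward transform followed by the inverse in this sketch returns the original values scaled by H * W.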

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
inverse | ::mlir::BoolAttr | bool attribute

Operands:

Operand | Description
input_real | unranked tensor of number values or 3D tensor of number values
input_imag | unranked tensor of number values or 3D tensor of number values

Results:

Result | Description
output_real | unranked tensor of number values or 3D tensor of number values
output_imag | unranked tensor of number values or 3D tensor of number values

tosa.floor (mlir::tosa::FloorOp) 

Elementwise floor op

Elementwise floor operation

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.fully_connected (mlir::tosa::FullyConnectedOp) 

Fully Connected operator

Performs a fully connected network.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes: 

Attribute | MLIR Type | Description
quantization_info | mlir::tosa::ConvOpQuantizationAttr | Attribute for Conv type op quantization information.

Operands:

Operand | Description
input | unranked tensor of number values or 2D tensor of number values
weight | unranked tensor of number values or 2D tensor of number values
bias | unranked tensor of number values or 1D tensor of number values

Results:

Result | Description
output | unranked tensor of number values or 2D tensor of number values

tosa.gather (mlir::tosa::GatherOp) 

Gather operation

Generate a tensor for which each element in the output is a slice of the values tensor based on the value of indices.
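The indexing can be sketched in Python with nested lists. The sketch assumes values has shape [N, K, C], indices has shape [N, W] and the output has shape [N, W, C]. These layouts are inferred from the operand ranks and should be treated as assumptions.

```python
def gather(values, indices):
    """values: [N][K][C], indices: [N][W] -> output: [N][W][C]."""
    return [
        # For each batch n, pick the slice values[n][k] for every index k.
        [list(values[n][k]) for k in indices[n]]
        for n in range(len(values))
    ]


values = [[[1, 2], [3, 4], [5, 6]]]  # N=1, K=3, C=2
indices = [[2, 0, 2]]                # N=1, W=3
print(gather(values, indices))       # [[[5, 6], [1, 2], [5, 6]]]
```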

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
values | unranked tensor of number values or 3D tensor of number values
indices | 2D tensor of 32-bit signless integer values

Results:

Result | Description
output | unranked tensor of number values or 3D tensor of number values

tosa.greater_equal (mlir::tosa::GreaterEqualOp) 

Returns the truth value of (x >= y) element-wise.

Elementwise comparison operation

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of 1-bit signless integer values

tosa.greater (mlir::tosa::GreaterOp) 

Returns the truth value of (x > y) element-wise.

Elementwise greater than comparison operation

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of 1-bit signless integer values

tosa.identity (mlir::tosa::IdentityOp) 

Identity operator

Returns a tensor with the same shape, size, type and content as the input.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.cond_if (mlir::tosa::IfOp) 

Conditional if operator

Evaluates a Boolean condition and then takes one of two distinct execution paths. This implements the semantic If-then-else structure.

Traits: RecursiveMemoryEffects, SingleBlockImplicitTerminator

Interfaces: InferShapedTypeOpInterface, TosaOpInterface

Operands: 

Operand | Description
cond | tensor of 1-bit signless integer values
inputs | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.log (mlir::tosa::LogOp) 

Elementwise log op

Elementwise natural logarithm operation

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.logical_and (mlir::tosa::LogicalAndOp) 

Returns the truth value of x AND y element-wise.

Elementwise logical AND of input1 and input2. Axis of size 1 will be broadcast, as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of 1-bit signless integer values
input2 | tensor of 1-bit signless integer values

Results:

Result | Description
z | tensor of 1-bit signless integer values

tosa.logical_left_shift (mlir::tosa::LogicalLeftShiftOp) 

Elementwise Logical Left Shift

Elementwise logical left shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.logical_not (mlir::tosa::LogicalNotOp) 

Returns the truth value of NOT x element-wise.

Elementwise logical NOT of input.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of 1-bit signless integer values

Results:

Result | Description
output | tensor of 1-bit signless integer values

tosa.logical_or (mlir::tosa::LogicalOrOp) 

Returns the truth value of x OR y element-wise.

Elementwise logical OR of input1 and input2. Axis of size 1 will be broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of 1-bit signless integer values
input2 | tensor of 1-bit signless integer values

Results:

Result | Description
z | tensor of 1-bit signless integer values

tosa.logical_right_shift (mlir::tosa::LogicalRightShiftOp) 

Elementwise Logical Right Shift

Elementwise logical right shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of number values
input2 | tensor of number values

Results:

Result | Description
output | tensor of number values

tosa.logical_xor (mlir::tosa::LogicalXorOp) 

Returns the truth value of x XOR y element-wise.

Elementwise logical XOR of input1 and input2. Axis of size 1 will be broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands: 

Operand | Description
input1 | tensor of 1-bit signless integer values
input2 | tensor of 1-bit signless integer values

Results:

Result | Description
z | tensor of 1-bit signless integer values

tosa.matmul (mlir::tosa::MatMulOp) 

Matrix multiplication operator

Performs a two dimensional matrix multiplication. This allows both inputs to be activations, rather than reserving weights as an attribute in the FULLY_CONNECTED operator.
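Since both operands are 3D tensors, the operation can be read as a batch of 2D products. The Python sketch below assumes a has shape [N, H, C], b has shape [N, C, W] and the result has shape [N, H, W]. It leaves out the zero points carried by the quantization attribute.

```python
def matmul(a, b):
    """Batched matmul: a is [N][H][C], b is [N][C][W], result is [N][H][W]."""
    result = []
    for a_mat, b_mat in zip(a, b):
        cols = list(zip(*b_mat))  # columns of the right-hand matrix
        result.append(
            [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a_mat]
        )
    return result


a = [[[1, 2], [3, 4]]]
b = [[[5, 6], [7, 8]]]
print(matmul(a, b))  # [[[19, 22], [43, 50]]]
```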

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| quantization_info | mlir::tosa::MatMulOpQuantizationAttr | Attribute for MatMulOp quantization information. |

Operands:

| Operand | Description |
| --- | --- |
| a | unranked tensor of number values or 3D tensor of number values |
| b | unranked tensor of number values or 3D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| c | unranked tensor of number values or 3D tensor of number values |
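
As a rough reference for the semantics, the sketch below multiplies 3D operands batch by batch using nested Python lists. It assumes the operand layout a = [N, H, C] and b = [N, C, W], and it ignores the zero points carried by quantization_info; the function name is made up for the example.

```python
def batched_matmul(a, b):
    """Multiply a[n] (H x C) by b[n] (C x W) for every batch index n."""
    result = []
    for a_matrix, b_matrix in zip(a, b):
        b_columns = list(zip(*b_matrix))  # columns of the right-hand matrix
        result.append([
            [sum(x * y for x, y in zip(row, column)) for column in b_columns]
            for row in a_matrix
        ])
    return result


# One batch entry, 2x2 times 2x2.
c = batched_matmul([[[1, 2], [3, 4]]], [[[5, 6], [7, 8]]])
```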

tosa.max_pool2d (mlir::tosa::MaxPool2dOp) 

Performs max pooling on the input.

This performs a max pooling over the given input tensor. A sliding window of the size given by the kernel attribute is passed over the input tensor, with the maximum value in each window being placed in the output tensor.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| kernel | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements |

Operands:

| Operand | Description |
| --- | --- |
| input | unranked tensor of number values or 4D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 4D tensor of number values |
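
The sliding-window behaviour can be sketched for a single H x W channel as follows. This is an illustration only: it assumes the pad attribute is ordered [top, bottom, left, right], that padded positions never win the maximum, and that every window overlaps the input at least once. The function name is hypothetical.

```python
def max_pool2d(image, kernel, stride, pad):
    """Max-pool one H x W channel given as a nested list."""
    kernel_y, kernel_x = kernel
    stride_y, stride_x = stride
    pad_top, pad_bottom, pad_left, pad_right = pad
    in_h, in_w = len(image), len(image[0])
    out_h = (in_h + pad_top + pad_bottom - kernel_y) // stride_y + 1
    out_w = (in_w + pad_left + pad_right - kernel_x) // stride_x + 1
    output = []
    for oy in range(out_h):
        row = []
        for ox in range(out_w):
            window = []
            for ky in range(kernel_y):
                for kx in range(kernel_x):
                    y = oy * stride_y - pad_top + ky
                    x = ox * stride_x - pad_left + kx
                    if 0 <= y < in_h and 0 <= x < in_w:  # skip padded positions
                        window.append(image[y][x])
            row.append(max(window))
        output.append(row)
    return output


image = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
pooled = max_pool2d(image, kernel=[2, 2], stride=[2, 2], pad=[0, 0, 0, 0])
```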

tosa.maximum (mlir::tosa::MaximumOp) 

Elementwise Maximum

Elementwise max of input1 and input2. Axis of size 1 will be broadcast, as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| input1 | tensor of number values |
| input2 | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |
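
The "axis of size 1 will be broadcast" rule shared by the elementwise binary operators can be sketched as a shape computation. The sketch assumes both operands already have the same rank; the helper name is invented for the example.

```python
def broadcast_shape(shape1, shape2):
    """Return the result shape when size-1 axes are broadcast against each other."""
    if len(shape1) != len(shape2):
        raise ValueError("this sketch assumes operands of equal rank")
    result = []
    for dim1, dim2 in zip(shape1, shape2):
        if dim1 == dim2 or dim2 == 1:
            result.append(dim1)
        elif dim1 == 1:
            result.append(dim2)
        else:
            raise ValueError(f"dimensions {dim1} and {dim2} are not broadcastable")
    return result


shape = broadcast_shape([2, 1, 3], [1, 4, 3])  # [2, 4, 3]
```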

tosa.minimum (mlir::tosa::MinimumOp) 

Elementwise Minimum

Elementwise minimum of input1 and input2. Axis of size 1 will be broadcast, as necessary.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| input1 | tensor of number values |
| input2 | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |

tosa.mul (mlir::tosa::MulOp) 

Multiplication operator

Elementwise multiplication (Hadamard product) of input1 and input2. Axis of size 1 will be broadcast, as necessary. i8/i16 input type can be promoted to i32 result type.

Traits: AlwaysSpeculatableImplTrait, Commutative, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| shift | ::mlir::IntegerAttr | 32-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input1 | tensor of number values |
| input2 | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |
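
The role of the shift attribute is not spelled out above. As an assumption based on the TOSA specification's description of MUL, the sketch below treats it as a rounding right shift applied to the product of i32 inputs; the function name is hypothetical and overflow handling is omitted.

```python
def mul_i32(value1: int, value2: int, shift: int = 0) -> int:
    """Multiply two integers, then apply a rounding right shift when shift > 0."""
    product = value1 * value2
    if shift > 0:
        rounding = 1 << (shift - 1)  # adds one half in the shifted domain
        product = (product + rounding) >> shift
    return product


unshifted = mul_i32(3, 5)    # 15
halved = mul_i32(3, 5, 1)    # (15 + 1) >> 1 == 8
```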

tosa.negate (mlir::tosa::NegateOp) 

Elementwise negate op

Elementwise negation operation

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| quantization_info | mlir::tosa::UnaryOpQuantizationAttr | Attribute for UnaryOp quantization information. |

Operands:

| Operand | Description |
| --- | --- |
| input1 | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |

tosa.pad (mlir::tosa::PadOp) 

Pads a tensor with the specified value.

The tosa.pad operation pads a tensor along the borders of each dimension with pad_const (defaults to zero), given a padding configuration, padding, that specifies the low and high padding amounts for each dimension.

Example:

%0 = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
"tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<4x9xf32>)

Example 2:

%0 = arith.constant dense<[[-1, 2], [3, 4]]> : tensor<2x2xi32>
"tosa.pad"(%arg0, %0)  : (tensor<1x2xf32>, tensor<2x2xi32>)  -> (tensor<?x9xf32>)

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| quantization_info | mlir::tosa::PadOpQuantizationAttr | Attribute for PadOp quantization information. |

Operands:

| Operand | Description |
| --- | --- |
| input1 | ranked tensor of number values |
| padding | tensor of 32-bit signless integer or 64-bit signless integer values |
| pad_const | 0D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | ranked tensor of number values |
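
The result shape in the first tosa.pad example (1x2 padded to 4x9) follows from adding the low and high padding to each dimension. The Python sketch below reproduces that arithmetic and pads a small 2D value; the helper names are invented, and negative padding values (as in the second example) are not handled.

```python
def padded_shape(input_shape, padding):
    """Each output dimension is low padding + input size + high padding."""
    return [low + dim + high for dim, (low, high) in zip(input_shape, padding)]


def pad_2d(rows, padding, pad_const=0):
    """Pad a 2D nested list; padding is [[top, bottom], [left, right]]."""
    (top, bottom), (left, right) = padding
    width = left + len(rows[0]) + right
    padded = [[pad_const] * width for _ in range(top)]
    for row in rows:
        padded.append([pad_const] * left + list(row) + [pad_const] * right)
    padded += [[pad_const] * width for _ in range(bottom)]
    return padded


shape = padded_shape([1, 2], [[1, 2], [3, 4]])   # [4, 9]
padded = pad_2d([[1.0, 2.0]], [[1, 2], [3, 4]])  # 4 rows of 9 values
```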

tosa.pow (mlir::tosa::PowOp) 

Computes the power of one value to another.

Elementwise input1 raised to the power of input2. Axis of size 1 will be broadcast, as necessary.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| input1 | tensor of number values |
| input2 | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| z | tensor of number values |

tosa.rfft2d (mlir::tosa::RFFT2dOp) 

Performs RFFT2D operation on the input.

Performs a batched 2D real-valued Fast Fourier Transform over the input where the input tensor consists of real values producing complex valued output. The complex output values will be split into the output_real and output_imag tensor arguments. RFFT2D takes advantage of Hermitian symmetry to only calculate the first half of the final output axis. Imaginary values with locations (0,0), (0,W/2), (H/2,0) and (H/2,W/2) are zero.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| input | unranked tensor of number values or 3D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output_real | unranked tensor of number values or 3D tensor of number values |
| output_imag | unranked tensor of number values or 3D tensor of number values |

tosa.reciprocal (mlir::tosa::ReciprocalOp) 

Elementwise reciprocal op

Elementwise reciprocal operation. For integer operation, a TABLE should be used with the appropriate ranges.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| input1 | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |

tosa.reduce_all (mlir::tosa::ReduceAllOp) 

Reduce All operator

Reduce a tensor along the given axis with a logical AND operation

Traits: AlwaysSpeculatableImplTrait, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 64-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

tosa.reduce_any (mlir::tosa::ReduceAnyOp) 

Reduce Any operator

Reduce a tensor along the given axis with a logical OR operation

Traits: AlwaysSpeculatableImplTrait, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 64-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

tosa.reduce_max (mlir::tosa::ReduceMaxOp) 

Reduce Max operator

Reduce a tensor along the given axis with a maximum operation

Traits: AlwaysSpeculatableImplTrait, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 64-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

tosa.reduce_min (mlir::tosa::ReduceMinOp) 

Reduce Min operator

Reduce a tensor along the given axis with a minimum operation

Traits: AlwaysSpeculatableImplTrait, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 64-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

tosa.reduce_prod (mlir::tosa::ReduceProdOp) 

Reduce Prod operator

Reduce a tensor along the given axis by computing the product of the axis.

Traits: AlwaysSpeculatableImplTrait, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 64-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

tosa.reduce_sum (mlir::tosa::ReduceSumOp) 

Reduce Sum operator

Reduce a tensor along the given axis by computing the sum of the axis.

Traits: AlwaysSpeculatableImplTrait, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 64-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |
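
The reduce operators differ only in the combining operation. The following sketch shows the pattern for a 2D input held as nested lists; it assumes the reduced axis is kept in the result with size 1, so the rank is unchanged. The function name is made up for the example.

```python
from functools import reduce
import operator


def reduce_axis(matrix, axis, combine):
    """Reduce a 2D nested list along `axis`, keeping that axis with size 1."""
    if axis == 0:
        # Combine down each column: the result has shape 1 x W.
        return [[reduce(combine, column) for column in zip(*matrix)]]
    if axis == 1:
        # Combine across each row: the result has shape H x 1.
        return [[reduce(combine, row)] for row in matrix]
    raise ValueError("axis must be 0 or 1 for a 2D input")


values = [[1, 2, 3], [4, 5, 6]]
column_sums = reduce_axis(values, 0, operator.add)  # like reduce_sum, axis = 0
row_maxima = reduce_axis(values, 1, max)            # like reduce_max, axis = 1
```

Passing operator.and_, operator.or_, min or operator.mul as the combining function gives the analogues of reduce_all, reduce_any, reduce_min and reduce_prod.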

tosa.rescale (mlir::tosa::RescaleOp) 

TOSA rescale operator

Rescale quantized values into a new domain. Supported rescalings are:

| Mode | Input | Output |
| --- | --- | --- |
| signed 8 to 8 | int8 | int8 |
| signed 8 to 16 | int8 | int16 |
| signed 8 to 32 | int8 | int32 |
| signed 16 to 8 | int16 | int8 |
| signed 16 to 16 | int16 | int16 |
| signed 16 to 32 | int16 | int32 |
| signed 32 to 8 | int32 | int8 |
| signed 32 to 16 | int32 | int16 |
| signed 32 to 32 | int32 | int32 |
| signed 48 to 8 | int48 | int8 |
| signed 48 to 16 | int48 | int16 |
| signed 48 to 32 | int48 | int32 |
| unsigned 8 to signed 8 | uint8 | int8 |
| signed 8 to unsigned 8 | int8 | uint8 |

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| input_zp | ::mlir::IntegerAttr | 32-bit signless integer attribute |
| output_zp | ::mlir::IntegerAttr | 32-bit signless integer attribute |
| multiplier | ::mlir::DenseI32ArrayAttr | i32 dense array attribute |
| shift | ::mlir::DenseI32ArrayAttr | i32 dense array attribute |
| scale32 | ::mlir::BoolAttr | bool attribute |
| double_round | ::mlir::BoolAttr | bool attribute |
| per_channel | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |
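
How the attributes combine is easiest to see on a single value. The sketch below is a simplified reading of the rescale arithmetic: subtract the input zero point, multiply by the integer multiplier, apply a rounding right shift, add the output zero point, and clamp to the output range. It assumes a single (per-tensor) multiplier and shift with shift of at least 1, and it omits the double_round adjustment; the function name and the int8 default range are assumptions.

```python
def rescale_value(value, input_zp, output_zp, multiplier, shift,
                  out_min=-128, out_max=127):
    """Rescale one quantized value; the effective scale is multiplier / 2**shift."""
    centered = value - input_zp
    rounding = 1 << (shift - 1)
    scaled = (centered * multiplier + rounding) >> shift
    return max(out_min, min(out_max, scaled + output_zp))


# multiplier = 2**30 with shift = 31 encodes a scale of 0.5.
half = rescale_value(100, input_zp=0, output_zp=0, multiplier=1 << 30, shift=31)
```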

tosa.reshape (mlir::tosa::ReshapeOp) 

Reshape operator

Returns a tensor with the same type/values as the input, with a new shape specified by the new_shape attribute. Reshape may operate on tensors of any rank. No data conversion happens during a reshape operation.

Traits: AlwaysSpeculatableImplTrait, InferTensorType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| new_shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands:

| Operand | Description |
| --- | --- |
| input1 | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | ranked tensor of number values |

tosa.resize (mlir::tosa::ResizeOp) 

Resize operation, supports various resize/upsample modes

Resizes a tensor. Resize is only allowed in the H and W dimensions. In expected use, the height dimension is scaled by the factor (scale_y_n/scale_y_d) and the width dimension is scaled by the factor (scale_x_n/scale_x_d). The output dimensions can therefore be derived from the input dimensions by inverting the scale. The [border_y, border_x] values adjust the output size to allow fractional sampling beyond the integer input position (IH-1, IW-1).

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| scale | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements |
| offset | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| border | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| mode | ::mlir::StringAttr | Supported resize/upsampling strategies |

Operands:

| Operand | Description |
| --- | --- |
| input | unranked tensor of number values or 4D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 4D tensor of number values |
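
The relationship between input size, scale, offset and border can be sketched for one dimension. The formula below follows the TOSA specification's RESIZE definition as far as I understand it, and assumes the attributes are laid out as scale = [scale_y_n, scale_y_d, scale_x_n, scale_x_d], offset = [offset_y, offset_x] and border = [border_y, border_x]; the helper name is hypothetical.

```python
def resize_output_dim(in_dim, scale_n, scale_d, offset, border):
    """Output size along one axis: ((in - 1) * n - offset + border) / d + 1."""
    numerator = (in_dim - 1) * scale_n - offset + border
    if numerator % scale_d != 0:
        raise ValueError("scale, offset and border do not give an integer output size")
    return numerator // scale_d + 1


# Doubling a height of 4: without a border the last sample sits at input row 3.
without_border = resize_output_dim(4, scale_n=2, scale_d=1, offset=0, border=0)  # 7
with_border = resize_output_dim(4, scale_n=2, scale_d=1, offset=0, border=1)     # 8
```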

tosa.reverse (mlir::tosa::ReverseOp) 

Reverse operator

Returns a tensor with the same type/values as the input, with the data reversed along the given axis. No data conversion happens during a reverse operation.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| axis | ::mlir::IntegerAttr | 64-bit signless integer attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

tosa.rsqrt (mlir::tosa::RsqrtOp) 

Elementwise 1/sqrt op

Elementwise reciprocal square root operation. For integer operation, a TABLE should be used with the appropriate ranges.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| input1 | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |

tosa.scatter (mlir::tosa::ScatterOp) 

Scatter operation

The values_out tensor is set to the values_in tensor with data modified as follows: data from the input tensor is inserted at the positions specified by the indices tensor.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| values_in | unranked tensor of number values or 3D tensor of number values |
| indices | 2D tensor of 32-bit signless integer values |
| input | unranked tensor of number values or 3D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| values_out | unranked tensor of number values or 3D tensor of number values |
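
A minimal sketch of the scatter semantics on nested lists follows. It assumes the layouts values_in = [N, K, C], indices = [N, W] and input = [N, W, C], so that row w of the input replaces row indices[n][w] of the output within batch n. The parameter is called inputs only to avoid shadowing Python's built-in input.

```python
def scatter(values_in, indices, inputs):
    """Copy values_in, then overwrite the rows selected by indices with inputs."""
    values_out = [[list(row) for row in batch] for batch in values_in]
    for n, batch_indices in enumerate(indices):
        for w, k in enumerate(batch_indices):
            values_out[n][k] = list(inputs[n][w])
    return values_out


values_in = [[[0, 0], [0, 0], [0, 0]]]   # N = 1, K = 3, C = 2
indices = [[2, 0]]                       # N = 1, W = 2
inputs = [[[1, 2], [3, 4]]]              # N = 1, W = 2, C = 2
values_out = scatter(values_in, indices, inputs)
```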

tosa.select (mlir::tosa::SelectOp) 

Elementwise select operator

Elementwise select of the output based on a condition.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| pred | tensor of 1-bit signless integer values |
| on_true | tensor of number values |
| on_false | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |

tosa.sigmoid (mlir::tosa::SigmoidOp) 

Computes elementwise sigmoid of input.

Sigmoid function: output = 1 / (1 + exp(-input)). For quantized integer data types, the TABLE operator should be used instead, with the following definition: the sigmoid table has 513 entries, each of 16-bit precision, covering the input range -16.0 to +16.0 in steps of 1/16.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| input | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |
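
The 513-entry table described above can be generated along the following lines. The text only fixes the entry count, the 16-bit precision and the input range; the scaling of the sigmoid output by 32768 and the clamp to the int16 maximum are assumptions made for this sketch, as is the function name.

```python
import math


def build_sigmoid_table():
    """Build 513 int16 entries covering -16.0 to +16.0 in steps of 1/16."""
    table = []
    for step in range(-256, 257):           # 513 sample points
        x = step / 16.0
        value = round(32768.0 / (1.0 + math.exp(-x)))
        table.append(min(value, 32767))     # keep the entry within int16
    return table


sigmoid_table = build_sigmoid_table()
```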

tosa.slice (mlir::tosa::SliceOp) 

Slice operator

Extracts a slice of input, beginning at the start coordinates and extending for size elements in each dimension. No data conversion happens during a slice operation.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| start | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |
| size | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands:

| Operand | Description |
| --- | --- |
| input | unranked tensor of number values or 1D/2D/3D/4D/5D/6D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 1D/2D/3D/4D/5D/6D tensor of number values |

tosa.sub (mlir::tosa::SubOp) 

Elementwise subtraction operator

Elementwise subtraction of input1 and input2. Axis of size 1 will be broadcast as necessary.

Traits: AlwaysSpeculatableImplTrait, ResultsBroadcastableShape, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| input1 | tensor of number values |
| input2 | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |

tosa.table (mlir::tosa::TableOp) 

Table lookup op

Interpolated table lookup operation. Input values are scaled to create a fixed-point 9.7 value. The high 9 bits are used to index into the table. The fractional bits are used to interpolate based on the looked up value and the index+1 value in the table. The TABLE operator then returns a 16.7 interpolated value. Note that there must be 513 values to handle the full range of inputs.

The TABLE operator is expected to be used as follows:

  • A RESCALE node is expected before the TABLE operator to scale the input to a full int16_t range for the table lookup
  • If an int16_t result is required then follow the TABLE operator with a RESCALE with a right shift of 7
  • If an int8_t result is required then follow the TABLE operator with a RESCALE with a right shift of 15

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| input | tensor of number values |
| table | unranked tensor of number values or 1D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |
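
The 9.7 fixed-point lookup with interpolation can be sketched for one int16 input value as follows. The sketch assumes the input is offset by 32768 so that the index runs from 0 to 511, which is why a 513th entry is needed for the index+1 lookup; the function name is invented.

```python
def table_lookup(table, value):
    """Interpolated lookup of an int16 `value` in a 513-entry table."""
    clipped = max(-32768, min(32767, value))
    index = (clipped + 32768) >> 7      # high 9 bits select the table entry
    fraction = clipped & 0x7F           # low 7 bits are the interpolation weight
    base = table[index]
    following = table[index + 1]
    # Result carries 7 extra fractional bits (the "16.7" value).
    return (base << 7) + (following - base) * fraction


# With table[i] == i the lookup reproduces the offset input exactly.
ramp = list(range(513))
middle = table_lookup(ramp, 0)  # 32768
```

Shifting the result right by 7, as the RESCALE guidance above suggests, brings it back to the 16-bit range of the table entries.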

tosa.tanh (mlir::tosa::TanhOp) 

Computes elementwise hyperbolic tangent of input

Parameterized hyperbolic tangent. For quantized integer data types, the TABLE operator should be used instead with the following definition. The tanh_table has 513 entries each of 16-bit precision and covering the input range -8.0 to +8.0 in steps of 1/32.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| input | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |

tosa.tile (mlir::tosa::TileOp) 

Tile operator

Replicates input1 multiples times along each dimension.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| multiples | ::mlir::DenseI64ArrayAttr | i64 dense array attribute |

Operands:

| Operand | Description |
| --- | --- |
| input1 | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 1D/2D/3D/4D tensor of number values |

tosa.transpose_conv2d (mlir::tosa::TransposeConv2DOp) 

Transpose 2D Convolution operator.

Performs a 2D transposed convolution over the given tensor input, using the weights tensor.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --- | --- | --- |
| out_pad | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 4 elements |
| stride | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with exactly 2 elements |
| out_shape | ::mlir::DenseI64ArrayAttr | i64 dense array attribute with at least 4 elements |
| quantization_info | mlir::tosa::ConvOpQuantizationAttr | Attribute for Conv type op quantization information. |

Operands:

| Operand | Description |
| --- | --- |
| input | unranked tensor of number values or 4D tensor of number values |
| filter | unranked tensor of number values or 4D tensor of number values |
| bias | unranked tensor of number values or 1D tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 4D tensor of number values |
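
As a rough guide to how stride, out_pad and the filter size determine the spatial output size, the sketch below computes one output dimension. The formula is an assumption drawn from my reading of the TOSA specification, with out_pad taken to be ordered [top, bottom, left, right]; the helper name is hypothetical, and the out_shape attribute still states the expected result explicitly.

```python
def transpose_conv2d_output_dim(in_dim, kernel_dim, stride, pad_before, pad_after):
    """Spatial output size along one axis of a transposed convolution."""
    return (in_dim - 1) * stride + pad_before + pad_after + kernel_dim


# Input height 4, filter height 3, stride 2, no output padding.
out_height = transpose_conv2d_output_dim(4, 3, 2, 0, 0)  # 9
```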

tosa.transpose (mlir::tosa::TransposeOp) 

Transpose operator

Permutes the dimensions of input1 based on perms.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, InferShapedTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| input1 | unranked tensor of number values or 1D/2D/3D/4D/5D/6D tensor of number values |
| perms | tensor of 32-bit signless integer or 64-bit signless integer values |

Results:

| Result | Description |
| --- | --- |
| output | unranked tensor of number values or 1D/2D/3D/4D/5D/6D tensor of number values |
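
The effect of perms on the result shape can be sketched as below, assuming output dimension i takes its size from input dimension perms[i]. A 2D data example is included; both helper names are made up for the illustration.

```python
def transposed_shape(shape, perms):
    """Output dimension i has the size of input dimension perms[i]."""
    if sorted(perms) != list(range(len(shape))):
        raise ValueError("perms must be a permutation of the input dimensions")
    return [shape[p] for p in perms]


def transpose_2d(rows, perms):
    """Apply perms to a 2D nested list; only [0, 1] and [1, 0] are valid."""
    if list(perms) == [0, 1]:
        return [list(row) for row in rows]
    if list(perms) == [1, 0]:
        return [list(column) for column in zip(*rows)]
    raise ValueError("perms must be [0, 1] or [1, 0] for a 2D input")


shape = transposed_shape([1, 2, 3], [2, 0, 1])            # [3, 1, 2]
swapped = transpose_2d([[1, 2, 3], [4, 5, 6]], [1, 0])    # 3 rows of 2 values
```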

tosa.while_loop (mlir::tosa::WhileOp) 

output = input; While (Cond(output)) {output = Body(output)}

Generates and evaluates a Boolean condition and either executes a loop body or exits to another control point. This is repeated, with the condition updated and re-evaluated on every iteration. It implements the semantics of a foreach or while iterative loop structure.

Traits: RecursiveMemoryEffects, SingleBlockImplicitTerminator

Interfaces: InferShapedTypeOpInterface, LoopLikeOpInterface, TosaOpInterface

Operands:

| Operand | Description |
| --- | --- |
| inputs | tensor of number values |

Results:

| Result | Description |
| --- | --- |
| output | tensor of number values |
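
The summary line above maps directly onto an ordinary loop. In the Python analogue below, cond and body stand in for the two regions of the operation and the loop-carried values are held in a tuple; all names are illustrative only.

```python
def while_loop(inputs, cond, body):
    """output = input; while cond(output): output = body(output)."""
    outputs = inputs
    while cond(outputs):
        outputs = body(outputs)
    return outputs


# Carry (counter, running_sum) and add the counter to the sum until it reaches 5.
final_state = while_loop(
    (0, 0),
    cond=lambda state: state[0] < 5,
    body=lambda state: (state[0] + 1, state[1] + state[0]),
)
```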

tosa.yield (mlir::tosa::YieldOp) 

Yield operator

Return operation within the condition and body regions of structured control flow. The operation takes variadic operands but produces no results of its own.

Traits: AlwaysSpeculatableImplTrait, Terminator

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), TosaOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| --- | --- |
| inputs | tensor of number values |