mlir.dialects.linalg.opdsl.lang.comprehension
Model classes representing a tensor comprehension.
These classes model the language at roughly the AST level, as produced by evaluating the DSL. Reasoning about an op typically involves first processing this form into config objects that represent actual op definitions (i.e., YAML).
Attributes
- ContractionOpInterface
- ConvolutionOpInterface
- FillOpInterface
- Canonicalizer
Classes
| Class | Description |
|---|---|
| TensorExpression | An expression that can appear on the RHS of a comprehension. |
| TensorUse | A used tensor represented by its (tensor_name, indices). |
| TensorFn | Application of a tensor function. |
| TensorReduceFn | Application of a reduction function. |
| const | Returns the given constant floating point or integer value. |
| index | Returns the iteration index for a given dimension name. |
| FunctionKind | Generic enumeration. |
| UnaryFnType | Unary function. |
| UnaryFn | Unary function namespace. |
| BinaryFnType | Binary function. |
| BinaryFn | Binary function namespace. |
| TernaryFnType | Ternary function. |
| TernaryFn | Ternary function namespace. |
| TypeFnType | Type conversion function. |
| TypeFn | Type conversion function namespace. |
| ReduceFnUse | Reduction function use. |
| ReduceFnType | Reduction function. |
| OperandKind | Generic enumeration. |
| OperandDef | Definition of an operand passed to an operation. |
| TensorDef | Tensor operand definition. |
| ScalarDef | Scalar operand definition. |
| IndexAttrDef | Index attribute definition. |
| UnaryFnAttrDef | Unary function attribute definition. |
| BinaryFnAttrDef | Binary function attribute definition. |
| TernaryFnAttrDef | Ternary function attribute definition. |
| TypeFnAttrDef | Type conversion function attribute definition. |
| Comprehension | Represents a single comprehension. |
| OpInterfaceDef | An interface that an op implements. |
| OpDefinitionDef | A method that an op implements. |
| OpMetadataDef | Metadata about the op (generally not behavior impacting). |
| LinalgOpDef | Definition of a linalg op. |
Module Contents
- class mlir.dialects.linalg.opdsl.lang.comprehension.TensorExpression
An expression that can appear on the RHS of a comprehension.
- abstract to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression
- visit_tensor_exprs(callback: Callable[[TensorExpression], None])
Visits all tensor expressions reachable by the expression.
- collect_dim_uses(uses: Set[DimDef])
Collects all DimDefs reachable through this expression.
- collect_tensor_uses(uses: Set[TensorUse])
Collects all TensorUses reachable through this expression.
- collect_indices(indices: Set[index])
Collects all index accesses reachable through this expression.
- collect_scalar_uses(uses: Set[ScalarDef])
Collects all ScalarDefs reachable through this expression.
- __add__(rhs: TensorExpression) -> TensorExpression
- __mul__(rhs) -> TensorExpression
- __sub__(rhs) -> TensorExpression
- __truediv__(rhs) -> TensorExpression
- __hash__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.TensorUse(operand_def: OperandDef, indices: Sequence[AffineExprDef])
Bases: TensorExpression
A used tensor represented by its (tensor_name, indices).
Note that forming a comprehension via direct assignment is performed through __setitem__ on the TensorDef level. However, performing a reduction with compound ops (+=, *=, etc.) is done by a sequence of calls: TensorDef.__getitem__, then TensorUse.__iadd__, then TensorDef.__setitem__.
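The following is a minimal sketch of this Python protocol. It uses hypothetical toy classes (`ToyDef`, `ToyUse`) rather than the real `TensorDef`/`TensorUse`, and it represents the RHS only by the tuple of dimension names it uses. `O[...] = expr` triggers a single `__setitem__`. `O[...] += expr` triggers `__getitem__`, then `__iadd__` on the returned use, then `__setitem__` with whatever `__iadd__` returned. Treating RHS-only dimensions as reduction dimensions is an assumption made for illustration.

```python
# Hypothetical toy classes that mimic how an embedded DSL can capture
# `O[...] = expr` and `O[...] += expr`. They are NOT opdsl's TensorDef/TensorUse.


class ToyUse:
    """Stand-in for a tensor use: an owning definition plus its index dims."""

    def __init__(self, owner, dims):
        self.owner = owner
        self.dims = tuple(dims)

    def __iadd__(self, rhs_dims):
        self.owner.calls.append("__iadd__")
        # Assumption: dims used on the RHS but not on the LHS are reduced.
        reduce_dims = tuple(d for d in rhs_dims if d not in self.dims)
        return ("reduce_add", reduce_dims, tuple(rhs_dims))


class ToyDef:
    """Stand-in for a tensor definition that records how it is used."""

    def __init__(self, name):
        self.name = name
        self.calls = []     # dunder calls, in the order Python makes them
        self.bindings = []  # (lhs_dims, value) pairs, one per assignment

    def __getitem__(self, dims):
        self.calls.append("__getitem__")
        return ToyUse(self, dims)

    def __setitem__(self, dims, value):
        self.calls.append("__setitem__")
        # For `+=`, `value` is whatever ToyUse.__iadd__ returned.
        self.bindings.append((tuple(dims), value))


O = ToyDef("O")
O["m", "n"] = ("m", "n")        # direct assignment: a single __setitem__
O["m", "n"] += ("m", "n", "k")  # compound: __getitem__, __iadd__, __setitem__

print(O.calls)
# ['__setitem__', '__getitem__', '__iadd__', '__setitem__']
print(O.bindings[1])
# (('m', 'n'), ('reduce_add', ('k',), ('m', 'n', 'k')))
```

Returning a fresh object from `__iadd__`, rather than `self`, is what allows the subsequent `__setitem__` to receive a reduction expression instead of the plain use.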
- operand_def
- indices
- to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression
- property tensor_name: str
- _compute_reduce_dims(rhs: TensorExpression) -> Set[DimDef]
- __iadd__(rhs: TensorExpression) -> TensorReduceFn
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.TensorFn(kind: FunctionKind, name: Optional[str], operand_def: Optional[OperandDef], type_var: Optional[mlir.dialects.linalg.opdsl.lang.types.TypeVar], args: Sequence[TensorExpression])
Bases: TensorExpression
Application of a tensor function.
- name
- kind
- operand_def
- type_var
- args
- to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression
- visit_tensor_exprs(callback: Callable[[TensorExpression], None])
Visits all tensor expressions reachable by the expression.
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.TensorReduceFn(reduce_use: ReduceFnUse, args: Sequence[TensorExpression])
Bases: TensorExpression
Application of a reduction function.
This captures the lhs (initial value) separately from the rhs.
- reduce_use
- args
- to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression
- visit_tensor_exprs(callback: Callable[[TensorExpression], None])
Visits all tensor expressions reachable by the expression.
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.const(value: Any)
Bases: TensorExpression
Returns the given constant floating point or integer value.
- to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.index(dim: DimDef)
Bases: TensorExpression
Returns the iteration index for a given dimension name.
Resolves the given dimension name to obtain its position in the iteration domain of the operation.
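This is a hedged sketch of what resolving a dimension name means. The documented class resolves through `resolve_dimension_name` with an `AffineBuildState`. This toy substitutes a plain tuple lookup (the helper `resolve_dimension` is hypothetical) and then emulates `O[m, n] = index(n)` with an explicit loop nest.

```python
# Hypothetical sketch: resolve a dimension name to its position in an
# iteration domain, then read that loop's index during iteration.
from itertools import product


def resolve_dimension(domain, dim_name):
    """Return the position of dim_name in the op's iteration domain."""
    return domain.index(dim_name)  # raises ValueError if the dim is unknown


domain = ("m", "n")
sizes = {"m": 2, "n": 3}
pos_n = resolve_dimension(domain, "n")

# Emulate O[m, n] = index(n): each element receives its own n coordinate.
out = [[0] * sizes["n"] for _ in range(sizes["m"])]
for point in product(*(range(sizes[d]) for d in domain)):
    m, n = point
    out[m][n] = point[pos_n]

print(pos_n)  # 1
print(out)    # [[0, 1, 2], [0, 1, 2]]
```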
- dim_def
- dim = -1
- resolve_dimension_name(affine_state: AffineBuildState)
- to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.FunctionKind
Bases: mlir.dialects.linalg.opdsl.lang.types.Enum
Generic enumeration.
Derive from this class to define new enumerations.
- UNARY = 0
- BINARY = 1
- TERNARY = 2
- TYPE = 3
- class mlir.dialects.linalg.opdsl.lang.comprehension.UnaryFnType(fn_name: str)
Unary function.
A unary function takes one tensor expression and returns the function evaluation result.
- fn_name
- __call__(arg: TensorExpression) -> TensorFn
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.UnaryFn
Unary function namespace.
- exp
- log
- abs
- ceil
- floor
- negf
- reciprocal
- round
- sqrt
- rsqrt
- square
- tanh
- erf
- class mlir.dialects.linalg.opdsl.lang.comprehension.BinaryFnType(fn_name: str)
Binary function.
A binary function takes two tensor expressions and returns the function evaluation result.
- fn_name
- __call__(arg0: TensorExpression, arg1: TensorExpression) -> TensorFn
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.BinaryFn
Binary function namespace.
As the integer types are signless, signedness is implemented by different functions that treat integers as signed or unsigned values.
Examples:
max_signed -> arith.MaxSIOp
max_unsigned -> arith.MaxUIOp
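To see why two functions are needed, here is a plain-Python sketch (no MLIR involved; the helpers are hypothetical). It compares the same signless 8-bit patterns once as signed values and once as unsigned values, which is the distinction between `arith.MaxSIOp` and `arith.MaxUIOp`.

```python
# Sketch: the same signless 8-bit patterns compared as signed vs unsigned,
# mirroring the difference between max_signed and max_unsigned.

BITS = 8


def as_signed(bits_value, width=BITS):
    """Interpret a raw bit pattern as a two's-complement signed integer."""
    return bits_value - (1 << width) if bits_value >> (width - 1) else bits_value


def max_signed(a, b):
    return a if as_signed(a) >= as_signed(b) else b


def max_unsigned(a, b):
    return a if a >= b else b


a, b = 0xFF, 0x01  # 0xFF is -1 when read as signed, 255 when unsigned
print(hex(max_signed(a, b)))    # 0x1
print(hex(max_unsigned(a, b)))  # 0xff
```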
- add
- sub
- mul
- div
- div_unsigned
- max_signed
- min_signed
- max_unsigned
- min_unsigned
- powf
- class mlir.dialects.linalg.opdsl.lang.comprehension.TernaryFnType(fn_name: str)
Ternary function.
A ternary function takes three tensor expressions and returns the function evaluation result.
- fn_name
- __call__(arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression) -> TensorFn
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.TypeFnType(fn_name: str)
Type conversion function.
A type conversion function takes a target type and a tensor expression and returns the tensor expression cast to that type.
- fn_name
- __call__(type_var: mlir.dialects.linalg.opdsl.lang.types.TypeVar, arg: TensorExpression) -> TensorFn
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.TypeFn
Type conversion function namespace.
As the integer types are signless, signedness is implemented by different cast functions that treat integers as signed (cast_signed) or unsigned (cast_unsigned) values.
Examples:
cast_signed(I32 -> I64) -> arith.ExtSIOp
cast_unsigned(I32 -> I64) -> arith.ExtUIOp
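This plain-Python sketch shows what the two extensions in the examples compute on raw bit patterns when widening 32 bits to 64. The helpers are hypothetical, not the opdsl API. Sign extension copies the sign bit into the new high bits. Zero extension leaves them zero.

```python
# Sketch of sign extension (as in arith.ExtSIOp) versus zero extension
# (as in arith.ExtUIOp) on raw bit patterns, widening 32 bits to 64.


def ext_signed(bits_value, src_width=32, dst_width=64):
    """Sign-extend: copy the source sign bit into every new high bit."""
    bits_value &= (1 << src_width) - 1  # keep only the source-width pattern
    if bits_value >> (src_width - 1):   # sign bit set
        bits_value |= ((1 << (dst_width - src_width)) - 1) << src_width
    return bits_value


def ext_unsigned(bits_value, src_width=32):
    """Zero-extend: keep the source pattern; new high bits stay zero."""
    return bits_value & ((1 << src_width) - 1)


x = 0xFFFFFFFF  # -1 as a signed i32, 4294967295 as an unsigned i32
print(hex(ext_signed(x)))           # 0xffffffffffffffff
print(hex(ext_unsigned(x)))         # 0xffffffff
print(hex(ext_signed(0x7FFFFFFF)))  # 0x7fffffff (sign bit clear: unchanged)
```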
- cast_signed
- cast_unsigned
- class mlir.dialects.linalg.opdsl.lang.comprehension.ReduceFnUse(binary_fn: Optional[BinaryFnType], binary_attr: Optional[BinaryFnAttrDef], *reduce_dims: DimDef)
Reduction function use.
A reduction use specifies the reduction function and dimensions.
- binary_fn
- binary_attr
- reduce_dims = ()
- __call__(*args: TensorExpression) -> TensorReduceFn
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.ReduceFnType(binary_fn: BinaryFnType)
Reduction function.
A binary function that reduces its RHS into its LHS.
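Here is a minimal sketch of "reducing the RHS into the LHS". It assumes that the LHS's current contents serve as the accumulator's initial value, which is consistent with TensorReduceFn capturing the LHS separately. For each output point, every RHS value along the reduction dimension is folded in with the binary function. This roughly mirrors what an expression such as `ReduceFn.add[k](...)` describes, with `k` a DimDef. `reduce_into` and its single hard-coded reduction loop are hypothetical simplifications.

```python
# Hedged sketch: fold RHS values along a reduction dim k into an output
# that already holds initial values, using a pluggable binary function.
import operator


def reduce_into(out, binary_fn, rhs, m_size, n_size, k_size):
    """Fold rhs(m, n, k) over k into out[m][n], starting from out's value."""
    for m in range(m_size):
        for n in range(n_size):
            acc = out[m][n]  # the LHS supplies the initial value
            for k in range(k_size):
                acc = binary_fn(acc, rhs(m, n, k))
            out[m][n] = acc
    return out


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]


def matmul_rhs(m, n, k):
    return A[m][k] * B[k][n]


# Like O[m, n] += A[m, k] * B[k, n], reducing over k with addition.
summed = reduce_into([[0, 0], [0, 0]], operator.add, matmul_rhs, 2, 2, 2)
# The same RHS reduced with max over k, starting from a non-zero LHS value.
maxed = reduce_into([[20, 0], [0, 0]], max, matmul_rhs, 2, 2, 2)
print(summed)  # [[19, 22], [43, 50]]
print(maxed)   # [[20, 16], [28, 32]]
```

Swapping `operator.add` for `max` changes only the combining function, which matches how ReduceFnType wraps a single BinaryFnType.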
- binary_fn
- __getitem__(reduce_dims: Tuple[DimDef]) -> ReduceFnUse
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.ReduceFn
- add
- mul
- max_signed
- min_signed
- max_unsigned
- min_unsigned
- class mlir.dialects.linalg.opdsl.lang.comprehension.OperandKind
Bases: mlir.dialects.linalg.opdsl.lang.types.Enum
Generic enumeration.
Derive from this class to define new enumerations.
- INPUT_TENSOR = 0
- SCALAR = 1
- OUTPUT_TENSOR = 2
- INDEX_ATTR = 3
- UNARY_FN_ATTR = 4
- BINARY_FN_ATTR = 5
- TERNARY_FN_ATTR = 6
- TYPE_FN_ATTR = 7
- class mlir.dialects.linalg.opdsl.lang.comprehension.OperandDef(kind: OperandKind, type_var: Optional[mlir.dialects.linalg.opdsl.lang.types.TypeVar] = None, size_exprs: Optional[Sequence[AffineExprDef]] = None, index_dims: Optional[Sequence[DimDef]] = None, default_indices: Optional[Sequence[int]] = None, default_fn: Optional[str] = None)
Definition of an operand passed to an operation.
Keeps the meta information of tensor, scalar, and attribute operands and provides the shared registration functionality.
- owner: Optional[LinalgOpDef] = None
- type_var = None
- size_exprs = None
- index_dims = None
- default_indices = None
- default_fn = None
- kind
- name: Optional[str] = None
- registered_index: int = -1
- attach(index: int, name: str, owner: LinalgOpDef)
- is_input() -> bool
- is_tensor() -> bool
- is_attribute() -> bool
- __hash__()
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.TensorDef(type_var: mlir.dialects.linalg.opdsl.lang.types.TypeVar, *shape: AffineExprDef, index_dims: Optional[Sequence[DimDef]] = None, output: bool = False)
Tensor operand definition.
Tensor operands are indexed using the associated indexing_map when forwarded to the body of the structured op. A unique name identifies each tensor operand, and an index determines its position in the operation’s parameter list. A tensor definition takes a type, a shape, and an optional flag to mark output tensors. Additionally, a tuple of index dimensions may be used to map the tensor to the loop dimensions of the operation. This mapping is needed to compute the indexing map of shape-only tensors that have no uses.
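The following is a hedged illustration of an indexing map. Each tensor's map selects which loop dimensions of the iteration domain index that tensor. In MLIR these are affine maps. For brevity, this sketch assumes a simpler representation: a map is the tuple of domain positions it picks, so `(d0, d1, d2) -> (d0, d2)` becomes `(0, 2)`. The helper names are hypothetical, and the tensor dims are passed in directly, much as index dimensions would be given for a shape-only tensor.

```python
# Hypothetical sketch: derive a projected "indexing map" for each tensor of a
# matmul-like op from the loop dims that index it. A map is represented by
# the domain positions it selects, e.g. (d0, d1, d2) -> (d0, d2) is (0, 2).


def indexing_map(domain, tensor_dims):
    return tuple(domain.index(d) for d in tensor_dims)


def apply_map(index_map, point):
    """Pick the tensor coordinates for one point of the iteration space."""
    return tuple(point[i] for i in index_map)


domain = ("m", "n", "k")
maps = {
    "A": indexing_map(domain, ("m", "k")),
    "B": indexing_map(domain, ("k", "n")),
    "O": indexing_map(domain, ("m", "n")),
}
print(maps)                             # {'A': (0, 2), 'B': (2, 1), 'O': (0, 1)}
print(apply_map(maps["B"], (4, 5, 6)))  # (6, 5)
```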
- operand_def
- __getitem__(dims: Sequence[AffineExprDef]) -> TensorUse
- __setitem__(dims: Sequence[AffineExprDef], value: TensorExpression)
Creates a new 1:1 comprehension by binding this tensor to an expression.
Note that due to the way assignment works in Python, we have to capture direct assignment as a setitem on the TensorDef.
- class mlir.dialects.linalg.opdsl.lang.comprehension.ScalarDef(type_var: mlir.dialects.linalg.opdsl.lang.types.TypeVar)
Bases: TensorExpression
Scalar operand definition.
Scalar operands are forwarded to the body of the structured op as they are. A unique name identifies the scalars and an index determines their position in the operation’s parameter list.
- operand_def
- property scalar_name: str
- to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression
- class mlir.dialects.linalg.opdsl.lang.comprehension.IndexAttrDef(*sizes: SymbolDef, default: Sequence[int])
Index attribute definition.
Index attributes provide a way to define and set symbols that can be used in indexing expressions. Every attribute specifies a tuple of symbols, which are replaced by integer values at compile time, together with their default values.
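This hypothetical sketch shows how index-attribute symbols behave in a convolution-style index expression. Stride symbols `SH`/`SW` and dilation symbols `DH`/`DW` take their default values unless they are overridden when the op is instantiated. The symbol names, `make_index_attr`, and `bind` are illustrative assumptions, not opdsl calls.

```python
# Hypothetical sketch: symbols used in index expressions resolve to integer
# defaults from their attribute, unless overridden at instantiation time.


def make_index_attr(*symbols, default):
    """Pair each symbol name with its default integer value."""
    if len(symbols) != len(default):
        raise ValueError("expected one default value per symbol")
    return dict(zip(symbols, default))


def bind(*attrs, overrides=None):
    """Merge attribute defaults, then apply instantiation-time overrides."""
    values = {}
    for attr in attrs:
        values.update(attr)
    values.update(overrides or {})
    return values


def input_point(oh, ow, kh, kw, s):
    # Input row/column read for output (oh, ow) and filter tap (kh, kw).
    return (oh * s["SH"] + kh * s["DH"], ow * s["SW"] + kw * s["DW"])


strides = make_index_attr("SH", "SW", default=(1, 1))
dilations = make_index_attr("DH", "DW", default=(1, 1))

print(input_point(3, 3, 1, 2, bind(strides, dilations)))  # (4, 5)
print(input_point(3, 3, 1, 2, bind(strides, dilations, overrides={"SH": 2, "SW": 2})))  # (7, 8)
```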
- operand_def
- class mlir.dialects.linalg.opdsl.lang.comprehension.UnaryFnAttrDef(default: UnaryFnType)
Unary function attribute definition.
Unary function attributes provide a way to make the arithmetic computation parametrizable. Every attribute specifies a default unary function that may be overridden at operation instantiation time.
- operand_def
- __call__(arg: TensorExpression) -> TensorFn
- class mlir.dialects.linalg.opdsl.lang.comprehension.BinaryFnAttrDef(default: BinaryFnType)
Binary function attribute definition.
Binary function attributes provide a way to make the arithmetic computation parametrizable. Every attribute specifies a default binary function that may be overridden at operation instantiation time.
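Here is a minimal sketch of a parametrizable binary function. It uses a hypothetical `ToyBinaryFnAttr` and a hypothetical function registry rather than opdsl classes. The body is written once against the attribute, and the concrete function is chosen at instantiation, falling back to the default.

```python
# Hypothetical sketch of a binary-function attribute with an overridable
# default. The registry stands in for the BinaryFn namespace.
import operator

BINARY_FNS = {"add": operator.add, "mul": operator.mul}


class ToyBinaryFnAttr:
    """A named function slot with a default that callers may override."""

    def __init__(self, default):
        self.default = default

    def resolve(self, override=None):
        return BINARY_FNS[override if override is not None else self.default]


def elementwise(lhs, rhs, fn_attr, override=None):
    # The op body is written once against fn_attr; the concrete function is
    # chosen when the op is instantiated.
    fn = fn_attr.resolve(override)
    return [fn(a, b) for a, b in zip(lhs, rhs)]


fun = ToyBinaryFnAttr(default="add")
print(elementwise([1, 5], [4, 2], fun))                  # [5, 7]
print(elementwise([1, 5], [4, 2], fun, override="mul"))  # [4, 10]
```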
- operand_def
- __call__(arg0: TensorExpression, arg1: TensorExpression) -> TensorFn
- __getitem__(reduce_dims: Tuple[DimDef]) -> ReduceFnUse
- class mlir.dialects.linalg.opdsl.lang.comprehension.TernaryFnAttrDef(default: TernaryFnType)
Ternary function attribute definition.
Ternary function attributes provide a way to make the arithmetic computation parametrizable. Every attribute specifies a default ternary function that may be overridden at operation instantiation time.
- operand_def
- __call__(arg0: TensorExpression, arg1: TensorExpression) -> TensorFn
- __getitem__(reduce_dims: Tuple[DimDef]) -> ReduceFnUse
- class mlir.dialects.linalg.opdsl.lang.comprehension.TypeFnAttrDef(default: TypeFnType)
Type conversion function attribute definition.
Type conversion function attributes provide a way to make type conversions parameterizable. Every attribute specifies a default type conversion function that may be overridden at operation instantiation time.
- operand_def
- __call__(type_var: mlir.dialects.linalg.opdsl.lang.types.TypeVar, arg: TensorExpression) -> TensorFn
- class mlir.dialects.linalg.opdsl.lang.comprehension.Comprehension(*bindings: Tuple[TensorUse, TensorExpression])
Represents a single comprehension.
- definitions = []
- values = []
- property all_reduction_dims: Set[Tuple[DimDef, ...]]
Gets the reduction dims for the comprehension, as a set of DimDef tuples.
- __repr__()
- class mlir.dialects.linalg.opdsl.lang.comprehension.OpInterfaceDef(cpp_name: str)
An interface that an op implements.
- cpp_name
- mlir.dialects.linalg.opdsl.lang.comprehension.ContractionOpInterface
- mlir.dialects.linalg.opdsl.lang.comprehension.ConvolutionOpInterface
- mlir.dialects.linalg.opdsl.lang.comprehension.FillOpInterface
- class mlir.dialects.linalg.opdsl.lang.comprehension.OpDefinitionDef(def_name: str)
A method that an op implements.
- def_name
- mlir.dialects.linalg.opdsl.lang.comprehension.Canonicalizer
- class mlir.dialects.linalg.opdsl.lang.comprehension.OpMetadataDef(name: str, cpp_class_name: Optional[str], doc: Optional[str])
Bases: mlir.dialects.linalg.opdsl.lang.yaml_helper.YAMLObject
Metadata about the op (generally not behavior impacting).
- yaml_tag = '!LinalgOpMetadata'
- name
- cpp_class_name
- doc
- implements: List[OpInterfaceDef] = []
- defines: List[OpDefinitionDef] = []
- to_yaml_custom_dict()
- class mlir.dialects.linalg.opdsl.lang.comprehension.LinalgOpDef(name: str, cpp_class_name: Optional[str] = None, doc: Optional[str] = None)
Definition of a linalg op.
- metadata
- registered_operands: Dict[str, OperandDef]
- domain: List[DimDef] = []
- comprehensions: List[Comprehension] = []
- _affine_state
- add_operand(name: str, operand: OperandDef)
Registers an operand.
- __repr__()
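This hypothetical sketch models the registration bookkeeping suggested by `OperandDef.attach`, `registered_index` (default -1), and `LinalgOpDef.registered_operands`. Each operand receives its name, its owner, and the next position in the op's operand list. Rejecting duplicate names is an assumption, and the real method may enforce additional ordering rules between operand kinds.

```python
# Hypothetical sketch of operand registration; not the opdsl implementation.


class ToyOperand:
    def __init__(self, kind):
        self.kind = kind
        self.name = None
        self.owner = None
        self.registered_index = -1  # -1 means "not registered yet"

    def attach(self, index, name, owner):
        self.registered_index = index
        self.name = name
        self.owner = owner


class ToyOpDef:
    def __init__(self, name):
        self.name = name
        self.registered_operands = {}  # insertion-ordered: name -> operand

    def add_operand(self, name, operand):
        # Assumption: a name may only be registered once per op.
        if name in self.registered_operands:
            raise ValueError(f"operand {name!r} is already registered")
        operand.attach(len(self.registered_operands), name, self)
        self.registered_operands[name] = operand


op = ToyOpDef("matmul")
for name, kind in [("A", "INPUT_TENSOR"), ("B", "INPUT_TENSOR"), ("C", "OUTPUT_TENSOR")]:
    op.add_operand(name, ToyOperand(kind))

print({n: o.registered_index for n, o in op.registered_operands.items()})
# {'A': 0, 'B': 1, 'C': 2}
```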