mlir.dialects.linalg.opdsl.lang.comprehension
=============================================

.. py:module:: mlir.dialects.linalg.opdsl.lang.comprehension

.. autoapi-nested-parse::

   Model classes representing a tensor comprehension.

   These classes model the language at the AST level, as it is evaluated.
   Reasoning about it typically involves processing this form into config
   objects that represent actual op definitions (i.e., YAML).

Attributes
----------

.. autoapisummary::

   mlir.dialects.linalg.opdsl.lang.comprehension.ContractionOpInterface
   mlir.dialects.linalg.opdsl.lang.comprehension.ConvolutionOpInterface
   mlir.dialects.linalg.opdsl.lang.comprehension.FillOpInterface
   mlir.dialects.linalg.opdsl.lang.comprehension.Canonicalizer

Classes
-------

.. autoapisummary::

   mlir.dialects.linalg.opdsl.lang.comprehension.TensorExpression
   mlir.dialects.linalg.opdsl.lang.comprehension.TensorUse
   mlir.dialects.linalg.opdsl.lang.comprehension.TensorFn
   mlir.dialects.linalg.opdsl.lang.comprehension.TensorReduceFn
   mlir.dialects.linalg.opdsl.lang.comprehension.const
   mlir.dialects.linalg.opdsl.lang.comprehension.index
   mlir.dialects.linalg.opdsl.lang.comprehension.FunctionKind
   mlir.dialects.linalg.opdsl.lang.comprehension.UnaryFnType
   mlir.dialects.linalg.opdsl.lang.comprehension.UnaryFn
   mlir.dialects.linalg.opdsl.lang.comprehension.BinaryFnType
   mlir.dialects.linalg.opdsl.lang.comprehension.BinaryFn
   mlir.dialects.linalg.opdsl.lang.comprehension.TernaryFnType
   mlir.dialects.linalg.opdsl.lang.comprehension.TernaryFn
   mlir.dialects.linalg.opdsl.lang.comprehension.TypeFnType
   mlir.dialects.linalg.opdsl.lang.comprehension.TypeFn
   mlir.dialects.linalg.opdsl.lang.comprehension.ReduceFnUse
   mlir.dialects.linalg.opdsl.lang.comprehension.ReduceFnType
   mlir.dialects.linalg.opdsl.lang.comprehension.ReduceFn
   mlir.dialects.linalg.opdsl.lang.comprehension.OperandKind
   mlir.dialects.linalg.opdsl.lang.comprehension.OperandDef
   mlir.dialects.linalg.opdsl.lang.comprehension.TensorDef
   mlir.dialects.linalg.opdsl.lang.comprehension.ScalarDef
   mlir.dialects.linalg.opdsl.lang.comprehension.IndexAttrDef
   mlir.dialects.linalg.opdsl.lang.comprehension.UnaryFnAttrDef
   mlir.dialects.linalg.opdsl.lang.comprehension.BinaryFnAttrDef
   mlir.dialects.linalg.opdsl.lang.comprehension.TernaryFnAttrDef
   mlir.dialects.linalg.opdsl.lang.comprehension.TypeFnAttrDef
   mlir.dialects.linalg.opdsl.lang.comprehension.Comprehension
   mlir.dialects.linalg.opdsl.lang.comprehension.OpInterfaceDef
   mlir.dialects.linalg.opdsl.lang.comprehension.OpDefinitionDef
   mlir.dialects.linalg.opdsl.lang.comprehension.OpMetadataDef
   mlir.dialects.linalg.opdsl.lang.comprehension.LinalgOpDef

Module Contents
---------------

.. py:class:: TensorExpression

   An expression that can appear on the RHS of a comprehension.

   .. py:method:: to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression
      :abstractmethod:

   .. py:method:: visit_tensor_exprs(callback: mlir.dialects.linalg.opdsl.lang.scalar_expr.Callable[[TensorExpression], None])

      Visits all tensor expressions reachable by the expression.

   .. py:method:: collect_dim_uses(uses: mlir.dialects.linalg.opdsl.lang.scalar_expr.Set[mlir.dialects.linalg.opdsl.lang.scalar_expr.DimDef])

      Collects all DimDefs reachable through this expression.

   .. py:method:: collect_tensor_uses(uses: mlir.dialects.linalg.opdsl.lang.scalar_expr.Set[TensorUse])

      Collects all TensorUses reachable through this expression.

   .. py:method:: collect_indices(indices: mlir.dialects.linalg.opdsl.lang.scalar_expr.Set[index])

      Collects all index accesses reachable through this expression.

   .. py:method:: collect_scalar_uses(uses: mlir.dialects.linalg.opdsl.lang.scalar_expr.Set[ScalarDef])

      Collects all ScalarDefs reachable through this expression.

   .. py:method:: __add__(rhs: TensorExpression) -> TensorExpression

   .. py:method:: __mul__(rhs) -> TensorExpression

   .. py:method:: __sub__(rhs) -> TensorExpression

   .. py:method:: __truediv__(rhs) -> TensorExpression

   .. py:method:: __hash__()

.. py:class:: TensorUse(operand_def: OperandDef, indices: mlir.dialects.linalg.opdsl.lang.scalar_expr.Sequence[mlir.dialects.linalg.opdsl.lang.scalar_expr.AffineExprDef])

   Bases: :py:obj:`TensorExpression`

   A used tensor represented by its (tensor_name, indices).

   Note that forming a comprehension via direct assignment is performed through
   ``__setitem__`` on the TensorDef level. However, performing a reduction with
   a compound assignment such as ``+=`` is done through the sequence:

   #. ``TensorDef.__getitem__``
   #. ``TensorUse.__iadd__``
   #. ``TensorDef.__setitem__``

   .. py:attribute:: operand_def

   .. py:attribute:: indices

   .. py:method:: to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression

   .. py:property:: tensor_name
      :type: str

   .. py:method:: _compute_reduce_dims(rhs: TensorExpression) -> mlir.dialects.linalg.opdsl.lang.scalar_expr.Set[mlir.dialects.linalg.opdsl.lang.scalar_expr.DimDef]

   .. py:method:: __iadd__(rhs: TensorExpression) -> TensorReduceFn

   .. py:method:: __repr__()

.. py:class:: TensorFn(kind: FunctionKind, name: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[str], operand_def: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[OperandDef], type_var: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[mlir.dialects.linalg.opdsl.lang.types.TypeVar], args: mlir.dialects.linalg.opdsl.lang.scalar_expr.Sequence[TensorExpression])

   Bases: :py:obj:`TensorExpression`

   Application of a tensor function.

   .. py:attribute:: name

   .. py:attribute:: kind

   .. py:attribute:: operand_def

   .. py:attribute:: type_var

   .. py:attribute:: args

   .. py:method:: to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression

   .. py:method:: visit_tensor_exprs(callback: mlir.dialects.linalg.opdsl.lang.scalar_expr.Callable[[TensorExpression], None])

      Visits all tensor expressions reachable by the expression.

   .. py:method:: __repr__()

.. py:class:: TensorReduceFn(reduce_use: ReduceFnUse, args: mlir.dialects.linalg.opdsl.lang.scalar_expr.Sequence[TensorExpression])

   Bases: :py:obj:`TensorExpression`

   Application of a reduction function.

   This captures the lhs (initial value) separately from the rhs.

   .. py:attribute:: reduce_use

   .. py:attribute:: lhs
      :type: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[TensorUse]
      :value: None

   .. py:attribute:: args

   .. py:method:: to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression

   .. py:method:: visit_tensor_exprs(callback: mlir.dialects.linalg.opdsl.lang.scalar_expr.Callable[[TensorExpression], None])

      Visits all tensor expressions reachable by the expression.

   .. py:method:: __repr__()

.. py:class:: const(value: mlir.dialects.linalg.opdsl.lang.scalar_expr.Any)

   Bases: :py:obj:`TensorExpression`

   Returns the given constant floating-point or integer value.

   .. py:method:: to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression

   .. py:method:: __repr__()

.. py:class:: index(dim: mlir.dialects.linalg.opdsl.lang.scalar_expr.DimDef)

   Bases: :py:obj:`TensorExpression`

   Returns the iteration index for a given dimension name.

   Resolves the given dimension name to obtain its position in the iteration
   domain of the operation.

   .. py:attribute:: dim_def

   .. py:attribute:: dim
      :value: -1

   .. py:method:: resolve_dimension_name(affine_state: mlir.dialects.linalg.opdsl.lang.scalar_expr.AffineBuildState)

   .. py:method:: to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression

   .. py:method:: __repr__()

.. py:class:: FunctionKind

   Bases: :py:obj:`mlir.dialects.linalg.opdsl.lang.types.Enum`

   Kinds of tensor functions: unary, binary, ternary, and type conversion.

   .. py:attribute:: UNARY
      :value: 0

   .. py:attribute:: BINARY
      :value: 1

   .. py:attribute:: TERNARY
      :value: 2

   .. py:attribute:: TYPE
      :value: 3

.. py:class:: UnaryFnType(fn_name: str)

   Unary function.

   A unary function takes one tensor expression and returns the function
   evaluation result.

   .. py:attribute:: fn_name

   .. py:method:: __call__(arg: TensorExpression) -> TensorFn

   .. py:method:: __repr__()

.. py:class:: UnaryFn

   Unary function namespace.

   .. py:attribute:: exp

   .. py:attribute:: log

   .. py:attribute:: abs

   .. py:attribute:: ceil

   .. py:attribute:: floor

   .. py:attribute:: negf

   .. py:attribute:: reciprocal

   .. py:attribute:: round

   .. py:attribute:: sqrt

   .. py:attribute:: rsqrt

   .. py:attribute:: square

   .. py:attribute:: tanh

   .. py:attribute:: erf

.. py:class:: BinaryFnType(fn_name: str)

   Binary function.

   A binary function takes two tensor expressions and returns the function
   evaluation result.

   .. py:attribute:: fn_name

   .. py:method:: __call__(arg0: TensorExpression, arg1: TensorExpression) -> TensorFn

   .. py:method:: __repr__()

.. py:class:: BinaryFn

   Binary function namespace.

   As the integer types are signless, signedness is implemented by different
   functions that treat integers as signed or unsigned values.

   Examples:

   * max_signed -> ``arith.MaxSIOp``
   * max_unsigned -> ``arith.MaxUIOp``

   .. py:attribute:: add

   .. py:attribute:: sub

   .. py:attribute:: mul

   .. py:attribute:: div

   .. py:attribute:: div_unsigned

   .. py:attribute:: max_signed

   .. py:attribute:: min_signed

   .. py:attribute:: max_unsigned

   .. py:attribute:: min_unsigned

   .. py:attribute:: powf

.. py:class:: TernaryFnType(fn_name: str)

   Ternary function.

   A ternary function takes three tensor expressions and returns the function
   evaluation result.

   .. py:attribute:: fn_name

   .. py:method:: __call__(arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression) -> TensorFn

   .. py:method:: __repr__()

.. py:class:: TernaryFn

   Ternary function namespace.

   .. py:attribute:: select

.. py:class:: TypeFnType(fn_name: str)

   Type conversion function.

   A type conversion function takes a target type and a tensor expression and
   returns the tensor expression cast to the target type.

   .. py:attribute:: fn_name

   .. py:method:: __call__(type_var: mlir.dialects.linalg.opdsl.lang.types.TypeVar, arg: TensorExpression) -> TensorFn

   .. py:method:: __repr__()

.. py:class:: TypeFn

   Type conversion function namespace.

   As the integer types are signless, signedness is implemented by different
   cast functions that treat integers as signed (``cast_signed``) or unsigned
   (``cast_unsigned``) values.

   Examples:

   * cast_signed(I32 -> I64) -> ``arith.ExtSIOp``
   * cast_unsigned(I32 -> I64) -> ``arith.ExtUIOp``

   .. py:attribute:: cast_signed

   .. py:attribute:: cast_unsigned

.. py:class:: ReduceFnUse(binary_fn: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[BinaryFnType], binary_attr: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[BinaryFnAttrDef], *reduce_dims: mlir.dialects.linalg.opdsl.lang.scalar_expr.DimDef)

   Reduction function use.

   A reduction use specifies the reduction function and dimensions.

   .. py:attribute:: binary_fn

   .. py:attribute:: binary_attr

   .. py:attribute:: reduce_dims
      :value: ()

   .. py:method:: __call__(*args: TensorExpression) -> TensorReduceFn

   .. py:method:: __repr__()

.. py:class:: ReduceFnType(binary_fn: BinaryFnType)

   Reduction function.

   A binary function that reduces its RHS into its LHS.

   .. py:attribute:: binary_fn

   .. py:method:: __getitem__(reduce_dims: mlir.dialects.linalg.opdsl.lang.scalar_expr.Tuple[mlir.dialects.linalg.opdsl.lang.scalar_expr.DimDef]) -> ReduceFnUse

   .. py:method:: __repr__()

.. py:class:: ReduceFn

   .. py:attribute:: add

   .. py:attribute:: mul

   .. py:attribute:: max_signed

   .. py:attribute:: min_signed

   .. py:attribute:: max_unsigned

   .. py:attribute:: min_unsigned

.. py:class:: OperandKind

   Bases: :py:obj:`mlir.dialects.linalg.opdsl.lang.types.Enum`

   Kinds of operands of an op definition: input and output tensors, scalars,
   index attributes, and function attributes.

   .. py:attribute:: INPUT_TENSOR
      :value: 0

   .. py:attribute:: SCALAR
      :value: 1

   .. py:attribute:: OUTPUT_TENSOR
      :value: 2

   .. py:attribute:: INDEX_ATTR
      :value: 3

   .. py:attribute:: UNARY_FN_ATTR
      :value: 4

   .. py:attribute:: BINARY_FN_ATTR
      :value: 5

   .. py:attribute:: TERNARY_FN_ATTR
      :value: 6

   .. py:attribute:: TYPE_FN_ATTR
      :value: 7

.. py:class:: OperandDef(kind: OperandKind, type_var: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[mlir.dialects.linalg.opdsl.lang.types.TypeVar] = None, size_exprs: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[mlir.dialects.linalg.opdsl.lang.scalar_expr.Sequence[mlir.dialects.linalg.opdsl.lang.scalar_expr.AffineExprDef]] = None, index_dims: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[mlir.dialects.linalg.opdsl.lang.scalar_expr.Sequence[mlir.dialects.linalg.opdsl.lang.scalar_expr.DimDef]] = None, default_indices: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[mlir.dialects.linalg.opdsl.lang.scalar_expr.Sequence[int]] = None, default_fn: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[str] = None)

   Definition of an operand passed to an operation.

   Keeps the meta information of tensor, scalar, and attribute operands and
   provides the shared registration functionality.

   .. py:attribute:: owner
      :type: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[LinalgOpDef]
      :value: None

   .. py:attribute:: type_var
      :value: None

   .. py:attribute:: size_exprs
      :value: None

   .. py:attribute:: index_dims
      :value: None

   .. py:attribute:: default_indices
      :value: None

   .. py:attribute:: default_fn
      :value: None

   .. py:attribute:: kind

   .. py:attribute:: name
      :type: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[str]
      :value: None

   .. py:attribute:: registered_index
      :type: int
      :value: -1

   .. py:method:: attach(index: int, name: str, owner: LinalgOpDef)

   .. py:method:: is_input() -> bool

   .. py:method:: is_tensor() -> bool

   .. py:method:: is_attribute() -> bool

   .. py:method:: __hash__()

   .. py:method:: __repr__()

.. py:class:: TensorDef(type_var: mlir.dialects.linalg.opdsl.lang.types.TypeVar, *shape: mlir.dialects.linalg.opdsl.lang.scalar_expr.AffineExprDef, index_dims: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[mlir.dialects.linalg.opdsl.lang.scalar_expr.Sequence[mlir.dialects.linalg.opdsl.lang.scalar_expr.DimDef]] = None, output: bool = False)

   Tensor operand definition.

   Tensor operands are indexed using the associated ``indexing_map`` when
   forwarded to the body of the structured op. A unique name identifies the
   tensor operands and an index determines their position in the operation's
   parameter list. A tensor definition takes a type, a shape, and an optional
   flag to mark output tensors. Additionally, a tuple of index dimensions may be
   used to map the tensor to the loop dimensions of the operation. This mapping
   is needed to compute the indexing map of shape-only tensors that have no
   uses.

   .. py:attribute:: operand_def

   .. py:method:: __getitem__(dims: mlir.dialects.linalg.opdsl.lang.scalar_expr.Sequence[mlir.dialects.linalg.opdsl.lang.scalar_expr.AffineExprDef]) -> TensorUse

   .. py:method:: __setitem__(dims: mlir.dialects.linalg.opdsl.lang.scalar_expr.Sequence[mlir.dialects.linalg.opdsl.lang.scalar_expr.AffineExprDef], value: TensorExpression)

      Creates a new 1:1 comprehension by binding this tensor to an expression.

      Note that due to the way assignment works in Python, we have to capture
      direct assignment as a ``__setitem__`` call on the TensorDef.

.. py:class:: ScalarDef(type_var: mlir.dialects.linalg.opdsl.lang.types.TypeVar)

   Bases: :py:obj:`TensorExpression`

   Scalar operand definition.

   Scalar operands are forwarded to the body of the structured op as they are.
   A unique name identifies the scalars and an index determines their position
   in the operation's parameter list.

   .. py:attribute:: operand_def

   .. py:property:: scalar_name
      :type: str

   .. py:method:: to_scalar_expression() -> mlir.dialects.linalg.opdsl.lang.scalar_expr.ScalarExpression

.. py:class:: IndexAttrDef(*sizes: mlir.dialects.linalg.opdsl.lang.scalar_expr.SymbolDef, default: mlir.dialects.linalg.opdsl.lang.scalar_expr.Sequence[int])

   Index attribute definition.

   Index attributes provide a way to define and set symbols that can be used in
   indexing expressions. Every attribute specifies a tuple of symbols, which are
   replaced by integer values at compile time, together with their default
   values.

   .. py:attribute:: operand_def

.. py:class:: UnaryFnAttrDef(default: UnaryFnType)

   Unary function attribute definition.

   Unary function attributes provide a way to make the arithmetic computation
   parameterizable. Every attribute specifies a default unary function that may
   be overridden at operation instantiation time.

   .. py:attribute:: operand_def

   .. py:method:: __call__(arg: TensorExpression) -> TensorFn

.. py:class:: BinaryFnAttrDef(default: BinaryFnType)

   Binary function attribute definition.

   Binary function attributes provide a way to make the arithmetic computation
   parameterizable. Every attribute specifies a default binary function that may
   be overridden at operation instantiation time.

   .. py:attribute:: operand_def

   .. py:method:: __call__(arg0: TensorExpression, arg1: TensorExpression) -> TensorFn

   .. py:method:: __getitem__(reduce_dims: mlir.dialects.linalg.opdsl.lang.scalar_expr.Tuple[mlir.dialects.linalg.opdsl.lang.scalar_expr.DimDef]) -> ReduceFnUse

.. py:class:: TernaryFnAttrDef(default: TernaryFnType)

   Ternary function attribute definition.

   Ternary function attributes provide a way to make the arithmetic computation
   parameterizable. Every attribute specifies a default ternary function that
   may be overridden at operation instantiation time.

   .. py:attribute:: operand_def

   .. py:method:: __call__(arg0: TensorExpression, arg1: TensorExpression) -> TensorFn

   .. py:method:: __getitem__(reduce_dims: mlir.dialects.linalg.opdsl.lang.scalar_expr.Tuple[mlir.dialects.linalg.opdsl.lang.scalar_expr.DimDef]) -> ReduceFnUse

.. py:class:: TypeFnAttrDef(default: TypeFnType)

   Type conversion function attribute definition.

   Type conversion function attributes provide a way to make type conversions
   parameterizable. Every attribute specifies a default type conversion function
   that may be overridden at operation instantiation time.

   .. py:attribute:: operand_def

   .. py:method:: __call__(type_var: mlir.dialects.linalg.opdsl.lang.types.TypeVar, arg: TensorExpression) -> TensorFn

.. py:class:: Comprehension(*bindings: mlir.dialects.linalg.opdsl.lang.scalar_expr.Tuple[TensorUse, TensorExpression])

   Represents a single comprehension.

   .. py:attribute:: definitions
      :value: []

   .. py:attribute:: values
      :value: []

   .. py:property:: all_reduction_dims
      :type: mlir.dialects.linalg.opdsl.lang.scalar_expr.Set[mlir.dialects.linalg.opdsl.lang.scalar_expr.Tuple[mlir.dialects.linalg.opdsl.lang.scalar_expr.DimDef, Ellipsis]]

      Gets the reduction dims for the comprehension or None.

   .. py:method:: __repr__()

.. py:class:: OpInterfaceDef(cpp_name: str)

   An interface that an op implements.

   .. py:attribute:: cpp_name

.. py:data:: ContractionOpInterface

.. py:data:: ConvolutionOpInterface

.. py:data:: FillOpInterface

.. py:class:: OpDefinitionDef(def_name: str)

   A method that an op implements.

   .. py:attribute:: def_name

.. py:data:: Canonicalizer

.. py:class:: OpMetadataDef(name: str, cpp_class_name: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[str], doc: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[str])

   Bases: :py:obj:`mlir.dialects.linalg.opdsl.lang.yaml_helper.YAMLObject`

   Metadata about the op (generally not behavior-impacting).

   .. py:attribute:: yaml_tag
      :value: '!LinalgOpMetadata'

   .. py:attribute:: name

   .. py:attribute:: cpp_class_name

   .. py:attribute:: doc

   .. py:attribute:: implements
      :type: mlir.dialects.linalg.opdsl.lang.scalar_expr.List[OpInterfaceDef]
      :value: []

   .. py:attribute:: defines
      :type: mlir.dialects.linalg.opdsl.lang.scalar_expr.List[OpDefinitionDef]
      :value: []

   .. py:method:: to_yaml_custom_dict()

.. py:class:: LinalgOpDef(name: str, cpp_class_name: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[str] = None, doc: mlir.dialects.linalg.opdsl.lang.scalar_expr.Optional[str] = None)

   Definition of a linalg op.

   .. py:attribute:: metadata

   .. py:attribute:: registered_operands
      :type: mlir.dialects.linalg.opdsl.lang.types.Dict[str, OperandDef]

   .. py:attribute:: domain
      :type: mlir.dialects.linalg.opdsl.lang.scalar_expr.List[mlir.dialects.linalg.opdsl.lang.scalar_expr.DimDef]
      :value: []

   .. py:attribute:: comprehensions
      :type: mlir.dialects.linalg.opdsl.lang.scalar_expr.List[Comprehension]
      :value: []

   .. py:attribute:: _affine_state

   .. py:method:: add_operand(name: str, operand: OperandDef)

      Registers an operand.

   .. py:method:: __repr__()
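
The notes on ``TensorUse`` and ``TensorDef.__setitem__`` depend on how Python desugars subscript assignment: ``C[i] = v`` becomes a single ``__setitem__`` call, while ``C[i] += v`` becomes ``__getitem__``, then ``__iadd__`` on the result, then ``__setitem__`` with the value that ``__iadd__`` returned. The sketch below illustrates only that call order. It uses two hypothetical toy classes, ``ToyTensorDef`` and ``ToyTensorUse``, which are not part of OpDSL and do not build real comprehensions.

.. code-block:: python

   # Toy model of subscript-assignment capture. ToyTensorDef and ToyTensorUse
   # are hypothetical stand-ins, not classes from mlir.dialects.linalg.opdsl.

   calls = []  # order in which Python invokes the special methods


   class ToyTensorUse:
       """A tensor indexed by dimension names (stand-in for TensorUse)."""

       def __init__(self, name, dims):
           self.name = name
           self.dims = dims

       def __iadd__(self, rhs):
           calls.append("TensorUse.__iadd__")
           # Describe the reduction, keeping the lhs separately, instead of
           # mutating this object.
           return ("reduce_add", self, rhs)


   class ToyTensorDef:
       """A tensor operand (stand-in for TensorDef)."""

       def __init__(self, name):
           self.name = name
           self.bindings = []  # (dims, value) pairs recorded by __setitem__

       def __getitem__(self, dims):
           calls.append("TensorDef.__getitem__")
           return ToyTensorUse(self.name, dims)

       def __setitem__(self, dims, value):
           calls.append("TensorDef.__setitem__")
           self.bindings.append((dims, value))


   C = ToyTensorDef("C")

   # Direct assignment: a single __setitem__ call.
   C["m", "n"] = "A*B"

   # Compound assignment: __getitem__, then __iadd__, then __setitem__.
   C["m", "n"] += "A*B"

   print(calls)
   # prints ['TensorDef.__setitem__', 'TensorDef.__getitem__',
   #         'TensorUse.__iadd__', 'TensorDef.__setitem__']

Because the compound form always finishes with ``__setitem__``, a DSL built this way can record every comprehension in ``__setitem__``, while ``__iadd__`` only needs to return an object describing the reduction; in this module's API that object is a ``TensorReduceFn``.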