MLIR

Multi-Level IR Compiler Framework

Passes

This document describes the available MLIR passes and their contracts.

General Transformation Passes 

-canonicalize: Canonicalize operations 

This pass performs various types of canonicalizations over a set of operations by iteratively applying the canonicalization patterns of all loaded dialects until either a fixpoint is reached or the maximum number of iterations/rewrites is exhausted. Canonicalization is best-effort and does not guarantee that the entire IR is in a canonical form after running this pass. See Operation Canonicalization for more details.
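The following Python sketch models the fixpoint loop described above on a toy IR. It is not MLIR's greedy rewrite driver. The tuple-based ops and the fold_neg_neg pattern are hypothetical, and max_iterations only mirrors the role of the -max-iterations option. The second return value reports whether a fixpoint was reached, which is roughly the condition that -test-convergence turns into a pass failure.

```python
from typing import Callable, Optional

# Toy "IR": a list of (opcode, operand) tuples. Illustrative encoding only.
Op = tuple[str, int]
Pattern = Callable[[Op], Optional[Op]]

def fold_neg_neg(op: Op) -> Optional[Op]:
    # Hypothetical pattern: "negneg x" folds to "id x".
    return ("id", op[1]) if op[0] == "negneg" else None

def canonicalize(ir: list[Op], patterns: list[Pattern],
                 max_iterations: int = 10) -> tuple[list[Op], bool]:
    """Apply patterns until nothing changes or the iteration budget runs out.

    Returns the rewritten IR and whether a fixpoint was reached."""
    for _ in range(max_iterations):
        changed = False
        new_ir = []
        for op in ir:
            for pattern in patterns:
                replacement = pattern(op)
                if replacement is not None:
                    op, changed = replacement, True
                    break
            new_ir.append(op)
        ir = new_ir
        if not changed:
            return ir, True   # fixpoint: no pattern applied in this sweep
    return ir, False          # budget exhausted; IR may not be fully canonical
```

A pair of patterns that rewrite a to b and b to a never converges; the loop then stops at max_iterations and reports False, which is why canonicalization is only best-effort.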

Options 

-top-down         : Seed the worklist in general top-down order
-region-simplify  : Perform control flow optimizations to the region tree
-max-iterations   : Max. iterations between applying patterns / simplifying regions
-max-num-rewrites : Max. number of pattern rewrites within an iteration
-test-convergence : Test only: Fail pass on non-convergence to detect cyclic patterns
-disable-patterns : Labels of patterns that should be filtered out during application
-enable-patterns  : Labels of patterns that should be used during application, all other patterns are filtered out

-control-flow-sink: Sink operations into conditional blocks 

This pass implements control-flow sink on operations that implement RegionBranchOpInterface by moving dominating operations whose only uses are in conditionally-executed regions into those regions, so that execution paths where their results are not needed do not perform unnecessary computations.

This is similar (but opposite) to loop-invariant code motion, which hoists operations out of regions executed more than once. The implementation of control-flow sink uses a simple and conservative cost model: operations are never duplicated and are only moved into singly-executed regions.

It is recommended to run canonicalization first to remove unreachable blocks: ops in unreachable blocks may prevent other operations from being sunk, as they may contain uses of their results.
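A minimal Python sketch of the sinking decision, assuming a toy representation in which each candidate operation is described only by the regions that contain its uses. The region names are hypothetical, and the sketch ignores side effects and dominance, which a real implementation must also check. An operation is sunk only when all of its uses lie in a single region that executes at most once, so nothing is ever duplicated.

```python
def plan_sinking(uses: dict[str, list[str]],
                 single_exec_regions: set[str]) -> dict[str, str]:
    """Map each op to the region it can be sunk into.

    uses: op name -> names of the regions containing its uses
          ("parent" marks a use outside any conditional region).
    single_exec_regions: conditionally executed regions that run at most once.
    """
    plan = {}
    for op, regions in uses.items():
        targets = set(regions)
        # Sink only if every use is in the same region and that region is
        # executed at most once; otherwise the op would have to be duplicated.
        if len(targets) == 1:
            (target,) = targets
            if target in single_exec_regions:
                plan[op] = target
    return plan
```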

Statistics 

num-sunk : Number of operations sunk

-cse: Eliminate common sub-expressions 

This pass implements a generalized algorithm for common sub-expression elimination. This pass relies on information provided by the Memory SideEffect interface to identify when it is safe to eliminate operations. See Common subexpression elimination for more general details on this optimization.
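As a rough illustration of the idea (not MLIR's implementation), the following Python sketch performs hash-based CSE within a single block. The pure set stands in for the side-effect query that the real pass makes through the Memory SideEffect interface; the op names are only labels, and dominance across blocks and regions is not modeled.

```python
# Toy op: (result name, opcode, operand names). Illustrative encoding only.
Op = tuple[str, str, tuple[str, ...]]

def cse(ops: list[Op], pure: set[str]) -> tuple[list[Op], dict[str, str]]:
    """Remove duplicate side-effect-free ops from one block.

    Returns the surviving ops and a map from each erased result to the
    earlier, identical value that replaces it."""
    seen: dict[tuple[str, tuple[str, ...]], str] = {}
    replaced: dict[str, str] = {}
    kept: list[Op] = []
    for result, opcode, operands in ops:
        # Operands may name values that were already eliminated.
        operands = tuple(replaced.get(o, o) for o in operands)
        key = (opcode, operands)
        if opcode in pure:
            if key in seen:
                replaced[result] = seen[key]   # reuse the first equivalent value
                continue
            seen[key] = result
        kept.append((result, opcode, operands))
    return kept, replaced
```

Ops outside the pure set (here the hypothetical foo.rand) are never merged, even when they look identical, because their results may differ between executions.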

Statistics 

num-cse'd : Number of operations CSE'd
num-dce'd : Number of operations DCE'd

-generate-runtime-verification: Generate additional runtime op verification checks 

This pass generates op-specific runtime checks using the RuntimeVerifiableOpInterface. It can be run for debugging purposes after passes that are suspected to introduce faulty IR.

-inline: Inline function calls 

Options 

-default-pipeline : The default optimizer pipeline used for callables
-op-pipelines     : Callable operation specific optimizer pipelines (in the form of `dialect.op(pipeline)`)
-max-iterations   : Maximum number of iterations when inlining within an SCC

-loop-invariant-code-motion: Hoist loop invariant instructions outside of the loop 

-mem2reg: Promotes memory slots into values. 

This pass removes loads out of and stores into a memory slot, and turns them into direct uses of SSA values. This is done generically using the PromoteAllocationOpInterface, PromoteOpInterface and PromoteMemOpInterface interfaces.

This pass will attempt to compute which definitions of the content of the memory slot reach operations that use the memory slot pointer. It will rewire or remove operations that use the slot pointer so they no longer use it. If any of this is not possible, the IR is left unmodified.

This pass only supports unstructured control-flow. Promotion of operations within subregions will not happen.
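The sketch below shows the core idea on a single straight-line block in Python: walk the block in order, remember the value of the last store as the reaching definition, and replace each load with it. The tuple encoding and the "undef" placeholder for an uninitialized slot are hypothetical. The real pass works through the interfaces named above and also handles control flow by inserting block arguments (see the new block args statistic).

```python
def promote_slot(block: list[tuple], slot: str) -> tuple[list[tuple], dict[str, str]]:
    """Promote one memory slot within a single straight-line block.

    Ops are tuples: ("store", slot, value), ("load", result, slot), or any
    other (opcode, *operands). Returns the rewritten block and a map from each
    removed load result to the SSA value that replaces it. If the slot pointer
    escapes, the block is returned unchanged with an empty map."""
    for op in block:
        kind = op[0]
        if (kind == "store" and op[2] == slot) or (
                kind not in ("store", "load") and slot in op[1:]):
            return block, {}            # pointer escapes: leave the IR untouched

    reaching = "undef"                  # hypothetical value of an unwritten slot
    forwarded: dict[str, str] = {}
    new_block: list[tuple] = []
    for op in block:
        if op[0] == "store" and op[1] == slot:
            # This store's value is the definition that reaches later loads.
            reaching = forwarded.get(op[2], op[2])
        elif op[0] == "load" and op[2] == slot:
            forwarded[op[1]] = reaching  # the load becomes a direct SSA use
        else:
            # Keep other ops, rewriting operands that referred to removed loads.
            new_block.append((op[0],) + tuple(forwarded.get(x, x) for x in op[1:]))
    return new_block, forwarded
```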

Options 

-region-simplify : Perform control flow optimizations to the region tree

Statistics 

promoted slots : Total number of memory slots promoted
new block args : Total number of new block arguments inserted in blocks

-print-ir: Print IR on the debug stream 

Print the entire IR on the debug stream. This is meant for debugging purposes to inspect the IR at a specific point in the pipeline.

Options 

-label : Label

-print-op-stats: Print statistics of operations 

Options 

-json : print the stats as JSON

-sccp: Sparse Conditional Constant Propagation 

This pass implements a general algorithm for sparse conditional constant propagation. This algorithm detects values that are known to be constant and optimistically propagates this throughout the IR. Any values proven to be constant are replaced, and removed if possible.

This implementation is based on the algorithm described by Wegman and Zadeck in “Constant Propagation with Conditional Branches” (1991).
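The two ingredients that distinguish SCCP from plain constant propagation are an optimistic three-level lattice and merging values only along control-flow edges already proven executable. The Python sketch below models just those two pieces; the names are illustrative, and the worklist that drives the full algorithm is omitted.

```python
class _Overdefined:
    """Marker for values proven not to be a single constant."""
    def __repr__(self) -> str:
        return "OVERDEFINED"

UNKNOWN = None                # optimistic start: no information yet
OVERDEFINED = _Overdefined()

def join(a, b):
    """Combine two lattice values: UNKNOWN < constant < OVERDEFINED."""
    if a is UNKNOWN:
        return b
    if b is UNKNOWN:
        return a
    if a is OVERDEFINED or b is OVERDEFINED or a != b:
        return OVERDEFINED
    return a                  # two equal constants remain that constant

def merge_incoming(incoming):
    """Value of a block argument, joined only over executable predecessor edges.

    incoming: list of (edge_is_executable, lattice_value) pairs."""
    result = UNKNOWN
    for executable, value in incoming:
        if executable:
            result = join(result, value)
    return result
```

For example, if a branch condition has already been folded to true, the edge from the else block is not executable, so a block argument receiving 2 from the then block and 3 from the else block is still the constant 2.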

-snapshot-op-locations: Generate new locations from the current IR 

This pass allows for generating new locations from the IR during any stage of compilation, by snapshotting the IR to a file and using that file to generate new locations for the operations.

Depending on the value of the tag option, different resulting locations may be generated:

  • If unset, the original location of the operation is replaced.

Example:

// old:
... loc("original_source.cpp":1:1)

// new:
... loc("snapshot_source.mlir":10:10)

  • If set, the new location is fused with the original location in the form of a Name Location with the specified tag.

Example:

// old:
... loc("original_source.cpp":1:1)

// new:
... loc(fused["original_source.cpp":1:1, "snapshot"("snapshot_source.mlir":10:10)])

Options 

-filename : The file to print the generated IR to
-tag      : A tag to use when fusing the new locations with the original. If unset, the locations are replaced.

-sroa: Scalar Replacement of Aggregates 

Scalar Replacement of Aggregates. Replaces allocations of aggregates with independent allocations of their elements.

Allocators must implement DestructurableAllocationOpInterface to provide the list of memory slots for which destructuring should be attempted.

This pass will only be applied if all accessors of the aggregate implement the DestructurableAccessorOpInterface. If the accessors provide a view into the struct, users of the view must ensure it is used in a type-safe manner and within bounds by implementing TypeSafeOpInterface.

Statistics 

destructured slots        : Total number of memory slots destructured
slots with memory benefit : Total number of memory slots for which the destructured size was smaller than the total size after eliminating unused fields
max subelement number     : Maximal number of sub-elements a successfully destructured slot initially had

-strip-debuginfo: Strip debug info from all operations 

This pass strips the IR of any location information by replacing all operation locations with the unknown location.

-symbol-dce: Eliminate dead symbols 

This pass deletes all symbols that are found to be unreachable. This is done by computing the set of operations that are known to be live, propagating that liveness to other symbols, and then deleting all symbols that are not within this live set. Live symbols are those that have a visibility that extends beyond the IR, e.g. public, or those that are referenced by live symbols or other non-Symbol operations.

For example, consider the following input:

func.func private @dead_private_function()
func.func private @live_private_function()

// Note: The `public` isn't necessary here, as this is the default.
func.func public @public_function() {
  "foo.return"() {uses = [@live_private_function]} : () -> ()
}

A known live function, public_function, contains a reference to an otherwise non-live function live_private_function. After running symbol-dce, only these two symbols should remain, as the final symbol dead_private_function is not visible outside of the current IR and there are no links to known-live operations. After running, we get the expected:

func.func private @live_private_function()

func.func public @public_function() {
  "foo.return"() {uses = [@live_private_function]} : () -> ()
}

See Symbols and SymbolTables for more information on Symbols.
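The liveness propagation described above amounts to a reachability walk over the symbol reference graph. A minimal Python sketch, assuming a toy encoding of symbols, their visibility, and their references; it treats only public symbols as externally visible and does not model MLIR's nested visibility.

```python
def symbol_dce(visibility: dict[str, str], refs: dict[str, list[str]],
               roots: list[str]) -> set[str]:
    """Return the set of live symbols.

    visibility: symbol name -> "public" or "private"
    refs: symbol name -> symbols referenced from its body
    roots: symbols referenced by non-symbol operations"""
    # Seed with symbols visible outside the IR plus non-symbol references.
    worklist = [s for s, vis in visibility.items() if vis == "public"] + list(roots)
    live: set[str] = set()
    while worklist:
        sym = worklist.pop()
        if sym in live:
            continue
        live.add(sym)
        worklist.extend(refs.get(sym, []))  # liveness flows along references
    return live
```

Applied to the example above, it keeps public_function and live_private_function; every other symbol (here dead_private_function) would be erased.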

Statistics 

num-dce'd : Number of symbols DCE'd

-symbol-privatize: Mark symbols private 

This pass marks all top-level symbols of the operation it runs on as private, except those listed in the exclude pass option.

Options 

-exclude : Comma separated list of symbols that should not be marked private

-topological-sort: Sort regions without SSA dominance in topological order 

Recursively sorts all nested regions without SSA dominance in topological order. The main purpose is readability, although the resulting order may also make certain transformations and analyses easier to apply. The pass sorts the operations in all nested regions such that, as much as possible, all users appear after their producers.

This sort is stable. If the block is already topologically sorted, the IR is not changed. Operations that form a cycle are moved to the end of the regions in a stable order.
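One way to obtain the stable behavior described above is to repeatedly emit the earliest remaining operation whose in-block producers have all been emitted. The Python sketch below does that on operation names; when no operation is ready, only operations on or behind a cycle remain, and they are appended in their original order. It illustrates the contract and is not the MLIR implementation, which may break cycles differently.

```python
def stable_topo_sort(ops: list[str], deps: dict[str, list[str]]) -> list[str]:
    """Order ops so producers precede users, keeping the original order
    whenever possible.

    deps: op -> ops in the same block whose results it uses."""
    remaining = list(ops)
    placed: set[str] = set()
    order: list[str] = []
    while remaining:
        # Earliest op whose in-block producers have all been placed.
        ready = next((op for op in remaining
                      if all(d in placed or d not in ops
                             for d in deps.get(op, []))), None)
        if ready is None:
            # Nothing is ready: the rest is a cycle (or depends on one).
            order.extend(remaining)
            break
        order.append(ready)
        placed.add(ready)
        remaining.remove(ready)
    return order
```

Because the earliest ready op is always chosen, an already sorted block comes out unchanged.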

-view-op-graph: Print Graphviz visualization of an operation 

This pass prints a Graphviz graph of a module.

  • Operations are represented as nodes;
  • Uses (data flow) as edges;
  • Control flow as dashed edges;
  • Regions/blocks as subgraphs.

By default, only data flow edges are printed.

Note: See https://www.graphviz.org/doc/info/lang.html for more information about the Graphviz DOT language.

Options 

-max-label-len            : Limit attribute/type length to number of chars
-print-attrs              : Print attributes of operations
-print-control-flow-edges : Print control flow edges
-print-data-flow-edges    : Print data flow edges
-print-result-types       : Print result types of operations

Bufferization Passes 

-buffer-deallocation: Adds all required dealloc operations for all allocations in the input program 

This pass implements an algorithm to automatically introduce all required deallocation operations for all buffers in the input program. This ensures that the resulting program does not have any memory leaks.

Input

#map0 = affine_map<(d0) -> (d0)>
module {
  func.func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
    cf.cond_br %arg0, ^bb1, ^bb2
  ^bb1:
    cf.br ^bb3(%arg1 : memref<2xf32>)
  ^bb2:
    %0 = memref.alloc() : memref<2xf32>
    linalg.generic {
      indexing_maps = [#map0, #map0],
      iterator_types = ["parallel"]}
      ins(%arg1 : memref<2xf32>) outs(%0 : memref<2xf32>) {
    ^bb0(%gen1_arg0: f32, %gen1_arg1: f32):
      %tmp1 = math.exp %gen1_arg0 : f32
      linalg.yield %tmp1 : f32
    }
    cf.br ^bb3(%0 : memref<2xf32>)
  ^bb3(%1: memref<2xf32>):
    "memref.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
    return
  }
}

Output

#map0 = affine_map<(d0) -> (d0)>
module {
  func.func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
    cf.cond_br %arg0, ^bb1, ^bb2
  ^bb1:  // pred: ^bb0
    %0 = memref.alloc() : memref<2xf32>
    memref.copy %arg1, %0 : memref<2xf32> to memref<2xf32>
    cf.br ^bb3(%0 : memref<2xf32>)
  ^bb2:  // pred: ^bb0
    %1 = memref.alloc() : memref<2xf32>
    linalg.generic {
      indexing_maps = [#map0, #map0],
      iterator_types = ["parallel"]}
      ins(%arg1 : memref<2xf32>) outs(%1 : memref<2xf32>) {
    ^bb0(%arg3: f32, %arg4: f32):
      %4 = math.exp %arg3 : f32
      linalg.yield %4 : f32
    }
    %2 = memref.alloc() : memref<2xf32>
    memref.copy %1, %2 : memref<2xf32> to memref<2xf32>
    memref.dealloc %1 : memref<2xf32>
    cf.br ^bb3(%2 : memref<2xf32>)
  ^bb3(%3: memref<2xf32>):  // 2 preds: ^bb1, ^bb2
    memref.copy %3, %arg2 : memref<2xf32> to memref<2xf32>
    memref.dealloc %3 : memref<2xf32>
    return
  }
}

-buffer-hoisting: Optimizes placement of allocation operations by moving them into common dominators and out of nested regions 

This pass implements an approach to aggressively move allocations upwards into common dominators and out of nested regions.

-buffer-loop-hoisting: Optimizes placement of allocation operations by moving them out of loop nests 

This pass implements an approach to aggressively move allocations upwards out of loop nests. It does not move allocations into common dominators.

-buffer-results-to-out-params: Converts memref-typed function results to out-params 

Some calling conventions prefer to pass output memrefs as “out params”. The conversion to this calling convention must be done as an atomic transformation of the entire program (hence this is a module pass).

For example, if a call is rewritten, the callee must be rewritten as well; otherwise, the IR ends up invalid. This is why the transformation requires an atomic change to the entire program (e.g., the whole module).

This pass is expected to run immediately after bufferization is finished. At that point, tensor-typed results will have been converted to memref-typed results, and can be consistently converted to out params.

All memref-typed results are appended to the function argument list.

The main issue with this pass (and the out-param calling convention) is that buffers for results need to be allocated in the caller. This currently only works for statically shaped memrefs.
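In a toy model where types are strings, the signature change can be sketched in Python as follows. The string matching on "memref<" and "?" is a simplification, and rejecting dynamically shaped results with an exception is an assumption that mirrors the static-shape restriction above.

```python
def results_to_out_params(args: list[str],
                          results: list[str]) -> tuple[list[str], list[str]]:
    """Move memref-typed results to the end of the argument list.

    Types are plain strings; '?' marks a dynamic dimension (toy encoding)."""
    new_args, new_results = list(args), []
    for ty in results:
        if ty.startswith("memref<"):
            if "?" in ty:
                # The caller must allocate the buffer, which needs a static shape.
                raise ValueError(f"cannot convert dynamically shaped result {ty}")
            new_args.append(ty)      # appended as an out-param
        else:
            new_results.append(ty)   # non-memref results stay results
    return new_args, new_results
```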

-bufferization-bufferize: Bufferize the bufferization dialect 

-drop-equivalent-buffer-results: Remove MemRef return values that are equivalent to a bbArg 

This pass removes MemRef return values from functions if they are equivalent to a function bbArg. In that case, the return value is redundant and the respective CallOp operand can be used at the call site.

Note: If a bbArg buffer is not returned directly but cast beforehand, the buffer is still considered equivalent.

-eliminate-empty-tensors: Try to eliminate all tensor.empty ops. 

This pass tries to eliminate all insert_slice op-anchored tensor.empty ops. That is, when a value that is equivalent to a tensor.empty op is inserted into another tensor, this pass tries to rewrite the IR in such a way that the destination tensor of the insert_slice op is used directly instead of the tensor.empty result.

-empty-tensor-to-alloc-tensor: Replace all empty ops by alloc_tensor ops. 

tensor.empty ops return a tensor of unspecified contents whose only purpose is to carry the tensor shape. This pass converts such ops to bufferization.alloc_tensor ops, which bufferize to buffer allocations.

-finalizing-bufferize: Finalize a partial bufferization 

A bufferize pass that finalizes a partial bufferization by removing remaining bufferization.to_tensor and bufferization.to_buffer operations.

The removal of those operations is only possible if the operations only exist in pairs, i.e., all uses of bufferization.to_tensor operations are bufferization.to_buffer operations.

This pass will fail if not all operations can be removed or if any operation with tensor typed operands remains.

-one-shot-bufferize: One-Shot Bufferize 

This pass bufferizes all ops that implement BufferizableOpInterface. It first performs an inplaceability analysis on SSA use-def chains of tensor values to determine which OpOperands may bufferize in-place, i.e., without inserting a buffer copy. It then rewrites the IR, inserting a buffer allocation and copy for each OpOperand that was decided to bufferize out-of-place.

One-Shot Bufferize (and BufferizableOpInterface) was designed for ops that are in destination-passing style. When bufferizing such ops, it is possible to reuse the buffer of a tensor OpOperand for a tensor OpResult. In essence, a possible destination of an operation is already passed as an SSA value.

tensor.insert is an example of an op in destination-passing style. E.g., when bufferizing %t0 = tensor.insert %f into %dest[%idx], buffer(%t0) is identical to buffer(%dest) in the absence of RaW conflicts. As a counterexample, tensor.generate is not in destination-passing style and always results in a new buffer allocation.

One-Shot Bufferize deallocates all buffers that it allocates. Yielding newly allocated buffers from a block can lead to bad performance because additional buffer copies are often needed to make sure that every buffer allocation is also deallocated again. By default, such IR is rejected by One-Shot Bufferize. Such IR can be allowed with allow-return-allocs. In that case, the -buffer-deallocation pass should be run after One-Shot Bufferize. Note that new buffer allocations that are returned from a function currently cannot be deallocated by -buffer-deallocation and therefore leak.

One-Shot Bufferize will by default reject IR that contains non-bufferizable ops, i.e., ops that do not implement BufferizableOpInterface. Such IR can be allowed with allow-unknown-ops=1. In that case, to_memref and to_tensor ops will be generated at the bufferization boundary. This is useful for compatibility with existing partial bufferization passes: These can bufferize the remaining IR after running One-Shot Bufferize.

Note: Running One-Shot Bufferize after a partial bufferization pass is currently not supported. Running partial bufferization passes after running One-Shot Bufferize is supported and the recommended way to gradually migrate from partial bufferization to One-Shot Bufferize.

With dialect-filter, bufferization can be restricted to a set of dialects. If no filter is specified, all ops that implement BufferizableOpInterface are bufferized. Ops from the std dialect are an exception: These ops are always ignored, even if no filter is specified. When specifying a dialect filter and allow-unknown-ops is not turned on, bufferization would fail when encountering an op that is not included in the filter (even if it is bufferizable).

One-Shot Bufferize will by default assume memref types with fully dynamic layout maps when a precise layout cannot be inferred. E.g., this is the case when wrapping a non-bufferizable op in to_memref/to_tensor ops. This behavior can be overridden with unknown-type-conversion. Valid values are fully-dynamic-layout-map and identity-layout-map.

For testing/debugging purposes, test-analysis-only=1 print-conflicts=1 prints analysis results and explains why an OpOperand was decided to bufferize out-of-place. This is useful for understanding why One-Shot Bufferize chose to insert a certain buffer copy.

bufferize-function-boundaries is an experimental flag for bufferizing FuncOp, ReturnOp and CallOp. This feature is still under development and supports only simple cases at the moment. In particular:

  • Recursive or circular function call graphs are not supported.
  • External functions (without bodies) that return a tensor are not supported.
  • Functions with multiple blocks or multiple ReturnOps are not supported.
  • Layout maps on function signatures can be controlled with a separate function-boundary-type-conversion option, which is similar to unknown-type-conversion but supports an additional infer-layout-map option. fully-dynamic-layout-map and identity-layout-map ensure that function signatures bufferize to easily predictable types, potentially at the cost of additional casts and copies, respectively. When layout maps are inferred, function return types may be more precise, but less predictable. Function argument types cannot be inferred and always have fully dynamic layout maps with infer-layout-map.

One-Shot Bufferize implements the following contract around function calls: The buffer of function arguments is always writable (unless annotated with bufferization.writable = false). A buffer copy may be inserted at the call site where necessary. Alias sets and equivalence info are propagated through function calls. Whenever a function is bufferized, all other functions that are being called were already analyzed and bufferized, so exact alias and equivalence information is available. This is why recursive function calls are not yet supported.
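The callee-before-caller order described above is a post-order walk of the call graph. A minimal Python sketch, assuming the call graph is given as a dictionary of hypothetical function names; a cycle is reported as an error because a callee would not be fully analyzed before its caller.

```python
def bufferization_order(calls: dict[str, list[str]]) -> list[str]:
    """Order functions so every callee is processed before its callers.

    calls: function -> functions it calls. Raises ValueError on recursion."""
    order: list[str] = []
    state: dict[str, str] = {}  # "visiting" while on the DFS stack, then "done"

    def visit(fn: str) -> None:
        if state.get(fn) == "done":
            return
        if state.get(fn) == "visiting":
            raise ValueError(f"recursive call graph involving {fn}")
        state[fn] = "visiting"
        for callee in calls.get(fn, []):
            visit(callee)
        state[fn] = "done"
        order.append(fn)        # post-order: all callees were appended first

    for fn in calls:
        visit(fn)
    return order
```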

One-Shot Bufferize gathers additional information during the analysis phase when function boundary bufferization is activated. E.g., whether a function argument is read/written and which returned values are aliasing/equivalent. For debugging purposes, such information can be printed with test-analysis-only.

Options 

-allow-return-allocs               : Allows returning/yielding new allocations from a block.
-allow-unknown-ops                 : Allows unknown (not bufferizable) ops in the input IR.
-analysis-fuzzer-seed              : Test only: Analyze ops in random order with a given seed (fuzzer)
-analysis-heuristic                : Heuristic that controls the IR traversal during analysis
-bufferize-function-boundaries     : Bufferize function boundaries (experimental).
-copy-before-write                 : Skip the analysis. Make a buffer copy on every write.
-create-deallocs                   : Specify if buffers should be deallocated. For compatibility with core bufferization passes.
-dialect-filter                    : Restrict bufferization to ops from these dialects.
-dump-alias-sets                   : Test only: Annotate tensor IR with alias sets
-no-analysis-func-filter           : Skip analysis of functions with these symbol names. Set copyBeforeWrite to true when bufferizing them.
-function-boundary-type-conversion : Controls layout maps when bufferizing function signatures.
-must-infer-memory-space           : The memory space of memref types must always be inferred. If unset, a default memory space of 0 is used when it cannot be inferred.
-test-analysis-only                : Test only: Only run inplaceability analysis and annotate IR
-print-conflicts                   : Test only: Annotate IR with RaW conflicts. Requires test-analysis-only.
-unknown-type-conversion           : Controls layout maps for non-inferrable memref types.

Statistics 

num-buffer-alloc        : Number of buffer allocations
num-buffer-dealloc      : Number of buffer deallocations
num-tensor-in-place     : Number of in-place tensor OpOperands
num-tensor-out-of-place : Number of out-of-place tensor OpOperands

-promote-buffers-to-stack: Promotes heap-based allocations to automatically managed stack-based allocations 

This pass implements a simple algorithm to convert heap-based memory allocations to stack-based ones. It uses a built-in heuristic to decide whether it makes sense to convert an allocation. Furthermore, dynamically shaped buffers whose dynamic sizes are limited by the rank of a tensor can also be converted. In either case, buffers are only transformed if they are considered to be small.

Options 

-max-alloc-size-in-bytes      : Maximal size in bytes to promote allocations to stack.
-max-rank-of-allocated-memref : Maximal memref rank to promote dynamic buffers.
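A simplified Python model of the size heuristic, assuming the size is computed as element count times element size, and that a dynamic dimension known to come from a tensor rank is bounded by -max-rank-of-allocated-memref. The exact rule in MLIR may differ; the shape encoding ("rank" for rank-derived sizes, None for arbitrary dynamic sizes) is hypothetical.

```python
def should_promote(shape: list, elem_bytes: int,
                   max_alloc_size_in_bytes: int,
                   max_rank_of_allocated_memref: int) -> bool:
    """Decide whether an allocation is small enough to move to the stack.

    shape entries: an int (static size), "rank" (dynamic size known to be a
    tensor rank), or None (arbitrary dynamic size)."""
    size = elem_bytes
    for dim in shape:
        if dim is None:
            return False                 # unbounded size: keep it on the heap
        # A rank-derived size is assumed to be at most the configured max rank.
        size *= max_rank_of_allocated_memref if dim == "rank" else dim
    return size <= max_alloc_size_in_bytes
```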

Conversion Passes 

-arm-neon-2d-to-intr: Convert Arm NEON structured ops to intrinsics 

-convert-affine-for-to-gpu: Convert top-level AffineFor Ops to GPU kernels 

Options 

-gpu-block-dims  : Number of GPU block dimensions for mapping
-gpu-thread-dims : Number of GPU thread dimensions for mapping

-convert-amdgpu-to-rocdl: Convert AMDGPU dialect to ROCDL dialect 

This pass converts supported AMDGPU ops to ROCDL dialect intrinsics.

Options 

-chipset : Chipset that these operations will run on

-convert-arith-to-llvm: Convert Arith dialect to LLVM dialect 

This pass converts supported Arith ops to LLVM dialect instructions.

Options 

-index-bitwidth : Bitwidth of the index type, 0 to use size of machine word

-convert-arith-to-spirv: Convert Arith dialect to SPIR-V dialect 

Options 

-emulate-lt-32-bit-scalar-types : Emulate narrower scalar types with 32-bit ones if not supported by the target
-enable-fast-math               : Enable fast math mode (assuming no NaN and infinity for floating point values) when performing conversion

-convert-async-to-llvm: Convert the operations from the async dialect into the LLVM dialect 

Convert async.execute operations to LLVM coroutines and use async runtime API to execute them.

Options 

-use-opaque-pointers : Generate LLVM IR using opaque pointers instead of typed pointers

-convert-bufferization-to-memref: Convert operations from the Bufferization dialect to the MemRef dialect 

This pass converts bufferization operations into memref operations.

In the current state, this pass only transforms a bufferization.clone operation into memref.alloc and memref.copy operations. This conversion is needed since some clone operations could remain after applying several transformation processes. Currently, only canonicalization transforms or eliminates clone operations. This can lead to errors if any clone op survives after all conversion passes (starting from the bufferization dialect) are performed.

See: https://llvm.discourse.group/t/bufferization-error-related-to-memref-clone/4665

To avoid these errors, this pass can be run as a last clean-up pass to transform the remaining operations and continue in other dialects (e.g., memref).

Note that this pass only transforms the operation without any further analyses. This pass does not consider any memory analysis or optimization and hence does not resolve any memory leaks.

-convert-cf-to-llvm: Convert ControlFlow operations to the LLVM dialect 

Convert ControlFlow operations into LLVM IR dialect operations.

If other operations are present and their results are required by the LLVM IR dialect operations, the pass will fail. Any LLVM IR operations or types already present in the IR will be kept as is.

Options 

-index-bitwidth      : Bitwidth of the index type, 0 to use size of machine word
-use-opaque-pointers : Generate LLVM IR using opaque pointers instead of typed pointers

-convert-cf-to-spirv: Convert ControlFlow dialect to SPIR-V dialect 

Options 

-emulate-lt-32-bit-scalar-types : Emulate narrower scalar types with 32-bit ones if not supported by the target

-convert-complex-to-libm: Convert Complex dialect to libm calls 

This pass converts supported Complex ops to libm calls.

-convert-complex-to-llvm: Convert Complex dialect to LLVM dialect 

-convert-complex-to-spirv: Convert Complex dialect to SPIRV dialect 

-convert-complex-to-standard: Convert Complex dialect to standard dialect 

-convert-func-to-llvm: Convert from the Func dialect to the LLVM dialect 

Convert Func dialect operations into the LLVM IR dialect operations.

Input invariant 

  • no tensor types;
  • all vectors are one-dimensional;
  • all blocks are reachable by following the successors of the first basic block;

If other operations are present and their results are required by the LLVM IR dialect operations, the pass will fail. Any LLVM IR operations or types already present in the IR will be kept as is.

Output IR 

Functions converted to LLVM IR. Function argument types are converted one-to-one. Function results are converted one-to-one and, in case more than 1 value is returned, packed into an LLVM IR struct type. Function calls and returns are updated accordingly. Block argument types are updated to use LLVM IR types.

Options 

-use-bare-ptr-memref-call-conv : Replace FuncOp's MemRef arguments with bare pointers to the MemRef element types
-index-bitwidth                : Bitwidth of the index type, 0 to use size of machine word
-data-layout                   : String description (LLVM format) of the data layout that is expected on the produced module
-use-opaque-pointers           : Generate LLVM IR using opaque pointers instead of typed pointers

-convert-func-to-spirv: Convert Func dialect to SPIR-V dialect 

Options 

-emulate-lt-32-bit-scalar-types : Emulate narrower scalar types with 32-bit ones if not supported by the target

-convert-gpu-launch-to-vulkan-launch: Convert gpu.launch_func to vulkanLaunch external call 

This pass is only intended for the mlir-vulkan-runner.

-convert-gpu-to-nvvm: Generate NVVM operations for gpu operations 

Options 

-index-bitwidth      : Bitwidth of the index type, 0 to use size of machine word
-has-redux           : Target gpu supports redux
-use-opaque-pointers : Generate LLVM IR using opaque pointers instead of typed pointers

-convert-gpu-to-rocdl: Generate ROCDL operations for gpu operations 

Options 

-chipset                       : Chipset that these operations will run on
-index-bitwidth                : Bitwidth of the index type, 0 to use size of machine word
-use-bare-ptr-memref-call-conv : Replace memref arguments in GPU functions with bare pointers. All memrefs must have static shape
-runtime                       : Runtime code will be run on (default is Unknown, can also use HIP or OpenCL)
-use-opaque-pointers           : Generate LLVM IR using opaque pointers instead of typed pointers

-convert-gpu-to-spirv: Convert GPU dialect to SPIR-V dialect 

This pass converts supported GPU device ops to SPIR-V ops. It does not handle GPU host ops.

A gpu.func op can have parameters to pass in resources. However, SPIR-V entry functions cannot take parameters; they use descriptors to access resources. By default, parameters to a gpu.func op are converted to global variables. These global variables are assigned sequential binding numbers following their order in the original gpu.func op, starting from 0, in set 0. One can attach spirv.interface_var_abi to those parameters to control the set and binding if wanted.
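The default numbering scheme can be written down in a few lines of Python. In this sketch, the abi dictionary stands in for parameters that carry an explicit spirv.interface_var_abi attribute; it illustrates the rule above and is not the conversion code.

```python
def assign_bindings(params: list[str],
                    abi: dict[str, tuple[int, int]]) -> dict[str, tuple[int, int]]:
    """Map each gpu.func parameter to a (descriptor set, binding) pair.

    By default, parameter i gets set 0 and binding i; an entry in `abi`
    (standing in for spirv.interface_var_abi) overrides the default."""
    return {p: abi.get(p, (0, i)) for i, p in enumerate(params)}
```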

Options 

-use-64bit-index : Use 64-bit integers to convert index types

-convert-index-to-llvm: Lower the index dialect to the llvm dialect. 

This pass lowers Index dialect operations to LLVM dialect operations. Operation conversions are 1-to-1 except for the exotic divides: ceildivs, ceildivu, and floordivs, which expand to a series of LLVM operations. Importantly, the index bitwidth should be correctly set to the target pointer width via index-bitwidth.
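For reference, the intended rounding of the three divisions can be written in Python, assuming ceil divisions round toward positive infinity and floor divisions toward negative infinity, with operands already in range, a nonzero divisor, and no overflow. This is a semantic model, not the LLVM operation sequence the pass emits. The unsigned case also shows why index-bitwidth matters: the bit pattern of -1 is a different unsigned number at 32 and at 64 bits.

```python
def floordivs(a: int, b: int) -> int:
    """Signed division rounding toward negative infinity."""
    return a // b                       # Python's // already floors

def ceildivs(a: int, b: int) -> int:
    """Signed division rounding toward positive infinity."""
    return -((-a) // b)                 # ceil(a / b) == -floor(-a / b)

def ceildivu(a: int, b: int, bits: int = 64) -> int:
    """Unsigned ceil division of the low `bits` bits of a and b."""
    mask = (1 << bits) - 1
    ua, ub = a & mask, b & mask         # reinterpret the bit patterns as unsigned
    return (ua + ub - 1) // ub
```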

Options 

-index-bitwidth : Bitwidth of the index type, 0 to use size of machine word

-convert-linalg-to-llvm: Convert the operations from the linalg dialect into the LLVM dialect 

Options 

-use-opaque-pointers : Generate LLVM IR using opaque pointers instead of typed pointers

-convert-linalg-to-std: Convert the operations from the linalg dialect into the Standard dialect 

-convert-math-to-funcs: Convert Math operations to calls of outlined implementations. 

This pass converts supported Math ops to calls of compiler generated functions implementing these operations in software. The LLVM dialect is used for LinkonceODR linkage of the generated functions.

Options 

-min-width-of-fpowi-exponent : Convert FPowI only if the width of its exponent's integer type is greater than or equal to this value
-convert-ctlz                : Convert math.ctlz to a software implementation. Enable for targets that do not natively support ctlz.
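To illustrate what a software implementation of these operations can look like, here are Python sketches of a bit-scanning ctlz and of integer powers by repeated squaring. They are not the functions MLIR generates, which may use different algorithms, and this powi raises ZeroDivisionError for a zero base with a negative exponent rather than producing infinity.

```python
def ctlz(x: int, bits: int = 32) -> int:
    """Count leading zero bits of x viewed as a `bits`-wide unsigned integer."""
    x &= (1 << bits) - 1
    count = 0
    for i in range(bits - 1, -1, -1):   # scan from the most significant bit
        if (x >> i) & 1:
            break
        count += 1
    return count

def powi(base: float, exp: int) -> float:
    """base ** exp for an integer exponent, by exponentiation by squaring."""
    result, b, n = 1.0, base, abs(exp)
    while n:
        if n & 1:
            result *= b                 # include this power of two of the exponent
        b *= b
        n >>= 1
    return 1.0 / result if exp < 0 else result
```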

-convert-math-to-libm: Convert Math dialect to libm calls 

This pass converts supported Math ops to libm calls.

-convert-math-to-llvm: Convert Math dialect to LLVM dialect 

Options 

-approximate-log1p : Enable approximation of Log1p.

-convert-math-to-spirv: Convert Math dialect to SPIR-V dialect 

-convert-memref-to-spirv: Convert MemRef dialect to SPIR-V dialect 

Options 

-bool-num-bits : The number of bits to store a boolean value

-convert-nvgpu-to-nvvm: Convert NVGPU dialect to NVVM dialect 

This pass converts supported NVGPU ops to NVVM dialect intrinsics.

Options 

-use-opaque-pointers : Generate LLVM IR using opaque pointers instead of typed pointers

-convert-openacc-to-scf: Convert the OpenACC ops to OpenACC with SCF dialect 

-convert-openmp-to-llvm: Convert the OpenMP ops to OpenMP ops with LLVM dialect 

-convert-parallel-loops-to-gpu: Convert mapped scf.parallel ops to gpu launch operations 

-convert-pdl-to-pdl-interp: Convert PDL ops to PDL interpreter ops 

-convert-scf-to-cf: Convert SCF dialect to ControlFlow dialect, replacing structured control flow with a CFG 

-convert-scf-to-openmp: Convert SCF parallel loop to OpenMP parallel + workshare constructs. 

Options 

-use-opaque-pointers : Generate LLVM IR using opaque pointers instead of typed pointers

-convert-scf-to-spirv: Convert SCF dialect to SPIR-V dialect. 

Converts SCF ops into SPIR-V structured control flow ops. Because SPIR-V structured control flow ops do not support yielding values, SPIR-V variables are created to hold the values yielded by SCF ops, and load/store operations are emitted to update them.

-convert-shape-constraints: Convert shape constraint operations to the standard dialect 

This pass eliminates shape constraints from the program, converting them to eager (side-effecting) error handling code.

This pass is separate from the regular convert-shape-to-std pass, even though both convert between the same dialects, because converting shape constraints can happen at a different point than the general shape computation lowering.

-convert-shape-to-std: Convert operations from the shape dialect into the standard dialect 

-convert-spirv-to-llvm: Convert SPIR-V dialect to LLVM dialect 

See https://mlir.llvm.org/docs/SPIRVToLLVMDialectConversion/ for more details.

Options 

-use-opaque-pointers : Generate LLVM IR using opaque pointers instead of typed pointers

-convert-tensor-to-linalg: Convert some Tensor dialect ops to Linalg dialect 

-convert-tensor-to-spirv: Convert Tensor dialect to SPIR-V dialect 

Options 

-emulate-lt-32-bit-scalar-types : Emulate narrower scalar types with 32-bit ones if not supported by the target

-convert-vector-to-gpu: Lower the operations from the vector dialect into the GPU dialect 

Options 

-use-nvgpu : Convert to NVGPU ops instead of GPU dialect ops

-convert-vector-to-llvm: Lower the operations from the vector dialect into the LLVM dialect 

Convert operations from the vector dialect into LLVM dialect operations. The lowering pass provides several options to control the kinds of optimizations that are allowed. It also provides options that enable the use of one or more architecture-specific dialects (AMX, X86Vector, ArmNeon, ArmSVE, etc.) in combination with the architecture-neutral vector dialect lowering.

Options 

-reassociate-fp-reductions  : Allows LLVM to reassociate floating-point reductions for speed
-force-32bit-vector-indices : Allows compiler to assume vector indices fit in 32-bit if that yields faster code
-enable-amx                 : Enables the use of AMX dialect while lowering the vector dialect.
-enable-arm-neon            : Enables the use of ArmNeon dialect while lowering the vector dialect.
-enable-arm-sve             : Enables the use of ArmSVE dialect while lowering the vector dialect.
-enable-x86vector           : Enables the use of X86Vector dialect while lowering the vector dialect.
-use-opaque-pointers        : Generate LLVM IR using opaque pointers instead of typed pointers

-convert-vector-to-scf: Lower the operations from the vector dialect into the SCF dialect 

Options 

-full-unroll   : Perform full unrolling when converting vector transfers to SCF
-target-rank   : Target vector rank to which transfer ops should be lowered
-lower-tensors : Lower transfer ops that operate on tensors

-convert-vector-to-spirv: Convert Vector dialect to SPIR-V dialect 

-finalize-memref-to-llvm: Finalize MemRef dialect to LLVM dialect conversion 

Finalize the conversion of operations from the MemRef dialect to the LLVM dialect. Some complex MemRef operations are not converted by this pass; run expand-strided-metadata beforehand to handle them.

Options 

-use-aligned-alloc     : Use aligned_alloc in place of malloc for heap allocations
-index-bitwidth        : Bitwidth of the index type, 0 to use size of machine word
-use-generic-functions : Use generic allocation and deallocation functions instead of the classic 'malloc', 'aligned_alloc' and 'free' functions
-use-opaque-pointers   : Generate LLVM IR using opaque pointers instead of typed pointers

-gpu-to-llvm: Convert GPU dialect to LLVM dialect with GPU runtime calls 

Creates a pass to convert GPU operations into a sequence of GPU runtime calls.

This pass does not generate code to call GPU runtime APIs directly but instead uses a small wrapper library that exports a stable and conveniently typed ABI on top of GPU runtimes such as CUDA or ROCm (HIP).

Options 

-use-bare-pointers-for-kernels : Use bare pointers to pass memref arguments to kernels. The kernel must use the same setting for this option.
-gpu-binary-annotation         : Annotation attribute string for GPU binary
-use-opaque-pointers           : Generate LLVM IR using opaque pointers instead of typed pointers

-launch-func-to-vulkan: Convert vulkanLaunch external call to Vulkan runtime external calls 

This pass is only intended for the mlir-vulkan-runner.

Options 

-use-opaque-pointers : Generate LLVM IR using opaque pointers instead of typed pointers

-lower-affine: Lower Affine operations to a combination of Standard and SCF operations 

Convert operations from the affine dialect into operations from the SCF and standard dialects.

affine.for operations are converted to scf.for operations that are free of certain structural restrictions (on their bounds and step). affine.if is similarly converted to the scf.if operation. affine.apply operations are converted into sequences of primitive arithmetic operations from the standard dialect that have the same effect, using operands of the index type. Consequently, named maps and sets that are no longer in use may be removed from the module.

For example, %r = affine.apply affine_map<(d0, d1)[s0] -> (d0 + 2*d1 + s0)>(%d0, %d1)[%s0] can be converted into:

%d0 = <...>
%d1 = <...>
%s0 = <...>
%0 = arith.constant 2 : index
%1 = arith.muli %0, %d1 : index
%2 = arith.addi %d0, %1 : index
%r = arith.addi %2, %s0 : index
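The lowering of a linear affine.apply can be sketched in Python. The helper below is hypothetical. It takes the map's result as a list of (coefficient, operand) terms and emits arith ops in the same order as the example above. It assumes a purely linear map (no floordiv, ceildiv, or mod) and index-typed operands.

```python
def lower_linear_affine_expr(terms):
    """Emit arith ops computing sum(coeff * operand) over `terms`.

    `terms` is a list of (coefficient, operand) pairs, e.g.
    [(1, "%d0"), (2, "%d1"), (1, "%s0")] for d0 + 2*d1 + s0.
    Returns (lines, name_of_result).
    """
    lines = []
    next_id = 0

    def fresh():
        nonlocal next_id
        name = f"%{next_id}"
        next_id += 1
        return name

    acc = None
    for coeff, operand in terms:
        if coeff == 1:
            term = operand                  # no multiplication needed
        else:
            cst = fresh()
            lines.append(f"{cst} = arith.constant {coeff} : index")
            term = fresh()
            lines.append(f"{term} = arith.muli {cst}, {operand} : index")
        if acc is None:
            acc = term                      # the first term starts the sum
        else:
            total = fresh()
            lines.append(f"{total} = arith.addi {acc}, {term} : index")
            acc = total
    return lines, acc


lines, result = lower_linear_affine_expr([(1, "%d0"), (2, "%d1"), (1, "%s0")])
print("\n".join(lines))
print("result:", result)    # result: %3
```

Apart from the final result being named %3 instead of %r, the printed sequence matches the hand-written lowering.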

Input invariant 

  • no Tensor types;

These restrictions may be lifted in the future.

Output IR 

Functions with affine.for and affine.if operations eliminated. These functions may contain operations from the Standard dialect in addition to those already present before the pass.

Invariants 

  • Functions without a body are not modified.
  • The semantics of the other functions is preserved.
  • Individual operations other than those mentioned above are not modified if they do not depend on the loop iterator value or on the result of affine.apply.

-lower-host-to-llvm: Lowers the host module code and gpu.launch_func to LLVM 

Creates a pass to emulate gpu.launch_func calls in the LLVM dialect and lower the host module code to LLVM.

This transformation creates a sequence of global variables that are later linked to the variables in the kernel module, and a series of copies to/from them to emulate the memory transfers between the host and the device. It also converts the remaining Arithmetic, Func, and MemRef dialect operations into the LLVM dialect, emitting C wrappers.

Options 

-use-opaque-pointers : Generate LLVM IR using opaque pointers instead of typed pointers

-map-memref-spirv-storage-class: Map numeric MemRef memory spaces to SPIR-V storage classes 

Options 

-client-api : The client API to use for populating mappings

-reconcile-unrealized-casts: Simplify and eliminate unrealized conversion casts 

Eliminate unrealized_conversion_cast operations, commonly introduced by partial dialect conversions, that transitively convert a value to another value of the same type, that is:

%0 = "producer.op"() : () -> !type.A
%1 = unrealized_conversion_cast %0 : !type.A to !type.B
%2 = unrealized_conversion_cast %1 : !type.B to !type.C
%3 = unrealized_conversion_cast %2 : !type.C to !type.A
"consumer.op"(%3) : (!type.A) -> ()

Such situations appear when the consumer operation is converted by one pass and the producer operation is converted by another pass, each of which produces an unrealized cast. This pass can be used to clean up the IR.
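The pass's effect on a cast chain can be modeled with a small Python sketch. The Value class and the reconcile function are hypothetical stand-ins for MLIR values and the pass's rewrite. The sketch assumes each cast has a single operand. After uses are redirected, the now-dead casts would be erased.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Value:
    name: str
    type: str
    # Operand of the unrealized_conversion_cast that produced this value, if any.
    cast_source: Optional["Value"] = None


def reconcile(value: Value) -> Value:
    """Return the earliest value in `value`'s cast chain that has the same type.

    If no earlier value in the chain has the same type, `value` is returned
    unchanged: the casts cannot be folded away.
    """
    replacement = value
    current = value
    while current.cast_source is not None:
        current = current.cast_source
        if current.type == value.type:
            replacement = current
    return replacement


v0 = Value("%0", "!type.A")                     # result of "producer.op"
v1 = Value("%1", "!type.B", cast_source=v0)
v2 = Value("%2", "!type.C", cast_source=v1)
v3 = Value("%3", "!type.A", cast_source=v2)
print(reconcile(v3).name)   # %0: "consumer.op" can use the producer's value directly
print(reconcile(v2).name)   # %2: no round trip to !type.C, so nothing to fold
```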

-tosa-to-arith: Lower TOSA to the Arith dialect 

Pass that converts TOSA operations to the equivalent operations using the operations in the Arith dialect. The ApplyScale operator is optionally included as it is often preserved until the final invocation.

Options 

-include-apply-rescale : Whether to include the lowering for tosa.apply_rescale to arith
-use-32-bit            : Whether to prioritize lowering to 32-bit operations

-tosa-to-linalg: Lower TOSA to LinAlg on tensors 

Pass that converts TOSA operations to the equivalent operations using the tensor operations in LinAlg.

-tosa-to-linalg-named: Lower TOSA to LinAlg named operations 

Pass that converts TOSA operations to the equivalent operations using the Linalg named operations.

-tosa-to-scf: Lower TOSA to the SCF dialect 

Pass that converts TOSA’s control flow operations to the equivalent SCF operations.

-tosa-to-tensor: Lower TOSA to the Tensor dialect 

Pass that converts TOSA operations to the equivalent operations using the operations in the Tensor dialect.

‘affine’ Dialect Passes 

-affine-data-copy-generate: Generate explicit copying for affine memory operations 

Options 

-fast-mem-capacity          : Set fast memory space capacity in KiB (default: unlimited)
-fast-mem-space             : Fast memory space identifier for copy generation (default: 1)
-generate-dma               : Generate DMA instead of point-wise copy
-min-dma-transfer           : Minimum DMA transfer size supported by the target in bytes
-slow-mem-space             : Slow memory space identifier for copy generation (default: 0)
-skip-non-unit-stride-loops : Testing purposes: avoid non-unit stride loop choice depths for copy placement
-tag-mem-space              : Tag memory space identifier for copy generation (default: 0)

-affine-expand-index-ops: Lower affine operations operating on indices into more fundamental operations 

-affine-loop-coalescing: Coalesce nested loops with independent bounds into a single loop 

-affine-loop-fusion: Fuse affine loop nests 

This pass performs fusion of loop nests using a slicing-based approach. The transformation works at MLIR Block granularity and applies to all blocks of the op the pass is run on. It combines two fusion strategies: producer-consumer fusion and sibling fusion. Producer-consumer fusion aims to fuse pairs of loops where the first one writes to a memref that the second reads. Sibling fusion targets pairs of loops that have no dependences between them but load from the same memref.

When possible, the fused loop nests are rewritten to access significantly smaller local buffers instead of the original memrefs, and the original memrefs are often either completely optimized away or contracted. By eliminating or contracting temporary/intermediate memrefs, this transformation improves locality and reduces the memory footprint. These benefits sometimes come at the expense of redundant computation; a cost model evaluates the available choices, such as the depth at which a source slice should be materialized in the destination slice.

Example 1: Producer-consumer fusion. Input:

func.func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
  %0 = memref.alloc() : memref<10xf32>
  %1 = memref.alloc() : memref<10xf32>
  %cst = arith.constant 0.000000e+00 : f32
  affine.for %arg2 = 0 to 10 {
    affine.store %cst, %0[%arg2] : memref<10xf32>
    affine.store %cst, %1[%arg2] : memref<10xf32>
  }
  affine.for %arg2 = 0 to 10 {
    %2 = affine.load %0[%arg2] : memref<10xf32>
    %3 = arith.addf %2, %2 : f32
    affine.store %3, %arg0[%arg2] : memref<10xf32>
  }
  affine.for %arg2 = 0 to 10 {
    %2 = affine.load %1[%arg2] : memref<10xf32>
    %3 = arith.mulf %2, %2 : f32
    affine.store %3, %arg1[%arg2] : memref<10xf32>
  }
  return
}

Output:

func.func @producer_consumer_fusion(%arg0: memref<10xf32>, %arg1: memref<10xf32>) {
  %0 = memref.alloc() : memref<1xf32>
  %1 = memref.alloc() : memref<1xf32>
  %cst = arith.constant 0.000000e+00 : f32
  affine.for %arg2 = 0 to 10 {
    affine.store %cst, %0[0] : memref<1xf32>
    affine.store %cst, %1[0] : memref<1xf32>
    %2 = affine.load %1[0] : memref<1xf32>
    %3 = arith.mulf %2, %2 : f32
    affine.store %3, %arg1[%arg2] : memref<10xf32>
    %4 = affine.load %0[0] : memref<1xf32>
    %5 = arith.addf %4, %4 : f32
    affine.store %5, %arg0[%arg2] : memref<10xf32>
  }
  return
}
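Example 1's rewrite can be mimicked in plain Python to show why the temporary buffer can shrink. After fusion, each consumer iteration reads only the element that the same producer iteration just wrote. In this sketch, the producer writes a value derived from a hypothetical input list instead of the constant used above, so the equivalence check is not trivial. The sketch models the transformation's result only, not how the pass computes slices.

```python
def unfused(inp):
    n = len(inp)
    tmp = [0.0] * n               # like %0 = memref.alloc() : memref<10xf32>
    out = [0.0] * n
    for i in range(n):            # producer loop: writes tmp
        tmp[i] = inp[i] * 0.5
    for i in range(n):            # consumer loop: reads tmp
        out[i] = tmp[i] + tmp[i]
    return out


def fused(inp):
    n = len(inp)
    tmp = [0.0]                   # contracted buffer, like memref<1xf32>
    out = [0.0] * n
    for i in range(n):            # single fused loop
        tmp[0] = inp[i] * 0.5     # producer slice materialized inside the consumer
        out[i] = tmp[0] + tmp[0]
    return out


data = [float(i) for i in range(10)]
assert fused(data) == unfused(data)
```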

Example 2: Sibling fusion. Input:

func.func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>,
                     %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>,
                     %arg4: memref<10x10xf32>) {
  affine.for %arg5 = 0 to 3 {
    affine.for %arg6 = 0 to 3 {
      %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
      %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32>
      %2 = arith.mulf %0, %1 : f32
      affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32>
    }
  }
  affine.for %arg5 = 0 to 3 {
    affine.for %arg6 = 0 to 3 {
      %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
      %1 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32>
      %2 = arith.addf %0, %1 : f32
      affine.store %2, %arg4[%arg5, %arg6] : memref<10x10xf32>
    }
  }
  return
}

Output:

func.func @sibling_fusion(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>,
                     %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>,
                     %arg4: memref<10x10xf32>) {
  affine.for %arg5 = 0 to 3 {
    affine.for %arg6 = 0 to 3 {
      %0 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
      %1 = affine.load %arg1[%arg5, %arg6] : memref<10x10xf32>
      %2 = arith.mulf %0, %1 : f32
      affine.store %2, %arg3[%arg5, %arg6] : memref<10x10xf32>
      %3 = affine.load %arg0[%arg5, %arg6] : memref<10x10xf32>
      %4 = affine.load %arg2[%arg5, %arg6] : memref<10x10xf32>
      %5 = arith.addf %3, %4 : f32
      affine.store %5, %arg4[%arg5, %arg6] : memref<10x10xf32>
    }
  }
  return
}

Options 

-fusion-compute-tolerance   : Fractional increase in additional computation tolerated while fusing
-fusion-fast-mem-space      : Faster memory space number to promote fusion buffers to
-fusion-local-buf-threshold : Threshold size (KiB) for promoting local buffers to fast memory space
-fusion-maximal             : Enables maximal loop fusion
-mode                       : fusion mode to attempt

-affine-loop-invariant-code-motion: Hoist loop invariant instructions outside of affine loops 

-affine-loop-normalize: Apply normalization transformations to affine loop-like ops 

Options 

-promote-single-iter : Promote single iteration loops

-affine-loop-tile: Tile affine loop nests 

Options 

-cache-size : Set size of cache to tile for in KiB (default: 512)
-separate   : Separate full and partial tiles (default: false)
-tile-size  : Use this tile size for all loops
-tile-sizes : List of tile sizes for each perfect nest (overridden by -tile-size)

-affine-loop-unroll: Unroll affine loops 

Options 

-unroll-factor         : Use this unroll factor for all loops being unrolled
-unroll-up-to-factor   : Allow unrolling up to the factor specified
-unroll-full           : Fully unroll loops
-unroll-num-reps       : Unroll innermost loops repeatedly this many times
-unroll-full-threshold : Unroll all loops with trip count less than or equal to this
-cleanup-unroll        : Fully unroll the cleanup loop when possible.

-affine-loop-unroll-jam: Unroll and jam affine loops 

Options 

-unroll-jam-factor : Use this unroll jam factor for all loops (default 4)

-affine-parallelize: Convert affine.for ops into 1-D affine.parallel 

Options 

-max-nested          : Maximum number of nested parallel loops to produce. Defaults to unlimited (UINT_MAX).
-parallel-reductions : Whether to parallelize reduction loops. Defaults to false.

-affine-pipeline-data-transfer: Pipeline non-blocking data transfers between explicitly managed levels of the memory hierarchy 

This pass performs a transformation to overlap non-blocking DMA operations in a loop with computations through double buffering. This is achieved by advancing dma_start operations with respect to other operations.

Input

func.func @pipelinedatatransfer() {
  %0 = memref.alloc() : memref<256xf32>
  %1 = memref.alloc() : memref<32xf32, 1>
  %2 = memref.alloc() : memref<1xf32>
  %c0 = arith.constant 0 : index
  %c128 = arith.constant 128 : index
  affine.for %i0 = 0 to 8 {
    affine.dma_start %0[%i0], %1[%i0], %2[%c0], %c128 : memref<256xf32>, memref<32xf32, 1>, memref<1xf32>
    affine.dma_wait %2[%c0], %c128 : memref<1xf32>
    %3 = affine.load %1[%i0] : memref<32xf32, 1>
    %4 = "compute"(%3) : (f32) -> f32
    affine.store %4, %1[%i0] : memref<32xf32, 1>
  }
  return
}

Output

module {
  func.func @pipelinedatatransfer() {
    %c8 = arith.constant 8 : index
    %c0 = arith.constant 0 : index
    %0 = memref.alloc() : memref<256xf32>
    %c0_0 = arith.constant 0 : index
    %c128 = arith.constant 128 : index
    %1 = memref.alloc() : memref<2x32xf32, 1>
    %2 = memref.alloc() : memref<2x1xf32>
    affine.dma_start %0[%c0], %1[%c0 mod 2, %c0], %2[%c0 mod 2, symbol(%c0_0)], %c128 : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
    affine.for %arg0 = 1 to 8 {
      affine.dma_start %0[%arg0], %1[%arg0 mod 2, %arg0], %2[%arg0 mod 2, symbol(%c0_0)], %c128 : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
      %8 = affine.apply #map3(%arg0)
      %9 = affine.apply #map4(%8)
      %10 = affine.apply #map4(%8)
      affine.dma_wait %2[%8 mod 2, symbol(%c0_0)], %c128 : memref<2x1xf32>
      %11 = affine.load %1[%8 mod 2, %8] : memref<2x32xf32, 1>
      %12 = "compute"(%11) : (f32) -> f32
      affine.store %12, %1[%8 mod 2, %8] : memref<2x32xf32, 1>
    }
    %3 = affine.apply #map3(%c8)
    %4 = affine.apply #map4(%3)
    %5 = affine.apply #map4(%3)
    affine.dma_wait %2[%3 mod 2, symbol(%c0_0)], %c128 : memref<2x1xf32>
    %6 = affine.load %1[%3 mod 2, %3] : memref<2x32xf32, 1>
    %7 = "compute"(%6) : (f32) -> f32
    affine.store %7, %1[%3 mod 2, %3] : memref<2x32xf32, 1>
    memref.dealloc %2 : memref<2x1xf32>
    memref.dealloc %1 : memref<2x32xf32, 1>
    return
  }
}
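The order of operations that double buffering produces in the output above can be sketched in Python. This is a hypothetical model of the schedule only: the event names and tuple layout are assumptions. It issues the transfer for iteration i + 1 before waiting on and computing iteration i, and uses buffer slot i mod 2, as the 2x32 and 2x1 buffers in the output do.

```python
def pipelined_schedule(num_iters):
    """Return (event, iteration, buffer_slot) tuples in issue order.

    Iteration i uses buffer slot i % 2, so the transfer for iteration i + 1
    can be in flight while iteration i computes on the other slot.
    """
    events = []
    if num_iters == 0:
        return events
    events.append(("dma_start", 0, 0))                  # prologue
    for i in range(1, num_iters):                       # steady state
        events.append(("dma_start", i, i % 2))          # issue the next transfer early
        events.append(("dma_wait", i - 1, (i - 1) % 2))
        events.append(("compute", i - 1, (i - 1) % 2))
    last = num_iters - 1                                # epilogue
    events.append(("dma_wait", last, last % 2))
    events.append(("compute", last, last % 2))
    return events


for event in pipelined_schedule(3):
    print(event)
```

The key invariant is that a slot is refilled only after the compute that used it two iterations earlier has been issued.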

-affine-scalrep: Replace affine memref accesses by scalars by forwarding stores to loads and eliminating redundant loads 

This pass performs store to load forwarding and redundant load elimination for affine memref accesses and potentially eliminates the entire memref if all its accesses are forwarded.

Input

func.func @store_load_affine_apply() -> memref<10x10xf32> {
  %cf7 = arith.constant 7.0 : f32
  %m = memref.alloc() : memref<10x10xf32>
  affine.for %i0 = 0 to 10 {
    affine.for %i1 = 0 to 10 {
      affine.store %cf7, %m[%i0, %i1] : memref<10x10xf32>
      %v0 = affine.load %m[%i0, %i1] : memref<10x10xf32>
      %v1 = arith.addf %v0, %v0 : f32
    }
  }
  return %m : memref<10x10xf32>
}

Output

module {
  func.func @store_load_affine_apply() -> memref<10x10xf32> {
    %cst = arith.constant 7.000000e+00 : f32
    %0 = memref.alloc() : memref<10x10xf32>
    affine.for %arg0 = 0 to 10 {
      affine.for %arg1 = 0 to 10 {
        affine.store %cst, %0[%arg0, %arg1] : memref<10x10xf32>
        %1 = arith.addf %cst, %cst : f32
      }
    }
    return %0 : memref<10x10xf32>
  }
}
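The forwarding logic can be sketched in Python for the simplest case: a straight-line sequence of accesses inside one loop body. The op encoding and the function are hypothetical. Unlike the real pass, this sketch does no dependence or dominance analysis. It simply forgets everything known about a memref whenever a store to a different (symbolic) index might alias.

```python
def forward_stores(ops):
    """Store-to-load forwarding and redundant-load elimination on straight-line ops.

    Each op is a tuple:
      ("store", value, memref, index)
      ("load", result, memref, index)
      ("compute", result, operands)   # assumed side-effect free
    Indices are compared symbolically.
    """
    known = {}      # (memref, index) -> SSA value currently held there
    rename = {}     # eliminated load result -> forwarded value
    out = []
    for op in ops:
        kind = op[0]
        if kind == "store":
            _, value, mem, idx = op
            value = rename.get(value, value)
            for key in [k for k in known if k[0] == mem and k[1] != idx]:
                del known[key]          # may alias: forget other indices of this memref
            known[(mem, idx)] = value
            out.append(("store", value, mem, idx))
        elif kind == "load":
            _, result, mem, idx = op
            if (mem, idx) in known:
                rename[result] = known[(mem, idx)]   # forward and drop the load
            else:
                known[(mem, idx)] = result           # later loads can reuse this one
                out.append(op)
        else:
            _, result, operands = op
            out.append(("compute", result, [rename.get(v, v) for v in operands]))
    return out


# The loop body of @store_load_affine_apply from the input above.
body = [
    ("store", "%cf7", "%m", ("%i0", "%i1")),
    ("load", "%v0", "%m", ("%i0", "%i1")),
    ("compute", "%v1", ["%v0", "%v0"]),
]
for op in forward_stores(body):
    print(op)
```

As in the output above, the store stays, the load disappears, and the addition uses the stored constant directly.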

-affine-simplify-structures: Simplify affine expressions in maps/sets and normalize memrefs 

-affine-super-vectorize: Vectorize to a target independent n-D vector abstraction 

Options 

-virtual-vector-size  : Specify an n-D virtual vector size for vectorization
-test-fastest-varying : Specify a 1-D, 2-D or 3-D pattern of fastest varying memory dimensions to match. See defaultPatterns in Vectorize.cpp for a description and examples. This is used for testing purposes
-vectorize-reductions : Vectorize known reductions expressed via iter_args. Switched off by default.

‘amdgpu’ Dialect Passes 

-amdgpu-emulate-atomics: Emulate atomic operations on chipsets that do not support them 

This pass rewrites any AMDGPU-specific atomic operation that is not supported on the given chipset into a compare-and-swap loop.

Options 

-chipset : Chipset that these operations will run on
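A compare-and-swap loop of the kind this pass produces can be sketched in Python for a hypothetical floating-point atomic max. The Cell class and its cmpxchg method are stand-ins for a memory location and the hardware's compare-and-swap. Which atomics actually need emulation depends on the -chipset option.

```python
import threading


class Cell:
    """A memory word whose only atomic primitive is compare-and-swap."""

    def __init__(self, value):
        self._value = value
        self._lock = threading.Lock()   # stands in for hardware atomicity

    def load(self):
        return self._value

    def cmpxchg(self, expected, new):
        """Atomically store `new` if the cell holds `expected`; return the old value.

        Real hardware compares raw bits; == on Python floats is a simplification
        (it differs for NaN and -0.0).
        """
        with self._lock:
            old = self._value
            if old == expected:
                self._value = new
            return old


def atomic_fmax(cell, operand):
    """Emulate an atomic floating-point max with a compare-and-swap loop."""
    current = cell.load()
    while True:
        desired = max(current, operand)
        observed = cell.cmpxchg(current, desired)
        if observed == current:         # nobody changed the cell in between
            return observed             # atomics conventionally return the old value
        current = observed              # retry with the freshly observed value


cell = Cell(0.0)
threads = [threading.Thread(target=atomic_fmax, args=(cell, float(v))) for v in range(50)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(cell.load())   # 49.0
```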

‘arith’ Dialect Passes 

-arith-bufferize: Bufferize Arith dialect ops. 

This pass bufferizes arith dialect ops.

This pass needs to be a module pass because it inserts memref.global ops into the module, which cannot be done safely from a function pass due to multi-threading. Most other bufferization passes can run in parallel at function granularity.

Options 

-alignment : Create global memrefs with a specified alignment

-arith-emulate-wide-int: Emulate 2*N-bit integer operations using N-bit operations 

Emulate arith integer operations that use too wide integer types with equivalent operations on supported narrow integer types. This is done by splitting original integer values into two halves.

This pass is intended to preserve semantics but not necessarily to provide the most efficient implementation. TODO: Optimize op emulation.

Currently, only power-of-two integer bitwidths are supported.

Options 

-widest-int-supported : Widest integer type supported by the target
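The idea of splitting values into halves can be illustrated with a Python sketch that emulates a 64-bit arith.addi using 32-bit pieces. It assumes a target whose widest supported integer is i32, as if -widest-int-supported were 32. The helper names are hypothetical, and the pass's actual op sequence may differ.

```python
MASK32 = (1 << 32) - 1


def split_i64(x):
    """Split an unsigned 64-bit value into (low, high) 32-bit halves."""
    return x & MASK32, (x >> 32) & MASK32


def join_i64(low, high):
    return ((high & MASK32) << 32) | (low & MASK32)


def emulated_addi_i64(a, b):
    """Add two i64 values using only 32-bit additions plus a carry."""
    a_lo, a_hi = split_i64(a)
    b_lo, b_hi = split_i64(b)
    lo_sum = a_lo + b_lo
    # 1 if the low halves overflowed. With only 32-bit ops, this carry
    # would be computed as the unsigned comparison (lo < a_lo).
    carry = lo_sum >> 32
    lo = lo_sum & MASK32
    hi = (a_hi + b_hi + carry) & MASK32     # wraps like a 32-bit add
    return join_i64(lo, hi)


print(hex(emulated_addi_i64(0xFFFFFFFF, 1)))   # 0x100000000
```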

-arith-expand: Legalize Arith ops to be convertible to LLVM. 

Options 

-include-bf16 : Enable the BF16 expansion patterns

-arith-int-narrowing: Reduce integer operation bitwidth 

Reduce bitwidths of integer types used in arith operations. This pass prefers the narrowest available integer bitwidths that are guaranteed to produce the same results.

Options 

-int-bitwidths-supported : Integer bitwidths supported

-arith-unsigned-when-equivalent: Replace signed ops with unsigned ones where they are proven equivalent 

Replace signed ops with their unsigned equivalents when integer range analysis determines that their arguments and results are all guaranteed to be non-negative when interpreted as signed integers. When this occurs, we know that the semantics of the signed and unsigned operations are the same, since they share the same behavior when their operands and results are in the range [0, signed_max(type)].

The affected ops include division, remainder, shifts, min, max, and integer comparisons.
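The equivalence argument can be checked exhaustively for 8-bit values with a Python sketch of arith.divsi and arith.divui semantics. The helper functions are illustrative models, not MLIR APIs.

```python
def to_unsigned(x, bits):
    """Reinterpret x as an unsigned `bits`-wide value."""
    return x & ((1 << bits) - 1)


def to_signed(x, bits):
    """Reinterpret x as a signed two's-complement `bits`-wide value."""
    x = to_unsigned(x, bits)
    return x - (1 << bits) if x >> (bits - 1) else x


def divsi(a, b, bits):
    """Signed division rounding toward zero, modeling arith.divsi (b != 0)."""
    a, b = to_signed(a, bits), to_signed(b, bits)
    quotient = abs(a) // abs(b)
    return to_unsigned(quotient if (a < 0) == (b < 0) else -quotient, bits)


def divui(a, b, bits):
    """Unsigned division, modeling arith.divui (b != 0)."""
    return to_unsigned(a, bits) // to_unsigned(b, bits)


# Whenever both i8 operands are in [0, 127], the two divisions agree,
# so divsi may be replaced by divui.
assert all(divsi(a, b, 8) == divui(a, b, 8)
           for a in range(128) for b in range(1, 128))

# With a negative dividend they differ: -8 / 2 is -4 (0xFC) signed, but 248 / 2 unsigned.
print(divsi(-8, 2, 8), divui(-8, 2, 8))   # 252 124
```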

-int-range-optimizations: Do optimizations based on integer range analysis 

This pass runs integer range analysis and applies optimizations based on its results. For example, it replaces an arith.cmpi with a constant if the result can be inferred from the ranges of its arguments.
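A minimal sketch of how value ranges can decide a comparison, for the signed less-than predicate only. The Range alias and fold_cmpi_slt are hypothetical simplifications of what integer range analysis provides.

```python
from typing import Optional, Tuple

Range = Tuple[int, int]   # inclusive [lo, hi] of possible signed values


def fold_cmpi_slt(lhs: Range, rhs: Range) -> Optional[bool]:
    """Decide `lhs < rhs` from value ranges, or return None if it is unknown."""
    if lhs[1] < rhs[0]:
        return True       # every lhs value is below every rhs value
    if lhs[0] >= rhs[1]:
        return False      # no lhs value can be below any rhs value
    return None           # ranges overlap: keep the arith.cmpi


print(fold_cmpi_slt((0, 9), (10, 20)))   # True  -> replace with a constant true
print(fold_cmpi_slt((10, 20), (0, 10)))  # False -> replace with a constant false
print(fold_cmpi_slt((0, 10), (5, 6)))    # None  -> leave as is
```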

‘arm_sme’ Dialect Passes 

-enable-arm-streaming: Enable Armv9 Streaming SVE mode 

Enables the Armv9 Streaming SVE mode [1] for func.func ops by annotating them with attributes. See options for more details.

[1] https://developer.arm.com/documentation/ddi0616/aa

Options 

-mode : Select how streaming-mode is managed at the function-level.

‘async’ Dialect Passes 

-async-func-to-async-runtime: Lower async.func operations to the explicit async.runtime and async.coro operations 

-async-parallel-for: Convert scf.parallel operations to multiple async compute ops executed concurrently for non-overlapping iteration ranges 

Options 

-async-dispatch : Dispatch async compute tasks using recursive work splitting. If `false`, async compute tasks will be launched using a simple for loop in the caller thread.
-num-workers    : The number of available workers to execute async operations. If `-1` the value will be retrieved from the runtime.
-min-task-size  : The minimum task size for sharding parallel operation.
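The idea of non-overlapping iteration ranges can be illustrated with a Python sketch of a one-dimensional split. The blocking heuristic is an assumption for illustration, not the pass's exact cost model: ceil-divide by the number of workers, but never go below the minimum task size. The real pass also handles multi-dimensional scf.parallel loops.

```python
def split_iterations(lower, upper, num_workers, min_task_size):
    """Split [lower, upper) into at most `num_workers` contiguous, non-overlapping
    blocks, each (except possibly the last) of at least `min_task_size` iterations."""
    trip_count = max(upper - lower, 0)
    if trip_count == 0:
        return []
    # Ceil-divide so every iteration is covered, then respect the minimum task size.
    block_size = max(-(-trip_count // num_workers), min_task_size)
    blocks = []
    start = lower
    while start < upper:
        end = min(start + block_size, upper)
        blocks.append((start, end))
        start = end
    return blocks


print(split_iterations(0, 100, 4, 10))   # [(0, 25), (25, 50), (50, 75), (75, 100)]
print(split_iterations(0, 5, 4, 10))     # [(0, 5)]: too little work to split
```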

-async-runtime-policy-based-ref-counting: Policy based reference counting for Async runtime operations 

This pass works at the async runtime abstraction level, after all async.execute and async.await operations are lowered to the async runtime API calls, and async coroutine operations.

This pass doesn’t rely on liveness analysis of reference-counted values, and instead uses a simple policy to create reference counting operations. If the program violates any of the assumptions, then this pass might lead to memory leaks or runtime errors.

The default reference counting policy assumptions:

  1. Async token can be awaited or added to the group only once.
  2. Async value or group can be awaited only once.

Under these assumptions, reference counting only needs to drop a reference:

  1. After async.runtime.await operation for async tokens and groups (as long as error handling is not implemented for the sync await).
  2. After async.runtime.is_error operation for async tokens and groups (this is the last operation in the coroutine resume function).
  3. After async.runtime.load operation for async values.

This pass introduces significantly less runtime overhead compared to automatic reference counting.

-async-runtime-ref-counting: Automatic reference counting for Async runtime operations 

This pass works at the async runtime abstraction level, after all async.execute and async.await operations are lowered to the async runtime API calls, and async coroutine operations.

It relies on the LLVM coroutines switched-resume lowering semantics for the correct placing of the reference counting operations.

See: https://llvm.org/docs/Coroutines.html#switched-resume-lowering

-async-runtime-ref-counting-opt: Optimize automatic reference counting operations for the Async runtime by removing redundant operations 

-async-to-async-runtime: Lower all high level async operations (e.g. async.execute) to the explicit async.runtime and async.coro operations 

‘func’ Dialect Passes 

-duplicate-function-elimination: Deduplicate functions 

Deduplicate functions that are equivalent in all aspects but their symbol name. The pass chooses one representative per equivalence class, erases the remainder, and updates function calls accordingly.
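The bookkeeping described above can be sketched in Python. The sketch assumes each function has already been reduced to a hashable key that captures everything except its symbol name; the real pass compares operations structurally. Choosing the first function in module order as the representative is also an assumption of this sketch.

```python
def deduplicate(functions, calls):
    """functions: dict mapping symbol name -> name-independent key, in module order.
    calls: list of callee symbol names.
    Returns (kept_functions, rewritten_calls)."""
    representative = {}   # key -> first symbol seen with that key
    replace = {}          # erased symbol -> representative symbol
    kept = {}
    for name, body in functions.items():
        if body in representative:
            replace[name] = representative[body]     # duplicate: erase it
        else:
            representative[body] = name
            kept[name] = body
    return kept, [replace.get(callee, callee) for callee in calls]


funcs = {"@add1": "x + 1", "@inc": "x + 1", "@dbl": "x * 2"}
kept, calls = deduplicate(funcs, ["@inc", "@dbl", "@add1"])
print(list(kept))   # ['@add1', '@dbl']
print(calls)        # ['@add1', '@dbl', '@add1']
```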

-func-bufferize: Bufferize func/call/return ops 

A bufferize pass that bufferizes func.func and func.call ops.

Because this pass updates func.func ops, it must be a module pass. It is useful to keep this pass separate from other bufferizations so that the other ones can be run at function-level in parallel.

This pass must be done atomically because it changes func op signatures, which requires atomically updating calls as well throughout the entire module.

This pass also changes the type of block arguments, which requires that all successor arguments of predecessors be converted. This is achieved by rewriting terminators based on the information provided by the BranchOpInterface. As this pass rewrites function operations, it also rewrites the corresponding return operations. Other return-like operations that implement the ReturnLike trait are not rewritten in general, as they require that the corresponding parent operation is also rewritten. Finally, this pass fails for unknown terminators, as we cannot decide whether they need rewriting.

‘gpu’ Dialect Passes 

-gpu-async-region: Make GPU ops async 

-gpu-kernel-outlining: Outline gpu.launch bodies to kernel functions 

-gpu-launch-sink-index-computations: Sink index computations into gpu.launch body 

-gpu-map-parallel-loops: Greedily maps loops to GPU hardware dimensions. 

‘linalg’ Dialect Passes 

-convert-elementwise-to-linalg: Convert ElementwiseMappable ops to linalg 

Convert ops with the ElementwiseMappable trait to linalg parallel loops.

This pass only converts ops that operate on ranked tensors. It can be run on any op that contains linalg ops (most commonly a FunctionOpInterface op).

-convert-linalg-to-affine-loops: Lower the operations from the linalg dialect into affine loops 

-convert-linalg-to-loops: Lower the operations from the linalg dialect into loops 

-convert-linalg-to-parallel-loops: Lower the operations from the linalg dialect into parallel loops 

-linalg-bufferize: Bufferize the linalg dialect 

-linalg-detensorize: Detensorize linalg ops 

Detensoring is the process through which a tensor value is converted to one or potentially more primitive value(s). During this process, operations with such detensored operands are also converted to an equivalent form that works on primitives.

The detensoring process is driven by linalg-on-tensor ops. In particular, a linalg-on-tensor op is checked to see whether all its operands can be detensored. If so, those operands are converted to their primitive counterparts and the linalg op is replaced by an equivalent op that takes those new primitive values as operands. Therefore, detensoring an op can be divided into 2 main logical phases:

  1. Detect/match an op that can be detensored.
  2. Detensor the operands of the op and replace it with a primitive equivalent.

In addition to detensoring individual ops, this pass detensors internal control flow inside a function. All blocks except for the entry block are detensored by converting their arguments whenever possible.

This can be run on any FunctionOpInterface op and must not be run on others. This is because it performs specific legalization of the blocks that make up the body, which it assumes belongs to a FunctionOpInterface op.

Options 

-aggressive-mode : Detensorize all ops that qualify for detensoring along with branch operands and basic-block arguments.

-linalg-fold-unit-extent-dims: Remove unit-extent dimension in Linalg ops on tensors 

Options 

-fold-one-trip-loops-only : Only folds the one-trip loops from Linalg ops on tensors (for testing purposes only)
-use-rank-reducing-slices : Generate rank-reducing slices instead of reassociative reshapes

-linalg-fuse-elementwise-ops: Fuse elementwise operations on tensors 

-linalg-generalize-named-ops: Convert named ops into generic ops 

-linalg-inline-scalar-operands: Inline scalar operands into linalg generic ops 

-linalg-named-op-conversion: Convert from one named linalg op to another. 

‘llvm’ Dialect Passes 

-ensure-debug-info-scope-on-llvm-func: Materialize LLVM debug info subprogram attribute on every LLVMFuncOp 

Having a debug info subprogram attribute on a function is required for emitting line tables from MLIR FileLocCol locations.

This is not intended to be a proper replacement for frontends emitting complete debug information; however, it is a convenient way to get line tables for debugging purposes. This allows stepping through code line-by-line in a debugger or getting a backtrace with line numbers.

-llvm-legalize-for-export: Legalize LLVM dialect to be convertible to LLVM IR 

-llvm-optimize-for-nvvm-target: Optimize NVVM IR 

-llvm-request-c-wrappers: Request C wrapper emission for all functions 

Annotate every builtin function in the module with the LLVM dialect attribute that instructs the conversion to LLVM to emit the C wrapper for the function. This pass is expected to be applied immediately before the conversion of builtin functions to LLVM to avoid the attribute being dropped by other passes.

‘memref’ Dialect Passes 

-expand-strided-metadata: Expand memref operations into easier to analyze constructs 

The pass expands memref operations that modify the metadata of a memref (sizes, offset, strides) into a sequence of easier to analyze constructs. In particular, this pass transforms operations into an explicit sequence of operations that model the effect of this operation on the different metadata. This pass uses affine constructs to materialize these effects.
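For instance, the metadata of a memref.subview result follows from simple arithmetic on the source's offset and strides. The Python sketch below performs that arithmetic directly, whereas the pass materializes equivalent computations as IR. The StridedLayout type and the helper functions are hypothetical.

```python
from typing import List, NamedTuple


class StridedLayout(NamedTuple):
    offset: int
    sizes: List[int]
    strides: List[int]


def subview_metadata(src: StridedLayout, offsets, sizes, strides) -> StridedLayout:
    """Metadata of a subview of `src` with the given offsets, sizes and strides."""
    new_offset = src.offset + sum(o * s for o, s in zip(offsets, src.strides))
    new_strides = [s * step for s, step in zip(src.strides, strides)]
    return StridedLayout(new_offset, list(sizes), new_strides)


def linear_index(layout: StridedLayout, indices) -> int:
    """Position of the element at `indices` in the underlying buffer."""
    return layout.offset + sum(i * s for i, s in zip(indices, layout.strides))


src = StridedLayout(0, [8, 8], [8, 1])                 # memref<8x8xf32>, identity layout
sub = subview_metadata(src, [2, 3], [4, 4], [1, 2])
print(sub)   # StridedLayout(offset=19, sizes=[4, 4], strides=[8, 2])
```

Element (i, j) of the subview then lands at buffer position 19 + 8*i + 2*j, the same position as element (2 + i, 3 + 2*j) of the source.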

-fold-memref-alias-ops: Fold memref alias ops into consumer load/store ops 

The pass folds loading/storing from/to memref aliasing ops to loading/storing from/to the original memref.

-memref-emulate-wide-int: Emulate 2*N-bit integer operations using N-bit operations 

Emulate memref integer operations that use too wide integer types with equivalent operations on supported narrow integer types. This is done by splitting original integer values into two halves.

Currently, only power-of-two integer bitwidths are supported.

Options 

-widest-int-supported : Widest integer type supported by the target

-memref-expand: Legalize memref operations to be convertible to LLVM. 

-normalize-memrefs: Normalize memrefs 

This pass transforms memref types with a non-trivial layout map into memref types with an identity layout map, e.g. (i, j) -> (i, j). This pass is inter-procedural, in the sense that it can modify function interfaces and call sites that pass memref types. In order to modify memref types while preserving the original behavior, users of those memref types are also modified to incorporate the resulting layout map. For instance, an AffineLoadOp will be updated to compose the layout map with the affine expression contained in the op. Operations marked with the MemRefsNormalizable trait are expected to be normalizable. Supported operations include affine operations, memref.alloc, memref.dealloc, and func.return.

Given an appropriate layout map specified in the code, this transformation can express tiled or linearized access to multi-dimensional data structures, but will not modify memref types without an explicit layout map.

Currently this pass only modifies functions in which all memref types can be normalized. If a function contains any operations that are not MemRefsNormalizable, then neither that function nor any function that calls it or is called by it will be modified.

Input

#tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
func.func @matmul(%A: memref<16xf64, #tile>,
             %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) {
  affine.for %arg3 = 0 to 16 {
    %a = affine.load %A[%arg3] : memref<16xf64, #tile>
    %p = arith.mulf %a, %a : f64
    affine.store %p, %A[%arg3] : memref<16xf64, #tile>
  }
  %c = memref.alloc() : memref<16xf64, #tile>
  %d = affine.load %c[0] : memref<16xf64, #tile>
  return %A: memref<16xf64, #tile>
}

Output

func.func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
  -> memref<4x4xf64> {
  affine.for %arg3 = 0 to 16 {
    %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
    %4 = arith.mulf %3, %3 : f64
    affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
  }
  %0 = memref.alloc() : memref<4x4xf64>
  %1 = affine.apply #map1()
  %2 = affine.load %0[0, 0] : memref<4x4xf64>
  return %arg0 : memref<4x4xf64>
}

Input

#linear8 = affine_map<(i, j) -> (i * 8 + j)>
func.func @linearize(%arg0: memref<8x8xi32, #linear8>,
                %arg1: memref<8x8xi32, #linear8>,
                %arg2: memref<8x8xi32, #linear8>) {
  %c8 = arith.constant 8 : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  affine.for %arg3 = %c0 to %c8 {
    affine.for %arg4 = %c0 to %c8 {
      affine.for %arg5 = %c0 to %c8 {
        %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8>
        %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8>
        %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
        %3 = arith.muli %0, %1 : i32
        %4 = arith.addi %2, %3 : i32
        affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
      }
    }
  }
  return
}

Output

func.func @linearize(%arg0: memref<64xi32>,
                %arg1: memref<64xi32>,
                %arg2: memref<64xi32>) {
  %c8 = arith.constant 8 : index
  %c0 = arith.constant 0 : index
  affine.for %arg3 = %c0 to %c8 {
    affine.for %arg4 = %c0 to %c8 {
      affine.for %arg5 = %c0 to %c8 {
        %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32>
        %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32>
        %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
        %3 = arith.muli %0, %1 : i32
        %4 = arith.addi %2, %3 : i32
        affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
      }
    }
  }
  return
}
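The normalized shapes in both examples follow from enumerating the image of the layout map. The Python sketch below transcribes #tile and #linear8 directly (Python's // and % agree with floordiv and mod for the non-negative indices used here); the helper normalized_shape is hypothetical and assumes every map result starts at 0 and that the map does not alias elements.

```python
import itertools
from typing import Callable, Sequence, Tuple

Index = Tuple[int, ...]


def tile(i: int) -> Index:
    # #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
    return (i // 4, i % 4)


def linear8(i: int, j: int) -> Index:
    # #linear8 = affine_map<(i, j) -> (i * 8 + j)>
    return (i * 8 + j,)


def normalized_shape(layout: Callable[..., Index], shape: Sequence[int]) -> Index:
    """Shape of the identity-layout memref that holds the image of `layout`."""
    images = [layout(*idx) for idx in itertools.product(*(range(s) for s in shape))]
    assert len(set(images)) == len(images), "layout map must not alias elements"
    rank = len(images[0])
    # Assumes every result dimension starts at 0, so extent = max + 1.
    return tuple(max(img[d] for img in images) + 1 for d in range(rank))


print(normalized_shape(tile, (16,)))      # (4, 4)
print(normalized_shape(linear8, (8, 8)))  # (64,)
```

An access %A[%arg3] through #tile then becomes tile(%arg3), i.e. [%arg3 floordiv 4, %arg3 mod 4], which matches the rewritten loads and stores in the first output.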

-resolve-ranked-shaped-type-result-dims: Resolve memref.dim of result values of ranked shape type 

The pass resolves memref.dim of results of operations that implement the ReifyRankedShapedTypeOpInterface in terms of the shapes of their operands.

-resolve-shaped-type-result-dims: Resolve memref.dim of result values 

The pass resolves memref.dim of results of operations that implement the InferShapedTypeOpInterface or ReifyRankedShapedTypeOpInterface in terms of the shapes of their operands.

‘nvgpu’ Dialect Passes 

-nvgpu-optimize-shared-memory: Optimizes accesses to shared memory memrefs in order to reduce bank conflicts. 

‘quant’ Dialect Passes 

Reducer Passes 

-opt-reduction-pass: A wrapper pass that reduces the file with optimization passes 

Options 

-opt-pass : The optimization passes used for reduction, e.g., symbol-dce
-test     : The location of the tester which tests the file interestingness
-test-arg : Arguments of the tester

-reduction-tree: Reduce the input with reduction-tree algorithm 

Options 

-traversal-mode : The graph traversal mode; the default is single-path mode
-test           : The location of the tester which tests the file interestingness
-test-arg       : Arguments of the tester

‘scf’ Dialect Passes 

-scf-bufferize: Bufferize the scf dialect. 

-scf-for-loop-canonicalization: Canonicalize operations within scf.for loop bodies 

-scf-for-loop-peeling: Peel for loops at their upper bounds. 

Options 

-skip-partial : Do not peel loops inside of the last, partial iteration of another already peeled loop.
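Peeling at the upper bound splits the iteration space into a main loop, whose trip count is a whole number of steps, and at most one trailing partial iteration. The Python sketch below models only that split of iteration indices, assuming a positive step; how the pass rewrites the loop body is not modelled, and peel_upper_bound is a hypothetical name.

```python
from typing import Tuple


def peel_upper_bound(lb: int, ub: int, step: int) -> Tuple[range, range]:
    """Split `for i in [lb, ub) step s` into a main loop and a partial last iteration."""
    assert step > 0
    if ub <= lb:
        return range(0), range(0)
    # lb plus the largest multiple of `step` that does not exceed ub - lb.
    split = ub - (ub - lb) % step
    return range(lb, split, step), range(split, ub, step)


main, rest = peel_upper_bound(0, 10, 4)
print(list(main), list(rest))  # [0, 4] [8]
```

In the main loop every iteration satisfies ub - i >= step, so a full step of work remains; this is the usual reason for peeling, since bounds such as min(step, ub - i) become constant there.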

-scf-for-loop-range-folding: Fold add/mul ops into loop range 
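The arithmetic identity behind this folding can be checked in Python: using i * k + c inside a loop over [lb, ub) with step s visits the same values as iterating directly over [lb * k + c, ub * k + c) with step s * k, provided k > 0. The pass's exact matching conditions (for example, which operands must be loop-invariant) are not described here, and the function names are hypothetical.

```python
from typing import List


def original(lb: int, ub: int, step: int, k: int, c: int) -> List[int]:
    # for i in [lb, ub) step `step`: use(i * k + c)
    return [i * k + c for i in range(lb, ub, step)]


def folded(lb: int, ub: int, step: int, k: int, c: int) -> List[int]:
    # The mul and add are folded into the bounds and step: use(j) directly.
    assert k > 0, "a positive multiplier preserves iteration order and bounds"
    return list(range(lb * k + c, ub * k + c, step * k))
```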

-scf-for-loop-specialization: Specialize for loops for vectorization 

-scf-for-to-while: Convert SCF for loops to SCF while loops 

This pass transforms scf.for operations into scf.while operations. The for-loop condition is placed in the ‘before’ region of the while operation, and the induction variable increment and the loop body are placed in the ‘after’ region. The loop-carried values of the while op are the induction variable (IV) of the for-loop plus any iter_args specified for the for-loop. Any ‘yield’ ops in the for-loop are rewritten to additionally yield the (incremented) induction variable.

# Before:
  scf.for %i = %c0 to %arg1 step %c1 {
    %0 = arith.addi %arg2, %arg2 : i32
    memref.store %0, %arg0[%i] : memref<?xi32>
  }

# After:
  %0 = scf.while (%i = %c0) : (index) -> index {
    %1 = arith.cmpi slt, %i, %arg1 : index
    scf.condition(%1) %i : index
  } do {
  ^bb0(%i: index):
    %1 = arith.addi %i, %c1 : index
    %2 = arith.addi %arg2, %arg2 : i32
    memref.store %2, %arg0[%i] : memref<?xi32>
    scf.yield %1 : index
  }
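The control structure of the transformation can be modelled in Python. This sketch assumes a positive step and a single iter_arg, and the function names are hypothetical; it shows that a while-shaped loop carrying the IV plus the iter_arg computes the same result as the for-loop.

```python
from typing import Callable, Tuple


def for_loop(lb: int, ub: int, step: int, init: int,
             body: Callable[[int, int], int]) -> int:
    """scf.for with one iter_arg: returns the final iter_arg value."""
    acc = init
    for i in range(lb, ub, step):
        acc = body(i, acc)
    return acc


def while_loop(lb: int, ub: int, step: int, init: int,
               body: Callable[[int, int], int]) -> Tuple[int, int]:
    """The same loop shaped like scf.while: returns (final IV, final iter_arg)."""
    i, acc = lb, init           # loop-carried values: the IV plus the iter_args
    while i < ub:               # 'before' region: arith.cmpi slt + scf.condition
        # 'after' region: increment the IV, run the body, yield both values
        i, acc = i + step, body(i, acc)
    return i, acc
```

Unlike the for-loop, the while form also produces the final induction variable as a result, as in the %0 result of scf.while above.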

-scf-parallel-loop-fusion: Fuse adjacent parallel loops 

-scf-parallel-loop-specialization: Specialize parallel loops for vectorization 

-scf-parallel-loop-tiling: Tile parallel loops 

Options 

-parallel-loop-tile-sizes : Factors to tile parallel loops by
-no-min-max-bounds        : Perform tiling with fixed upper bound with inbound check inside the internal loops
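The two tiling strategies selected by -no-min-max-bounds can be compared on a single dimension. In this Python sketch (function names hypothetical), the tile size is assumed to count iterations, so the outer loop steps by step * tile. The default strategy bounds the inner loop with a min, while the alternative uses a fixed inner upper bound plus an in-bounds check; both visit exactly the original iterations.

```python
from typing import List


def tile_with_min_bound(lb: int, ub: int, step: int, tile: int) -> List[int]:
    """Inner loop upper bound is min(outer + step * tile, ub)."""
    visited = []
    for outer in range(lb, ub, step * tile):
        for i in range(outer, min(outer + step * tile, ub), step):
            visited.append(i)
    return visited


def tile_with_inbound_check(lb: int, ub: int, step: int, tile: int) -> List[int]:
    """Inner loop has a fixed trip count; a guard skips out-of-bounds iterations."""
    visited = []
    for outer in range(lb, ub, step * tile):
        for i in range(outer, outer + step * tile, step):
            if i < ub:  # in-bounds check inside the inner loop
                visited.append(i)
    return visited
```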

-test-scf-parallel-loop-collapsing: Test parallel loops collapsing transformation 

This pass is purely for testing the scf::collapseParallelLoops transformation. The transformation does not have opinions on how a parallel loop should be collapsed, so this pass is structured for the common case on GPUs of collapsing to a 3-D parallel loop. Three lists can be provided to collapsed-indices-{0,1,2} to describe how the loop should be collapsed; together they must reference every iterator in the original parallel loop.

# Before:
scf.parallel (%arg0, %arg1)
             = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
  "test.sink"(%5, %3) : (index, index) -> ()
  scf.yield
}

# After:
scf.parallel (%arg0) = (%c0) to (%c4) step (%c1) {
  %0 = arith.remsi %arg0, %c2 : index
  %1 = arith.divsi %arg0, %c2 : index
  %2 = arith.muli %0, %c7 : index
  %3 = arith.addi %2, %c3 : index
  %4 = arith.muli %1, %c7 : index
  %5 = arith.addi %4, %c3 : index
  "test.sink"(%5, %3) : (index, index) -> ()
}
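The arith.remsi/arith.divsi/arith.muli/arith.addi sequence in the result recovers the original induction variables from the single collapsed index. The Python sketch below models this for two dimensions. Reading %c7 as a step and %c3 as a lower bound is an assumption, since those constants are not defined in the snippet; collapse_2d is a hypothetical name.

```python
from typing import List, Tuple


def collapse_2d(lb: Tuple[int, int], step: Tuple[int, int],
                trips: Tuple[int, int]) -> List[Tuple[int, int]]:
    """Recover (iv0, iv1) for every iteration of the collapsed 1-D loop."""
    ivs = []
    for k in range(trips[0] * trips[1]):          # collapsed loop: 0 .. trip0 * trip1
        inner, outer = k % trips[1], k // trips[1]  # arith.remsi / arith.divsi
        ivs.append((lb[0] + outer * step[0],      # arith.muli + arith.addi
                    lb[1] + inner * step[1]))
    return ivs


print(collapse_2d((3, 3), (7, 7), (2, 2)))  # [(3, 3), (3, 10), (10, 3), (10, 10)]
```

Every pair of original induction-variable values appears exactly once, which is the property the collapsed loop must preserve.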

Options 

-collapsed-indices-0 : Which loop indices to combine into the position 0 loop index
-collapsed-indices-1 : Which loop indices to combine into the position 1 loop index
-collapsed-indices-2 : Which loop indices to combine into the position 2 loop index

‘shape’ Dialect Passes 

-outline-shape-computation: Using shape.func to preserve shape computation 

This pass outlines the shape computation part in high-level IR by adding shape.func operations and populating the corresponding mapping information into ShapeMappingAnalysis. The shape computation part is usually introduced by shape reification, and each single dynamic shape is denoted by shape.with_shape.

There are two main reasons this shape-outline pass is needed:

  1. Many passes don’t take the shape reification part into consideration. Therefore, the shape reification part needs to be “removed” temporarily for these passes.
  2. Sometimes shape reification cannot be redone after converting from dialect A to dialect B, because op-level shape reification is only implemented on A.

Input:

func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
  tensor<?x4x?xf32> {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
  %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
  %2 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
  %3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
  %4 = shape.value_of %3 : tensor<?x4x?xf32>
  %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
        tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
  %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
  %7 = arith.addi %6, %c2 : index
  %8 = shape.from_extents %7, %c4, %1 : index, index, index
  %9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
  %10 = shape.value_of %9 : tensor<?x4x?xf32>
  return %10 : tensor<?x4x?xf32>
}

Output

func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
  tensor<?x4x?xf32> {
  %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
  %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
        tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
  return %1 : tensor<?x4x?xf32>
}
shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
  %c2 = arith.constant 2 : index
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
  %1 = get_extent %0, %c2 : tensor<3xindex>, index -> index
  %2 = get_extent %0, %c0 : tensor<3xindex>, index -> index
  %3 = arith.addi %2, %c2 : index
  %4 = from_extents %3, %c4, %1 : index, index, index
  return %4 : !shape.shape
}
shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
  %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
  return %0 : tensor<3xindex>
}

In the example above, the shape computation is inlined in the input IR, where it determines the shapes of two values (the results of test.abs and test.concat). In the output IR, the shape computation is outlined into shape.func functions.

The shape mapping information will then be:

// ---- Shape Mapping Infomation -----
// - Shape for: %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_0(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
// - Shape for: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)

-remove-shape-constraints: Replace all cstr_ ops with a true witness 

-shape-bufferize: Bufferize the shape dialect. 

-shape-to-shape-lowering: Legalize Shape dialect to be convertible to Arith 

‘sparse_tensor’ Dialect Passes 

-post-sparsification-rewrite: Applies sparse tensor rewriting rules after sparsification 

A pass that applies rewriting rules to sparse tensor operations after running the actual sparsification pass.

Options 

-enable-runtime-library : Enable runtime library for manipulating sparse tensors
-enable-foreach         : Enable rewriting rules for the foreach operator
-enable-convert         : Enable rewriting rules for the convert operator

-pre-sparsification-rewrite: Applies sparse tensor rewriting rules prior to sparsification 

A pass that applies rewriting rules to sparse tensor operations prior to running the actual sparsification pass.

-sparse-buffer-rewrite: Rewrite sparse primitives on buffers to actual code 

A pass that rewrites sparse primitives on buffers to the MLIR implementation of the primitives. For example, the sparse_tensor.sort operator is implemented in this pass.

Options 

-enable-buffer-initialization : Enable zero-initialization of the memory buffers

-sparse-gpu-codegen: Generates GPU code during sparsification 

Enables the sparse compiler to use GPU acceleration.

Options 

-num_threads : Sets the number of GPU threads

-sparse-storage-specifier-to-llvm: Lower sparse storage specifier to LLVM structure 

This pass rewrites sparse tensor storage specifier-related operations into the LLVM dialect, and converts sparse tensor storage specifiers into an llvm.struct.

Example of the conversion:

Before:
  %0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
  : !sparse_tensor.storage_specifier<#CSR> to i64

After:
  %0 = llvm.extractvalue %arg0[0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>

-sparse-tensor-codegen: Convert sparse tensors and primitives to actual code 

A pass that converts sparse tensor types and primitives to actual compiler-visible buffers and compiler IR that implements these primitives on the selected sparse tensor storage schemes.

This pass provides an alternative to the SparseTensorConversion pass, eliminating the dependence on a runtime support library and providing many more opportunities for subsequent compiler optimization of the generated code.

Example of the conversion:

  Before:
    func.func @foo(%arg0: tensor<8x8xf32, #CSR>) -> memref<?xindex> {
      %0 = sparse_tensor.pointers %arg0 {dimension = 1 : index}
         : tensor<8x8xf32, #CSR> to memref<?xindex>
      return %0 : memref<?xindex>
    }

  After:
    func.func @foo(%arg0: memref<2xindex>,
                   %arg1: memref<3xindex>,
                   %arg2: memref<?xindex>,
                   %arg3: memref<?xindex>,
                   %arg4: memref<?xf32>) -> memref<?xindex> {
      return %arg2 : memref<?xindex>
    }
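In the After form, the CSR tensor is passed as separate buffers, and %arg2 is the pointers buffer returned for dimension 1. Reading memref<2xindex> and memref<3xindex> as dimension sizes and memory sizes, and %arg3 and %arg4 as the indices and values buffers, is an inference from the signature rather than something stated here. The Python sketch below (dense_to_csr is a hypothetical name) shows what the pointers, indices, and values buffers hold for a row-compressed matrix.

```python
from typing import List, Tuple


def dense_to_csr(dense: List[List[float]]) -> Tuple[List[int], List[int], List[float]]:
    """Return (pointers, indices, values) for a row-compressed (CSR) encoding."""
    pointers = [0]             # pointers[i] .. pointers[i + 1] bound row i's entries
    indices: List[int] = []    # column index of each stored entry
    values: List[float] = []   # the stored nonzero values
    for row in dense:
        for j, v in enumerate(row):
            if v != 0:
                indices.append(j)
                values.append(v)
        pointers.append(len(values))
    return pointers, indices, values


print(dense_to_csr([[1, 0, 2, 0], [0, 0, 0, 0], [0, 3, 0, 4]]))
# ([0, 2, 2, 4], [0, 2, 1, 3], [1, 2, 3, 4])
```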

Options 

-enable-buffer-initialization : Enable zero-initialization of the memory buffers
-create-sparse-deallocs       : Specify whether the temporary buffers created by the sparse compiler should be deallocated. This option exists for compatibility with core bufferization passes and is only used when enable-runtime-library=false. See also create-deallocs for BufferizationOption.

-sparse-tensor-conversion: Convert sparse tensors and primitives to library calls 

A pass that converts sparse tensor primitives into calls to a runtime support library. Sparse tensor types are converted into opaque pointers to the underlying sparse storage schemes.

The use of opaque pointers together with a runtime support library keeps the conversion relatively simple, but at the expense of IR opacity, which obscures opportunities for subsequent optimization of the IR. An alternative is provided by the SparseTensorCodegen pass.

Example of the conversion:

  Before:
    func.func @foo(%arg0: tensor<8x8xf32, #CSR>) -> memref<?xindex> {
      %0 = sparse_tensor.pointers %arg0 {dimension = 1 : index}
         : tensor<8x8xf32, #CSR> to memref<?xindex>
      return %0 : memref<?xindex>
    }

  After:
    func.func @foo(%arg0: !llvm.ptr<i8>) -> memref<?xindex> {
      %c1 = arith.constant 1 : index
      %0 = call @sparsePointers0(%arg0, %c1)
         : (!llvm.ptr<i8>, index) -> memref<?xindex>
      return %0 : memref<?xindex>
    }

Options 

-s2s-strategy : Set the strategy for sparse-to-sparse conversion

-sparse-vectorization: Vectorizes loops after sparsification 

A pass that converts loops after sparsification into vector loops. The vector dialect is used as the target to provide an architecture-neutral way of exploiting any platform that supports SIMD instructions.

The vector length (vl) describes the number of packed data elements (e.g. both vector<16xf32> and vector<16xf64> have a vector length of 16 even though the actual bitwidths differ). A small multiple of the actual lengths supported in hardware typically results in efficient SIMD code, since the backend will map longer vectors to multiple vector registers, thereby effectively unrolling an additional level within the generated for-loop.

Example of the conversion:

  Before:
    %3 = memref.load %2[] : memref<f32>
    %4 = scf.for %arg3 = %c0 to %c1024 step %c1 iter_args(%arg4 = %3) -> (f32) {
      %6 = memref.load %0[%arg3] : memref<?xf32>
      %7 = memref.load %1[%arg3] : memref<1024xf32>
      %8 = arith.mulf %6, %7 : f32
      %9 = arith.addf %arg4, %8 : f32
      scf.yield %9 : f32
    }
    memref.store %4, %2[] : memref<f32>

  After:
    %3 = memref.load %2[] : memref<f32>
    %4 = vector.insertelement %3, %cst[%c0 : index] : vector<32xf32>
    %5 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4) -> (vector<32xf32>) {
      %8 = vector.load %0[%arg3] : memref<?xf32>, vector<32xf32>
      %9 = vector.load %1[%arg3] : memref<1024xf32>, vector<32xf32>
      %10 = arith.mulf %8, %9 : vector<32xf32>
      %11 = arith.addf %arg4, %10 : vector<32xf32>
      scf.yield %11 : vector<32xf32>
    }
    %6 = vector.reduction <add>, %5 : vector<32xf32> into f32
    memref.store %6, %2[] : memref<f32>
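The After loop keeps vl independent partial sums and reduces them once at the end. The Python sketch below mirrors that structure, assuming %cst is a zero vector into whose lane 0 the initial scalar is inserted; the function names are hypothetical. It uses integers because reordering floating-point additions can change rounding, whereas integer sums agree exactly.

```python
from typing import List


def scalar_sum_of_products(a: List[int], b: List[int], init: int) -> int:
    """The 'Before' loop: one scalar multiply-add per iteration."""
    acc = init
    for x, y in zip(a, b):
        acc += x * y
    return acc


def vectorized_sum_of_products(a: List[int], b: List[int], init: int, vl: int) -> int:
    """The 'After' loop: vl independent lanes, then one horizontal reduction."""
    assert len(a) == len(b) and len(a) % vl == 0, "no remainder handling in this sketch"
    acc = [0] * vl
    acc[0] = init                         # vector.insertelement into lane 0
    for base in range(0, len(a), vl):     # scf.for ... step %c32
        for lane in range(vl):            # lane-wise arith.mulf + arith.addf
            acc[lane] += a[base + lane] * b[base + lane]
    return sum(acc)                       # vector.reduction <add>
```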

Options 

-vl                       : Set the vector length (use 0 to disable vectorization)
-enable-vla-vectorization : Enable vector length agnostic vectorization
-enable-simd-index32      : Enable i32 indexing into vectors (for efficient gather/scatter)

-sparsification: Automatically generate sparse tensor code from sparse tensor types 

A pass that implements the core functionality of a sparse compiler. Each Linalg operation (MLIR’s tensor index notation) that operates on sparse tensor types is converted into code in which the sparsity is explicit both in terms of co-iterating looping logic as well as selected sparse storage schemes.

See the SparseTensor dialect documentation for more background.

Example input:

#matvec = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>, // A
    affine_map<(i,j) -> (j)>,   // b
    affine_map<(i,j) -> (i)>    // x (out)
  ],
  iterator_types = ["parallel", "reduction"],
  doc = "X(i) += A(i,j) * B(j)"
}

// Multiply a sparse matrix A with a dense vector b into a dense vector x.
func.func @kernel_matvec(%arga: tensor<?x?xf64, #SparseMatrix>,
                         %argb: tensor<?xf64>,
                         %argx: tensor<?xf64>) -> tensor<?xf64> {
  %0 = linalg.generic #matvec
    ins(%arga, %argb: tensor<?x?xf64, #SparseMatrix>, tensor<?xf64>)
    outs(%argx: tensor<?xf64>) {
    ^bb(%a: f64, %b: f64, %x: f64):
      %0 = arith.mulf %a, %b : f64
      %1 = arith.addf %x, %0 : f64
      linalg.yield %1 : f64
  } -> tensor<?xf64>
  return %0 : tensor<?xf64>
}
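For intuition about the kind of loop nest sparsification produces for this kernel, the Python sketch below computes x(i) += A(i,j) * b(j) while visiting only the stored entries of A. It assumes #SparseMatrix is a CSR-like encoding (dense rows, compressed columns), which the example does not define; the generated code would be IR loops over memrefs rather than Python, and csr_matvec is a hypothetical name.

```python
from typing import List


def csr_matvec(pointers: List[int], indices: List[int], values: List[float],
               b: List[float], x: List[float]) -> List[float]:
    """x(i) += A(i,j) * b(j), iterating only over the stored entries of A."""
    out = list(x)
    for i in range(len(pointers) - 1):                 # dense loop over rows i
        for k in range(pointers[i], pointers[i + 1]):  # compressed loop over row i
            out[i] += values[k] * b[indices[k]]
    return out


# A = [[1, 0, 2], [0, 0, 3], [4, 0, 0]] in CSR form.
print(csr_matvec([0, 2, 3, 4], [0, 2, 2, 0], [1, 2, 3, 4], [1, 2, 3], [0, 0, 0]))
# [7, 9, 4]
```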

Options 

-enable-index-reduction   : Enable dependent index reduction based algorithm to handle non-trivial index expressions on sparse inputs (experimental feature)
-parallelization-strategy : Set the parallelization strategy
-enable-gpu-libgen        : Enable GPU acceleration by means of direct library calls (like cuSPARSE)
-enable-runtime-library   : Enable runtime library for manipulating sparse tensors

‘spv’ Dialect Passes 

-decorate-spirv-composite-type-layout: Decorate SPIR-V composite type with layout info 

Module pass that converts composite types used by objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and PushConstant storage classes to attach layout information. Right now this pass only supports Vulkan layout rules.

-spirv-canonicalize-gl: Canonicalize GLSL ops 

Pass to run canonicalization patterns that involve GL ops. These patterns cannot be run in default canonicalization because GL ops aren’t always available, so they should be invoked explicitly when needed.

-spirv-lower-abi-attrs: Lower ABI attributes specified during SPIR-V lowering 

Operation pass that lowers the ABI attributes specified during SPIR-V Lowering. Specifically:

  1. Creates the global variables for arguments of entry point function using the specification in the spirv.interface_var_abi attribute for each argument.
  2. Inserts the EntryPointOp and the ExecutionModeOp for entry point functions using the specification in the spirv.entry_point_abi attribute.

-spirv-rewrite-inserts: Rewrite sequential chains of spirv.CompositeInsert operations into spirv.CompositeConstruct operations 

-spirv-unify-aliased-resource: Unify access of multiple aliased resources into access of one single resource 

-spirv-update-vce: Deduce and attach minimal (version, capabilities, extensions) requirements to spirv.module ops 

Operation pass that deduces and attaches the minimal version/capabilities/extensions requirements for spirv.module ops. For each spirv.module op, this pass requires a spirv.target_env attribute on it or on an enclosing module-like op to drive the deduction. This is needed because an op can be enabled by multiple extensions/capabilities, so the pass must know which one to pick. spirv.target_env gives the hard limit of what the target environment can support; this pass deduces what is actually needed for a specific spirv.module op.

-spirv-webgpu-prepare: Prepare SPIR-V to target WebGPU by expanding unsupported ops and replacing with supported ones 

‘tensor’ Dialect Passes 

-fold-tensor-subset-ops: Fold tensor subset ops into producer/consumer ops 

The pass folds tensor subset ops into producer/consumer ops.

At the moment, the following foldings occur when possible:

  • tensor.extract_slice into vector.transfer_read
  • vector.transfer_write into tensor.insert_slice

-tensor-bufferize: Bufferize the tensor dialect 

‘transform’ Dialect Passes 

-transform-dialect-check-uses: warn about potential use-after-free in the transform dialect 

This pass analyzes operations from the transform dialect and its extensions and warns if a transform IR value may be used by an operation after it was “freed” by some other operation, as described by side effects on the TransformMappingResource. This statically detects situations that lead to errors when interpreting the Transform IR.

The pass is capable of handling branching control flow and reports all potential use-after-free situations; e.g., a may-use-after-free is reported if at least one of the control flow paths between the definition of a value and its use contains an operation with a “free” effect on the TransformMappingResource. It does not currently perform an SCCP-style data flow analysis to prove that some branches are not taken. However, SCCP and other control flow simplifications can be performed on the transform IR prior to this pass, provided that transform ops implement the relevant control flow interfaces.
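The "at least one path" rule amounts to a forward may-analysis: a value counts as possibly freed at a point if it is freed on any incoming path, with no branch pruning. The Python sketch below runs such an analysis over a toy CFG in which each op is labelled "def", "use", or "free" (a free also counts as a use, like an op that consumes a handle). These labels stand in for the real side effects on TransformMappingResource, and all names are hypothetical.

```python
from typing import Dict, List, Set, Tuple

Op = Tuple[str, str]  # (kind, value), kind in {"def", "use", "free"}


def may_use_after_free(blocks: Dict[str, List[Op]],
                       succs: Dict[str, List[str]]) -> List[Tuple[str, int, str]]:
    """Report (block, op index, value) for each use reachable after a free on some path."""
    preds: Dict[str, List[str]] = {b: [] for b in blocks}
    for b, targets in succs.items():
        for t in targets:
            preds[t].append(b)

    def transfer(state: Set[str], ops: List[Op]) -> Set[str]:
        state = set(state)
        for kind, v in ops:
            if kind == "free":
                state.add(v)
            elif kind == "def":
                state.discard(v)
        return state

    # Forward "may be freed" dataflow: union over predecessors, iterate to a fixpoint.
    freed_out: Dict[str, Set[str]] = {b: set() for b in blocks}
    changed = True
    while changed:
        changed = False
        for b in blocks:
            freed_in = set().union(*(freed_out[p] for p in preds[b]))
            out = transfer(freed_in, blocks[b])
            if out != freed_out[b]:
                freed_out[b] = out
                changed = True

    reports = []
    for b, ops in blocks.items():
        state = set().union(*(freed_out[p] for p in preds[b]))
        for idx, (kind, v) in enumerate(ops):
            if kind in ("use", "free") and v in state:
                reports.append((b, idx, v))
            if kind == "free":
                state.add(v)
            elif kind == "def":
                state.discard(v)
    return reports


# Diamond CFG: only branch "a" frees %h, yet the use in "exit" is reported.
blocks = {"entry": [("def", "h")], "a": [("free", "h")], "b": [], "exit": [("use", "h")]}
succs = {"entry": ["a", "b"], "a": ["exit"], "b": ["exit"], "exit": []}
print(may_use_after_free(blocks, succs))  # [('exit', 0, 'h')]
```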

-transform-infer-effects: infer transform side effects for symbols 

This pass analyzes the definitions of transform dialect callable symbol operations, such as transform.named_sequence, and annotates the symbol arguments with attributes indicating the side effects that the nested operations have on them.

‘vector’ Dialect Passes 

-lower-vector-mask: Lower ‘vector.mask’ operations 

-vector-bufferize: Bufferize Vector dialect ops 

TOSA Dialect Passes 

-tosa-infer-shapes: Propagate shapes across TOSA operations 

Pass that uses operand types to propagate shapes to TOSA operations. This includes legalizing unranked and dynamic shapes towards static shapes.

-tosa-layerwise-constant-fold: Fold layerwise operations on constant tensors 

Pass that enables folding of full-layer operations on constant tensors.

-tosa-make-broadcastable: TOSA rank Reshape to enable Broadcasting 

Pass that enables broadcasting by making all input arrays have the same number of dimensions. It inserts RESHAPE operations that prepend dimensions of size one until the numbers of dimensions are equal. This implements an approach similar to step 1 of NumPy's 4-step broadcasting: https://numpy.org/doc/stable/reference/ufuncs.html#broadcasting
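The rank-equalizing step can be sketched in Python. The sketch only computes the target shapes for a set of operand shapes; the pass itself inserts the corresponding RESHAPE operations, and make_broadcastable is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def make_broadcastable(shapes: Sequence[Sequence[int]]) -> List[Tuple[int, ...]]:
    """Prepend size-1 dimensions so every shape has the maximum rank."""
    rank = max(len(s) for s in shapes)
    return [(1,) * (rank - len(s)) + tuple(s) for s in shapes]


print(make_broadcastable([(3,), (2, 4, 3)]))  # [(1, 1, 3), (2, 4, 3)]
```

After this step, the remaining broadcasting rules can be applied dimension by dimension, since all operands have the same rank.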

-tosa-optional-decompositions: Applies Tosa operations optional decompositions 

Pass to apply the TOSA operation decompositions exposed as populate functions in include/mlir/Dialect/Tosa/Transforms/Passes.h

-tosa-validate: Validates TOSA dialect 

This pass validates whether input TOSA operations match the specification for given criteria, e.g. TOSA profile.

Options 

-profile                  : Validate whether operations match the given profile
-strict-op-spec-alignment : Verify whether the properties of certain operations align with the spec requirements