MLIR 23.0.0git
LowerVectorContract.cpp File Reference

Classes

struct  ContractOpToElementwise
 Lower vector.contract with all size one reduction dimensions to elementwise ops when possible.
class  OuterProductOpLowering
 Progressive lowering of OuterProductOp.

Macros

#define DEBUG_TYPE   "vector-contract-lowering"

Functions

static std::optional< int64_t > getResultIndex (AffineMap map, int64_t index)
static SmallVector< Attribute > adjustIter (ArrayAttr iteratorTypes, int64_t index)
static AffineMap adjustMap (AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad (Location loc, Value val, int64_t index, int64_t pos, PatternRewriter &rewriter)
 Returns val with the dimension at position index dropped by indexing that dimension with pos.
static Value reshapeStore (Location loc, Value val, Value result, int64_t index, int64_t pos, PatternRewriter &rewriter)
 Inserts val into result at position pos along dimension index.
static std::optional< Value > createContractArithOp (Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value(), arith::FastMathFlagsAttr fmf={})
 Helper to create the arithmetic operation associated with a kind of contraction.
static SmallVector< int64_t > getReductionIndex (AffineMap map, ArrayAttr iteratorTypes)
 Return the positions of the reductions in the given map.
static std::optional< unsigned > getDimPosition (AffineMap map, unsigned dim)
 Look for a given dimension in an affine map and return its position.
static Value createAdd (Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter, arith::FastMathFlagsAttr fmf={})
 Creates an arith::AddIOp if isInt is true; otherwise creates an arith::AddFOp. Both use operands x and y.

Macro Definition Documentation

◆ DEBUG_TYPE

#define DEBUG_TYPE   "vector-contract-lowering"

Definition at line 26 of file LowerVectorContract.cpp.

Function Documentation

◆ adjustIter()

SmallVector< Attribute > adjustIter ( ArrayAttr iteratorTypes,
int64_t index )
static

Definition at line 45 of file LowerVectorContract.cpp.

References ArrayAttr().

◆ adjustMap()

◆ createAdd()

Value createAdd ( Location loc,
Value x,
Value y,
bool isInt,
PatternRewriter & rewriter,
arith::FastMathFlagsAttr fmf = {} )
static

Creates an arith::AddIOp if isInt is true; otherwise creates an arith::AddFOp. Both use operands x and y.

Definition at line 204 of file LowerVectorContract.cpp.

◆ createContractArithOp()

std::optional< Value > createContractArithOp ( Location loc,
Value x,
Value y,
Value acc,
vector::CombiningKind kind,
PatternRewriter & rewriter,
bool isInt,
Value mask = Value(),
arith::FastMathFlagsAttr fmf = {} )
static

Helper to create the arithmetic operation associated with a kind of contraction.

Definition at line 142 of file LowerVectorContract.cpp.

Referenced by OuterProductOpLowering::matchAndRewrite(), and ContractOpToElementwise::matchAndRewriteMaskableOp().

◆ getDimPosition()

std::optional< unsigned > getDimPosition ( AffineMap map,
unsigned dim )
static

Look for a given dimension in an affine map and return its position.

Return std::nullopt if the dimension is not in the map results.
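
The lookup described above can be sketched outside of MLIR. The following is a minimal model, not the code in LowerVectorContract.cpp: it assumes the map is a projected permutation and represents it with a hypothetical DimMap (a plain vector in which entry i is the input dimension used by result i). The names DimMap, findDimPosition and findDimPositionExample are illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for a projected-permutation AffineMap:
// entry i holds the input dimension that result i refers to.
using DimMap = std::vector<unsigned>;

// Returns the result position that refers to `dim`, or std::nullopt if no
// result of the map uses that dimension.
std::optional<unsigned> findDimPosition(const DimMap &map, unsigned dim) {
  for (std::size_t i = 0; i < map.size(); ++i)
    if (map[i] == dim)
      return static_cast<unsigned>(i);
  return std::nullopt;
}

// Usage: models the map (d0, d1, d2) -> (d2, d0).
void findDimPositionExample() {
  DimMap map{2, 0};
  assert(findDimPosition(map, 2) == 0u);          // d2 is result 0
  assert(findDimPosition(map, 0) == 1u);          // d0 is result 1
  assert(!findDimPosition(map, 1).has_value());   // d1 is not a result
}
```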

Definition at line 194 of file LowerVectorContract.cpp.

References mlir::AffineMap::getDimPosition(), and mlir::AffineMap::getNumResults().

Referenced by ContractOpToElementwise::matchAndRewriteMaskableOp().

◆ getReductionIndex()

SmallVector< int64_t > getReductionIndex ( AffineMap map,
ArrayAttr iteratorTypes )
static

Return the positions of the reductions in the given map.
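
As an illustration of what "positions of the reductions" means, here is a minimal sketch that does not use MLIR types. It assumes the map is a projected permutation, modeled by a hypothetical DimMap (entry i is the input dimension used by result i), and models the iterator types with a hypothetical IteratorKind enum. The function name reductionResultPositions is also illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for the MLIR iterator-type attribute and AffineMap.
enum class IteratorKind { Parallel, Reduction };
using DimMap = std::vector<unsigned>;

// Collects every result position of `map` whose underlying iteration
// dimension is a reduction according to `iteratorTypes`.
std::vector<std::int64_t>
reductionResultPositions(const DimMap &map,
                         const std::vector<IteratorKind> &iteratorTypes) {
  std::vector<std::int64_t> positions;
  for (std::size_t i = 0; i < map.size(); ++i) {
    assert(map[i] < iteratorTypes.size() && "map refers to unknown dimension");
    if (iteratorTypes[map[i]] == IteratorKind::Reduction)
      positions.push_back(static_cast<std::int64_t>(i));
  }
  return positions;
}
```

For a matmul-like contraction with iterators (parallel, parallel, reduction), the LHS map (m, n, k) -> (m, k) is modeled as {0, 2} and yields position 1, while the result map (m, n, k) -> (m, n) yields no positions.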

Definition at line 182 of file LowerVectorContract.cpp.

References ArrayAttr(), mlir::AffineMap::getDimPosition(), mlir::AffineMap::getNumResults(), and mlir::vector::isReductionIterator().

Referenced by ContractOpToElementwise::matchAndRewriteMaskableOp().

◆ getResultIndex()

std::optional< int64_t > getResultIndex ( AffineMap map,
int64_t index )
static

◆ reshapeLoad()

Value reshapeLoad ( Location loc,
Value val,
int64_t index,
int64_t pos,
PatternRewriter & rewriter )
static

Returns val with the dimension at position index dropped by indexing that dimension with pos.

If index == -1, returns val unchanged. If index == 0, the result is a single vector.extract val[pos].

Example (index == 0): extract the sub-vector at pos along the leading dimension.

  // val : vector<4x8xf32>, pos = 2
  res = vector.extract val[2] : vector<8xf32> from vector<4x8xf32>

For index > 0, recursively applies the same drop to each sub-vector of the leading dimension and reassembles the result.
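
The recursion can be modeled on plain data to make the index bookkeeping concrete. This is a sketch under stated assumptions, not the MLIR implementation: Tensor is a hypothetical row-major n-D container, extractLeading stands in for vector.extract on the leading dimension, and dropDimAt is an illustrative name for the behavior described above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical row-major n-D vector used in place of an MLIR Value.
struct Tensor {
  std::vector<std::int64_t> shape;
  std::vector<float> data;
};

static std::int64_t numElements(const std::vector<std::int64_t> &shape) {
  std::int64_t n = 1;
  for (std::int64_t d : shape)
    n *= d;
  return n;
}

// Models `vector.extract t[pos]`: the pos-th slice of the leading dimension.
Tensor extractLeading(const Tensor &t, std::int64_t pos) {
  assert(!t.shape.empty() && pos >= 0 && pos < t.shape[0]);
  Tensor out;
  out.shape.assign(t.shape.begin() + 1, t.shape.end());
  const std::int64_t n = numElements(out.shape);
  out.data.assign(t.data.begin() + pos * n, t.data.begin() + (pos + 1) * n);
  return out;
}

// Drops dimension `index` of `val` by fixing it to `pos`.
Tensor dropDimAt(const Tensor &val, std::int64_t index, std::int64_t pos) {
  if (index == -1)
    return val;                      // nothing to drop
  if (index == 0)
    return extractLeading(val, pos); // single extract on the leading dim
  Tensor result;
  result.shape = val.shape;
  result.shape.erase(result.shape.begin() + index);
  // Recurse into each leading-dimension slice and reassemble in order.
  for (std::int64_t d = 0; d < val.shape[0]; ++d) {
    Tensor sub = dropDimAt(extractLeading(val, d), index - 1, pos);
    result.data.insert(result.data.end(), sub.data.begin(), sub.data.end());
  }
  return result;
}
```

In this model, a 2x3 input holding 0..5 with index == 1 and pos == 2 produces a 2-element result containing the last column.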

Definition at line 87 of file LowerVectorContract.cpp.

References mlir::VectorType::Builder::dropDim(), mlir::Value::getType(), mlir::Builder::getZeroAttr(), load, reshapeLoad(), and result.

Referenced by reshapeLoad().

◆ reshapeStore()

Value reshapeStore ( Location loc,
Value val,
Value result,
int64_t index,
int64_t pos,
PatternRewriter & rewriter )
static

Inserts val into result at position pos along dimension index.

This is the inverse of reshapeLoad. If index == -1, returns val. If index == 0, the result is a single vector.insert val, result [pos].

Example (index == 0): insert val at pos along the leading dimension.

  // val : vector<4xf32>, acc : vector<2x4xf32>, pos = 1
  res = vector.insert val, acc [1] : vector<4xf32> into vector<2x4xf32>

For index > 0, recursively applies the same insertion to each sub-vector of the leading dimension and reassembles the result.
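
The insertion recursion can likewise be modeled on plain data. This is a sketch under stated assumptions, not the MLIR implementation: Tensor is a hypothetical row-major n-D container, extractLeading and insertLeading stand in for vector.extract and vector.insert on the leading dimension, and storeDimAt is an illustrative name for the behavior described above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical row-major n-D vector used in place of an MLIR Value.
struct Tensor {
  std::vector<std::int64_t> shape;
  std::vector<float> data;
};

static std::int64_t numElements(const std::vector<std::int64_t> &shape) {
  std::int64_t n = 1;
  for (std::int64_t d : shape)
    n *= d;
  return n;
}

// Models `vector.extract t[pos]` on the leading dimension.
Tensor extractLeading(const Tensor &t, std::int64_t pos) {
  assert(!t.shape.empty() && pos >= 0 && pos < t.shape[0]);
  Tensor out;
  out.shape.assign(t.shape.begin() + 1, t.shape.end());
  const std::int64_t n = numElements(out.shape);
  out.data.assign(t.data.begin() + pos * n, t.data.begin() + (pos + 1) * n);
  return out;
}

// Models `vector.insert val, dest[pos]` on the leading dimension.
Tensor insertLeading(const Tensor &val, Tensor dest, std::int64_t pos) {
  assert(!dest.shape.empty() && pos >= 0 && pos < dest.shape[0]);
  const std::int64_t n = numElements(val.shape);
  assert(static_cast<std::int64_t>(val.data.size()) == n);
  std::copy(val.data.begin(), val.data.end(), dest.data.begin() + pos * n);
  return dest;
}

// Writes `val` into `result` at position `pos` along dimension `index`.
Tensor storeDimAt(const Tensor &val, const Tensor &result, std::int64_t index,
                  std::int64_t pos) {
  if (index == -1)
    return val;                            // val already has the full shape
  if (index == 0)
    return insertLeading(val, result, pos); // single insert on the leading dim
  Tensor out = result;
  // Recurse into matching leading-dimension slices of val and result.
  for (std::int64_t d = 0; d < result.shape[0]; ++d) {
    Tensor updated = storeDimAt(extractLeading(val, d),
                                extractLeading(result, d), index - 1, pos);
    out = insertLeading(updated, out, d);
  }
  return out;
}
```

In this model, storing a 2-element val into a zero-filled 2x3 result with index == 1 and pos == 2 overwrites only the last column.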

Definition at line 120 of file LowerVectorContract.cpp.

References reshapeStore(), and result.

Referenced by reshapeStore().