MLIR  19.0.0git
Classes | Namespaces | Macros | Enumerations | Functions
VectorOps.cpp File Reference
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include <cassert>
#include <cstdint>
#include <numeric>
#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc"
#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc"
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"

Go to the source code of this file.

Classes

struct  mlir::vector::detail::BitmaskEnumStorage
 
struct  CanonicalizeContractAdd< AddOpType >
 Return a fused vector::ContractionOp which represents a pattern such as: More...
 

Namespaces

 mlir
 Include the generated interface declarations.
 
 mlir::vector
 
 mlir::vector::detail
 

Macros

#define GET_ATTRDEF_LIST
 
#define GET_OP_LIST
 
#define GET_ATTRDEF_CLASSES
 
#define GET_OP_CLASSES
 

Enumerations

enum class  MaskFormat { AllTrue = 0 , AllFalse = 1 , Unknown = 2 }
 Helper enum to classify mask value. More...
 

Functions

static MaskFormat getMaskFormat (Value mask)
 Helper method to classify a mask value. More...
 
static bool isSupportedCombiningKind (CombiningKind combiningKind, Type elementType)
 
static bool isSplatWriteConsistentWithMaskedRead (vector::TransferWriteOp write, vector::TransferReadOp read)
 Check if write is of a constant splat and the masked read is padded with the same splat value – meaning it could be the same value as the initial constant splat. More...
 
static LogicalResult incSlicePosition (MutableArrayRef< int64_t > position, ArrayRef< int64_t > shape, ArrayRef< int64_t > offsets)
 
static bool verifyDimMap (VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
 
static LogicalResult verifyOutputShape (ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
 
static int64_t getResultIndex (AffineMap map, AffineExpr targetExpr)
 
static std::vector< std::pair< int64_t, int64_t > > getDimMap (ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
 
template<typename IntType >
static SmallVector< IntType > extractVector (ArrayAttr arrayAttr)
 
static LogicalResult foldExtractOpFromExtractChain (ExtractOp extractOp)
 Fold the result of chains of ExtractOp in place by simply concatenating the positions. More...
 
static bool hasZeroDimVectors (Operation *op)
 Returns true if the operation has a 0-D vector type operand or result. More...
 
static Value foldExtractFromBroadcast (ExtractOp extractOp)
 Fold extractOp with scalar result coming from BroadcastOp or SplatOp. More...
 
static Value foldExtractFromShapeCast (ExtractOp extractOp)
 
static Value foldExtractFromExtractStrided (ExtractOp extractOp)
 Fold an ExtractOp from ExtractStridedSliceOp. More...
 
static Value foldExtractStridedOpFromInsertChain (ExtractOp extractOp)
 Fold extract_op fed from a chain of insertStridedSlice ops. More...
 
static Value foldScalarExtractFromFromElements (ExtractOp extractOp)
 Try to fold the extraction of a scalar from a vector defined by vector.from_elements. More...
 
static void populateFromInt64AttrArray (ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
 
static LogicalResult rewriteFromElementsAsSplat (FromElementsOp fromElementsOp, PatternRewriter &rewriter)
 Rewrite a vector.from_elements into a vector.splat if all elements are the same SSA value. More...
 
static llvm::SetVector< int64_t > computeBroadcastedUnitDims (ArrayRef< int64_t > srcShape, ArrayRef< int64_t > dstShape)
 Return the dimensions of the result vector that were formerly ones in the source vector and thus correspond to "dim-1" broadcasting. More...
 
static bool isStepIndexArray (ArrayAttr idxArr, uint64_t begin, size_t width)
 
template<typename OpType >
static LogicalResult isIntegerArrayAttrSmallerThanShape (OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
 
template<typename OpType >
static LogicalResult isIntegerArrayAttrConfinedToRange (OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
 
template<typename OpType >
static LogicalResult isIntegerArrayAttrConfinedToShape (OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
 
template<typename OpType >
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape (OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
 
static ArrayAttr makeI64ArrayAttr (ArrayRef< int64_t > values, MLIRContext *context)
 
static Type inferStridedSliceOpResultType (VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
 
static LogicalResult foldExtractStridedOpFromInsertChain (ExtractStridedSliceOp op)
 
template<typename EmitFun >
static LogicalResult verifyPermutationMap (AffineMap permutationMap, EmitFun emitOpError)
 
static LogicalResult verifyTransferOp (VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, VectorType inferredMaskType, AffineMap permutationMap, ArrayAttr inBounds)
 
static void printTransferAttrs (OpAsmPrinter &p, VectorTransferOpInterface op)
 
template<typename TransferOp >
static bool isInBounds (TransferOp op, int64_t resultIdx, int64_t indicesIdx)
 
template<typename TransferOp >
static LogicalResult foldTransferInBoundsAttribute (TransferOp op)
 
template<typename TransferOp >
static LogicalResult foldTransferFullMask (TransferOp op)
 

Macro Definition Documentation

◆ GET_ATTRDEF_CLASSES

#define GET_ATTRDEF_CLASSES

◆ GET_ATTRDEF_LIST

#define GET_ATTRDEF_LIST

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

◆ GET_OP_LIST

#define GET_OP_LIST

Enumeration Type Documentation

◆ MaskFormat

enum class MaskFormat (strongly typed)

Helper enum to classify mask value.

Enumerator
AllTrue 
AllFalse 
Unknown 

Definition at line 58 of file VectorOps.cpp.

Function Documentation

◆ computeBroadcastedUnitDims()

static llvm::SetVector<int64_t> computeBroadcastedUnitDims ( ArrayRef< int64_t >  srcShape,
ArrayRef< int64_t >  dstShape 
)
static

Return the dimensions of the result vector that were formerly ones in the source vector and thus correspond to "dim-1" broadcasting.

Definition at line 2243 of file VectorOps.cpp.

◆ extractVector()

template<typename IntType >
static SmallVector<IntType> extractVector ( ArrayAttr  arrayAttr)
static

Definition at line 1345 of file VectorOps.cpp.

◆ foldExtractFromBroadcast()

static Value foldExtractFromBroadcast ( ExtractOp  extractOp)
static

Fold extractOp with scalar result coming from BroadcastOp or SplatOp.

Definition at line 1625 of file VectorOps.cpp.

References mlir::Operation::getOperand(), and mlir::Value::getType().

◆ foldExtractFromExtractStrided()

static Value foldExtractFromExtractStrided ( ExtractOp  extractOp)
static

Fold an ExtractOp from ExtractStridedSliceOp.

Definition at line 1754 of file VectorOps.cpp.

References hasZeroDimVectors().

◆ foldExtractFromShapeCast()

static Value foldExtractFromShapeCast ( ExtractOp  extractOp)
static

Definition at line 1684 of file VectorOps.cpp.

References delinearize(), hasZeroDimVectors(), and mlir::linearize().

◆ foldExtractOpFromExtractChain()

static LogicalResult foldExtractOpFromExtractChain ( ExtractOp  extractOp)
static

Fold the result of chains of ExtractOp in place by simply concatenating the positions.

Definition at line 1353 of file VectorOps.cpp.

References mlir::failure(), and mlir::success().

◆ foldExtractStridedOpFromInsertChain() [1/2]

static Value foldExtractStridedOpFromInsertChain ( ExtractOp  extractOp)
static

Fold extract_op fed from a chain of insertStridedSlice ops.

Definition at line 1806 of file VectorOps.cpp.

References hasZeroDimVectors().

◆ foldExtractStridedOpFromInsertChain() [2/2]

static LogicalResult foldExtractStridedOpFromInsertChain ( ExtractStridedSliceOp  op)
static

Definition at line 3500 of file VectorOps.cpp.

◆ foldScalarExtractFromFromElements()

static Value foldScalarExtractFromFromElements ( ExtractOp  extractOp)
static

Try to fold the extraction of a scalar from a vector defined by vector.from_elements.

E.g.:

%0 = vector.from_elements %a, %b : vector<2xf32>
%1 = vector.extract %0[0] : f32 from vector<2xf32>
==> fold to %a

Definition at line 1886 of file VectorOps.cpp.

References mlir::Value::getDefiningOp().

◆ foldTransferFullMask()

template<typename TransferOp >
static LogicalResult foldTransferFullMask ( TransferOp  op)
static

Definition at line 4170 of file VectorOps.cpp.

◆ foldTransferInBoundsAttribute()

template<typename TransferOp >
static LogicalResult foldTransferInBoundsAttribute ( TransferOp  op)
static

Definition at line 4136 of file VectorOps.cpp.

◆ getDimMap()

static std::vector<std::pair<int64_t, int64_t> > getDimMap ( ArrayRef< AffineMap >  indexingMaps,
ArrayAttr  iteratorTypes,
IteratorType  targetIteratorType,
MLIRContext *  context 
)
static

◆ getMaskFormat()

static MaskFormat getMaskFormat ( Value  mask)
static

Helper method to classify a mask value.

Currently, the method looks "under the hood" of a constant value with dense attributes and a constant mask operation (since the client may be called at various stages during progressive lowering).

Definition at line 68 of file VectorOps.cpp.

References AllFalse, AllTrue, mlir::Value::getDefiningOp(), and Unknown.

◆ getResultIndex()

static int64_t getResultIndex ( AffineMap  map,
AffineExpr  targetExpr 
)
static

Definition at line 1064 of file VectorOps.cpp.

References mlir::AffineMap::getNumResults(), and mlir::AffineMap::getResult().

Referenced by getDimMap().

◆ hasZeroDimVectors()

static bool hasZeroDimVectors ( Operation *  op)
static

Returns true if the operation has a 0-D vector type operand or result.

Definition at line 1614 of file VectorOps.cpp.

Referenced by foldExtractFromExtractStrided(), foldExtractFromShapeCast(), and foldExtractStridedOpFromInsertChain().

◆ incSlicePosition()

static LogicalResult incSlicePosition ( MutableArrayRef< int64_t >  position,
ArrayRef< int64_t >  shape,
ArrayRef< int64_t >  offsets 
)
static

Definition at line 297 of file VectorOps.cpp.

References mlir::failure(), and mlir::success().

◆ inferStridedSliceOpResultType()

static Type inferStridedSliceOpResultType ( VectorType  vectorType,
ArrayAttr  offsets,
ArrayAttr  sizes,
ArrayAttr  strides 
)
static

Definition at line 3408 of file VectorOps.cpp.

References mlir::get().

◆ isInBounds()

template<typename TransferOp >
static bool isInBounds ( TransferOp  op,
int64_t  resultIdx,
int64_t  indicesIdx 
)
static

Definition at line 4119 of file VectorOps.cpp.

◆ isIntegerArrayAttrConfinedToRange()

template<typename OpType >
static LogicalResult isIntegerArrayAttrConfinedToRange ( OpType  op,
ArrayAttr  arrayAttr,
int64_t  min,
int64_t  max,
StringRef  attrName,
bool  halfOpen = true 
)
static

Definition at line 2945 of file VectorOps.cpp.

◆ isIntegerArrayAttrConfinedToShape()

template<typename OpType >
static LogicalResult isIntegerArrayAttrConfinedToShape ( OpType  op,
ArrayAttr  arrayAttr,
ArrayRef< int64_t >  shape,
StringRef  attrName,
bool  halfOpen = true,
int64_t  min = 0 
)
static

Definition at line 2965 of file VectorOps.cpp.

◆ isIntegerArrayAttrSmallerThanShape()

template<typename OpType >
static LogicalResult isIntegerArrayAttrSmallerThanShape ( OpType  op,
ArrayAttr  arrayAttr,
ArrayRef< int64_t >  shape,
StringRef  attrName 
)
static

Definition at line 2930 of file VectorOps.cpp.

◆ isSplatWriteConsistentWithMaskedRead()

static bool isSplatWriteConsistentWithMaskedRead ( vector::TransferWriteOp  write,
vector::TransferReadOp  read 
)
static

Check if write is of a constant splat and the masked read is padded with the same splat value – meaning it could be the same value as the initial constant splat.

Definition at line 176 of file VectorOps.cpp.

References mlir::DenseElementsAttr::getSplatValue(), mlir::DenseElementsAttr::isSplat(), mlir::m_Constant(), and mlir::matchPattern().

◆ isStepIndexArray()

static bool isStepIndexArray ( ArrayAttr  idxArr,
uint64_t  begin,
size_t  width 
)
static

Definition at line 2531 of file VectorOps.cpp.

◆ isSumOfIntegerArrayAttrConfinedToShape()

template<typename OpType >
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape ( OpType  op,
ArrayAttr  arrayAttr1,
ArrayAttr  arrayAttr2,
ArrayRef< int64_t >  shape,
StringRef  attrName1,
StringRef  attrName2,
bool  halfOpen = true,
int64_t  min = 1 
)
static

Definition at line 2988 of file VectorOps.cpp.

◆ isSupportedCombiningKind()

static bool isSupportedCombiningKind ( CombiningKind  combiningKind,
Type  elementType 
)
static

◆ makeI64ArrayAttr()

static ArrayAttr makeI64ArrayAttr ( ArrayRef< int64_t >  values,
MLIRContext *  context 
)
static

Definition at line 3009 of file VectorOps.cpp.

References mlir::get().

◆ populateFromInt64AttrArray()

static void populateFromInt64AttrArray ( ArrayAttr  arrayAttr,
SmallVectorImpl< int64_t > &  results 
)
static

Definition at line 2199 of file VectorOps.cpp.

◆ printTransferAttrs()

static void printTransferAttrs ( OpAsmPrinter &  p,
VectorTransferOpInterface  op 
)
static

Definition at line 3968 of file VectorOps.cpp.

◆ rewriteFromElementsAsSplat()

static LogicalResult rewriteFromElementsAsSplat ( FromElementsOp  fromElementsOp,
PatternRewriter &  rewriter 
)
static

Rewrite a vector.from_elements into a vector.splat if all elements are the same SSA value.

E.g.:

%0 = vector.from_elements %a, %a, %a : vector<3xf32>
==> rewrite to vector.splat %a : vector<3xf32>

Definition at line 2222 of file VectorOps.cpp.

References mlir::failure(), mlir::RewriterBase::replaceOpWithNewOp(), and mlir::success().

◆ verifyDimMap()

static bool verifyDimMap ( VectorType  lhsType,
VectorType  rhsType,
const std::vector< std::pair< int64_t, int64_t >> &  map 
)
static

Definition at line 853 of file VectorOps.cpp.

◆ verifyOutputShape()

static LogicalResult verifyOutputShape ( ContractionOp  op,
VectorType  lhsType,
VectorType  rhsType,
Type  accType,
Type  resType,
const std::vector< std::pair< int64_t, int64_t >> &  contractingDimMap,
const std::vector< std::pair< int64_t, int64_t >> &  batchDimMap 
)
static

Definition at line 864 of file VectorOps.cpp.

◆ verifyPermutationMap()

template<typename EmitFun >
static LogicalResult verifyPermutationMap ( AffineMap  permutationMap,
EmitFun  emitOpError 
)
static

◆ verifyTransferOp()

static LogicalResult verifyTransferOp ( VectorTransferOpInterface  op,
ShapedType  shapedType,
VectorType  vectorType,
VectorType  maskType,
VectorType  inferredMaskType,
AffineMap  permutationMap,
ArrayAttr  inBounds 
)
static

Definition at line 3883 of file VectorOps.cpp.