MLIR  16.0.0git
Classes | Namespaces | Macros | Enumerations | Functions
VectorOps.cpp File Reference
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include <numeric>
#include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
+ Include dependency graph for VectorOps.cpp:

Go to the source code of this file.

Classes

struct  mlir::vector::detail::BitmaskEnumStorage
 
struct  CanonicalizeContractAdd< AddOpType >
 Return a fused vector::ContractionOp which represents a pattern such as: More...
 

Namespaces

 mlir
 Include the generated interface declarations.
 
 mlir::vector
 
 mlir::vector::detail
 

Macros

#define GET_ATTRDEF_LIST
 
#define GET_OP_LIST
 
#define GET_ATTRDEF_CLASSES
 
#define GET_OP_CLASSES
 

Enumerations

enum  MaskFormat { MaskFormat::AllTrue = 0, MaskFormat::AllFalse = 1, MaskFormat::Unknown = 2 }
 Helper enum to classify mask value. More...
 

Functions

static MaskFormat getMaskFormat (Value mask)
 Helper method to classify a mask value. More...
 
static bool isSupportedCombiningKind (CombiningKind combiningKind, Type elementType)
 
static bool verifyDimMap (VectorType lhsType, VectorType rhsType, const std::vector< std::pair< int64_t, int64_t >> &map)
 
static LogicalResult verifyOutputShape (ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType, Type resType, const std::vector< std::pair< int64_t, int64_t >> &contractingDimMap, const std::vector< std::pair< int64_t, int64_t >> &batchDimMap)
 
static int64_t getResultIndex (AffineMap map, AffineExpr targetExpr)
 
static std::vector< std::pair< int64_t, int64_t > > getDimMap (ArrayRef< AffineMap > indexingMaps, ArrayAttr iteratorTypes, IteratorType targetIteratorType, MLIRContext *context)
 
template<typename IntType >
static SmallVector< IntType > extractVector (ArrayAttr arrayAttr)
 
static LogicalResult foldExtractOpFromExtractChain (ExtractOp extractOp)
 Fold the result of chains of ExtractOp in place by simply concatenating the positions. More...
 
static Value foldExtractFromBroadcast (ExtractOp extractOp)
 Fold extractOp with scalar result coming from BroadcastOp or SplatOp. More...
 
static Value foldExtractFromShapeCast (ExtractOp extractOp)
 
static Value foldExtractFromExtractStrided (ExtractOp extractOp)
 Fold an ExtractOp from ExtractStridedSliceOp. More...
 
static Value foldExtractStridedOpFromInsertChain (ExtractOp op)
 Fold extract_op fed from a chain of insertStridedSlice ops. More...
 
static void populateFromInt64AttrArray (ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
 
static bool isStepIndexArray (ArrayAttr idxArr, uint64_t begin, size_t width)
 
template<typename OpType >
static LogicalResult isIntegerArrayAttrSmallerThanShape (OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName)
 
template<typename OpType >
static LogicalResult isIntegerArrayAttrConfinedToRange (OpType op, ArrayAttr arrayAttr, int64_t min, int64_t max, StringRef attrName, bool halfOpen=true)
 
template<typename OpType >
static LogicalResult isIntegerArrayAttrConfinedToShape (OpType op, ArrayAttr arrayAttr, ArrayRef< int64_t > shape, StringRef attrName, bool halfOpen=true, int64_t min=0)
 
template<typename OpType >
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape (OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, ArrayRef< int64_t > shape, StringRef attrName1, StringRef attrName2, bool halfOpen=true, int64_t min=1)
 
static ArrayAttr makeI64ArrayAttr (ArrayRef< int64_t > values, MLIRContext *context)
 
static Type inferStridedSliceOpResultType (VectorType vectorType, ArrayAttr offsets, ArrayAttr sizes, ArrayAttr strides)
 
static LogicalResult foldExtractStridedOpFromInsertChain (ExtractStridedSliceOp op)
 
template<typename EmitFun >
static LogicalResult verifyPermutationMap (AffineMap permutationMap, EmitFun emitOpError)
 
static LogicalResult verifyTransferOp (VectorTransferOpInterface op, ShapedType shapedType, VectorType vectorType, VectorType maskType, AffineMap permutationMap, ArrayAttr inBounds)
 
static void printTransferAttrs (OpAsmPrinter &p, VectorTransferOpInterface op)
 
static LogicalResult foldMemRefCast (Operation *op)
 This is a common helper used for patterns of the form

someop(memrefcast) -> someop

It folds the source of the memref.cast into the root operation directly. More...

 
static LogicalResult foldTensorCast (Operation *op)
 
template<typename TransferOp >
static bool isInBounds (TransferOp op, int64_t resultIdx, int64_t indicesIdx)
 
template<typename TransferOp >
static LogicalResult foldTransferInBoundsAttribute (TransferOp op)
 

Macro Definition Documentation

◆ GET_ATTRDEF_CLASSES

#define GET_ATTRDEF_CLASSES

◆ GET_ATTRDEF_LIST

#define GET_ATTRDEF_LIST

◆ GET_OP_CLASSES

#define GET_OP_CLASSES

◆ GET_OP_LIST

#define GET_OP_LIST

Enumeration Type Documentation

◆ MaskFormat

enum MaskFormat
strong

Helper enum to classify mask value.

Enumerator
AllTrue 
AllFalse 
Unknown 

Definition at line 46 of file VectorOps.cpp.

Function Documentation

◆ extractVector()

template<typename IntType >
static SmallVector<IntType> extractVector ( ArrayAttr  arrayAttr)
static

Definition at line 1044 of file VectorOps.cpp.

◆ foldExtractFromBroadcast()

static Value foldExtractFromBroadcast ( ExtractOp  extractOp)
static

Fold extractOp with scalar result coming from BroadcastOp or SplatOp.

Definition at line 1290 of file VectorOps.cpp.

References mlir::Type::cast(), mlir::Type::dyn_cast(), mlir::Operation::getOperand(), mlir::Value::getType(), and mlir::Type::isa().

Referenced by foldExtractStridedOpFromInsertChain().

◆ foldExtractFromExtractStrided()

static Value foldExtractFromExtractStrided ( ExtractOp  extractOp)
static

Fold an ExtractOp from ExtractStridedSliceOp.

Definition at line 1383 of file VectorOps.cpp.

References mlir::Value::getDefiningOp().

Referenced by foldExtractStridedOpFromInsertChain().

◆ foldExtractFromShapeCast()

static Value foldExtractFromShapeCast ( ExtractOp  extractOp)
static

◆ foldExtractOpFromExtractChain()

static LogicalResult foldExtractOpFromExtractChain ( ExtractOp  extractOp)
static

Fold the result of chains of ExtractOp in place by simply concatenating the positions.

Definition at line 1052 of file VectorOps.cpp.

References extractPosition(), mlir::failure(), mlir::Value::getDefiningOp(), and mlir::success().

Referenced by foldExtractStridedOpFromInsertChain().

◆ foldExtractStridedOpFromInsertChain() [1/2]

static Value foldExtractStridedOpFromInsertChain ( ExtractOp  op)
static

◆ foldExtractStridedOpFromInsertChain() [2/2]

static LogicalResult foldExtractStridedOpFromInsertChain ( ExtractStridedSliceOp  op)
static

◆ foldMemRefCast()

static LogicalResult foldMemRefCast ( Operation * op)
static

This is a common helper used for patterns of the form

someop(memrefcast) -> someop

It folds the source of the memref.cast into the root operation directly.

Definition at line 3043 of file VectorOps.cpp.

References mlir::tensor::canFoldIntoConsumerOp(), mlir::Operation::getOpOperands(), and mlir::success().

Referenced by foldTransferInBoundsAttribute().

◆ foldTensorCast()

static LogicalResult foldTensorCast ( Operation * op)
static

◆ foldTransferInBoundsAttribute()

template<typename TransferOp >
static LogicalResult foldTransferInBoundsAttribute ( TransferOp  op)
static

Definition at line 3085 of file VectorOps.cpp.

References mlir::RewritePatternSet::add(), mlir::OperationState::addAttribute(), mlir::AsmParser::addTypeToList(), AllFalse, AllTrue, mlir::applyPermutationMap(), mlir::OperationState::attributes, mlir::Type::cast(), mlir::vector::checkSameValueRAW(), mlir::vector::checkSameValueWAW(), mlir::DataLayout::closest(), mlir::AffineMap::compose(), mlir::compressUnusedDims(), mlir::OpBuilder::create(), mlir::Attribute::dyn_cast(), mlir::AffineExpr::dyn_cast(), mlir::Type::dyn_cast(), mlir::AsmParser::emitError(), mlir::Operation::emitOpError(), mlir::detail::enumerate(), mlir::RewriterBase::eraseOp(), mlir::failed(), mlir::failure(), foldMemRefCast(), foldTensorCast(), mlir::SideEffects::Effect::Base< DerivedEffect, BaseEffect >::get(), mlir::SideEffects::Resource::Base< DefaultResource >::get(), mlir::DenseElementsAttr::get(), mlir::NamedAttrList::get(), mlir::Builder::getBoolArrayAttr(), mlir::AsmParser::getBuilder(), mlir::getConstantIntValue(), mlir::AsmParser::getCurrentLocation(), mlir::Value::getDefiningOp(), mlir::Builder::getDenseI32ArrayAttr(), mlir::Builder::getI1Type(), mlir::Builder::getIndexType(), getIndices(), getMaskFormat(), mlir::bufferization::getMemRefType(), mlir::AffineMap::getPermutationMap(), mlir::AffineMap::getResult(), mlir::vector::getTransferMinorIdentityMap(), mlir::Value::getType(), mlir::DataLayout::getTypeSizeInBits(), mlir::getValueOrCreateConstantIndexOp(), getVectorType(), mlir::vector::isDisjointTransferIndices(), mlir::isEqualConstantIntOrValue(), mlir::Type::isF16(), mlir::Type::isF32(), isInBounds(), mlir::vector::isLastMemrefDimUnitStride(), mlir::OpAsmParser::UnresolvedOperand::location, mlir::RewriterBase::notifyMatchFailure(), mlir::OperationState::operands, mlir::OpRewritePattern< SourceOp >::OpRewritePattern(), mlir::AsmParser::parseColonTypeList(), mlir::AsmParser::parseComma(), mlir::OpAsmParser::parseOperand(), mlir::OpAsmParser::parseOperandList(), mlir::AsmParser::parseOptionalAttrDict(), 
mlir::AsmParser::parseOptionalComma(), print(), printTransferAttrs(), mlir::RewriterBase::replaceOp(), mlir::RewriterBase::replaceOpWithNewOp(), mlir::OpAsmParser::resolveOperand(), mlir::OpAsmParser::resolveOperands(), mlir::NamedAttrList::set(), mlir::AsmParser::Square, mlir::LogicalResult::succeeded(), mlir::succeeded(), mlir::success(), mlir::OperationState::types, Unknown, mlir::RewriterBase::updateRootInPlace(), mlir::arith::ConstantIndexOp::value(), value, vectorShape(), vectorType(), mlir::verify(), verifyPermutationMap(), and verifyTransferOp().

◆ getDimMap()

static std::vector<std::pair<int64_t, int64_t> > getDimMap ( ArrayRef< AffineMap > indexingMaps,
ArrayAttr  iteratorTypes,
IteratorType  targetIteratorType,
MLIRContext * context 
)
static

◆ getMaskFormat()

static MaskFormat getMaskFormat ( Value  mask)
static

Helper method to classify a mask value.

Currently, the method looks "under the hood" of a constant value with dense attributes and a constant mask operation (since the client may be called at various stages during progressive lowering).

Definition at line 56 of file VectorOps.cpp.

References AllFalse, AllTrue, mlir::Value::getDefiningOp(), and Unknown.

Referenced by foldTransferInBoundsAttribute().

◆ getResultIndex()

static int64_t getResultIndex ( AffineMap  map,
AffineExpr  targetExpr 
)
static

Definition at line 782 of file VectorOps.cpp.

References mlir::AffineMap::getNumResults(), and mlir::AffineMap::getResult().

Referenced by getDimMap().

◆ inferStridedSliceOpResultType()

static Type inferStridedSliceOpResultType ( VectorType  vectorType,
ArrayAttr  offsets,
ArrayAttr  sizes,
ArrayAttr  strides 
)
static

◆ isInBounds()

template<typename TransferOp >
static bool isInBounds ( TransferOp  op,
int64_t  resultIdx,
int64_t  indicesIdx 
)
static

◆ isIntegerArrayAttrConfinedToRange()

template<typename OpType >
static LogicalResult isIntegerArrayAttrConfinedToRange ( OpType  op,
ArrayAttr  arrayAttr,
int64_t  min,
int64_t  max,
StringRef  attrName,
bool  halfOpen = true 
)
static

Definition at line 2094 of file VectorOps.cpp.

References max(), and mlir::success().

Referenced by inferStridedSliceOpResultType(), and makeI64ArrayAttr().

◆ isIntegerArrayAttrConfinedToShape()

template<typename OpType >
static LogicalResult isIntegerArrayAttrConfinedToShape ( OpType  op,
ArrayAttr  arrayAttr,
ArrayRef< int64_t >  shape,
StringRef  attrName,
bool  halfOpen = true,
int64_t  min = 0 
)
static

Definition at line 2114 of file VectorOps.cpp.

References max(), min(), and mlir::success().

Referenced by inferStridedSliceOpResultType(), and makeI64ArrayAttr().

◆ isIntegerArrayAttrSmallerThanShape()

template<typename OpType >
static LogicalResult isIntegerArrayAttrSmallerThanShape ( OpType  op,
ArrayAttr  arrayAttr,
ArrayRef< int64_t >  shape,
StringRef  attrName 
)
static

Definition at line 2079 of file VectorOps.cpp.

References mlir::success().

Referenced by inferStridedSliceOpResultType().

◆ isStepIndexArray()

static bool isStepIndexArray ( ArrayAttr  idxArr,
uint64_t  begin,
size_t  width 
)
static

Definition at line 1802 of file VectorOps.cpp.

◆ isSumOfIntegerArrayAttrConfinedToShape()

template<typename OpType >
static LogicalResult isSumOfIntegerArrayAttrConfinedToShape ( OpType  op,
ArrayAttr  arrayAttr1,
ArrayAttr  arrayAttr2,
ArrayRef< int64_t >  shape,
StringRef  attrName1,
StringRef  attrName2,
bool  halfOpen = true,
int64_t  min = 1 
)
static

Definition at line 2137 of file VectorOps.cpp.

References max(), min(), and mlir::success().

Referenced by inferStridedSliceOpResultType(), and makeI64ArrayAttr().

◆ isSupportedCombiningKind()

static bool isSupportedCombiningKind ( CombiningKind  combiningKind,
Type  elementType 
)
static

◆ makeI64ArrayAttr()

static ArrayAttr makeI64ArrayAttr ( ArrayRef< int64_t >  values,
MLIRContext * context 
)
static

◆ populateFromInt64AttrArray()

static void populateFromInt64AttrArray ( ArrayAttr  arrayAttr,
SmallVectorImpl< int64_t > &  results 
)
static

Definition at line 1627 of file VectorOps.cpp.

References getVectorType().

Referenced by foldExtractStridedOpFromInsertChain().

◆ printTransferAttrs()

static void printTransferAttrs ( OpAsmPrinter & p,
VectorTransferOpInterface  op 
)
static

◆ verifyDimMap()

static bool verifyDimMap ( VectorType  lhsType,
VectorType  rhsType,
const std::vector< std::pair< int64_t, int64_t >> &  map 
)
static

Definition at line 598 of file VectorOps.cpp.

◆ verifyOutputShape()

static LogicalResult verifyOutputShape ( ContractionOp  op,
VectorType  lhsType,
VectorType  rhsType,
Type  accType,
Type  resType,
const std::vector< std::pair< int64_t, int64_t >> &  contractingDimMap,
const std::vector< std::pair< int64_t, int64_t >> &  batchDimMap 
)
static

Definition at line 609 of file VectorOps.cpp.

◆ verifyPermutationMap()

template<typename EmitFun >
static LogicalResult verifyPermutationMap ( AffineMap  permutationMap,
EmitFun  emitOpError 
)
static

◆ verifyTransferOp()

static LogicalResult verifyTransferOp ( VectorTransferOpInterface  op,
ShapedType  shapedType,
VectorType  vectorType,
VectorType  maskType,
AffineMap  permutationMap,
ArrayAttr  inBounds 
)
static