17 #include "llvm/ADT/STLExtras.h"
26 spirv::MemoryAccessAttr memoryOperand) {
27 auto pointerType = cast<PointerType>(pointer);
28 Type pointeeType = pointerType.getPointeeType();
29 if (!isa<ScalarType, VectorType>(pointeeType)) {
31 "Pointer must point to a scalar or vector type but provided ")
36 spirv::MemoryAccess operandSet = memoryOperand.getValue();
38 if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) &&
39 spirv::bitEnumContainsAll(operandSet,
40 spirv::MemoryAccess::MakePointerAvailable)) {
42 "not compatible with memory operand 'MakePointerAvailable'");
45 if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) &&
46 spirv::bitEnumContainsAll(operandSet,
47 spirv::MemoryAccess::MakePointerVisible)) {
49 "not compatible with memory operand 'MakePointerVisible'");
56 if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
57 spirv::MemoryAccess::Aligned)) {
58 return op->
emitOpError(
"has unhandled memory operand 'Aligned'");
75 getResult().
getType(), getMemoryOperandAttr());
84 getObject().
getType(), getMemoryOperandAttr());
92 auto typeA = cast<spirv::CooperativeMatrixType>(getA().
getType());
93 auto typeB = cast<spirv::CooperativeMatrixType>(getB().
getType());
94 auto typeC = cast<spirv::CooperativeMatrixType>(getC().
getType());
100 if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA)
101 return emitOpError(
"operand #0 must be of use 'MatrixA'");
102 if (typeB.getUse() != CooperativeMatrixUseKHR::MatrixB)
103 return emitOpError(
"operand #1 must be of use 'MatrixB'");
104 if (typeC.getUse() != CooperativeMatrixUseKHR::MatrixAcc)
105 return emitOpError(
"operand #2 must be of use 'MatrixAcc'");
108 if (!llvm::all_equal({typeA.getScope(), typeB.getScope(), typeC.getScope()}))
109 return emitOpError(
"matrix scope mismatch");
112 if (typeA.getRows() != typeC.getRows())
113 return emitOpError(
"matrix size mismatch on dimension 'M'");
114 if (typeB.getColumns() != typeC.getColumns())
115 return emitOpError(
"matrix size mismatch on dimension 'N'");
116 if (typeA.getColumns() != typeB.getRows())
117 return emitOpError(
"matrix size mismatch on dimension 'K'");
125 if (getMatrixOperands()) {
126 Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
127 typeC.getElementType()};
128 if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) {
129 return emitOpError(
"Matrix Operands require all matrix element types to "
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
@ Type
An inlay hint that for a type annotation.
static LogicalResult verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, spirv::MemoryAccessAttr memoryOperand)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...