17 #include "llvm/ADT/STLExtras.h"
25 spirv::MemoryAccessAttr memoryOperand,
26 IntegerAttr alignment) {
27 auto pointerType = cast<PointerType>(pointer);
28 Type pointeeType = pointerType.getPointeeType();
29 if (!isa<ScalarType, VectorType>(pointeeType)) {
31 "Pointer must point to a scalar or vector type but provided ")
36 spirv::MemoryAccess operandSet = memoryOperand.getValue();
38 if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) &&
39 spirv::bitEnumContainsAll(operandSet,
40 spirv::MemoryAccess::MakePointerAvailable)) {
42 "not compatible with memory operand 'MakePointerAvailable'");
45 if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) &&
46 spirv::bitEnumContainsAll(operandSet,
47 spirv::MemoryAccess::MakePointerVisible)) {
49 "not compatible with memory operand 'MakePointerVisible'");
55 if (spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
57 return op->
emitOpError(
"missing value for the 'Aligned' memory operand");
60 if (!spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
63 "found alignment attribute for non-'Aligned' memory operand");
80 getResult().
getType(), getMemoryOperandAttr(),
90 getObject().
getType(), getMemoryOperandAttr(),
99 auto typeA = cast<spirv::CooperativeMatrixType>(getA().
getType());
100 auto typeB = cast<spirv::CooperativeMatrixType>(getB().
getType());
101 auto typeC = cast<spirv::CooperativeMatrixType>(getC().
getType());
107 if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA)
108 return emitOpError(
"operand #0 must be of use 'MatrixA'");
109 if (typeB.getUse() != CooperativeMatrixUseKHR::MatrixB)
110 return emitOpError(
"operand #1 must be of use 'MatrixB'");
111 if (typeC.getUse() != CooperativeMatrixUseKHR::MatrixAcc)
112 return emitOpError(
"operand #2 must be of use 'MatrixAcc'");
115 if (!llvm::all_equal({typeA.getScope(), typeB.getScope(), typeC.getScope()}))
116 return emitOpError(
"matrix scope mismatch");
119 if (typeA.getRows() != typeC.getRows())
120 return emitOpError(
"matrix size mismatch on dimension 'M'");
121 if (typeB.getColumns() != typeC.getColumns())
122 return emitOpError(
"matrix size mismatch on dimension 'N'");
123 if (typeA.getColumns() != typeB.getRows())
124 return emitOpError(
"matrix size mismatch on dimension 'K'");
132 if (getMatrixOperands()) {
133 Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
134 typeC.getElementType()};
135 if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) {
136 return emitOpError(
"Matrix Operands require all matrix element types to "
Operation is the basic unit of execution within MLIR.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
@ Type
An inlay hint that for a type annotation.
static LogicalResult verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, spirv::MemoryAccessAttr memoryOperand, IntegerAttr alignment)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...