17 #include "llvm/ADT/STLExtras.h"
26 spirv::MemoryAccessAttr memoryOperand) {
27 auto pointerType = cast<PointerType>(pointer);
28 Type pointeeType = pointerType.getPointeeType();
29 if (!isa<ScalarType, VectorType>(pointeeType)) {
30 return op->emitOpError(
31 "Pointer must point to a scalar or vector type but provided ")
36 spirv::MemoryAccess operandSet = memoryOperand.getValue();
38 if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) &&
39 spirv::bitEnumContainsAll(operandSet,
40 spirv::MemoryAccess::MakePointerAvailable)) {
41 return op->emitOpError(
42 "not compatible with memory operand 'MakePointerAvailable'");
45 if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) &&
46 spirv::bitEnumContainsAll(operandSet,
47 spirv::MemoryAccess::MakePointerVisible)) {
48 return op->emitOpError(
49 "not compatible with memory operand 'MakePointerVisible'");
56 if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
57 spirv::MemoryAccess::Aligned)) {
58 return op->emitOpError(
"has unhandled memory operand 'Aligned'");
75 getResult().getType(), getMemoryOperandAttr());
84 getObject().getType(), getMemoryOperandAttr());
92 auto typeA = cast<spirv::CooperativeMatrixType>(getA().getType());
93 auto typeB = cast<spirv::CooperativeMatrixType>(getB().getType());
94 auto typeC = cast<spirv::CooperativeMatrixType>(getC().getType());
100 if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA)
101 return emitOpError(
"operand #0 must be of use 'MatrixA'");
102 if (typeB.getUse() != CooperativeMatrixUseKHR::MatrixB)
103 return emitOpError(
"operand #1 must be of use 'MatrixB'");
104 if (typeC.getUse() != CooperativeMatrixUseKHR::MatrixAcc)
105 return emitOpError(
"operand #2 must be of use 'MatrixAcc'");
108 if (!llvm::all_equal({typeA.getScope(), typeB.getScope(), typeC.getScope()}))
109 return emitOpError(
"matrix scope mismatch");
112 if (typeA.getRows() != typeC.getRows())
113 return emitOpError(
"matrix size mismatch on dimension 'M'");
114 if (typeB.getColumns() != typeC.getColumns())
115 return emitOpError(
"matrix size mismatch on dimension 'N'");
116 if (typeA.getColumns() != typeB.getRows())
117 return emitOpError(
"matrix size mismatch on dimension 'K'");
125 if (getMatrixOperands()) {
126 Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
127 typeC.getElementType()};
128 if (!llvm::all_of(elementTypes,
129 [](Type ty) {
return isa<IntegerType>(ty); })) {
130 return emitOpError(
"Matrix Operands require all matrix element types to "
144 if (!isa<CooperativeMatrixNVType>(getCooperativeMatrixType())) {
146 "type attribute must be a '!spirv.NV.coopmatrix' type, found ")
147 << getCooperativeMatrixType() <<
" instead";
158 OperationState &result) {
159 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
160 Type strideType = parser.getBuilder().getIntegerType(32);
161 Type columnMajorType = parser.getBuilder().getIntegerType(1);
164 if (parser.parseOperandList(operandInfo, 3) ||
166 parser.parseType(ptrType) || parser.parseKeywordType(
"as", elementType)) {
169 if (parser.resolveOperands(operandInfo,
170 {ptrType, strideType, columnMajorType},
171 parser.getNameLoc(), result.operands)) {
175 result.addTypes(elementType);
180 printer <<
" " << getPointer() <<
", " << getStride() <<
", "
183 if (
auto memAccess = getMemoryAccess())
184 printer <<
" [\"" << stringifyMemoryAccess(*memAccess) <<
"\"]";
185 printer <<
" : " << getPointer().getType() <<
" as " << getType();
190 Type pointeeType = llvm::cast<PointerType>(pointer).getPointeeType();
191 if (!llvm::isa<ScalarType>(pointeeType) &&
192 !llvm::isa<VectorType>(pointeeType))
193 return op->emitError(
194 "Pointer must point to a scalar or vector type but provided ")
196 StorageClass storage = llvm::cast<PointerType>(pointer).getStorageClass();
197 if (storage != StorageClass::Workgroup &&
198 storage != StorageClass::StorageBuffer &&
199 storage != StorageClass::PhysicalStorageBuffer)
200 return op->emitError(
201 "Pointer storage class must be Workgroup, StorageBuffer or "
202 "PhysicalStorageBufferEXT but provided ")
203 << stringifyStorageClass(storage);
209 getResult().getType());
217 OperationState &result) {
218 SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
219 Type strideType = parser.getBuilder().getIntegerType(32);
220 Type columnMajorType = parser.getBuilder().getIntegerType(1);
223 if (parser.parseOperandList(operandInfo, 4) ||
225 parser.parseType(ptrType) || parser.parseComma() ||
226 parser.parseType(elementType)) {
229 if (parser.resolveOperands(
230 operandInfo, {ptrType, elementType, strideType, columnMajorType},
231 parser.getNameLoc(), result.operands)) {
239 printer <<
" " << getPointer() <<
", " << getObject() <<
", " << getStride()
240 <<
", " << getColumnmajor();
242 if (
auto memAccess = getMemoryAccess())
243 printer <<
" [\"" << stringifyMemoryAccess(*memAccess) <<
"\"]";
244 printer <<
" : " << getPointer().getType() <<
", " << getOperand(1).getType();
249 getObject().getType());
257 if (op.getC().getType() != op.getResult().getType())
258 return op.emitOpError(
"result and third operand must have the same type");
259 auto typeA = llvm::cast<CooperativeMatrixNVType>(op.getA().getType());
260 auto typeB = llvm::cast<CooperativeMatrixNVType>(op.getB().getType());
261 auto typeC = llvm::cast<CooperativeMatrixNVType>(op.getC().getType());
262 auto typeR = llvm::cast<CooperativeMatrixNVType>(op.getResult().getType());
263 if (typeA.getRows() != typeR.getRows() ||
264 typeA.getColumns() != typeB.getRows() ||
265 typeB.getColumns() != typeR.getColumns())
266 return op.emitOpError(
"matrix size must match");
267 if (typeR.getScope() != typeA.getScope() ||
268 typeR.getScope() != typeB.getScope() ||
269 typeR.getScope() != typeC.getScope())
270 return op.emitOpError(
"matrix scope must match");
271 auto elementTypeA = typeA.getElementType();
272 auto elementTypeB = typeB.getElementType();
273 if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
274 if (llvm::cast<IntegerType>(elementTypeA).getWidth() !=
275 llvm::cast<IntegerType>(elementTypeB).getWidth())
276 return op.emitOpError(
277 "matrix A and B integer element types must be the same bit width");
278 }
else if (elementTypeA != elementTypeB) {
279 return op.emitOpError(
280 "matrix A and B non-integer element types must match");
282 if (typeR.getElementType() != typeC.getElementType())
283 return op.emitOpError(
"matrix accumulator element type must match");
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Operation is the basic unit of execution within MLIR.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
@ Type
An inlay hint that for a type annotation.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state, StringRef attrName)
Parses optional memory access (a.k.a.
static LogicalResult verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, spirv::MemoryAccessAttr memoryOperand)
static LogicalResult verifyCoopMatrixMulAddNV(NVCooperativeMatrixMulAddOp op)
static LogicalResult verifyPointerAndCoopMatrixNVType(Operation *op, Type pointer, Type coopMatrix)
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This class represents an efficient way to signal success or failure.