MLIR 22.0.0git
CooperativeMatrixOps.cpp
Go to the documentation of this file.
1//===- CooperativeMatrixOps.cpp - MLIR SPIR-V Cooperative Matrix Ops -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the Cooperative Matrix operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "SPIRVParsingUtils.h"
17#include "llvm/ADT/STLExtras.h"
18
19using namespace mlir::spirv::AttrNames;
20
21namespace mlir::spirv {
22
23static LogicalResult
24verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
25 spirv::MemoryAccessAttr memoryOperand,
26 IntegerAttr alignment) {
27 auto pointerType = cast<PointerType>(pointer);
28 Type pointeeType = pointerType.getPointeeType();
29 if (!isa<ScalarType, VectorType>(pointeeType)) {
30 return op->emitOpError(
31 "Pointer must point to a scalar or vector type but provided ")
32 << pointeeType;
33 }
34
35 if (memoryOperand) {
36 spirv::MemoryAccess operandSet = memoryOperand.getValue();
37
38 if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) &&
39 spirv::bitEnumContainsAll(operandSet,
40 spirv::MemoryAccess::MakePointerAvailable)) {
41 return op->emitOpError(
42 "not compatible with memory operand 'MakePointerAvailable'");
43 }
44
45 if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) &&
46 spirv::bitEnumContainsAll(operandSet,
47 spirv::MemoryAccess::MakePointerVisible)) {
48 return op->emitOpError(
49 "not compatible with memory operand 'MakePointerVisible'");
50 }
51
52 // TODO: Need to check that NonPrivatePointer is set for MakePointer*. See
53 // #145485.
54
55 if (spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
56 !alignment) {
57 return op->emitOpError("missing value for the 'Aligned' memory operand");
58 }
59
60 if (!spirv::bitEnumContainsAll(operandSet, spirv::MemoryAccess::Aligned) &&
61 alignment) {
62 return op->emitOpError(
63 "found alignment attribute for non-'Aligned' memory operand");
64 }
65 }
66
67 // TODO: Verify the memory object behind the pointer:
68 // > If the Shader capability was declared, Pointer must point into an array
69 // > and any ArrayStride decoration on Pointer is ignored.
70
71 return success();
72}
73
74//===----------------------------------------------------------------------===//
75// spirv.KHR.CooperativeMatrixLoad
76//===----------------------------------------------------------------------===//
77
78LogicalResult KHRCooperativeMatrixLoadOp::verify() {
79 return verifyCoopMatrixAccess(*this, getPointer().getType(),
80 getResult().getType(), getMemoryOperandAttr(),
81 getAlignmentAttr());
82}
83
84//===----------------------------------------------------------------------===//
85// spirv.KHR.CooperativeMatrixStore
86//===----------------------------------------------------------------------===//
87
88LogicalResult KHRCooperativeMatrixStoreOp::verify() {
89 return verifyCoopMatrixAccess(*this, getPointer().getType(),
90 getObject().getType(), getMemoryOperandAttr(),
91 getAlignmentAttr());
92}
93
94//===----------------------------------------------------------------------===//
95// spirv.KHR.CooperativeMatrixMulAdd
96//===----------------------------------------------------------------------===//
97
98LogicalResult KHRCooperativeMatrixMulAddOp::verify() {
99 auto typeA = cast<spirv::CooperativeMatrixType>(getA().getType());
100 auto typeB = cast<spirv::CooperativeMatrixType>(getB().getType());
101 auto typeC = cast<spirv::CooperativeMatrixType>(getC().getType());
102
103 // Check element types. ODS enforces that `type(c) == type(result)`, so no
104 // need to check it here.
105
106 // Check the 'use' part of the type against the operands and the result.
107 if (typeA.getUse() != CooperativeMatrixUseKHR::MatrixA)
108 return emitOpError("operand #0 must be of use 'MatrixA'");
109 if (typeB.getUse() != CooperativeMatrixUseKHR::MatrixB)
110 return emitOpError("operand #1 must be of use 'MatrixB'");
111 if (typeC.getUse() != CooperativeMatrixUseKHR::MatrixAcc)
112 return emitOpError("operand #2 must be of use 'MatrixAcc'");
113
114 // Check the 'scope' part of the type.
115 if (!llvm::all_equal({typeA.getScope(), typeB.getScope(), typeC.getScope()}))
116 return emitOpError("matrix scope mismatch");
117
118 // Check dimension sizes. We expect 'MxK * KxN + MxN -> MxN'.
119 if (typeA.getRows() != typeC.getRows())
120 return emitOpError("matrix size mismatch on dimension 'M'");
121 if (typeB.getColumns() != typeC.getColumns())
122 return emitOpError("matrix size mismatch on dimension 'N'");
123 if (typeA.getColumns() != typeB.getRows())
124 return emitOpError("matrix size mismatch on dimension 'K'");
125
126 // The spec does not restrict the element types:
127 // > A, B, C, and Result Type need not necessarily have the same component
128 // > type, this is defined by the client API.
129
130 // Check that if Cooperative Matrix Operands are provided, the element type
131 // is integer.
132 if (getMatrixOperands()) {
133 Type elementTypes[] = {typeA.getElementType(), typeB.getElementType(),
134 typeC.getElementType()};
135 if (!llvm::all_of(elementTypes, llvm::IsaPred<IntegerType>)) {
136 return emitOpError("Matrix Operands require all matrix element types to "
137 "be Integer Types");
138 }
139 }
140
141 // Any further requirements need to be checked against VCE.
142 return success();
143}
144
145} // namespace mlir::spirv
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
static LogicalResult verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, spirv::MemoryAccessAttr memoryOperand, IntegerAttr alignment)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304