18 #include "llvm/ADT/TypeSwitch.h"
25 bool requireSameBitWidth =
true,
26 bool skipBitWidthCheck =
false) {
28 if (skipBitWidthCheck)
36 using TypePair = std::pair<Type, Type>;
37 auto [operandElemTy, resultElemTy] =
40 [resultType](
auto concreteOperandTy) -> TypePair {
41 if (
auto concreteResultTy =
42 dyn_cast<decltype(concreteOperandTy)>(resultType)) {
43 return {concreteOperandTy.getElementType(),
44 concreteResultTy.getElementType()};
48 .Default([resultType](
Type operandType) -> TypePair {
49 return {operandType, resultType};
52 if (!operandElemTy || !resultElemTy)
53 return op->
emitOpError(
"incompatible operand and result types");
55 unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
56 unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
57 bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
59 if (requireSameBitWidth) {
60 if (!isSameBitWidth) {
62 "expected the same bit widths for operand type and result "
63 "type, but provided ")
64 << operandElemTy <<
" and " << resultElemTy;
71 "expected the different bit widths for operand type and result "
72 "type, but provided ")
73 << operandElemTy <<
" and " << resultElemTy;
85 auto operandType = getOperand().getType();
86 auto resultType = getResult().getType();
87 if (operandType == resultType) {
88 return emitError(
"result type must be different from operand type");
90 if (llvm::isa<spirv::PointerType>(operandType) &&
91 !llvm::isa<spirv::PointerType>(resultType)) {
93 "unhandled bit cast conversion from pointer type to non-pointer type");
95 if (!llvm::isa<spirv::PointerType>(operandType) &&
96 llvm::isa<spirv::PointerType>(resultType)) {
98 "unhandled bit cast conversion from non-pointer type to pointer type");
102 if (operandBitWidth != resultBitWidth) {
103 return emitOpError(
"mismatch in result type bitwidth ")
104 << resultBitWidth <<
" and operand type bitwidth "
115 auto operandType = llvm::cast<spirv::PointerType>(getPointer().
getType());
116 auto resultType = llvm::cast<spirv::ScalarType>(getResult().
getType());
117 if (!resultType || !resultType.isSignlessInteger())
118 return emitError(
"result must be a scalar type of unsigned integer");
119 auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
122 auto addressingModel = spirvModule.getAddressingModel();
123 if ((addressingModel == spirv::AddressingModel::Logical) ||
124 (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
125 operandType.getStorageClass() !=
126 spirv::StorageClass::PhysicalStorageBuffer))
127 return emitError(
"operand must be a physical pointer");
136 auto operandType = llvm::cast<spirv::ScalarType>(getOperand().
getType());
137 auto resultType = llvm::cast<spirv::PointerType>(getResult().
getType());
138 if (!operandType || !operandType.isSignlessInteger())
139 return emitError(
"result must be a scalar type of unsigned integer");
140 auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
143 auto addressingModel = spirvModule.getAddressingModel();
144 if ((addressingModel == spirv::AddressingModel::Logical) ||
145 (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
146 resultType.getStorageClass() !=
147 spirv::StorageClass::PhysicalStorageBuffer))
148 return emitError(
"result must be a physical pointer");
157 auto operandType = llvm::cast<spirv::PointerType>(getPointer().
getType());
158 auto resultType = llvm::cast<spirv::PointerType>(getResult().
getType());
160 spirv::StorageClass operandStorage = operandType.getStorageClass();
161 if (operandStorage != spirv::StorageClass::Workgroup &&
162 operandStorage != spirv::StorageClass::CrossWorkgroup &&
163 operandStorage != spirv::StorageClass::Function)
164 return emitError(
"pointer must point to the Workgroup, CrossWorkgroup"
165 ", or Function Storage Class");
167 spirv::StorageClass resultStorage = resultType.getStorageClass();
168 if (resultStorage != spirv::StorageClass::Generic)
169 return emitError(
"result type must be of storage class Generic");
171 Type operandPointeeType = operandType.getPointeeType();
172 Type resultPointeeType = resultType.getPointeeType();
173 if (operandPointeeType != resultPointeeType)
174 return emitOpError(
"pointer operand's pointee type must have the same "
175 "as the op result type, but found ")
176 << operandPointeeType <<
" vs " << resultPointeeType;
185 auto operandType = llvm::cast<spirv::PointerType>(getPointer().
getType());
186 auto resultType = llvm::cast<spirv::PointerType>(getResult().
getType());
188 spirv::StorageClass operandStorage = operandType.getStorageClass();
189 if (operandStorage != spirv::StorageClass::Generic)
190 return emitError(
"pointer type must be of storage class Generic");
192 spirv::StorageClass resultStorage = resultType.getStorageClass();
193 if (resultStorage != spirv::StorageClass::Workgroup &&
194 resultStorage != spirv::StorageClass::CrossWorkgroup &&
195 resultStorage != spirv::StorageClass::Function)
196 return emitError(
"result must point to the Workgroup, CrossWorkgroup, "
197 "or Function Storage Class");
199 Type operandPointeeType = operandType.getPointeeType();
200 Type resultPointeeType = resultType.getPointeeType();
201 if (operandPointeeType != resultPointeeType)
202 return emitOpError(
"pointer operand's pointee type must have the same "
203 "as the op result type, but found ")
204 << operandPointeeType <<
" vs " << resultPointeeType;
213 auto operandType = llvm::cast<spirv::PointerType>(getPointer().
getType());
214 auto resultType = llvm::cast<spirv::PointerType>(getResult().
getType());
216 spirv::StorageClass operandStorage = operandType.getStorageClass();
217 if (operandStorage != spirv::StorageClass::Generic)
218 return emitError(
"pointer type must be of storage class Generic");
220 spirv::StorageClass resultStorage = resultType.getStorageClass();
221 if (resultStorage != spirv::StorageClass::Workgroup &&
222 resultStorage != spirv::StorageClass::CrossWorkgroup &&
223 resultStorage != spirv::StorageClass::Function)
224 return emitError(
"result must point to the Workgroup, CrossWorkgroup, "
225 "or Function Storage Class");
227 Type operandPointeeType = operandType.getPointeeType();
228 Type resultPointeeType = resultType.getPointeeType();
229 if (operandPointeeType != resultPointeeType)
230 return emitOpError(
"pointer operand's pointee type must have the same "
231 "as the op result type, but found ")
232 << operandPointeeType <<
" vs " << resultPointeeType;
277 auto operandType = getOperand().getType();
278 auto resultType = getResult().getType();
281 if (
auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
282 unsigned operandNumElements = vectorType.getNumElements();
283 unsigned resultNumElements =
284 llvm::cast<VectorType>(resultType).getNumElements();
285 if (operandNumElements != resultNumElements) {
287 "operand and result must have same number of elements");
298 auto operandType = getOperand().getType();
299 auto resultType = getResult().getType();
302 if (
auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
303 unsigned operandNumElements = vectorType.getNumElements();
304 unsigned resultNumElements =
305 llvm::cast<VectorType>(resultType).getNumElements();
306 if (operandNumElements != resultNumElements) {
308 "operand and result must have same number of elements");
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Type getType() const
Return the type of this value.
@ Type
An inlay hint for a type annotation.
static LogicalResult verifyCastOp(Operation *op, bool requireSameBitWidth=true, bool skipBitWidthCheck=false)
unsigned getBitWidth(Type type)
Returns the bit width of the type.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.