MLIR 22.0.0git
CastOps.cpp
Go to the documentation of this file.
//===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the cast and conversion operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir::spirv::AttrNames;
21
22namespace mlir::spirv {
23
24static LogicalResult verifyCastOp(Operation *op,
25 bool requireSameBitWidth = true,
26 bool skipBitWidthCheck = false) {
27 // Some CastOps have no limit on bit widths for result and operand type.
28 if (skipBitWidthCheck)
29 return success();
30
31 Type operandType = op->getOperand(0).getType();
32 Type resultType = op->getResult(0).getType();
33
34 // ODS checks that result type and operand type have the same shape. Check
35 // that composite types match and extract the element types, if any.
36 using TypePair = std::pair<Type, Type>;
37 auto [operandElemTy, resultElemTy] =
39 .Case<VectorType, spirv::CooperativeMatrixType>(
40 [resultType](auto concreteOperandTy) -> TypePair {
41 if (auto concreteResultTy =
42 dyn_cast<decltype(concreteOperandTy)>(resultType)) {
43 return {concreteOperandTy.getElementType(),
44 concreteResultTy.getElementType()};
45 }
46 return {};
47 })
48 .Default([resultType](Type operandType) -> TypePair {
49 return {operandType, resultType};
50 });
51
52 if (!operandElemTy || !resultElemTy)
53 return op->emitOpError("incompatible operand and result types");
54
55 unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
56 unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
57 bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
58
59 if (requireSameBitWidth) {
60 if (!isSameBitWidth) {
61 return op->emitOpError(
62 "expected the same bit widths for operand type and result "
63 "type, but provided ")
64 << operandElemTy << " and " << resultElemTy;
65 }
66 return success();
67 }
68
69 if (isSameBitWidth) {
70 return op->emitOpError(
71 "expected the different bit widths for operand type and result "
72 "type, but provided ")
73 << operandElemTy << " and " << resultElemTy;
74 }
75 return success();
76}

//===----------------------------------------------------------------------===//
// spirv.BitcastOp
//===----------------------------------------------------------------------===//

82LogicalResult BitcastOp::verify() {
83 // TODO: The SPIR-V spec validation rules are different for different
84 // versions.
85 auto operandType = getOperand().getType();
86 auto resultType = getResult().getType();
87 if (operandType == resultType) {
88 return emitError("result type must be different from operand type");
89 }
90 if (llvm::isa<spirv::PointerType>(operandType) &&
91 !llvm::isa<spirv::PointerType>(resultType)) {
92 return emitError(
93 "unhandled bit cast conversion from pointer type to non-pointer type");
94 }
95 if (!llvm::isa<spirv::PointerType>(operandType) &&
96 llvm::isa<spirv::PointerType>(resultType)) {
97 return emitError(
98 "unhandled bit cast conversion from non-pointer type to pointer type");
99 }
100 auto operandBitWidth = getBitWidth(operandType);
101 auto resultBitWidth = getBitWidth(resultType);
102 if (operandBitWidth != resultBitWidth) {
103 return emitOpError("mismatch in result type bitwidth ")
104 << resultBitWidth << " and operand type bitwidth "
105 << operandBitWidth;
106 }
107 return success();
108}

//===----------------------------------------------------------------------===//
// spirv.ConvertPtrToUOp
//===----------------------------------------------------------------------===//

114LogicalResult ConvertPtrToUOp::verify() {
115 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
116 auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
117 if (!resultType || !resultType.isSignlessInteger())
118 return emitError("result must be a scalar type of unsigned integer");
119 auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
120 if (!spirvModule)
121 return success();
122 auto addressingModel = spirvModule.getAddressingModel();
123 if ((addressingModel == spirv::AddressingModel::Logical) ||
124 (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
125 operandType.getStorageClass() !=
126 spirv::StorageClass::PhysicalStorageBuffer))
127 return emitError("operand must be a physical pointer");
128 return success();
129}

//===----------------------------------------------------------------------===//
// spirv.ConvertUToPtrOp
//===----------------------------------------------------------------------===//

135LogicalResult ConvertUToPtrOp::verify() {
136 auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
137 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
138 if (!operandType || !operandType.isSignlessInteger())
139 return emitError("result must be a scalar type of unsigned integer");
140 auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
141 if (!spirvModule)
142 return success();
143 auto addressingModel = spirvModule.getAddressingModel();
144 if ((addressingModel == spirv::AddressingModel::Logical) ||
145 (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
146 resultType.getStorageClass() !=
147 spirv::StorageClass::PhysicalStorageBuffer))
148 return emitError("result must be a physical pointer");
149 return success();
150}

//===----------------------------------------------------------------------===//
// spirv.PtrCastToGenericOp
//===----------------------------------------------------------------------===//

156LogicalResult PtrCastToGenericOp::verify() {
157 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
158 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
159
160 spirv::StorageClass operandStorage = operandType.getStorageClass();
161 if (operandStorage != spirv::StorageClass::Workgroup &&
162 operandStorage != spirv::StorageClass::CrossWorkgroup &&
163 operandStorage != spirv::StorageClass::Function)
164 return emitError("pointer must point to the Workgroup, CrossWorkgroup"
165 ", or Function Storage Class");
166
167 spirv::StorageClass resultStorage = resultType.getStorageClass();
168 if (resultStorage != spirv::StorageClass::Generic)
169 return emitError("result type must be of storage class Generic");
170
171 Type operandPointeeType = operandType.getPointeeType();
172 Type resultPointeeType = resultType.getPointeeType();
173 if (operandPointeeType != resultPointeeType)
174 return emitOpError("pointer operand's pointee type must have the same "
175 "as the op result type, but found ")
176 << operandPointeeType << " vs " << resultPointeeType;
177 return success();
178}

//===----------------------------------------------------------------------===//
// spirv.GenericCastToPtrOp
//===----------------------------------------------------------------------===//

184LogicalResult GenericCastToPtrOp::verify() {
185 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
186 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
187
188 spirv::StorageClass operandStorage = operandType.getStorageClass();
189 if (operandStorage != spirv::StorageClass::Generic)
190 return emitError("pointer type must be of storage class Generic");
191
192 spirv::StorageClass resultStorage = resultType.getStorageClass();
193 if (resultStorage != spirv::StorageClass::Workgroup &&
194 resultStorage != spirv::StorageClass::CrossWorkgroup &&
195 resultStorage != spirv::StorageClass::Function)
196 return emitError("result must point to the Workgroup, CrossWorkgroup, "
197 "or Function Storage Class");
198
199 Type operandPointeeType = operandType.getPointeeType();
200 Type resultPointeeType = resultType.getPointeeType();
201 if (operandPointeeType != resultPointeeType)
202 return emitOpError("pointer operand's pointee type must have the same "
203 "as the op result type, but found ")
204 << operandPointeeType << " vs " << resultPointeeType;
205 return success();
206}

//===----------------------------------------------------------------------===//
// spirv.GenericCastToPtrExplicitOp
//===----------------------------------------------------------------------===//

212LogicalResult GenericCastToPtrExplicitOp::verify() {
213 auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
214 auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
215
216 spirv::StorageClass operandStorage = operandType.getStorageClass();
217 if (operandStorage != spirv::StorageClass::Generic)
218 return emitError("pointer type must be of storage class Generic");
219
220 spirv::StorageClass resultStorage = resultType.getStorageClass();
221 if (resultStorage != spirv::StorageClass::Workgroup &&
222 resultStorage != spirv::StorageClass::CrossWorkgroup &&
223 resultStorage != spirv::StorageClass::Function)
224 return emitError("result must point to the Workgroup, CrossWorkgroup, "
225 "or Function Storage Class");
226
227 Type operandPointeeType = operandType.getPointeeType();
228 Type resultPointeeType = resultType.getPointeeType();
229 if (operandPointeeType != resultPointeeType)
230 return emitOpError("pointer operand's pointee type must have the same "
231 "as the op result type, but found ")
232 << operandPointeeType << " vs " << resultPointeeType;
233 return success();
234}

//===----------------------------------------------------------------------===//
// spirv.ConvertFToSOp
//===----------------------------------------------------------------------===//

240LogicalResult ConvertFToSOp::verify() {
241 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
242 /*skipBitWidthCheck=*/true);
243}

//===----------------------------------------------------------------------===//
// spirv.ConvertFToUOp
//===----------------------------------------------------------------------===//

249LogicalResult ConvertFToUOp::verify() {
250 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
251 /*skipBitWidthCheck=*/true);
252}

//===----------------------------------------------------------------------===//
// spirv.ConvertSToFOp
//===----------------------------------------------------------------------===//

258LogicalResult ConvertSToFOp::verify() {
259 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
260 /*skipBitWidthCheck=*/true);
261}

//===----------------------------------------------------------------------===//
// spirv.ConvertUToFOp
//===----------------------------------------------------------------------===//

267LogicalResult ConvertUToFOp::verify() {
268 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
269 /*skipBitWidthCheck=*/true);
270}

//===----------------------------------------------------------------------===//
// spirv.FConvertOp
//===----------------------------------------------------------------------===//

276LogicalResult spirv::FConvertOp::verify() {
277 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
278}

//===----------------------------------------------------------------------===//
// spirv.SConvertOp
//===----------------------------------------------------------------------===//

284LogicalResult spirv::SConvertOp::verify() {
285 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
286}

//===----------------------------------------------------------------------===//
// spirv.UConvertOp
//===----------------------------------------------------------------------===//

292LogicalResult spirv::UConvertOp::verify() {
293 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
294}
295
296} // namespace mlir::spirv
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Type getType() const
Return the type of this value.
Definition Value.h:105
static LogicalResult verifyCastOp(Operation *op, bool requireSameBitWidth=true, bool skipBitWidthCheck=false)
Definition CastOps.cpp:24
unsigned getBitWidth(Type type)
Returns the bit width of the type.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144