MLIR 23.0.0git
CastOps.cpp
Go to the documentation of this file.
1//===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Defines the cast and conversion operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15#include "SPIRVOpUtils.h"
16#include "SPIRVParsingUtils.h"
17
18#include "llvm/ADT/TypeSwitch.h"
19
20using namespace mlir::spirv::AttrNames;
21
22namespace mlir::spirv {
23
24static LogicalResult verifyCastOp(Operation *op,
25 bool requireSameBitWidth = true,
26 bool skipBitWidthCheck = false) {
27 // Some CastOps have no limit on bit widths for result and operand type.
28 if (skipBitWidthCheck)
29 return success();
30
31 Type operandType = op->getOperand(0).getType();
32 Type resultType = op->getResult(0).getType();
33
34 // ODS checks that result type and operand type have the same shape. Check
35 // that composite types match and extract the element types, if any.
36 using TypePair = std::pair<Type, Type>;
37 auto [operandElemTy, resultElemTy] =
39 .Case<VectorType, spirv::CooperativeMatrixType>(
40 [resultType](auto concreteOperandTy) -> TypePair {
41 if (auto concreteResultTy =
42 dyn_cast<decltype(concreteOperandTy)>(resultType)) {
43 return {concreteOperandTy.getElementType(),
44 concreteResultTy.getElementType()};
45 }
46 return {};
47 })
48 .Default([resultType](Type operandType) -> TypePair {
49 return {operandType, resultType};
50 });
51
52 if (!operandElemTy || !resultElemTy)
53 return op->emitOpError("incompatible operand and result types");
54
55 unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
56 unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
57 bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
58
59 if (requireSameBitWidth) {
60 if (!isSameBitWidth) {
61 return op->emitOpError(
62 "expected the same bit widths for operand type and result "
63 "type, but provided ")
64 << operandElemTy << " and " << resultElemTy;
65 }
66 return success();
67 }
68
69 if (isSameBitWidth) {
70 return op->emitOpError(
71 "expected the different bit widths for operand type and result "
72 "type, but provided ")
73 << operandElemTy << " and " << resultElemTy;
74 }
75 return success();
76}
77
78//===----------------------------------------------------------------------===//
79// spirv.BitcastOp
80//===----------------------------------------------------------------------===//
81
82LogicalResult BitcastOp::verify() {
83 // TODO: The SPIR-V spec validation rules are different for different
84 // versions.
85 auto operandType = getOperand().getType();
86 auto resultType = getResult().getType();
87 if (operandType == resultType) {
88 return emitError("result type must be different from operand type");
89 }
90
91 auto operandCoopMatrixType =
92 dyn_cast<spirv::CooperativeMatrixType>(operandType);
93 auto resultCoopMatrixType =
94 dyn_cast<spirv::CooperativeMatrixType>(resultType);
95 if (operandCoopMatrixType || resultCoopMatrixType) {
96 if (!operandCoopMatrixType || !resultCoopMatrixType)
97 return emitError("unhandled bit cast conversion from cooperative matrix "
98 "type to non-cooperative matrix type");
99
100 if (operandCoopMatrixType.getRows() != resultCoopMatrixType.getRows() ||
101 operandCoopMatrixType.getColumns() != resultCoopMatrixType.getColumns())
102 return emitError("cooperative matrix dimensions must match");
103
104 if (operandCoopMatrixType.getScope() != resultCoopMatrixType.getScope())
105 return emitError("cooperative matrix scope must match");
106
107 if (operandCoopMatrixType.getUse() != resultCoopMatrixType.getUse())
108 return emitError("cooperative matrix use must match");
109
110 unsigned operandBitWidth =
111 getBitWidth(operandCoopMatrixType.getElementType());
112 unsigned resultBitWidth =
113 getBitWidth(resultCoopMatrixType.getElementType());
114 if (operandBitWidth != resultBitWidth)
115 return emitOpError("mismatch in result and operand type bitwidth");
116
117 return success();
118 }
119
120 if (isa<spirv::PointerType>(operandType) &&
121 !isa<spirv::PointerType>(resultType)) {
122 return emitError(
123 "unhandled bit cast conversion from pointer type to non-pointer type");
124 }
125 if (!isa<spirv::PointerType>(operandType) &&
126 isa<spirv::PointerType>(resultType)) {
127 return emitError(
128 "unhandled bit cast conversion from non-pointer type to pointer type");
129 }
130 auto operandBitWidth = getBitWidth(operandType);
131 auto resultBitWidth = getBitWidth(resultType);
132 if (operandBitWidth != resultBitWidth) {
133 return emitOpError("mismatch in result type bitwidth ")
134 << resultBitWidth << " and operand type bitwidth "
135 << operandBitWidth;
136 }
137 return success();
138}
139
140//===----------------------------------------------------------------------===//
141// spirv.ConvertPtrToUOp
142//===----------------------------------------------------------------------===//
143
144LogicalResult ConvertPtrToUOp::verify() {
145 auto operandType = cast<spirv::PointerType>(getPointer().getType());
146 auto resultType = cast<spirv::ScalarType>(getResult().getType());
147 if (!resultType || !resultType.isSignlessInteger())
148 return emitError("result must be a scalar type of unsigned integer");
149 auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
150 if (!spirvModule)
151 return success();
152 auto addressingModel = spirvModule.getAddressingModel();
153 if ((addressingModel == spirv::AddressingModel::Logical) ||
154 (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
155 operandType.getStorageClass() !=
156 spirv::StorageClass::PhysicalStorageBuffer))
157 return emitError("operand must be a physical pointer");
158 return success();
159}
160
161//===----------------------------------------------------------------------===//
162// spirv.ConvertUToPtrOp
163//===----------------------------------------------------------------------===//
164
165LogicalResult ConvertUToPtrOp::verify() {
166 auto operandType = cast<spirv::ScalarType>(getOperand().getType());
167 auto resultType = cast<spirv::PointerType>(getResult().getType());
168 if (!operandType || !operandType.isSignlessInteger())
169 return emitError("result must be a scalar type of unsigned integer");
170 auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
171 if (!spirvModule)
172 return success();
173 auto addressingModel = spirvModule.getAddressingModel();
174 if ((addressingModel == spirv::AddressingModel::Logical) ||
175 (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
176 resultType.getStorageClass() !=
177 spirv::StorageClass::PhysicalStorageBuffer))
178 return emitError("result must be a physical pointer");
179 return success();
180}
181
182//===----------------------------------------------------------------------===//
183// spirv.PtrCastToGenericOp
184//===----------------------------------------------------------------------===//
185
186LogicalResult PtrCastToGenericOp::verify() {
187 auto operandType = cast<spirv::PointerType>(getPointer().getType());
188 auto resultType = cast<spirv::PointerType>(getResult().getType());
189
190 spirv::StorageClass operandStorage = operandType.getStorageClass();
191 if (operandStorage != spirv::StorageClass::Workgroup &&
192 operandStorage != spirv::StorageClass::CrossWorkgroup &&
193 operandStorage != spirv::StorageClass::Function)
194 return emitError("pointer must point to the Workgroup, CrossWorkgroup"
195 ", or Function Storage Class");
196
197 spirv::StorageClass resultStorage = resultType.getStorageClass();
198 if (resultStorage != spirv::StorageClass::Generic)
199 return emitError("result type must be of storage class Generic");
200
201 Type operandPointeeType = operandType.getPointeeType();
202 Type resultPointeeType = resultType.getPointeeType();
203 if (operandPointeeType != resultPointeeType)
204 return emitOpError("pointer operand's pointee type must have the same "
205 "as the op result type, but found ")
206 << operandPointeeType << " vs " << resultPointeeType;
207 return success();
208}
209
210//===----------------------------------------------------------------------===//
211// spirv.GenericCastToPtrOp
212//===----------------------------------------------------------------------===//
213
214LogicalResult GenericCastToPtrOp::verify() {
215 auto operandType = cast<spirv::PointerType>(getPointer().getType());
216 auto resultType = cast<spirv::PointerType>(getResult().getType());
217
218 spirv::StorageClass operandStorage = operandType.getStorageClass();
219 if (operandStorage != spirv::StorageClass::Generic)
220 return emitError("pointer type must be of storage class Generic");
221
222 spirv::StorageClass resultStorage = resultType.getStorageClass();
223 if (resultStorage != spirv::StorageClass::Workgroup &&
224 resultStorage != spirv::StorageClass::CrossWorkgroup &&
225 resultStorage != spirv::StorageClass::Function)
226 return emitError("result must point to the Workgroup, CrossWorkgroup, "
227 "or Function Storage Class");
228
229 Type operandPointeeType = operandType.getPointeeType();
230 Type resultPointeeType = resultType.getPointeeType();
231 if (operandPointeeType != resultPointeeType)
232 return emitOpError("pointer operand's pointee type must have the same "
233 "as the op result type, but found ")
234 << operandPointeeType << " vs " << resultPointeeType;
235 return success();
236}
237
238//===----------------------------------------------------------------------===//
239// spirv.GenericCastToPtrExplicitOp
240//===----------------------------------------------------------------------===//
241
242LogicalResult GenericCastToPtrExplicitOp::verify() {
243 auto operandType = cast<spirv::PointerType>(getPointer().getType());
244 auto resultType = cast<spirv::PointerType>(getResult().getType());
245
246 spirv::StorageClass operandStorage = operandType.getStorageClass();
247 if (operandStorage != spirv::StorageClass::Generic)
248 return emitError("pointer type must be of storage class Generic");
249
250 spirv::StorageClass resultStorage = resultType.getStorageClass();
251 if (resultStorage != spirv::StorageClass::Workgroup &&
252 resultStorage != spirv::StorageClass::CrossWorkgroup &&
253 resultStorage != spirv::StorageClass::Function)
254 return emitError("result must point to the Workgroup, CrossWorkgroup, "
255 "or Function Storage Class");
256
257 Type operandPointeeType = operandType.getPointeeType();
258 Type resultPointeeType = resultType.getPointeeType();
259 if (operandPointeeType != resultPointeeType)
260 return emitOpError("pointer operand's pointee type must have the same "
261 "as the op result type, but found ")
262 << operandPointeeType << " vs " << resultPointeeType;
263 return success();
264}
265
266//===----------------------------------------------------------------------===//
267// spirv.ConvertFToSOp
268//===----------------------------------------------------------------------===//
269
270LogicalResult ConvertFToSOp::verify() {
271 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
272 /*skipBitWidthCheck=*/true);
273}
274
275//===----------------------------------------------------------------------===//
276// spirv.ConvertFToUOp
277//===----------------------------------------------------------------------===//
278
279LogicalResult ConvertFToUOp::verify() {
280 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
281 /*skipBitWidthCheck=*/true);
282}
283
284//===----------------------------------------------------------------------===//
285// spirv.ConvertSToFOp
286//===----------------------------------------------------------------------===//
287
288LogicalResult ConvertSToFOp::verify() {
289 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
290 /*skipBitWidthCheck=*/true);
291}
292
293//===----------------------------------------------------------------------===//
294// spirv.ConvertUToFOp
295//===----------------------------------------------------------------------===//
296
297LogicalResult ConvertUToFOp::verify() {
298 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
299 /*skipBitWidthCheck=*/true);
300}
301
302//===----------------------------------------------------------------------===//
303// spirv.FConvertOp
304//===----------------------------------------------------------------------===//
305
306LogicalResult spirv::FConvertOp::verify() {
307 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
308}
309
310//===----------------------------------------------------------------------===//
311// spirv.SConvertOp
312//===----------------------------------------------------------------------===//
313
314LogicalResult spirv::SConvertOp::verify() {
315 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
316}
317
318//===----------------------------------------------------------------------===//
319// spirv.UConvertOp
320//===----------------------------------------------------------------------===//
321
322LogicalResult spirv::UConvertOp::verify() {
323 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
324}
325
326} // namespace mlir::spirv
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
Type getType() const
Return the type of this value.
Definition Value.h:105
static LogicalResult verifyCastOp(Operation *op, bool requireSameBitWidth=true, bool skipBitWidthCheck=false)
Definition CastOps.cpp:24
unsigned getBitWidth(Type type)
Returns the bit width of the type.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139