MLIR  19.0.0git
CastOps.cpp
Go to the documentation of this file.
1 //===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the cast and conversion operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir::spirv::AttrNames;
21 
22 namespace mlir::spirv {
23 
25  bool requireSameBitWidth = true,
26  bool skipBitWidthCheck = false) {
27  // Some CastOps have no limit on bit widths for result and operand type.
28  if (skipBitWidthCheck)
29  return success();
30 
31  Type operandType = op->getOperand(0).getType();
32  Type resultType = op->getResult(0).getType();
33 
34  // ODS checks that result type and operand type have the same shape. Check
35  // that composite types match and extract the element types, if any.
36  using TypePair = std::pair<Type, Type>;
37  auto [operandElemTy, resultElemTy] =
38  TypeSwitch<Type, TypePair>(operandType)
39  .Case<VectorType, spirv::CooperativeMatrixType,
41  [resultType](auto concreteOperandTy) -> TypePair {
42  if (auto concreteResultTy =
43  dyn_cast<decltype(concreteOperandTy)>(resultType)) {
44  return {concreteOperandTy.getElementType(),
45  concreteResultTy.getElementType()};
46  }
47  return {};
48  })
49  .Default([resultType](Type operandType) -> TypePair {
50  return {operandType, resultType};
51  });
52 
53  if (!operandElemTy || !resultElemTy)
54  return op->emitOpError("incompatible operand and result types");
55 
56  unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
57  unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
58  bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
59 
60  if (requireSameBitWidth) {
61  if (!isSameBitWidth) {
62  return op->emitOpError(
63  "expected the same bit widths for operand type and result "
64  "type, but provided ")
65  << operandElemTy << " and " << resultElemTy;
66  }
67  return success();
68  }
69 
70  if (isSameBitWidth) {
71  return op->emitOpError(
72  "expected the different bit widths for operand type and result "
73  "type, but provided ")
74  << operandElemTy << " and " << resultElemTy;
75  }
76  return success();
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // spirv.BitcastOp
81 //===----------------------------------------------------------------------===//
82 
84  // TODO: The SPIR-V spec validation rules are different for different
85  // versions.
86  auto operandType = getOperand().getType();
87  auto resultType = getResult().getType();
88  if (operandType == resultType) {
89  return emitError("result type must be different from operand type");
90  }
91  if (llvm::isa<spirv::PointerType>(operandType) &&
92  !llvm::isa<spirv::PointerType>(resultType)) {
93  return emitError(
94  "unhandled bit cast conversion from pointer type to non-pointer type");
95  }
96  if (!llvm::isa<spirv::PointerType>(operandType) &&
97  llvm::isa<spirv::PointerType>(resultType)) {
98  return emitError(
99  "unhandled bit cast conversion from non-pointer type to pointer type");
100  }
101  auto operandBitWidth = getBitWidth(operandType);
102  auto resultBitWidth = getBitWidth(resultType);
103  if (operandBitWidth != resultBitWidth) {
104  return emitOpError("mismatch in result type bitwidth ")
105  << resultBitWidth << " and operand type bitwidth "
106  << operandBitWidth;
107  }
108  return success();
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // spirv.ConvertPtrToUOp
113 //===----------------------------------------------------------------------===//
114 
115 LogicalResult ConvertPtrToUOp::verify() {
116  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
117  auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
118  if (!resultType || !resultType.isSignlessInteger())
119  return emitError("result must be a scalar type of unsigned integer");
120  auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
121  if (!spirvModule)
122  return success();
123  auto addressingModel = spirvModule.getAddressingModel();
124  if ((addressingModel == spirv::AddressingModel::Logical) ||
125  (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
126  operandType.getStorageClass() !=
127  spirv::StorageClass::PhysicalStorageBuffer))
128  return emitError("operand must be a physical pointer");
129  return success();
130 }
131 
132 //===----------------------------------------------------------------------===//
133 // spirv.ConvertUToPtrOp
134 //===----------------------------------------------------------------------===//
135 
136 LogicalResult ConvertUToPtrOp::verify() {
137  auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
138  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
139  if (!operandType || !operandType.isSignlessInteger())
140  return emitError("result must be a scalar type of unsigned integer");
141  auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
142  if (!spirvModule)
143  return success();
144  auto addressingModel = spirvModule.getAddressingModel();
145  if ((addressingModel == spirv::AddressingModel::Logical) ||
146  (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
147  resultType.getStorageClass() !=
148  spirv::StorageClass::PhysicalStorageBuffer))
149  return emitError("result must be a physical pointer");
150  return success();
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // spirv.PtrCastToGenericOp
155 //===----------------------------------------------------------------------===//
156 
157 LogicalResult PtrCastToGenericOp::verify() {
158  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
159  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
160 
161  spirv::StorageClass operandStorage = operandType.getStorageClass();
162  if (operandStorage != spirv::StorageClass::Workgroup &&
163  operandStorage != spirv::StorageClass::CrossWorkgroup &&
164  operandStorage != spirv::StorageClass::Function)
165  return emitError("pointer must point to the Workgroup, CrossWorkgroup"
166  ", or Function Storage Class");
167 
168  spirv::StorageClass resultStorage = resultType.getStorageClass();
169  if (resultStorage != spirv::StorageClass::Generic)
170  return emitError("result type must be of storage class Generic");
171 
172  Type operandPointeeType = operandType.getPointeeType();
173  Type resultPointeeType = resultType.getPointeeType();
174  if (operandPointeeType != resultPointeeType)
175  return emitOpError("pointer operand's pointee type must have the same "
176  "as the op result type, but found ")
177  << operandPointeeType << " vs " << resultPointeeType;
178  return success();
179 }
180 
181 //===----------------------------------------------------------------------===//
182 // spirv.GenericCastToPtrOp
183 //===----------------------------------------------------------------------===//
184 
185 LogicalResult GenericCastToPtrOp::verify() {
186  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
187  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
188 
189  spirv::StorageClass operandStorage = operandType.getStorageClass();
190  if (operandStorage != spirv::StorageClass::Generic)
191  return emitError("pointer type must be of storage class Generic");
192 
193  spirv::StorageClass resultStorage = resultType.getStorageClass();
194  if (resultStorage != spirv::StorageClass::Workgroup &&
195  resultStorage != spirv::StorageClass::CrossWorkgroup &&
196  resultStorage != spirv::StorageClass::Function)
197  return emitError("result must point to the Workgroup, CrossWorkgroup, "
198  "or Function Storage Class");
199 
200  Type operandPointeeType = operandType.getPointeeType();
201  Type resultPointeeType = resultType.getPointeeType();
202  if (operandPointeeType != resultPointeeType)
203  return emitOpError("pointer operand's pointee type must have the same "
204  "as the op result type, but found ")
205  << operandPointeeType << " vs " << resultPointeeType;
206  return success();
207 }
208 
209 //===----------------------------------------------------------------------===//
210 // spirv.GenericCastToPtrExplicitOp
211 //===----------------------------------------------------------------------===//
212 
213 LogicalResult GenericCastToPtrExplicitOp::verify() {
214  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
215  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
216 
217  spirv::StorageClass operandStorage = operandType.getStorageClass();
218  if (operandStorage != spirv::StorageClass::Generic)
219  return emitError("pointer type must be of storage class Generic");
220 
221  spirv::StorageClass resultStorage = resultType.getStorageClass();
222  if (resultStorage != spirv::StorageClass::Workgroup &&
223  resultStorage != spirv::StorageClass::CrossWorkgroup &&
224  resultStorage != spirv::StorageClass::Function)
225  return emitError("result must point to the Workgroup, CrossWorkgroup, "
226  "or Function Storage Class");
227 
228  Type operandPointeeType = operandType.getPointeeType();
229  Type resultPointeeType = resultType.getPointeeType();
230  if (operandPointeeType != resultPointeeType)
231  return emitOpError("pointer operand's pointee type must have the same "
232  "as the op result type, but found ")
233  << operandPointeeType << " vs " << resultPointeeType;
234  return success();
235 }
236 
237 //===----------------------------------------------------------------------===//
238 // spirv.ConvertFToSOp
239 //===----------------------------------------------------------------------===//
240 
241 LogicalResult ConvertFToSOp::verify() {
242  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
243  /*skipBitWidthCheck=*/true);
244 }
245 
246 //===----------------------------------------------------------------------===//
247 // spirv.ConvertFToUOp
248 //===----------------------------------------------------------------------===//
249 
250 LogicalResult ConvertFToUOp::verify() {
251  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
252  /*skipBitWidthCheck=*/true);
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // spirv.ConvertSToFOp
257 //===----------------------------------------------------------------------===//
258 
259 LogicalResult ConvertSToFOp::verify() {
260  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
261  /*skipBitWidthCheck=*/true);
262 }
263 
264 //===----------------------------------------------------------------------===//
265 // spirv.ConvertUToFOp
266 //===----------------------------------------------------------------------===//
267 
268 LogicalResult ConvertUToFOp::verify() {
269  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
270  /*skipBitWidthCheck=*/true);
271 }
272 
273 //===----------------------------------------------------------------------===//
274 // spirv.INTELConvertBF16ToFOp
275 //===----------------------------------------------------------------------===//
276 
277 LogicalResult INTELConvertBF16ToFOp::verify() {
278  auto operandType = getOperand().getType();
279  auto resultType = getResult().getType();
280  // ODS checks that vector result type and vector operand type have the same
281  // shape.
282  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
283  unsigned operandNumElements = vectorType.getNumElements();
284  unsigned resultNumElements =
285  llvm::cast<VectorType>(resultType).getNumElements();
286  if (operandNumElements != resultNumElements) {
287  return emitOpError(
288  "operand and result must have same number of elements");
289  }
290  }
291  return success();
292 }
293 
294 //===----------------------------------------------------------------------===//
295 // spirv.INTELConvertFToBF16Op
296 //===----------------------------------------------------------------------===//
297 
298 LogicalResult INTELConvertFToBF16Op::verify() {
299  auto operandType = getOperand().getType();
300  auto resultType = getResult().getType();
301  // ODS checks that vector result type and vector operand type have the same
302  // shape.
303  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
304  unsigned operandNumElements = vectorType.getNumElements();
305  unsigned resultNumElements =
306  llvm::cast<VectorType>(resultType).getNumElements();
307  if (operandNumElements != resultNumElements) {
308  return emitOpError(
309  "operand and result must have same number of elements");
310  }
311  }
312  return success();
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // spirv.FConvertOp
317 //===----------------------------------------------------------------------===//
318 
319 LogicalResult spirv::FConvertOp::verify() {
320  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
321 }
322 
323 //===----------------------------------------------------------------------===//
324 // spirv.SConvertOp
325 //===----------------------------------------------------------------------===//
326 
327 LogicalResult spirv::SConvertOp::verify() {
328  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
329 }
330 
331 //===----------------------------------------------------------------------===//
332 // spirv.UConvertOp
333 //===----------------------------------------------------------------------===//
334 
335 LogicalResult spirv::UConvertOp::verify() {
336  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
337 }
338 
339 } // namespace mlir::spirv
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
@ Type
An inlay hint for a type annotation.
static LogicalResult verifyCastOp(Operation *op, bool requireSameBitWidth=true, bool skipBitWidthCheck=false)
Definition: CastOps.cpp:24
unsigned getBitWidth(Type type)
Returns the bit width of the type.
Definition: SPIRVOpUtils.h:14
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26