MLIR  20.0.0git
CastOps.cpp
Go to the documentation of this file.
1 //===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the cast and conversion operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 #include "SPIRVOpUtils.h"
16 #include "SPIRVParsingUtils.h"
17 
18 #include "llvm/ADT/TypeSwitch.h"
19 
20 using namespace mlir::spirv::AttrNames;
21 
22 namespace mlir::spirv {
23 
24 static LogicalResult verifyCastOp(Operation *op,
25  bool requireSameBitWidth = true,
26  bool skipBitWidthCheck = false) {
27  // Some CastOps have no limit on bit widths for result and operand type.
28  if (skipBitWidthCheck)
29  return success();
30 
31  Type operandType = op->getOperand(0).getType();
32  Type resultType = op->getResult(0).getType();
33 
34  // ODS checks that result type and operand type have the same shape. Check
35  // that composite types match and extract the element types, if any.
36  using TypePair = std::pair<Type, Type>;
37  auto [operandElemTy, resultElemTy] =
38  TypeSwitch<Type, TypePair>(operandType)
39  .Case<VectorType, spirv::CooperativeMatrixType>(
40  [resultType](auto concreteOperandTy) -> TypePair {
41  if (auto concreteResultTy =
42  dyn_cast<decltype(concreteOperandTy)>(resultType)) {
43  return {concreteOperandTy.getElementType(),
44  concreteResultTy.getElementType()};
45  }
46  return {};
47  })
48  .Default([resultType](Type operandType) -> TypePair {
49  return {operandType, resultType};
50  });
51 
52  if (!operandElemTy || !resultElemTy)
53  return op->emitOpError("incompatible operand and result types");
54 
55  unsigned operandTypeBitWidth = operandElemTy.getIntOrFloatBitWidth();
56  unsigned resultTypeBitWidth = resultElemTy.getIntOrFloatBitWidth();
57  bool isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
58 
59  if (requireSameBitWidth) {
60  if (!isSameBitWidth) {
61  return op->emitOpError(
62  "expected the same bit widths for operand type and result "
63  "type, but provided ")
64  << operandElemTy << " and " << resultElemTy;
65  }
66  return success();
67  }
68 
69  if (isSameBitWidth) {
70  return op->emitOpError(
71  "expected the different bit widths for operand type and result "
72  "type, but provided ")
73  << operandElemTy << " and " << resultElemTy;
74  }
75  return success();
76 }
77 
78 //===----------------------------------------------------------------------===//
79 // spirv.BitcastOp
80 //===----------------------------------------------------------------------===//
81 
82 LogicalResult BitcastOp::verify() {
83  // TODO: The SPIR-V spec validation rules are different for different
84  // versions.
85  auto operandType = getOperand().getType();
86  auto resultType = getResult().getType();
87  if (operandType == resultType) {
88  return emitError("result type must be different from operand type");
89  }
90  if (llvm::isa<spirv::PointerType>(operandType) &&
91  !llvm::isa<spirv::PointerType>(resultType)) {
92  return emitError(
93  "unhandled bit cast conversion from pointer type to non-pointer type");
94  }
95  if (!llvm::isa<spirv::PointerType>(operandType) &&
96  llvm::isa<spirv::PointerType>(resultType)) {
97  return emitError(
98  "unhandled bit cast conversion from non-pointer type to pointer type");
99  }
100  auto operandBitWidth = getBitWidth(operandType);
101  auto resultBitWidth = getBitWidth(resultType);
102  if (operandBitWidth != resultBitWidth) {
103  return emitOpError("mismatch in result type bitwidth ")
104  << resultBitWidth << " and operand type bitwidth "
105  << operandBitWidth;
106  }
107  return success();
108 }
109 
110 //===----------------------------------------------------------------------===//
111 // spirv.ConvertPtrToUOp
112 //===----------------------------------------------------------------------===//
113 
114 LogicalResult ConvertPtrToUOp::verify() {
115  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
116  auto resultType = llvm::cast<spirv::ScalarType>(getResult().getType());
117  if (!resultType || !resultType.isSignlessInteger())
118  return emitError("result must be a scalar type of unsigned integer");
119  auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
120  if (!spirvModule)
121  return success();
122  auto addressingModel = spirvModule.getAddressingModel();
123  if ((addressingModel == spirv::AddressingModel::Logical) ||
124  (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
125  operandType.getStorageClass() !=
126  spirv::StorageClass::PhysicalStorageBuffer))
127  return emitError("operand must be a physical pointer");
128  return success();
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // spirv.ConvertUToPtrOp
133 //===----------------------------------------------------------------------===//
134 
135 LogicalResult ConvertUToPtrOp::verify() {
136  auto operandType = llvm::cast<spirv::ScalarType>(getOperand().getType());
137  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
138  if (!operandType || !operandType.isSignlessInteger())
139  return emitError("result must be a scalar type of unsigned integer");
140  auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
141  if (!spirvModule)
142  return success();
143  auto addressingModel = spirvModule.getAddressingModel();
144  if ((addressingModel == spirv::AddressingModel::Logical) ||
145  (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
146  resultType.getStorageClass() !=
147  spirv::StorageClass::PhysicalStorageBuffer))
148  return emitError("result must be a physical pointer");
149  return success();
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // spirv.PtrCastToGenericOp
154 //===----------------------------------------------------------------------===//
155 
156 LogicalResult PtrCastToGenericOp::verify() {
157  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
158  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
159 
160  spirv::StorageClass operandStorage = operandType.getStorageClass();
161  if (operandStorage != spirv::StorageClass::Workgroup &&
162  operandStorage != spirv::StorageClass::CrossWorkgroup &&
163  operandStorage != spirv::StorageClass::Function)
164  return emitError("pointer must point to the Workgroup, CrossWorkgroup"
165  ", or Function Storage Class");
166 
167  spirv::StorageClass resultStorage = resultType.getStorageClass();
168  if (resultStorage != spirv::StorageClass::Generic)
169  return emitError("result type must be of storage class Generic");
170 
171  Type operandPointeeType = operandType.getPointeeType();
172  Type resultPointeeType = resultType.getPointeeType();
173  if (operandPointeeType != resultPointeeType)
174  return emitOpError("pointer operand's pointee type must have the same "
175  "as the op result type, but found ")
176  << operandPointeeType << " vs " << resultPointeeType;
177  return success();
178 }
179 
180 //===----------------------------------------------------------------------===//
181 // spirv.GenericCastToPtrOp
182 //===----------------------------------------------------------------------===//
183 
184 LogicalResult GenericCastToPtrOp::verify() {
185  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
186  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
187 
188  spirv::StorageClass operandStorage = operandType.getStorageClass();
189  if (operandStorage != spirv::StorageClass::Generic)
190  return emitError("pointer type must be of storage class Generic");
191 
192  spirv::StorageClass resultStorage = resultType.getStorageClass();
193  if (resultStorage != spirv::StorageClass::Workgroup &&
194  resultStorage != spirv::StorageClass::CrossWorkgroup &&
195  resultStorage != spirv::StorageClass::Function)
196  return emitError("result must point to the Workgroup, CrossWorkgroup, "
197  "or Function Storage Class");
198 
199  Type operandPointeeType = operandType.getPointeeType();
200  Type resultPointeeType = resultType.getPointeeType();
201  if (operandPointeeType != resultPointeeType)
202  return emitOpError("pointer operand's pointee type must have the same "
203  "as the op result type, but found ")
204  << operandPointeeType << " vs " << resultPointeeType;
205  return success();
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // spirv.GenericCastToPtrExplicitOp
210 //===----------------------------------------------------------------------===//
211 
212 LogicalResult GenericCastToPtrExplicitOp::verify() {
213  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
214  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
215 
216  spirv::StorageClass operandStorage = operandType.getStorageClass();
217  if (operandStorage != spirv::StorageClass::Generic)
218  return emitError("pointer type must be of storage class Generic");
219 
220  spirv::StorageClass resultStorage = resultType.getStorageClass();
221  if (resultStorage != spirv::StorageClass::Workgroup &&
222  resultStorage != spirv::StorageClass::CrossWorkgroup &&
223  resultStorage != spirv::StorageClass::Function)
224  return emitError("result must point to the Workgroup, CrossWorkgroup, "
225  "or Function Storage Class");
226 
227  Type operandPointeeType = operandType.getPointeeType();
228  Type resultPointeeType = resultType.getPointeeType();
229  if (operandPointeeType != resultPointeeType)
230  return emitOpError("pointer operand's pointee type must have the same "
231  "as the op result type, but found ")
232  << operandPointeeType << " vs " << resultPointeeType;
233  return success();
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // spirv.ConvertFToSOp
238 //===----------------------------------------------------------------------===//
239 
240 LogicalResult ConvertFToSOp::verify() {
241  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
242  /*skipBitWidthCheck=*/true);
243 }
244 
245 //===----------------------------------------------------------------------===//
246 // spirv.ConvertFToUOp
247 //===----------------------------------------------------------------------===//
248 
249 LogicalResult ConvertFToUOp::verify() {
250  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
251  /*skipBitWidthCheck=*/true);
252 }
253 
254 //===----------------------------------------------------------------------===//
255 // spirv.ConvertSToFOp
256 //===----------------------------------------------------------------------===//
257 
258 LogicalResult ConvertSToFOp::verify() {
259  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
260  /*skipBitWidthCheck=*/true);
261 }
262 
263 //===----------------------------------------------------------------------===//
264 // spirv.ConvertUToFOp
265 //===----------------------------------------------------------------------===//
266 
267 LogicalResult ConvertUToFOp::verify() {
268  return verifyCastOp(*this, /*requireSameBitWidth=*/false,
269  /*skipBitWidthCheck=*/true);
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // spirv.INTELConvertBF16ToFOp
274 //===----------------------------------------------------------------------===//
275 
276 LogicalResult INTELConvertBF16ToFOp::verify() {
277  auto operandType = getOperand().getType();
278  auto resultType = getResult().getType();
279  // ODS checks that vector result type and vector operand type have the same
280  // shape.
281  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
282  unsigned operandNumElements = vectorType.getNumElements();
283  unsigned resultNumElements =
284  llvm::cast<VectorType>(resultType).getNumElements();
285  if (operandNumElements != resultNumElements) {
286  return emitOpError(
287  "operand and result must have same number of elements");
288  }
289  }
290  return success();
291 }
292 
293 //===----------------------------------------------------------------------===//
294 // spirv.INTELConvertFToBF16Op
295 //===----------------------------------------------------------------------===//
296 
297 LogicalResult INTELConvertFToBF16Op::verify() {
298  auto operandType = getOperand().getType();
299  auto resultType = getResult().getType();
300  // ODS checks that vector result type and vector operand type have the same
301  // shape.
302  if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
303  unsigned operandNumElements = vectorType.getNumElements();
304  unsigned resultNumElements =
305  llvm::cast<VectorType>(resultType).getNumElements();
306  if (operandNumElements != resultNumElements) {
307  return emitOpError(
308  "operand and result must have same number of elements");
309  }
310  }
311  return success();
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // spirv.FConvertOp
316 //===----------------------------------------------------------------------===//
317 
318 LogicalResult spirv::FConvertOp::verify() {
319  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
320 }
321 
322 //===----------------------------------------------------------------------===//
323 // spirv.SConvertOp
324 //===----------------------------------------------------------------------===//
325 
326 LogicalResult spirv::SConvertOp::verify() {
327  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
328 }
329 
330 //===----------------------------------------------------------------------===//
331 // spirv.UConvertOp
332 //===----------------------------------------------------------------------===//
333 
334 LogicalResult spirv::UConvertOp::verify() {
335  return verifyCastOp(*this, /*requireSameBitWidth=*/false);
336 }
337 
338 } // namespace mlir::spirv
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
Type getType() const
Return the type of this value.
Definition: Value.h:129
@ Type
An inlay hint for a type annotation.
static LogicalResult verifyCastOp(Operation *op, bool requireSameBitWidth=true, bool skipBitWidthCheck=false)
Definition: CastOps.cpp:24
unsigned getBitWidth(Type type)
Returns the bit width of the type.
Definition: SPIRVOpUtils.h:14
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426