MLIR  19.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/IR/BuiltinTypes.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/MathExtras.h"
25 
26 #include <functional>
27 #include <optional>
28 
29 #define DEBUG_TYPE "mlir-spirv-conversion"
30 
31 using namespace mlir;
32 
33 //===----------------------------------------------------------------------===//
34 // Utility functions
35 //===----------------------------------------------------------------------===//
36 
37 /// Checks whether the extension requirements in `candidates` can all be
38 /// satisfied by the given `targetEnv`.
39 ///
40 /// `candidates` is a vector of vector for extension requirements following
41 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
42 /// convention.
///
/// Returns `failure()` at the first OR-group that `targetEnv` does not allow
/// and `success()` when every group is allowed. `label` is only streamed into
/// the debug message.
///
/// NOTE(review): the embedded listing numbers skip 44 and 46, so the lines
/// carrying the function name and the `candidates` parameter are not visible
/// here -- restore them from the original source before building.
43 template <typename LabelT>
45  LabelT label, const spirv::TargetEnv &targetEnv,
47  for (const auto &ors : candidates) {
 // This OR-group is allowed by the target; move on to the next one.
48  if (targetEnv.allows(ors))
49  continue;
50 
 // Debug-only: list the extensions of the group that could not be satisfied.
51  LLVM_DEBUG({
52  SmallVector<StringRef> extStrings;
53  for (spirv::Extension ext : ors)
54  extStrings.push_back(spirv::stringifyExtension(ext));
55 
56  llvm::dbgs() << label << " illegal: requires at least one extension in ["
57  << llvm::join(extStrings, ", ")
58  << "] but none allowed in target environment\n";
59  });
60  return failure();
61  }
62  return success();
63 }
64 
65 /// Checks whether the capability requirements in `candidates` can all be
66 /// satisfied by the given `targetEnv`.
67 ///
68 /// `candidates` is a vector of vector for capability requirements following
69 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
70 /// convention.
///
/// Returns `failure()` at the first OR-group that `targetEnv` does not allow
/// and `success()` when every group is allowed. `label` is only streamed into
/// the debug message.
///
/// NOTE(review): the embedded listing numbers skip 72 and 74, so the lines
/// carrying the function name and the `candidates` parameter are not visible
/// here -- restore them from the original source before building.
71 template <typename LabelT>
73  LabelT label, const spirv::TargetEnv &targetEnv,
75  for (const auto &ors : candidates) {
 // This OR-group is allowed by the target; move on to the next one.
76  if (targetEnv.allows(ors))
77  continue;
78 
 // Debug-only: list the capabilities of the group that could not be satisfied.
79  LLVM_DEBUG({
80  SmallVector<StringRef> capStrings;
81  for (spirv::Capability cap : ors)
82  capStrings.push_back(spirv::stringifyCapability(cap));
83 
84  llvm::dbgs() << label << " illegal: requires at least one capability in ["
85  << llvm::join(capStrings, ", ")
86  << "] but none allowed in target environment\n";
87  });
88  return failure();
89  }
90  return success();
91 }
92 
93 /// Returns true if the given `storageClass` needs explicit layout when used in
94 /// Shader environments.
95 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
96  switch (storageClass) {
97  case spirv::StorageClass::PhysicalStorageBuffer:
98  case spirv::StorageClass::PushConstant:
99  case spirv::StorageClass::StorageBuffer:
100  case spirv::StorageClass::Uniform:
101  return true;
102  default:
103  return false;
104  }
105 }
106 
107 /// Wraps the given `elementType` in a struct and gets the pointer to the
108 /// struct. This is used to satisfy Vulkan interface requirements.
109 static spirv::PointerType
110 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
111  auto structType = needsExplicitLayout(storageClass)
112  ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
113  : spirv::StructType::get(elementType);
114  return spirv::PointerType::get(structType, storageClass);
115 }
116 
117 //===----------------------------------------------------------------------===//
118 // Type Conversion
119 //===----------------------------------------------------------------------===//
120 
// Body of the file-local index-type helper: yields a 64-bit integer type when
// `options.use64bitIndex` is set and a 32-bit integer type otherwise, cast to
// `spirv::ScalarType`.
// NOTE(review): the signature lines (embedded listing numbers 121-122) are
// missing from this listing; `ctx` and `options` are presumably its
// parameters -- restore from the original source.
123  return cast<spirv::ScalarType>(
124  IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
125 }
126 
// Forwards to the global-scope `getIndexType`, passing `getContext()` and
// `options`.
// NOTE(review): the signature line (embedded listing number 127) is missing
// from this listing -- restore from the original source.
128  return ::getIndexType(getContext(), options);
129 }
130 
131 MLIRContext *SPIRVTypeConverter::getContext() const {
132  return targetEnv.getAttr().getContext();
133 }
134 
135 bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
136  return targetEnv.allows(capability);
137 }
138 
// Computes the size in bytes of `type`, or std::nullopt when it cannot be
// determined. Handled here: SPIR-V scalar types, complex, vector, memref and
// tensor types; everything else yields std::nullopt.
//
// NOTE(review): the line with the function name and parameter list (embedded
// listing number 142) is missing from this listing; the recursive calls below
// suggest `getTypeNumBytes(options, type)` -- restore from the original source.
//
139 // TODO: This is a utility function that should probably be exposed by the
140 // SPIR-V dialect. Keeping it local till the use case arises.
141 static std::optional<int64_t>
143  if (isa<spirv::ScalarType>(type)) {
144  auto bitWidth = type.getIntOrFloatBitWidth();
145  // According to the SPIR-V spec:
146  // "There is no physical size or bit pattern defined for values with boolean
147  // type. If they are stored (in conjunction with OpVariable), they can only
148  // be used with logical addressing operations, not physical, and only with
149  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
150  // Private, Function, Input, and Output."
151  if (bitWidth == 1)
152  return std::nullopt;
153  return bitWidth / 8;
154  }
155 
 // Complex: twice the size of the element type.
156  if (auto complexType = dyn_cast<ComplexType>(type)) {
157  auto elementSize = getTypeNumBytes(options, complexType.getElementType());
158  if (!elementSize)
159  return std::nullopt;
160  return 2 * *elementSize;
161  }
162 
 // Vector: element count times the size of the element type.
163  if (auto vecType = dyn_cast<VectorType>(type)) {
164  auto elementSize = getTypeNumBytes(options, vecType.getElementType());
165  if (!elementSize)
166  return std::nullopt;
167  return vecType.getNumElements() * *elementSize;
168  }
169 
170  if (auto memRefType = dyn_cast<MemRefType>(type)) {
171  // TODO: Layout should also be controlled by the ABI attributes. For now
172  // using the layout from MemRef.
173  int64_t offset;
174  SmallVector<int64_t, 4> strides;
 // Only statically shaped memrefs with computable strides/offset are sized.
175  if (!memRefType.hasStaticShape() ||
176  failed(getStridesAndOffset(memRefType, strides, offset)))
177  return std::nullopt;
178 
179  // To get the size of the memref object in memory, the total size is the
180  // max(stride * dimension-size) computed for all dimensions times the size
181  // of the element.
182  auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
183  if (!elementSize)
184  return std::nullopt;
185 
 // A rank-0 memref is sized as a single element.
186  if (memRefType.getRank() == 0)
187  return elementSize;
188 
 // Any dynamic dimension, offset or stride makes the size unknown.
189  auto dims = memRefType.getShape();
190  if (llvm::is_contained(dims, ShapedType::kDynamic) ||
191  ShapedType::isDynamic(offset) ||
192  llvm::is_contained(strides, ShapedType::kDynamic))
193  return std::nullopt;
194 
 // Largest (dimension size * stride) product over all dimensions.
195  int64_t memrefSize = -1;
196  for (const auto &shape : enumerate(dims))
197  memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
198 
199  return (offset + memrefSize) * *elementSize;
200  }
201 
 // Tensor: element size multiplied by every (static) dimension.
202  if (auto tensorType = dyn_cast<TensorType>(type)) {
203  if (!tensorType.hasStaticShape())
204  return std::nullopt;
205 
206  auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
207  if (!elementSize)
208  return std::nullopt;
209 
210  int64_t size = *elementSize;
211  for (auto shape : tensorType.getShape())
212  size *= shape;
213 
214  return size;
215  }
216 
217  // TODO: Add size computation for other types.
218  return std::nullopt;
219 }
220 
221 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
///
/// Returns `type` unchanged when its capability and extension requirements are
/// met by `targetEnv`. Otherwise, if `options.emulateLT32BitScalarTypes` is set
/// and the bitwidth is at most 32, returns the 32-bit float type or a 32-bit
/// integer type with the same signedness. In every other case returns a null
/// type.
///
/// NOTE(review): the embedded listing numbers skip 223-224 (function name and
/// leading parameters) and 227-228 (presumably the declarations of
/// `extensions` and `capabilities`); restore them from the original source.
222 static Type
225  std::optional<spirv::StorageClass> storageClass = {}) {
226  // Get extension and capability requirements for the given type.
229  type.getExtensions(extensions, storageClass);
230  type.getCapabilities(capabilities, storageClass);
231 
232  // If all requirements are met, then we can accept this type as-is.
233  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
234  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
235  return type;
236 
237  // Otherwise we need to adjust the type, which really means adjusting the
238  // bitwidth given this is a scalar type.
239  if (!options.emulateLT32BitScalarTypes)
240  return nullptr;
241 
242  // We only emulate narrower scalar types here and do not truncate results.
243  if (type.getIntOrFloatBitWidth() > 32) {
244  LLVM_DEBUG(llvm::dbgs()
245  << type
246  << " not converted to 32-bit for SPIR-V to avoid truncation\n");
247  return nullptr;
248  }
249 
 // Narrow floats widen to f32.
250  if (auto floatType = dyn_cast<FloatType>(type)) {
251  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
252  return Builder(targetEnv.getContext()).getF32Type();
253  }
254 
 // Narrow integers widen to a 32-bit integer, keeping the signedness.
255  auto intType = cast<IntegerType>(type);
256  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
257  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
258  intType.getSignedness());
259 }
260 
261 /// Converts a sub-byte integer `type` to i32 regardless of target environment.
262 ///
263 /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
264 /// the above given that these sub-byte types are not supported at all in
265 /// SPIR-V; there are no compute/storage capability for them like other
266 /// supported integer types.
///
/// Returns a null type when `options.subByteTypeStorage` is not `Packed` or
/// when the bitwidth is not a power of two; otherwise returns a 32-bit integer
/// type with the same signedness as `type`.
///
/// NOTE(review): no "width < 8" check is performed in this body; callers are
/// presumably expected to pass only sub-byte types -- confirm.
/// NOTE(review): the first signature line (embedded listing number 267) is
/// missing from this listing -- restore from the original source.
268  IntegerType type) {
269  if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
270  LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
271  return nullptr;
272  }
273 
274  if (!llvm::isPowerOf2_32(type.getWidth())) {
275  LLVM_DEBUG(llvm::dbgs()
276  << "unsupported non-power-of-two bitwidth in sub-byte" << type
277  << "\n");
278  return nullptr;
279  }
280 
281  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
282  return IntegerType::get(type.getContext(), /*width=*/32,
283  type.getSignedness());
284 }
285 
286 /// Returns a type with the same shape but with any index element type converted
287 /// to the matching integer type. This is a noop when the element type is not
288 /// the index type.
///
/// The replacement element type comes from `getIndexType(type.getContext(),
/// options)`.
///
/// NOTE(review): the second signature line (embedded listing number 291,
/// presumably declaring `options`) is missing from this listing.
289 static ShapedType
290 convertIndexElementType(ShapedType type,
292  Type indexType = dyn_cast<IndexType>(type.getElementType());
 // Not an index element type: hand the type back untouched.
293  if (!indexType)
294  return type;
295 
296  return type.clone(getIndexType(type.getContext(), options));
297 }
298 
299 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
///
/// Index element types are first replaced by the configured index integer
/// type. Vectors whose element type is not a `spirv::ScalarType` are only
/// handled for integer elements, via `convertSubByteIntegerType`, and only up
/// to 4 elements. Single-element vectors of rank <= 1 convert to their
/// (converted) element type. Other vectors must be valid SPIR-V composite
/// types; they are kept as-is when the target supports them, and otherwise
/// rebuilt with the converted scalar element type. Returns a null type on
/// failure.
///
/// NOTE(review): the embedded listing numbers skip 301 (function name and
/// first parameter) and 340-341 (presumably the declarations of `extensions`
/// and `capabilities`); restore them from the original source.
300 static Type
302  const SPIRVConversionOptions &options, VectorType type,
303  std::optional<spirv::StorageClass> storageClass = {}) {
304  type = cast<VectorType>(convertIndexElementType(type, options));
305  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
306  if (!scalarType) {
307  // If this is not a spec allowed scalar type, try to handle sub-byte integer
308  // types.
309  auto intType = dyn_cast<IntegerType>(type.getElementType());
310  if (!intType) {
311  LLVM_DEBUG(llvm::dbgs()
312  << type
313  << " illegal: cannot convert non-scalar element type\n");
314  return nullptr;
315  }
316 
 // NOTE(review): `convertSubByteIntegerType` has early returns that yield a
 // null type, but `elementType` is used below (returned directly or passed
 // to `VectorType::get`) without a null check -- confirm whether a guard is
 // needed here.
317  Type elementType = convertSubByteIntegerType(options, intType);
318  if (type.getRank() <= 1 && type.getNumElements() == 1)
319  return elementType;
320 
321  if (type.getNumElements() > 4) {
322  LLVM_DEBUG(llvm::dbgs()
323  << type << " illegal: > 4-element unimplemented\n");
324  return nullptr;
325  }
326 
327  return VectorType::get(type.getShape(), elementType);
328  }
329 
 // Single-element vectors of rank <= 1 are converted as a plain scalar.
330  if (type.getRank() <= 1 && type.getNumElements() == 1)
331  return convertScalarType(targetEnv, options, scalarType, storageClass);
332 
333  if (!spirv::CompositeType::isValid(type)) {
334  LLVM_DEBUG(llvm::dbgs()
335  << type << " illegal: not a valid composite type\n");
336  return nullptr;
337  }
338 
339  // Get extension and capability requirements for the given type.
342  cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
343  cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
344 
345  // If all requirements are met, then we can accept this type as-is.
346  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
347  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
348  return type;
349 
 // Otherwise rebuild the vector around the converted element type.
350  auto elementType =
351  convertScalarType(targetEnv, options, scalarType, storageClass);
352  if (elementType)
353  return VectorType::get(type.getShape(), elementType);
354  return nullptr;
355 }
356 
/// Converts a complex `type` to a 2-element vector of its element type.
///
/// Returns a null type when the element type is not a `spirv::ScalarType`,
/// when the scalar conversion fails, or when the scalar conversion would
/// change the element type (emulation is not supported for complex types).
///
/// NOTE(review): the line with the function name and first parameter
/// (embedded listing number 358) is missing from this listing.
357 static Type
359  const SPIRVConversionOptions &options, ComplexType type,
360  std::optional<spirv::StorageClass> storageClass = {}) {
361  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
362  if (!scalarType) {
363  LLVM_DEBUG(llvm::dbgs()
364  << type << " illegal: cannot convert non-scalar element type\n");
365  return nullptr;
366  }
367 
368  auto elementType =
369  convertScalarType(targetEnv, options, scalarType, storageClass);
370  if (!elementType)
371  return nullptr;
 // A changed element type would mean emulating the complex type; reject it.
372  if (elementType != type.getElementType()) {
373  LLVM_DEBUG(llvm::dbgs()
374  << type << " illegal: complex type emulation unsupported\n");
375  return nullptr;
376  }
377 
378  return VectorType::get(2, elementType);
379 }
380 
381 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
382 ///
383 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
384 /// create composite constants with OpConstantComposite to embed relative large
385 /// constant values and use OpCompositeExtract and OpCompositeInsert to
386 /// manipulate, like what we do for vectors.
///
/// The result is a `spirv::ArrayType` of the converted scalar element type
/// whose element count is the tensor byte size divided by the original scalar
/// byte size. Returns a null type for dynamic shapes, non-scalar element
/// types, zero-element tensors, or when a size cannot be deduced.
///
/// NOTE(review): the second signature line (embedded listing number 388,
/// presumably declaring `options`) is missing from this listing.
387 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
389  TensorType type) {
390  // TODO: Handle dynamic shapes.
391  if (!type.hasStaticShape()) {
392  LLVM_DEBUG(llvm::dbgs()
393  << type << " illegal: dynamic shape unimplemented\n");
394  return nullptr;
395  }
396 
397  type = cast<TensorType>(convertIndexElementType(type, options));
398  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
399  if (!scalarType) {
400  LLVM_DEBUG(llvm::dbgs()
401  << type << " illegal: cannot convert non-scalar element type\n");
402  return nullptr;
403  }
404 
405  std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
406  std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
407  if (!scalarSize || !tensorSize) {
408  LLVM_DEBUG(llvm::dbgs()
409  << type << " illegal: cannot deduce element count\n");
410  return nullptr;
411  }
412 
 // Element count is derived from the unconverted sizes.
413  int64_t arrayElemCount = *tensorSize / *scalarSize;
414  if (arrayElemCount == 0) {
415  LLVM_DEBUG(llvm::dbgs()
416  << type << " illegal: cannot handle zero-element tensors\n");
417  return nullptr;
418  }
419 
420  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
421  if (!arrayElemType)
422  return nullptr;
 // The converted element size is only checked for availability here; its
 // value is not used further in this function.
423  std::optional<int64_t> arrayElemSize =
424  getTypeNumBytes(options, arrayElemType);
425  if (!arrayElemSize) {
426  LLVM_DEBUG(llvm::dbgs()
427  << type << " illegal: cannot deduce converted element size\n");
428  return nullptr;
429  }
430 
431  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
432 }
433 
// Converts a memref whose elements are stored as `options.boolNumBits`-wide
// integers (only 8 is implemented) into a SPIR-V pointer type in
// `storageClass`. Returns a null type on failure. `convertMemrefType`
// dispatches here (as `convertBoolMemrefType`) for 1-bit integer elements.
//
// NOTE(review): the leading signature lines (embedded listing numbers
// 434-435) are missing from this listing; `targetEnv` and `options` are
// presumably declared there.
436  MemRefType type,
437  spirv::StorageClass storageClass) {
438  unsigned numBoolBits = options.boolNumBits;
439  if (numBoolBits != 8) {
440  LLVM_DEBUG(llvm::dbgs()
441  << "using non-8-bit storage for bool types unimplemented");
442  return nullptr;
443  }
444  auto elementType = dyn_cast<spirv::ScalarType>(
445  IntegerType::get(type.getContext(), numBoolBits));
446  if (!elementType)
447  return nullptr;
448  Type arrayElemType =
449  convertScalarType(targetEnv, options, elementType, storageClass);
450  if (!arrayElemType)
451  return nullptr;
452  std::optional<int64_t> arrayElemSize =
453  getTypeNumBytes(options, arrayElemType);
454  if (!arrayElemSize) {
455  LLVM_DEBUG(llvm::dbgs()
456  << type << " illegal: cannot deduce converted element size\n");
457  return nullptr;
458  }
459 
460  if (!type.hasStaticShape()) {
461  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
462  // to the element.
463  if (targetEnv.allows(spirv::Capability::Kernel))
464  return spirv::PointerType::get(arrayElemType, storageClass);
 // A stride is only attached for storage classes with explicit layout.
465  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
466  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
467  // For Vulkan we need extra wrapping struct and array to satisfy interface
468  // needs.
469  return wrapInStructAndGetPointer(arrayType, storageClass);
470  }
471 
472  if (type.getNumElements() == 0) {
473  LLVM_DEBUG(llvm::dbgs()
474  << type << " illegal: zero-element memrefs are not supported\n");
475  return nullptr;
476  }
477 
 // Total bytes needed, then number of converted elements covering them
 // (both rounded up).
478  int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
479  int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
480  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
481  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
482  if (targetEnv.allows(spirv::Capability::Kernel))
483  return spirv::PointerType::get(arrayType, storageClass);
484  return wrapInStructAndGetPointer(arrayType, storageClass);
485 }
486 
// Converts a memref of sub-byte integers into a SPIR-V pointer type in
// `storageClass`, using the type returned by `convertSubByteIntegerType` as
// the array element type. Returns a null type on failure. `convertMemrefType`
// dispatches here (as `convertSubByteMemrefType`) for integer elements
// narrower than 8 bits.
//
// NOTE(review): the leading signature lines (embedded listing numbers
// 487-488) are missing from this listing; `targetEnv` and `options` are
// presumably declared there.
489  MemRefType type,
490  spirv::StorageClass storageClass) {
491  IntegerType elementType = cast<IntegerType>(type.getElementType());
492  Type arrayElemType = convertSubByteIntegerType(options, elementType);
493  if (!arrayElemType)
494  return nullptr;
 // NOTE(review): the optional is dereferenced without a check; this relies on
 // `getTypeNumBytes` always succeeding for the converted element type.
495  int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
496 
497  if (!type.hasStaticShape()) {
498  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
499  // to the element.
500  if (targetEnv.allows(spirv::Capability::Kernel))
501  return spirv::PointerType::get(arrayElemType, storageClass);
 // A stride is only attached for storage classes with explicit layout.
502  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
503  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
504  // For Vulkan we need extra wrapping struct and array to satisfy interface
505  // needs.
506  return wrapInStructAndGetPointer(arrayType, storageClass);
507  }
508 
509  if (type.getNumElements() == 0) {
510  LLVM_DEBUG(llvm::dbgs()
511  << type << " illegal: zero-element memrefs are not supported\n");
512  return nullptr;
513  }
514 
 // Total bytes needed for the packed sub-byte values, then the number of
 // converted elements covering them (both rounded up).
515  int64_t memrefSize =
516  llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
517  int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
518  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
519  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
520  if (targetEnv.allows(spirv::Capability::Kernel))
521  return spirv::PointerType::get(arrayType, storageClass);
522  return wrapInStructAndGetPointer(arrayType, storageClass);
523 }
524 
/// Converts a memref `type` into a SPIR-V pointer type.
///
/// The memory space must be a `spirv::StorageClassAttr`. 1-bit and other
/// sub-8-bit integer element types are delegated to `convertBoolMemrefType`
/// and `convertSubByteMemrefType`. Otherwise the element type (vector,
/// complex, SPIR-V scalar or index) is converted first. Dynamically shaped
/// memrefs become a pointer to the element (Kernel capability) or a
/// struct-wrapped runtime array; statically shaped ones become a pointer to an
/// array, struct-wrapped unless the Kernel capability is allowed. Returns a
/// null type on failure.
///
/// NOTE(review): the second signature line (embedded listing number 526,
/// presumably declaring `options`) is missing from this listing.
525 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
527  MemRefType type) {
528  auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
529  if (!attr) {
530  LLVM_DEBUG(
531  llvm::dbgs()
532  << type
533  << " illegal: expected memory space to be a SPIR-V storage class "
534  "attribute; please use MemorySpaceToStorageClassConverter to map "
535  "numeric memory spaces beforehand\n");
536  return nullptr;
537  }
538  spirv::StorageClass storageClass = attr.getValue();
539 
 // Integer elements narrower than 8 bits take dedicated paths.
540  if (isa<IntegerType>(type.getElementType())) {
541  if (type.getElementTypeBitWidth() == 1)
542  return convertBoolMemrefType(targetEnv, options, type, storageClass);
543  if (type.getElementTypeBitWidth() < 8)
544  return convertSubByteMemrefType(targetEnv, options, type, storageClass);
545  }
546 
 // Convert the element type according to its kind.
547  Type arrayElemType;
548  Type elementType = type.getElementType();
549  if (auto vecType = dyn_cast<VectorType>(elementType)) {
550  arrayElemType =
551  convertVectorType(targetEnv, options, vecType, storageClass);
552  } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
553  arrayElemType =
554  convertComplexType(targetEnv, options, complexType, storageClass);
555  } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
556  arrayElemType =
557  convertScalarType(targetEnv, options, scalarType, storageClass);
558  } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
 // Index elements: rewrite the memref itself so later size computations
 // see the integer element type.
559  type = cast<MemRefType>(convertIndexElementType(type, options));
560  arrayElemType = type.getElementType();
561  } else {
562  LLVM_DEBUG(
563  llvm::dbgs()
564  << type
565  << " unhandled: can only convert scalar or vector element type\n");
566  return nullptr;
567  }
568  if (!arrayElemType)
569  return nullptr;
570 
571  std::optional<int64_t> arrayElemSize =
572  getTypeNumBytes(options, arrayElemType);
573  if (!arrayElemSize) {
574  LLVM_DEBUG(llvm::dbgs()
575  << type << " illegal: cannot deduce converted element size\n");
576  return nullptr;
577  }
578 
579  if (!type.hasStaticShape()) {
580  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
581  // to the element.
582  if (targetEnv.allows(spirv::Capability::Kernel))
583  return spirv::PointerType::get(arrayElemType, storageClass);
 // A stride is only attached for storage classes with explicit layout.
584  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
585  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
586  // For Vulkan we need extra wrapping struct and array to satisfy interface
587  // needs.
588  return wrapInStructAndGetPointer(arrayType, storageClass);
589  }
590 
591  std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
592  if (!memrefSize) {
593  LLVM_DEBUG(llvm::dbgs()
594  << type << " illegal: cannot deduce element count\n");
595  return nullptr;
596  }
597 
598  if (*memrefSize == 0) {
599  LLVM_DEBUG(llvm::dbgs()
600  << type << " illegal: zero-element memrefs are not supported\n");
601  return nullptr;
602  }
603 
 // Number of converted elements needed to cover the memref's byte size,
 // rounded up.
604  int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
605  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
606  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
607  if (targetEnv.allows(spirv::Capability::Kernel))
608  return spirv::PointerType::get(arrayType, storageClass);
609  return wrapInStructAndGetPointer(arrayType, storageClass);
610 }
611 
612 //===----------------------------------------------------------------------===//
613 // Type casting materialization
614 //===----------------------------------------------------------------------===//
615 
616 /// Converts the given `inputs` to the original source `type` considering the
617 /// `targetEnv`'s capabilities.
618 ///
619 /// This function is meant to be used for source materialization in type
620 /// converters. When the type converter needs to materialize a cast op back
621 /// to some original source type, we need to check whether the original source
622 /// type is supported in the target environment. If so, we can insert legal
623 /// SPIR-V cast ops accordingly.
624 ///
625 /// Note that in SPIR-V the capabilities for storage and compute are separate.
626 /// This function is meant to handle the **compute** side; so it does not
627 /// involve storage classes in its logic. The storage side is expected to be
628 /// handled by MemRef conversion logic.
///
/// Every unsupported case falls back to an `UnrealizedConversionCastOp` from
/// `inputs` to `type`. Supported cases produce `spirv.IEqual` against one (for
/// i1), or `spirv.SConvert` / `spirv.UConvert` depending on whether `type` is
/// a signed integer.
///
/// NOTE(review): the embedded listing numbers skip 667-668, presumably the
/// declarations of `exts` and `caps`; restore them from the original source.
629 std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
630  OpBuilder &builder, Type type,
631  ValueRange inputs, Location loc) {
632  // We can only cast one value in SPIR-V.
633  if (inputs.size() != 1) {
634  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
635  return castOp.getResult(0);
636  }
637  Value input = inputs.front();
638 
639  // Only support integer types for now. Floating point types to be implemented.
640  if (!isa<IntegerType>(type)) {
641  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
642  return castOp.getResult(0);
643  }
 // NOTE(review): only `type` was checked to be an integer above; this cast
 // assumes the input value's type is an integer type too -- confirm.
644  auto inputType = cast<IntegerType>(input.getType());
645 
646  auto scalarType = dyn_cast<spirv::ScalarType>(type);
647  if (!scalarType) {
648  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
649  return castOp.getResult(0);
650  }
651 
652  // Only support source type with a smaller bitwidth. This would mean we are
653  // truncating to go back so we don't need to worry about the signedness.
654  // For extension, we cannot have enough signal here to decide which op to use.
655  if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
656  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
657  return castOp.getResult(0);
658  }
659 
660  // Boolean values would need to use different ops than normal integer values.
661  if (type.isInteger(1)) {
662  Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
663  return builder.create<spirv::IEqualOp>(loc, input, one);
664  }
665 
666  // Check that the source integer type is supported by the environment.
669  scalarType.getExtensions(exts);
670  scalarType.getCapabilities(caps);
671  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
672  failed(checkExtensionRequirements(type, targetEnv, exts))) {
673  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
674  return castOp.getResult(0);
675  }
676 
677  // We've already made sure this is truncating previously, so we don't need to
678  // care about signedness here. Still try to use a corresponding op for better
679  // consistency though.
680  if (type.isSignedInteger()) {
681  return builder.create<spirv::SConvertOp>(loc, type, input);
682  }
683  return builder.create<spirv::UConvertOp>(loc, type, input);
684 }
685 
686 //===----------------------------------------------------------------------===//
687 // SPIRVTypeConverter
688 //===----------------------------------------------------------------------===//
689 
// Constructor body: stores the target environment and options, then registers
// the type conversions and the source/target materializations below.
//
// NOTE(review): the embedded listing numbers skip 690-691 (the constructor
// signature; `targetAttr` and `options` are presumably its parameters) and
// 738 (the call that receives the `castToSourceType` lambda, presumably a
// source-materialization registration); restore them from the original
// source.
692  : targetEnv(targetAttr), options(options) {
693  // Add conversions. The order matters here: later ones will be tried earlier.
694 
695  // Allow all SPIR-V dialect specific types. This assumes all builtin types
696  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
697  // were tried before.
698  //
699  // TODO: This assumes that the SPIR-V types are valid to use in the given
700  // target environment, which should be the case if the whole pipeline is
701  // driven by the same target environment. Still, we probably still want to
702  // validate and convert to be safe.
703  addConversion([](spirv::SPIRVType type) { return type; });
704 
 // `index` maps to the configured index integer type.
705  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
706 
 // Integers: SPIR-V scalar types go through `convertScalarType`, widths
 // below 8 through `convertSubByteIntegerType`; anything else yields a
 // null type.
707  addConversion([this](IntegerType intType) -> std::optional<Type> {
708  if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
709  return convertScalarType(this->targetEnv, this->options, scalarType);
710  if (intType.getWidth() < 8)
711  return convertSubByteIntegerType(this->options, intType);
712  return Type();
713  });
714 
 // Floats: only SPIR-V scalar types are converted; others yield a null type.
715  addConversion([this](FloatType floatType) -> std::optional<Type> {
716  if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
717  return convertScalarType(this->targetEnv, this->options, scalarType);
718  return Type();
719  });
720 
721  addConversion([this](ComplexType complexType) {
722  return convertComplexType(this->targetEnv, this->options, complexType);
723  });
724 
725  addConversion([this](VectorType vectorType) {
726  return convertVectorType(this->targetEnv, this->options, vectorType);
727  });
728 
729  addConversion([this](TensorType tensorType) {
730  return convertTensorType(this->targetEnv, this->options, tensorType);
731  });
732 
733  addConversion([this](MemRefType memRefType) {
734  return convertMemrefType(this->targetEnv, this->options, memRefType);
735  });
736 
737  // Register some last line of defense casting logic.
739  [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
740  return castToSourceType(this->targetEnv, builder, type, inputs, loc);
741  });
 // Target materialization always inserts an unrealized conversion cast.
742  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
743  Location loc) {
744  auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
745  return std::optional<Value>(cast.getResult(0));
746  });
747 }
748 
749 //===----------------------------------------------------------------------===//
750 // func::FuncOp Conversion Patterns
751 //===----------------------------------------------------------------------===//
752 
753 namespace {
754 /// A pattern for rewriting function signature to convert arguments of functions
755 /// to be of valid SPIR-V types.
///
/// NOTE(review): the embedded listing numbers skip 758 and 760 (presumably the
/// inherited-constructor declaration and the return type of `matchAndRewrite`);
/// restore them from the original source.
756 class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
757 public:
759 
761  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
762  ConversionPatternRewriter &rewriter) const override;
763 };
764 } // namespace
765 
// Rewrites a `func.func` into a `spirv.func` with converted argument and
// result types. Fails when the function has more than one result or when any
// argument/result type cannot be converted. All attributes except the function
// type and the symbol name are copied over, the body is moved into the new op,
// and the original op is erased.
//
// NOTE(review): the return-type line (embedded listing number 766) is missing
// from this listing -- restore from the original source.
767 FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
768  ConversionPatternRewriter &rewriter) const {
769  auto fnType = funcOp.getFunctionType();
 // Only functions with at most one result are handled.
770  if (fnType.getNumResults() > 1)
771  return failure();
772 
 // Convert every argument type and record the 1:1 mapping.
773  TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
774  for (const auto &argType : enumerate(fnType.getInputs())) {
775  auto convertedType = getTypeConverter()->convertType(argType.value());
776  if (!convertedType)
777  return failure();
778  signatureConverter.addInputs(argType.index(), convertedType);
779  }
780 
 // `resultType` stays null for functions without a result.
781  Type resultType;
782  if (fnType.getNumResults() == 1) {
783  resultType = getTypeConverter()->convertType(fnType.getResult(0));
784  if (!resultType)
785  return failure();
786  }
787 
788  // Create the converted spirv.func op.
789  auto newFuncOp = rewriter.create<spirv::FuncOp>(
790  funcOp.getLoc(), funcOp.getName(),
791  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
792  resultType ? TypeRange(resultType)
793  : TypeRange()));
794 
795  // Copy over all attributes other than the function name and type.
796  for (const auto &namedAttr : funcOp->getAttrs()) {
797  if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
798  namedAttr.getName() != SymbolTable::getSymbolAttrName())
799  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
800  }
801 
 // Move the body into the new op and convert its block argument types.
802  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
803  newFuncOp.end());
804  if (failed(rewriter.convertRegionTypes(
805  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
806  return failure();
807  rewriter.eraseOp(funcOp);
808  return success();
809 }
810 
// Adds the `FuncOpConversion` pattern, built with `typeConverter`, to
// `patterns`.
// NOTE(review): the first signature line (embedded listing number 811, with
// the function name and the `typeConverter` parameter) is missing from this
// listing -- restore from the original source.
812  RewritePatternSet &patterns) {
813  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
814 }
815 
816 //===----------------------------------------------------------------------===//
817 // Builtin Variables
818 //===----------------------------------------------------------------------===//
819 
820 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
821  spirv::BuiltIn builtin) {
822  // Look through all global variables in the given `body` block and check if
823  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
824  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
825  if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
826  spirv::SPIRVDialect::getAttributeName(
827  spirv::Decoration::BuiltIn))) {
828  auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
829  if (varBuiltIn && *varBuiltIn == builtin) {
830  return varOp;
831  }
832  }
833  }
834  return nullptr;
835 }
836 
837 /// Gets name of global variable for a builtin.
838 static std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
839  StringRef suffix) {
840  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
841 }
842 
843 /// Gets or inserts a global variable for a builtin within `body` block.
844 static spirv::GlobalVariableOp
845 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
846  Type integerType, OpBuilder &builder,
847  StringRef prefix, StringRef suffix) {
848  if (auto varOp = getBuiltinVariable(body, builtin))
849  return varOp;
850 
851  OpBuilder::InsertionGuard guard(builder);
852  builder.setInsertionPointToStart(&body);
853 
854  spirv::GlobalVariableOp newVarOp;
855  switch (builtin) {
856  case spirv::BuiltIn::NumWorkgroups:
857  case spirv::BuiltIn::WorkgroupSize:
858  case spirv::BuiltIn::WorkgroupId:
859  case spirv::BuiltIn::LocalInvocationId:
860  case spirv::BuiltIn::GlobalInvocationId: {
861  auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
862  spirv::StorageClass::Input);
863  std::string name = getBuiltinVarName(builtin, prefix, suffix);
864  newVarOp =
865  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
866  break;
867  }
868  case spirv::BuiltIn::SubgroupId:
869  case spirv::BuiltIn::NumSubgroups:
870  case spirv::BuiltIn::SubgroupSize: {
871  auto ptrType =
872  spirv::PointerType::get(integerType, spirv::StorageClass::Input);
873  std::string name = getBuiltinVarName(builtin, prefix, suffix);
874  newVarOp =
875  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
876  break;
877  }
878  default:
879  emitError(loc, "unimplemented builtin variable generation for ")
880  << stringifyBuiltIn(builtin);
881  }
882  return newVarOp;
883 }
884 
886  spirv::BuiltIn builtin,
887  Type integerType, OpBuilder &builder,
888  StringRef prefix, StringRef suffix) {
890  if (!parent) {
891  op->emitError("expected operation to be within a module-like op");
892  return nullptr;
893  }
894 
895  spirv::GlobalVariableOp varOp =
896  getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
897  builtin, integerType, builder, prefix, suffix);
898  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
899  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
900 }
901 
902 //===----------------------------------------------------------------------===//
903 // Push constant storage
904 //===----------------------------------------------------------------------===//
905 
906 /// Returns the pointer type for the push constant storage containing
907 /// `elementCount` 32-bit integer values.
908 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
909  Builder &builder,
910  Type indexType) {
911  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
912  /*stride=*/4);
913  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
914  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
915 }
916 
917 /// Returns the push constant varible containing `elementCount` 32-bit integer
918 /// values in `body`. Returns null op if such an op does not exit.
919 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
920  unsigned elementCount) {
921  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
922  auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
923  if (!ptrType)
924  continue;
925 
926  // Note that Vulkan requires "There must be no more than one push constant
927  // block statically used per shader entry point." So we should always reuse
928  // the existing one.
929  if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
930  auto numElements = cast<spirv::ArrayType>(
931  cast<spirv::StructType>(ptrType.getPointeeType())
932  .getElementType(0))
933  .getNumElements();
934  if (numElements == elementCount)
935  return varOp;
936  }
937  }
938  return nullptr;
939 }
940 
941 /// Gets or inserts a global variable for push constant storage containing
942 /// `elementCount` 32-bit integer values in `block`.
943 static spirv::GlobalVariableOp
945  unsigned elementCount, OpBuilder &b,
946  Type indexType) {
947  if (auto varOp = getPushConstantVariable(block, elementCount))
948  return varOp;
949 
950  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
951  auto type = getPushConstantStorageType(elementCount, builder, indexType);
952  const char *name = "__push_constant_var__";
953  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
954  /*initializer=*/nullptr);
955 }
956 
957 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
958  unsigned offset, Type integerType,
959  OpBuilder &builder) {
960  Location loc = op->getLoc();
962  if (!parent) {
963  op->emitError("expected operation to be within a module-like op");
964  return nullptr;
965  }
966 
967  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
968  loc, parent->getRegion(0).front(), elementCount, builder, integerType);
969 
970  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
971  Value offsetOp = builder.create<spirv::ConstantOp>(
972  loc, integerType, builder.getI32IntegerAttr(offset));
973  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
974  auto acOp = builder.create<spirv::AccessChainOp>(
975  loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
976  return builder.create<spirv::LoadOp>(loc, acOp);
977 }
978 
979 //===----------------------------------------------------------------------===//
980 // Index calculation
981 //===----------------------------------------------------------------------===//
982 
984  int64_t offset, Type integerType,
985  Location loc, OpBuilder &builder) {
986  assert(indices.size() == strides.size() &&
987  "must provide indices for all dimensions");
988 
989  // TODO: Consider moving to use affine.apply and patterns converting
990  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
991  // broken down into progressive small steps so we can have intermediate steps
992  // using other dialects. At the moment SPIR-V is the final sink.
993 
994  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
995  loc, integerType, IntegerAttr::get(integerType, offset));
996  for (const auto &index : llvm::enumerate(indices)) {
997  Value strideVal = builder.createOrFold<spirv::ConstantOp>(
998  loc, integerType,
999  IntegerAttr::get(integerType, strides[index.index()]));
1000  Value update =
1001  builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1002  linearizedIndex =
1003  builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1004  }
1005  return linearizedIndex;
1006 }
1007 
1009  MemRefType baseType, Value basePtr,
1010  ValueRange indices, Location loc,
1011  OpBuilder &builder) {
1012  // Get base and offset of the MemRefType and verify they are static.
1013 
1014  int64_t offset;
1015  SmallVector<int64_t, 4> strides;
1016  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
1017  llvm::is_contained(strides, ShapedType::kDynamic) ||
1018  ShapedType::isDynamic(offset)) {
1019  return nullptr;
1020  }
1021 
1022  auto indexType = typeConverter.getIndexType();
1023 
1024  SmallVector<Value, 2> linearizedIndices;
1025  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1026 
1027  // Add a '0' at the start to index into the struct.
1028  linearizedIndices.push_back(zero);
1029 
1030  if (baseType.getRank() == 0) {
1031  linearizedIndices.push_back(zero);
1032  } else {
1033  linearizedIndices.push_back(
1034  linearizeIndex(indices, strides, offset, indexType, loc, builder));
1035  }
1036  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
1037 }
1038 
1040  MemRefType baseType, Value basePtr,
1041  ValueRange indices, Location loc,
1042  OpBuilder &builder) {
1043  // Get base and offset of the MemRefType and verify they are static.
1044 
1045  int64_t offset;
1046  SmallVector<int64_t, 4> strides;
1047  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
1048  llvm::is_contained(strides, ShapedType::kDynamic) ||
1049  ShapedType::isDynamic(offset)) {
1050  return nullptr;
1051  }
1052 
1053  auto indexType = typeConverter.getIndexType();
1054 
1055  SmallVector<Value, 2> linearizedIndices;
1056  Value linearIndex;
1057  if (baseType.getRank() == 0) {
1058  linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1059  } else {
1060  linearIndex =
1061  linearizeIndex(indices, strides, offset, indexType, loc, builder);
1062  }
1063  Type pointeeType =
1064  cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
1065  if (isa<spirv::ArrayType>(pointeeType)) {
1066  linearizedIndices.push_back(linearIndex);
1067  return builder.create<spirv::AccessChainOp>(loc, basePtr,
1068  linearizedIndices);
1069  }
1070  return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
1071  linearizedIndices);
1072 }
1073 
1075  MemRefType baseType, Value basePtr,
1076  ValueRange indices, Location loc,
1077  OpBuilder &builder) {
1078 
1079  if (typeConverter.allows(spirv::Capability::Kernel)) {
1080  return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1081  builder);
1082  }
1083 
1084  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1085  builder);
1086 }
1087 
1088 //===----------------------------------------------------------------------===//
1089 // SPIR-V ConversionTarget
1090 //===----------------------------------------------------------------------===//
1091 
1092 std::unique_ptr<SPIRVConversionTarget>
1094  std::unique_ptr<SPIRVConversionTarget> target(
1095  // std::make_unique does not work here because the constructor is private.
1096  new SPIRVConversionTarget(targetAttr));
1097  SPIRVConversionTarget *targetPtr = target.get();
1098  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1099  // We need to capture the raw pointer here because it is stable:
1100  // target will be destroyed once this function is returned.
1101  [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1102  return target;
1103 }
1104 
1105 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1106  : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1107 
1108 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1109  // Make sure this op is available at the given version. Ops not implementing
1110  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1111  // SPIR-V versions.
1112  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1113  std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1114  if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1115  LLVM_DEBUG(llvm::dbgs()
1116  << op->getName() << " illegal: requiring min version "
1117  << spirv::stringifyVersion(*minVersion) << "\n");
1118  return false;
1119  }
1120  }
1121  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1122  std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1123  if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1124  LLVM_DEBUG(llvm::dbgs()
1125  << op->getName() << " illegal: requiring max version "
1126  << spirv::stringifyVersion(*maxVersion) << "\n");
1127  return false;
1128  }
1129  }
1130 
1131  // Make sure this op's required extensions are allowed to use. Ops not
1132  // implementing QueryExtensionInterface do not require extensions to be
1133  // available.
1134  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1135  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1136  extensions.getExtensions())))
1137  return false;
1138 
1139  // Make sure this op's required extensions are allowed to use. Ops not
1140  // implementing QueryCapabilityInterface do not require capabilities to be
1141  // available.
1142  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1143  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1144  capabilities.getCapabilities())))
1145  return false;
1146 
1147  SmallVector<Type, 4> valueTypes;
1148  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
1149  valueTypes.append(op->result_type_begin(), op->result_type_end());
1150 
1151  // Ensure that all types have been converted to SPIRV types.
1152  if (llvm::any_of(valueTypes,
1153  [](Type t) { return !isa<spirv::SPIRVType>(t); }))
1154  return false;
1155 
1156  // Special treatment for global variables, whose type requirements are
1157  // conveyed by type attributes.
1158  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1159  valueTypes.push_back(globalVar.getType());
1160 
1161  // Make sure the op's operands/results use types that are allowed by the
1162  // target environment.
1163  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1164  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1165  for (Type valueType : valueTypes) {
1166  typeExtensions.clear();
1167  cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1168  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1169  typeExtensions)))
1170  return false;
1171 
1172  typeCapabilities.clear();
1173  cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1174  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1175  typeCapabilities)))
1176  return false;
1177  }
1178 
1179  return true;
1180 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static bool needsExplicitLayout(spirv::StorageClass storageClass)
Returns true if the given storageClass needs explicit layout when used in Shader environments.
static spirv::GlobalVariableOp getPushConstantVariable(Block &body, unsigned elementCount)
Returns the push constant variable containing elementCount 32-bit integer values in body.
static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type, spirv::StorageClass storageClass)
static Type convertTensorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, TensorType type)
Converts a tensor type to a suitable type under the given targetEnv.
static LogicalResult checkCapabilityRequirements(LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates)
Checks that candidates capability requirements are possible to be satisfied with the given isAllowedFn...
static std::optional< int64_t > getTypeNumBytes(const SPIRVConversionOptions &options, Type type)
static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, IntegerType type)
Converts a sub-byte integer type to i32 regardless of target environment.
static spirv::GlobalVariableOp getBuiltinVariable(Block &body, spirv::BuiltIn builtin)
static ShapedType convertIndexElementType(ShapedType type, const SPIRVConversionOptions &options)
Returns a type with the same shape but with any index element type converted to the matching integer ...
static spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc, Block &block, unsigned elementCount, OpBuilder &b, Type indexType)
Gets or inserts a global variable for push constant storage containing elementCount 32-bit integer va...
static Type convertComplexType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, ComplexType type, std::optional< spirv::StorageClass > storageClass={})
static LogicalResult checkExtensionRequirements(LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates)
Checks that candidates extension requirements are possible to be satisfied with the given targetEnv.
std::optional< Value > castToSourceType(const spirv::TargetEnv &targetEnv, OpBuilder &builder, Type type, ValueRange inputs, Location loc)
Converts the given inputs to the original source type considering the targetEnv's capabilities.
static spirv::ScalarType getIndexType(MLIRContext *ctx, const SPIRVConversionOptions &options)
static spirv::GlobalVariableOp getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix, StringRef suffix)
Gets or inserts a global variable for a builtin within body block.
static std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix, StringRef suffix)
Gets name of global variable for a builtin.
static Type convertScalarType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, spirv::ScalarType type, std::optional< spirv::StorageClass > storageClass={})
Converts a scalar type to a suitable type under the given targetEnv.
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type, spirv::StorageClass storageClass)
static Type convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, std::optional< spirv::StorageClass > storageClass={})
Converts a vector type to a suitable type under the given targetEnv.
static spirv::PointerType wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass)
Wraps the given elementType in a struct and gets the pointer to the struct.
static spirv::PointerType getPushConstantStorageType(unsigned elementCount, Builder &builder, Type indexType)
Returns the pointer type for the push constant storage containing elementCount 32-bit integer values.
static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type)
Block represents an ordered list of Operations.
Definition: Block.h:30
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:190
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
FloatType getF32Type()
Definition: Builders.cpp:63
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Convert the types of block arguments within the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:242
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:322
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
operand_type_iterator operand_type_end()
Definition: Operation.h:391
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_iterator result_type_end()
Definition: Operation.h:422
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
result_type_iterator result_type_begin()
Definition: Operation.h:421
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_iterator operand_type_begin()
Definition: Operation.h:390
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:812
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:836
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:91
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:79
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:102
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:538
void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:590
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional< StorageClass > storage=std::nullopt)
Definition: SPIRVTypes.cpp:621
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:62
MLIRContext * getContext() const
Returns the MLIRContext.
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
Definition: Utils.cpp:1874
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26