MLIR  19.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
24 #include "mlir/IR/BuiltinTypes.h"
25 #include "mlir/IR/Operation.h"
26 #include "mlir/IR/PatternMatch.h"
27 #include "mlir/Support/LLVM.h"
30 #include "llvm/ADT/STLExtras.h"
31 #include "llvm/ADT/SmallVector.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/MathExtras.h"
35 
36 #include <functional>
37 #include <optional>
38 
39 #define DEBUG_TYPE "mlir-spirv-conversion"
40 
41 using namespace mlir;
42 
43 namespace {
44 
45 //===----------------------------------------------------------------------===//
46 // Utility functions
47 //===----------------------------------------------------------------------===//
48 
/// Returns the largest of the widths 4, 3 and 2 that evenly divides `size`,
/// or 1 when none of them does.
static int getComputeVectorSize(int64_t size) {
  if (size % 4 == 0)
    return 4;
  if (size % 3 == 0)
    return 3;
  return size % 2 == 0 ? 2 : 1;
}
56 
57 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
58  LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
59  if (vecType.isScalable()) {
60  LLVM_DEBUG(llvm::dbgs()
61  << "--scalable vectors are not supported -> BAIL\n");
62  return std::nullopt;
63  }
64  SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
65  std::optional<SmallVector<int64_t>> targetShape =
66  SmallVector<int64_t>(1, getComputeVectorSize(vecType.getShape().back()));
67  if (!targetShape) {
68  LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
69  return std::nullopt;
70  }
71  auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
72  if (!maybeShapeRatio) {
73  LLVM_DEBUG(llvm::dbgs()
74  << "--could not compute integral shape ratio -> BAIL\n");
75  return std::nullopt;
76  }
77  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
78  LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
79  return std::nullopt;
80  }
81  LLVM_DEBUG(llvm::dbgs()
82  << "--found an integral shape ratio to unroll to -> SUCCESS\n");
83  return targetShape;
84 }
85 
86 /// Checks that `candidates` extension requirements are possible to be satisfied
87 /// with the given `targetEnv`.
88 ///
89 /// `candidates` is a vector of vector for extension requirements following
90 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
91 /// convention.
92 template <typename LabelT>
93 static LogicalResult checkExtensionRequirements(
94  LabelT label, const spirv::TargetEnv &targetEnv,
96  for (const auto &ors : candidates) {
97  if (targetEnv.allows(ors))
98  continue;
99 
100  LLVM_DEBUG({
101  SmallVector<StringRef> extStrings;
102  for (spirv::Extension ext : ors)
103  extStrings.push_back(spirv::stringifyExtension(ext));
104 
105  llvm::dbgs() << label << " illegal: requires at least one extension in ["
106  << llvm::join(extStrings, ", ")
107  << "] but none allowed in target environment\n";
108  });
109  return failure();
110  }
111  return success();
112 }
113 
114 /// Checks that `candidates`capability requirements are possible to be satisfied
115 /// with the given `isAllowedFn`.
116 ///
117 /// `candidates` is a vector of vector for capability requirements following
118 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
119 /// convention.
120 template <typename LabelT>
121 static LogicalResult checkCapabilityRequirements(
122  LabelT label, const spirv::TargetEnv &targetEnv,
123  const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
124  for (const auto &ors : candidates) {
125  if (targetEnv.allows(ors))
126  continue;
127 
128  LLVM_DEBUG({
129  SmallVector<StringRef> capStrings;
130  for (spirv::Capability cap : ors)
131  capStrings.push_back(spirv::stringifyCapability(cap));
132 
133  llvm::dbgs() << label << " illegal: requires at least one capability in ["
134  << llvm::join(capStrings, ", ")
135  << "] but none allowed in target environment\n";
136  });
137  return failure();
138  }
139  return success();
140 }
141 
142 /// Returns true if the given `storageClass` needs explicit layout when used in
143 /// Shader environments.
144 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
145  switch (storageClass) {
146  case spirv::StorageClass::PhysicalStorageBuffer:
147  case spirv::StorageClass::PushConstant:
148  case spirv::StorageClass::StorageBuffer:
149  case spirv::StorageClass::Uniform:
150  return true;
151  default:
152  return false;
153  }
154 }
155 
156 /// Wraps the given `elementType` in a struct and gets the pointer to the
157 /// struct. This is used to satisfy Vulkan interface requirements.
158 static spirv::PointerType
159 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
160  auto structType = needsExplicitLayout(storageClass)
161  ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
162  : spirv::StructType::get(elementType);
163  return spirv::PointerType::get(structType, storageClass);
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // Type Conversion
168 //===----------------------------------------------------------------------===//
169 
170 static spirv::ScalarType getIndexType(MLIRContext *ctx,
172  return cast<spirv::ScalarType>(
173  IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
174 }
175 
176 // TODO: This is a utility function that should probably be exposed by the
177 // SPIR-V dialect. Keeping it local till the use case arises.
178 static std::optional<int64_t>
179 getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
180  if (isa<spirv::ScalarType>(type)) {
181  auto bitWidth = type.getIntOrFloatBitWidth();
182  // According to the SPIR-V spec:
183  // "There is no physical size or bit pattern defined for values with boolean
184  // type. If they are stored (in conjunction with OpVariable), they can only
185  // be used with logical addressing operations, not physical, and only with
186  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
187  // Private, Function, Input, and Output."
188  if (bitWidth == 1)
189  return std::nullopt;
190  return bitWidth / 8;
191  }
192 
193  if (auto complexType = dyn_cast<ComplexType>(type)) {
194  auto elementSize = getTypeNumBytes(options, complexType.getElementType());
195  if (!elementSize)
196  return std::nullopt;
197  return 2 * *elementSize;
198  }
199 
200  if (auto vecType = dyn_cast<VectorType>(type)) {
201  auto elementSize = getTypeNumBytes(options, vecType.getElementType());
202  if (!elementSize)
203  return std::nullopt;
204  return vecType.getNumElements() * *elementSize;
205  }
206 
207  if (auto memRefType = dyn_cast<MemRefType>(type)) {
208  // TODO: Layout should also be controlled by the ABI attributes. For now
209  // using the layout from MemRef.
210  int64_t offset;
211  SmallVector<int64_t, 4> strides;
212  if (!memRefType.hasStaticShape() ||
213  failed(getStridesAndOffset(memRefType, strides, offset)))
214  return std::nullopt;
215 
216  // To get the size of the memref object in memory, the total size is the
217  // max(stride * dimension-size) computed for all dimensions times the size
218  // of the element.
219  auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
220  if (!elementSize)
221  return std::nullopt;
222 
223  if (memRefType.getRank() == 0)
224  return elementSize;
225 
226  auto dims = memRefType.getShape();
227  if (llvm::is_contained(dims, ShapedType::kDynamic) ||
228  ShapedType::isDynamic(offset) ||
229  llvm::is_contained(strides, ShapedType::kDynamic))
230  return std::nullopt;
231 
232  int64_t memrefSize = -1;
233  for (const auto &shape : enumerate(dims))
234  memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
235 
236  return (offset + memrefSize) * *elementSize;
237  }
238 
239  if (auto tensorType = dyn_cast<TensorType>(type)) {
240  if (!tensorType.hasStaticShape())
241  return std::nullopt;
242 
243  auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
244  if (!elementSize)
245  return std::nullopt;
246 
247  int64_t size = *elementSize;
248  for (auto shape : tensorType.getShape())
249  size *= shape;
250 
251  return size;
252  }
253 
254  // TODO: Add size computation for other types.
255  return std::nullopt;
256 }
257 
258 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
259 static Type
260 convertScalarType(const spirv::TargetEnv &targetEnv,
262  std::optional<spirv::StorageClass> storageClass = {}) {
263  // Get extension and capability requirements for the given type.
266  type.getExtensions(extensions, storageClass);
267  type.getCapabilities(capabilities, storageClass);
268 
269  // If all requirements are met, then we can accept this type as-is.
270  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
271  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
272  return type;
273 
274  // Otherwise we need to adjust the type, which really means adjusting the
275  // bitwidth given this is a scalar type.
276  if (!options.emulateLT32BitScalarTypes)
277  return nullptr;
278 
279  // We only emulate narrower scalar types here and do not truncate results.
280  if (type.getIntOrFloatBitWidth() > 32) {
281  LLVM_DEBUG(llvm::dbgs()
282  << type
283  << " not converted to 32-bit for SPIR-V to avoid truncation\n");
284  return nullptr;
285  }
286 
287  if (auto floatType = dyn_cast<FloatType>(type)) {
288  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
289  return Builder(targetEnv.getContext()).getF32Type();
290  }
291 
292  auto intType = cast<IntegerType>(type);
293  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
294  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
295  intType.getSignedness());
296 }
297 
298 /// Converts a sub-byte integer `type` to i32 regardless of target environment.
299 ///
300 /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
301 /// the above given that these sub-byte types are not supported at all in
302 /// SPIR-V; there are no compute/storage capability for them like other
303 /// supported integer types.
304 static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
305  IntegerType type) {
306  if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
307  LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
308  return nullptr;
309  }
310 
311  if (!llvm::isPowerOf2_32(type.getWidth())) {
312  LLVM_DEBUG(llvm::dbgs()
313  << "unsupported non-power-of-two bitwidth in sub-byte" << type
314  << "\n");
315  return nullptr;
316  }
317 
318  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
319  return IntegerType::get(type.getContext(), /*width=*/32,
320  type.getSignedness());
321 }
322 
323 /// Returns a type with the same shape but with any index element type converted
324 /// to the matching integer type. This is a noop when the element type is not
325 /// the index type.
326 static ShapedType
327 convertIndexElementType(ShapedType type,
329  Type indexType = dyn_cast<IndexType>(type.getElementType());
330  if (!indexType)
331  return type;
332 
333  return type.clone(getIndexType(type.getContext(), options));
334 }
335 
336 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
337 static Type
338 convertVectorType(const spirv::TargetEnv &targetEnv,
339  const SPIRVConversionOptions &options, VectorType type,
340  std::optional<spirv::StorageClass> storageClass = {}) {
341  type = cast<VectorType>(convertIndexElementType(type, options));
342  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
343  if (!scalarType) {
344  // If this is not a spec allowed scalar type, try to handle sub-byte integer
345  // types.
346  auto intType = dyn_cast<IntegerType>(type.getElementType());
347  if (!intType) {
348  LLVM_DEBUG(llvm::dbgs()
349  << type
350  << " illegal: cannot convert non-scalar element type\n");
351  return nullptr;
352  }
353 
354  Type elementType = convertSubByteIntegerType(options, intType);
355  if (type.getRank() <= 1 && type.getNumElements() == 1)
356  return elementType;
357 
358  if (type.getNumElements() > 4) {
359  LLVM_DEBUG(llvm::dbgs()
360  << type << " illegal: > 4-element unimplemented\n");
361  return nullptr;
362  }
363 
364  return VectorType::get(type.getShape(), elementType);
365  }
366 
367  if (type.getRank() <= 1 && type.getNumElements() == 1)
368  return convertScalarType(targetEnv, options, scalarType, storageClass);
369 
370  if (!spirv::CompositeType::isValid(type)) {
371  LLVM_DEBUG(llvm::dbgs()
372  << type << " illegal: not a valid composite type\n");
373  return nullptr;
374  }
375 
376  // Get extension and capability requirements for the given type.
379  cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
380  cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
381 
382  // If all requirements are met, then we can accept this type as-is.
383  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
384  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
385  return type;
386 
387  auto elementType =
388  convertScalarType(targetEnv, options, scalarType, storageClass);
389  if (elementType)
390  return VectorType::get(type.getShape(), elementType);
391  return nullptr;
392 }
393 
394 static Type
395 convertComplexType(const spirv::TargetEnv &targetEnv,
396  const SPIRVConversionOptions &options, ComplexType type,
397  std::optional<spirv::StorageClass> storageClass = {}) {
398  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
399  if (!scalarType) {
400  LLVM_DEBUG(llvm::dbgs()
401  << type << " illegal: cannot convert non-scalar element type\n");
402  return nullptr;
403  }
404 
405  auto elementType =
406  convertScalarType(targetEnv, options, scalarType, storageClass);
407  if (!elementType)
408  return nullptr;
409  if (elementType != type.getElementType()) {
410  LLVM_DEBUG(llvm::dbgs()
411  << type << " illegal: complex type emulation unsupported\n");
412  return nullptr;
413  }
414 
415  return VectorType::get(2, elementType);
416 }
417 
418 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
419 ///
420 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
421 /// create composite constants with OpConstantComposite to embed relative large
422 /// constant values and use OpCompositeExtract and OpCompositeInsert to
423 /// manipulate, like what we do for vectors.
424 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
426  TensorType type) {
427  // TODO: Handle dynamic shapes.
428  if (!type.hasStaticShape()) {
429  LLVM_DEBUG(llvm::dbgs()
430  << type << " illegal: dynamic shape unimplemented\n");
431  return nullptr;
432  }
433 
434  type = cast<TensorType>(convertIndexElementType(type, options));
435  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
436  if (!scalarType) {
437  LLVM_DEBUG(llvm::dbgs()
438  << type << " illegal: cannot convert non-scalar element type\n");
439  return nullptr;
440  }
441 
442  std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
443  std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
444  if (!scalarSize || !tensorSize) {
445  LLVM_DEBUG(llvm::dbgs()
446  << type << " illegal: cannot deduce element count\n");
447  return nullptr;
448  }
449 
450  int64_t arrayElemCount = *tensorSize / *scalarSize;
451  if (arrayElemCount == 0) {
452  LLVM_DEBUG(llvm::dbgs()
453  << type << " illegal: cannot handle zero-element tensors\n");
454  return nullptr;
455  }
456 
457  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
458  if (!arrayElemType)
459  return nullptr;
460  std::optional<int64_t> arrayElemSize =
461  getTypeNumBytes(options, arrayElemType);
462  if (!arrayElemSize) {
463  LLVM_DEBUG(llvm::dbgs()
464  << type << " illegal: cannot deduce converted element size\n");
465  return nullptr;
466  }
467 
468  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
469 }
470 
471 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
473  MemRefType type,
474  spirv::StorageClass storageClass) {
475  unsigned numBoolBits = options.boolNumBits;
476  if (numBoolBits != 8) {
477  LLVM_DEBUG(llvm::dbgs()
478  << "using non-8-bit storage for bool types unimplemented");
479  return nullptr;
480  }
481  auto elementType = dyn_cast<spirv::ScalarType>(
482  IntegerType::get(type.getContext(), numBoolBits));
483  if (!elementType)
484  return nullptr;
485  Type arrayElemType =
486  convertScalarType(targetEnv, options, elementType, storageClass);
487  if (!arrayElemType)
488  return nullptr;
489  std::optional<int64_t> arrayElemSize =
490  getTypeNumBytes(options, arrayElemType);
491  if (!arrayElemSize) {
492  LLVM_DEBUG(llvm::dbgs()
493  << type << " illegal: cannot deduce converted element size\n");
494  return nullptr;
495  }
496 
497  if (!type.hasStaticShape()) {
498  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
499  // to the element.
500  if (targetEnv.allows(spirv::Capability::Kernel))
501  return spirv::PointerType::get(arrayElemType, storageClass);
502  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
503  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
504  // For Vulkan we need extra wrapping struct and array to satisfy interface
505  // needs.
506  return wrapInStructAndGetPointer(arrayType, storageClass);
507  }
508 
509  if (type.getNumElements() == 0) {
510  LLVM_DEBUG(llvm::dbgs()
511  << type << " illegal: zero-element memrefs are not supported\n");
512  return nullptr;
513  }
514 
515  int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
516  int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
517  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
518  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
519  if (targetEnv.allows(spirv::Capability::Kernel))
520  return spirv::PointerType::get(arrayType, storageClass);
521  return wrapInStructAndGetPointer(arrayType, storageClass);
522 }
523 
524 static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
526  MemRefType type,
527  spirv::StorageClass storageClass) {
528  IntegerType elementType = cast<IntegerType>(type.getElementType());
529  Type arrayElemType = convertSubByteIntegerType(options, elementType);
530  if (!arrayElemType)
531  return nullptr;
532  int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
533 
534  if (!type.hasStaticShape()) {
535  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
536  // to the element.
537  if (targetEnv.allows(spirv::Capability::Kernel))
538  return spirv::PointerType::get(arrayElemType, storageClass);
539  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
540  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
541  // For Vulkan we need extra wrapping struct and array to satisfy interface
542  // needs.
543  return wrapInStructAndGetPointer(arrayType, storageClass);
544  }
545 
546  if (type.getNumElements() == 0) {
547  LLVM_DEBUG(llvm::dbgs()
548  << type << " illegal: zero-element memrefs are not supported\n");
549  return nullptr;
550  }
551 
552  int64_t memrefSize =
553  llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
554  int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
555  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
556  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
557  if (targetEnv.allows(spirv::Capability::Kernel))
558  return spirv::PointerType::get(arrayType, storageClass);
559  return wrapInStructAndGetPointer(arrayType, storageClass);
560 }
561 
562 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
564  MemRefType type) {
565  auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
566  if (!attr) {
567  LLVM_DEBUG(
568  llvm::dbgs()
569  << type
570  << " illegal: expected memory space to be a SPIR-V storage class "
571  "attribute; please use MemorySpaceToStorageClassConverter to map "
572  "numeric memory spaces beforehand\n");
573  return nullptr;
574  }
575  spirv::StorageClass storageClass = attr.getValue();
576 
577  if (isa<IntegerType>(type.getElementType())) {
578  if (type.getElementTypeBitWidth() == 1)
579  return convertBoolMemrefType(targetEnv, options, type, storageClass);
580  if (type.getElementTypeBitWidth() < 8)
581  return convertSubByteMemrefType(targetEnv, options, type, storageClass);
582  }
583 
584  Type arrayElemType;
585  Type elementType = type.getElementType();
586  if (auto vecType = dyn_cast<VectorType>(elementType)) {
587  arrayElemType =
588  convertVectorType(targetEnv, options, vecType, storageClass);
589  } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
590  arrayElemType =
591  convertComplexType(targetEnv, options, complexType, storageClass);
592  } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
593  arrayElemType =
594  convertScalarType(targetEnv, options, scalarType, storageClass);
595  } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
596  type = cast<MemRefType>(convertIndexElementType(type, options));
597  arrayElemType = type.getElementType();
598  } else {
599  LLVM_DEBUG(
600  llvm::dbgs()
601  << type
602  << " unhandled: can only convert scalar or vector element type\n");
603  return nullptr;
604  }
605  if (!arrayElemType)
606  return nullptr;
607 
608  std::optional<int64_t> arrayElemSize =
609  getTypeNumBytes(options, arrayElemType);
610  if (!arrayElemSize) {
611  LLVM_DEBUG(llvm::dbgs()
612  << type << " illegal: cannot deduce converted element size\n");
613  return nullptr;
614  }
615 
616  if (!type.hasStaticShape()) {
617  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
618  // to the element.
619  if (targetEnv.allows(spirv::Capability::Kernel))
620  return spirv::PointerType::get(arrayElemType, storageClass);
621  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
622  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
623  // For Vulkan we need extra wrapping struct and array to satisfy interface
624  // needs.
625  return wrapInStructAndGetPointer(arrayType, storageClass);
626  }
627 
628  std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
629  if (!memrefSize) {
630  LLVM_DEBUG(llvm::dbgs()
631  << type << " illegal: cannot deduce element count\n");
632  return nullptr;
633  }
634 
635  if (*memrefSize == 0) {
636  LLVM_DEBUG(llvm::dbgs()
637  << type << " illegal: zero-element memrefs are not supported\n");
638  return nullptr;
639  }
640 
641  int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
642  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
643  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
644  if (targetEnv.allows(spirv::Capability::Kernel))
645  return spirv::PointerType::get(arrayType, storageClass);
646  return wrapInStructAndGetPointer(arrayType, storageClass);
647 }
648 
649 //===----------------------------------------------------------------------===//
650 // Type casting materialization
651 //===----------------------------------------------------------------------===//
652 
653 /// Converts the given `inputs` to the original source `type` considering the
654 /// `targetEnv`'s capabilities.
655 ///
656 /// This function is meant to be used for source materialization in type
657 /// converters. When the type converter needs to materialize a cast op back
658 /// to some original source type, we need to check whether the original source
659 /// type is supported in the target environment. If so, we can insert legal
660 /// SPIR-V cast ops accordingly.
661 ///
662 /// Note that in SPIR-V the capabilities for storage and compute are separate.
663 /// This function is meant to handle the **compute** side; so it does not
664 /// involve storage classes in its logic. The storage side is expected to be
665 /// handled by MemRef conversion logic.
666 static std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
667  OpBuilder &builder, Type type,
668  ValueRange inputs, Location loc) {
669  // We can only cast one value in SPIR-V.
670  if (inputs.size() != 1) {
671  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
672  return castOp.getResult(0);
673  }
674  Value input = inputs.front();
675 
676  // Only support integer types for now. Floating point types to be implemented.
677  if (!isa<IntegerType>(type)) {
678  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
679  return castOp.getResult(0);
680  }
681  auto inputType = cast<IntegerType>(input.getType());
682 
683  auto scalarType = dyn_cast<spirv::ScalarType>(type);
684  if (!scalarType) {
685  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
686  return castOp.getResult(0);
687  }
688 
689  // Only support source type with a smaller bitwidth. This would mean we are
690  // truncating to go back so we don't need to worry about the signedness.
691  // For extension, we cannot have enough signal here to decide which op to use.
692  if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
693  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
694  return castOp.getResult(0);
695  }
696 
697  // Boolean values would need to use different ops than normal integer values.
698  if (type.isInteger(1)) {
699  Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
700  return builder.create<spirv::IEqualOp>(loc, input, one);
701  }
702 
703  // Check that the source integer type is supported by the environment.
706  scalarType.getExtensions(exts);
707  scalarType.getCapabilities(caps);
708  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
709  failed(checkExtensionRequirements(type, targetEnv, exts))) {
710  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
711  return castOp.getResult(0);
712  }
713 
714  // We've already made sure this is truncating previously, so we don't need to
715  // care about signedness here. Still try to use a corresponding op for better
716  // consistency though.
717  if (type.isSignedInteger()) {
718  return builder.create<spirv::SConvertOp>(loc, type, input);
719  }
720  return builder.create<spirv::UConvertOp>(loc, type, input);
721 }
722 
723 //===----------------------------------------------------------------------===//
724 // Builtin Variables
725 //===----------------------------------------------------------------------===//
726 
727 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
728  spirv::BuiltIn builtin) {
729  // Look through all global variables in the given `body` block and check if
730  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
731  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
732  if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
733  spirv::SPIRVDialect::getAttributeName(
734  spirv::Decoration::BuiltIn))) {
735  auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
736  if (varBuiltIn && *varBuiltIn == builtin) {
737  return varOp;
738  }
739  }
740  }
741  return nullptr;
742 }
743 
744 /// Gets name of global variable for a builtin.
745 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
746  StringRef suffix) {
747  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
748 }
749 
750 /// Gets or inserts a global variable for a builtin within `body` block.
751 static spirv::GlobalVariableOp
752 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
753  Type integerType, OpBuilder &builder,
754  StringRef prefix, StringRef suffix) {
755  if (auto varOp = getBuiltinVariable(body, builtin))
756  return varOp;
757 
758  OpBuilder::InsertionGuard guard(builder);
759  builder.setInsertionPointToStart(&body);
760 
761  spirv::GlobalVariableOp newVarOp;
762  switch (builtin) {
763  case spirv::BuiltIn::NumWorkgroups:
764  case spirv::BuiltIn::WorkgroupSize:
765  case spirv::BuiltIn::WorkgroupId:
766  case spirv::BuiltIn::LocalInvocationId:
767  case spirv::BuiltIn::GlobalInvocationId: {
768  auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
769  spirv::StorageClass::Input);
770  std::string name = getBuiltinVarName(builtin, prefix, suffix);
771  newVarOp =
772  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
773  break;
774  }
775  case spirv::BuiltIn::SubgroupId:
776  case spirv::BuiltIn::NumSubgroups:
777  case spirv::BuiltIn::SubgroupSize: {
778  auto ptrType =
779  spirv::PointerType::get(integerType, spirv::StorageClass::Input);
780  std::string name = getBuiltinVarName(builtin, prefix, suffix);
781  newVarOp =
782  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
783  break;
784  }
785  default:
786  emitError(loc, "unimplemented builtin variable generation for ")
787  << stringifyBuiltIn(builtin);
788  }
789  return newVarOp;
790 }
791 
792 //===----------------------------------------------------------------------===//
793 // Push constant storage
794 //===----------------------------------------------------------------------===//
795 
796 /// Returns the pointer type for the push constant storage containing
797 /// `elementCount` 32-bit integer values.
798 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
799  Builder &builder,
800  Type indexType) {
801  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
802  /*stride=*/4);
803  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
804  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
805 }
806 
807 /// Returns the push constant varible containing `elementCount` 32-bit integer
808 /// values in `body`. Returns null op if such an op does not exit.
809 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
810  unsigned elementCount) {
811  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
812  auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
813  if (!ptrType)
814  continue;
815 
816  // Note that Vulkan requires "There must be no more than one push constant
817  // block statically used per shader entry point." So we should always reuse
818  // the existing one.
819  if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
820  auto numElements = cast<spirv::ArrayType>(
821  cast<spirv::StructType>(ptrType.getPointeeType())
822  .getElementType(0))
823  .getNumElements();
824  if (numElements == elementCount)
825  return varOp;
826  }
827  }
828  return nullptr;
829 }
830 
831 /// Gets or inserts a global variable for push constant storage containing
832 /// `elementCount` 32-bit integer values in `block`.
833 static spirv::GlobalVariableOp
834 getOrInsertPushConstantVariable(Location loc, Block &block,
835  unsigned elementCount, OpBuilder &b,
836  Type indexType) {
837  if (auto varOp = getPushConstantVariable(block, elementCount))
838  return varOp;
839 
840  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
841  auto type = getPushConstantStorageType(elementCount, builder, indexType);
842  const char *name = "__push_constant_var__";
843  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
844  /*initializer=*/nullptr);
845 }
846 
847 //===----------------------------------------------------------------------===//
848 // func::FuncOp Conversion Patterns
849 //===----------------------------------------------------------------------===//
850 
851 /// A pattern for rewriting function signature to convert arguments of functions
852 /// to be of valid SPIR-V types.
///
/// The matched `func.func` is replaced by a `spirv.func` with the same name.
/// Argument and result types are run through the type converter, every
/// attribute other than the symbol name and the function type is copied, and
/// the body is moved into the new op with its region types converted.
/// The pattern fails for functions with more than one result and for any
/// argument/result type the converter cannot convert.
853 struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
855 
856  LogicalResult
857  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
858  ConversionPatternRewriter &rewriter) const override {
859  FunctionType fnType = funcOp.getFunctionType();
  // Only functions with at most one result are handled.
860  if (fnType.getNumResults() > 1)
861  return failure();
862 
  // Convert the argument types one by one and record them as the new
  // signature; give up if any argument type has no converted form.
863  TypeConverter::SignatureConversion signatureConverter(
864  fnType.getNumInputs());
865  for (const auto &argType : enumerate(fnType.getInputs())) {
866  auto convertedType = getTypeConverter()->convertType(argType.value());
867  if (!convertedType)
868  return failure();
869  signatureConverter.addInputs(argType.index(), convertedType);
870  }
871 
  // `resultType` stays null for functions without a result.
872  Type resultType;
873  if (fnType.getNumResults() == 1) {
874  resultType = getTypeConverter()->convertType(fnType.getResult(0));
875  if (!resultType)
876  return failure();
877  }
878 
879  // Create the converted spirv.func op.
880  auto newFuncOp = rewriter.create<spirv::FuncOp>(
881  funcOp.getLoc(), funcOp.getName(),
882  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
883  resultType ? TypeRange(resultType)
884  : TypeRange()));
885 
886  // Copy over all attributes other than the function name and type.
887  for (const auto &namedAttr : funcOp->getAttrs()) {
888  if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
889  namedAttr.getName() != SymbolTable::getSymbolAttrName())
890  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
891  }
892 
  // Move the body into the new function, then convert the types inside the
  // moved region using the signature conversion computed above.
893  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
894  newFuncOp.end());
895  if (failed(rewriter.convertRegionTypes(
896  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
897  return failure();
898  rewriter.eraseOp(funcOp);
899  return success();
900  }
901 };
902 
903 /// A pattern for rewriting function signature to convert vector arguments of
904 /// functions to be of valid types
///
/// Every vector argument for which `getTargetShape` yields a shape is split
/// into one new argument per tile of that shape; the original vector value is
/// rebuilt at the top of the entry block by a chain of
/// `vector.insert_strided_slice` ops. All other arguments are carried over
/// with their original type. Function declarations are rejected.
905 struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
907 
908  LogicalResult matchAndRewrite(func::FuncOp funcOp,
909  PatternRewriter &rewriter) const override {
910  FunctionType fnType = funcOp.getFunctionType();
911 
912  // TODO: Handle declarations.
913  if (funcOp.isDeclaration()) {
914  LLVM_DEBUG(llvm::dbgs()
915  << fnType << " illegal: declarations are unsupported\n");
916  return failure();
917  }
918 
919  // Create a new func op with the original type and copy the function body.
920  auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
921  funcOp.getName(), fnType);
922  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
923  newFuncOp.end());
924 
925  Location loc = newFuncOp.getBody().getLoc();
926 
  // All ops created below are inserted at the start of the entry block; the
  // guard restores the rewriter's insertion point when this function exits.
927  Block &entryBlock = newFuncOp.getBlocks().front();
928  OpBuilder::InsertionGuard guard(rewriter);
929  rewriter.setInsertionPointToStart(&entryBlock);
930 
  // Records, per original argument, the list of types replacing it.
931  OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
932 
933  // For arguments that are of illegal types and require unrolling.
934  // `unrolledInputNums` stores the indices of arguments that result from
935  // unrolling in the new function signature. `newInputNo` is a counter.
936  SmallVector<size_t> unrolledInputNums;
937  size_t newInputNo = 0;
938 
939  // For arguments that are of legal types and do not require unrolling.
940  // `tmpOps` stores a mapping from temporary operations that serve as
941  // placeholders for new arguments that will be added later. These operations
942  // will be erased once the entry block's argument list is updated.
943  llvm::SmallDenseMap<Operation *, size_t> tmpOps;
944 
945  // This counts the number of new operations created.
946  size_t newOpCount = 0;
947 
948  // Enumerate through the arguments.
949  for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
950  // Check whether the argument is of vector type.
951  auto origVecType = dyn_cast<VectorType>(origType);
952  if (!origVecType) {
953  // We need a placeholder for the old argument that will be erased later.
954  Value result = rewriter.create<arith::ConstantOp>(
955  loc, origType, rewriter.getZeroAttr(origType));
956  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
957  tmpOps.insert({result.getDefiningOp(), newInputNo});
958  oneToNTypeMapping.addInputs(origInputNo, origType);
959  ++newInputNo;
960  ++newOpCount;
961  continue;
962  }
963  // Check whether the vector needs unrolling.
964  auto targetShape = getTargetShape(origVecType);
965  if (!targetShape) {
966  // We need a placeholder for the old argument that will be erased later.
967  Value result = rewriter.create<arith::ConstantOp>(
968  loc, origType, rewriter.getZeroAttr(origType));
969  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
970  tmpOps.insert({result.getDefiningOp(), newInputNo});
971  oneToNTypeMapping.addInputs(origInputNo, origType);
972  ++newInputNo;
973  ++newOpCount;
974  continue;
975  }
  // From here on the argument is a vector that gets unrolled into tiles of
  // `unrolledType`.
976  VectorType unrolledType =
977  VectorType::get(*targetShape, origVecType.getElementType());
978  auto originalShape =
979  llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
980 
981  // Prepare the result vector.
982  Value result = rewriter.create<arith::ConstantOp>(
983  loc, origVecType, rewriter.getZeroAttr(origVecType));
984  ++newOpCount;
985  // Prepare the placeholder for the new arguments that will be added later.
986  Value dummy = rewriter.create<arith::ConstantOp>(
987  loc, unrolledType, rewriter.getZeroAttr(unrolledType));
988  ++newOpCount;
989 
990  // Create the `vector.insert_strided_slice` ops.
  // One insert is emitted per tile offset; each insert corresponds to one
  // new function argument, whose index is appended to `unrolledInputNums`.
991  SmallVector<int64_t> strides(targetShape->size(), 1);
992  SmallVector<Type> newTypes;
993  for (SmallVector<int64_t> offsets :
994  StaticTileOffsetRange(originalShape, *targetShape)) {
995  result = rewriter.create<vector::InsertStridedSliceOp>(
996  loc, dummy, result, offsets, strides);
997  newTypes.push_back(unrolledType);
998  unrolledInputNums.push_back(newInputNo);
999  ++newInputNo;
1000  ++newOpCount;
1001  }
1002  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1003  oneToNTypeMapping.addInputs(origInputNo, newTypes);
1004  }
1005 
1006  // Change the function signature.
1007  auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1008  auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1009  rewriter.modifyOpInPlace(newFuncOp,
1010  [&] { newFuncOp.setFunctionType(newFnType); });
1011 
1012  // Update the arguments in the entry block.
  // The old arguments no longer have uses at this point: each was replaced
  // by a placeholder or by the rebuilt vector in the loop above.
1013  entryBlock.eraseArguments(0, fnType.getNumInputs());
1014  SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
1015  entryBlock.addArguments(convertedTypes, locs);
1016 
1017  // Replace the placeholder values with the new arguments. We assume there is
1018  // only one block for now.
  // `unrolledInputIdx` walks `unrolledInputNums` in the same order in which
  // the insert ops were created above.
1019  size_t unrolledInputIdx = 0;
1020  for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1021  // We first look for operands that are placeholders for initially legal
1022  // arguments.
1023  Operation &curOp = op;
1024  for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
1025  Operation *operandOp = operandVal.getDefiningOp();
1026  if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
1027  size_t idx = operandIdx;
1028  rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
1029  curOp.setOperand(idx, newFuncOp.getArgument(it->second));
1030  });
1031  }
1032  }
1033  // Since all newly created operations are in the beginning, reaching the
1034  // end of them means that any later `vector.insert_strided_slice` should
1035  // not be touched.
1036  if (count >= newOpCount)
1037  continue;
  // Operand 0 of a newly created insert is the `dummy` placeholder; swap
  // it for the matching unrolled argument.
1038  if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1039  size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1040  rewriter.modifyOpInPlace(&curOp, [&] {
1041  curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1042  });
1043  ++unrolledInputIdx;
1044  }
1045  }
1046 
1047  // Erase the original funcOp. The `tmpOps` do not need to be erased since
1048  // they have no uses and will be handled by dead-code elimination.
1049  rewriter.eraseOp(funcOp);
1050  return success();
1051  }
1052 };
1053 
1054 //===----------------------------------------------------------------------===//
1055 // func::ReturnOp Conversion Patterns
1056 //===----------------------------------------------------------------------===//
1057 
1058 /// A pattern for rewriting function signature and the return op to convert
1059 /// vectors to be of valid types.
///
/// Each returned vector for which `getTargetShape` yields a shape is split
/// into tiles with `vector.extract_strided_slice`; the tiles become separate
/// return operands and the enclosing function's result types are updated to
/// match. Other results are passed through unchanged.
/// NOTE(review): the function type is rewritten from this one return op, so
/// this presumably expects a single `func.return` per function — confirm.
1060 struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
1062 
1063  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1064  PatternRewriter &rewriter) const override {
1065  // Check whether the parent funcOp is valid.
1066  auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1067  if (!funcOp)
1068  return failure();
1069 
1070  FunctionType fnType = funcOp.getFunctionType();
1071  OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
1072  Location loc = returnOp.getLoc();
1073 
1074  // For the new return op.
1075  SmallVector<Value> newOperands;
1076 
1077  // Enumerate through the results.
1078  for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1079  // Check whether the result is of vector type.
1080  auto origVecType = dyn_cast<VectorType>(origType);
1081  if (!origVecType) {
1082  oneToNTypeMapping.addInputs(origResultNo, origType);
1083  newOperands.push_back(returnOp.getOperand(origResultNo));
1084  continue;
1085  }
1086  // Check whether the vector needs unrolling.
1087  auto targetShape = getTargetShape(origVecType);
1088  if (!targetShape) {
1089  // The original return operand can be used.
1090  oneToNTypeMapping.addInputs(origResultNo, origType);
1091  newOperands.push_back(returnOp.getOperand(origResultNo));
1092  continue;
1093  }
1094  VectorType unrolledType =
1095  VectorType::get(*targetShape, origVecType.getElementType());
1096 
1097  // Create `vector.extract_strided_slice` ops to form legal vectors from
1098  // the original operand of illegal type.
1099  auto originalShape =
1100  llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1101  SmallVector<int64_t> strides(targetShape->size(), 1);
1102  SmallVector<Type> newTypes;
1103  Value returnValue = returnOp.getOperand(origResultNo);
  // One slice (and one new result type) per tile offset.
1104  for (SmallVector<int64_t> offsets :
1105  StaticTileOffsetRange(originalShape, *targetShape)) {
1106  Value result = rewriter.create<vector::ExtractStridedSliceOp>(
1107  loc, returnValue, offsets, *targetShape, strides);
1108  newOperands.push_back(result);
1109  newTypes.push_back(unrolledType);
1110  }
1111  oneToNTypeMapping.addInputs(origResultNo, newTypes);
1112  }
1113 
1114  // Change the function signature.
  // Inputs are kept as they are; only the result list is replaced.
1115  auto newFnType =
1116  FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
1117  TypeRange(oneToNTypeMapping.getConvertedTypes()));
1118  rewriter.modifyOpInPlace(funcOp,
1119  [&] { funcOp.setFunctionType(newFnType); });
1120 
1121  // Replace the return op using the new operands. This will automatically
1122  // update the entry block as well.
1123  rewriter.replaceOp(returnOp,
1124  rewriter.create<func::ReturnOp>(loc, newOperands));
1125 
1126  return success();
1127  }
1128 };
1129 
1130 } // namespace
1131 
1132 //===----------------------------------------------------------------------===//
1133 // Public function for builtin variables
1134 //===----------------------------------------------------------------------===//
1135 
/// Loads the value of the `builtin` variable at `builder`'s insertion point.
/// The variable is obtained from getOrInsertBuiltinVariable on the first block
/// of region 0 of `parent`. When `parent` is null, an error is emitted on `op`
/// and a null value is returned.
1137  spirv::BuiltIn builtin,
1138  Type integerType, OpBuilder &builder,
1139  StringRef prefix, StringRef suffix) {
1141  if (!parent) {
1142  op->emitError("expected operation to be within a module-like op");
1143  return nullptr;
1144  }
1145 
1146  spirv::GlobalVariableOp varOp =
1147  getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
1148  builtin, integerType, builder, prefix, suffix);
  // Take the variable's address, then load through it.
1149  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
1150  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
1151 }
1152 
1153 //===----------------------------------------------------------------------===//
1154 // Public function for push constant storage
1155 //===----------------------------------------------------------------------===//
1156 
1157 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
1158  unsigned offset, Type integerType,
1159  OpBuilder &builder) {
1160  Location loc = op->getLoc();
1162  if (!parent) {
1163  op->emitError("expected operation to be within a module-like op");
1164  return nullptr;
1165  }
1166 
1167  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1168  loc, parent->getRegion(0).front(), elementCount, builder, integerType);
1169 
1170  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1171  Value offsetOp = builder.create<spirv::ConstantOp>(
1172  loc, integerType, builder.getI32IntegerAttr(offset));
1173  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
1174  auto acOp = builder.create<spirv::AccessChainOp>(
1175  loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
1176  return builder.create<spirv::LoadOp>(loc, acOp);
1177 }
1178 
1179 //===----------------------------------------------------------------------===//
1180 // Public functions for index calculation
1181 //===----------------------------------------------------------------------===//
1182 
/// Emits ops computing `offset + sum(indices[i] * strides[i])` in
/// `integerType` and returns the resulting value. `indices` and `strides`
/// must have the same size. Each step goes through createOrFold.
1184  int64_t offset, Type integerType,
1185  Location loc, OpBuilder &builder) {
1186  assert(indices.size() == strides.size() &&
1187  "must provide indices for all dimensions");
1188 
1189  // TODO: Consider moving to use affine.apply and patterns converting
1190  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
1191  // broken down into progressive small steps so we can have intermediate steps
1192  // using other dialects. At the moment SPIR-V is the final sink.
1193 
  // Start the running sum at the static offset.
1194  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
1195  loc, integerType, IntegerAttr::get(integerType, offset));
1196  for (const auto &index : llvm::enumerate(indices)) {
1197  Value strideVal = builder.createOrFold<spirv::ConstantOp>(
1198  loc, integerType,
1199  IntegerAttr::get(integerType, strides[index.index()]));
1200  Value update =
1201  builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1202  linearizedIndex =
1203  builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1204  }
1205  return linearizedIndex;
1206 }
1207 
/// Builds a `spirv.AccessChain` from `basePtr` to the element of `baseType`
/// addressed by `indices`. The chain has two indices: a leading zero and the
/// linearized element index (zero for rank-0 memrefs). Returns a null value
/// when the strides or offset of `baseType` cannot be obtained or are dynamic.
1209  MemRefType baseType, Value basePtr,
1210  ValueRange indices, Location loc,
1211  OpBuilder &builder) {
1212  // Get strides and offset of the MemRefType and verify they are static.
1213 
1214  int64_t offset;
1215  SmallVector<int64_t, 4> strides;
1216  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
1217  llvm::is_contained(strides, ShapedType::kDynamic) ||
1218  ShapedType::isDynamic(offset)) {
1219  return nullptr;
1220  }
1221 
1222  auto indexType = typeConverter.getIndexType();
1223 
1224  SmallVector<Value, 2> linearizedIndices;
1225  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1226 
1227  // Add a '0' at the start to index into the struct.
1228  linearizedIndices.push_back(zero);
1229 
1230  if (baseType.getRank() == 0) {
1231  linearizedIndices.push_back(zero);
1232  } else {
1233  linearizedIndices.push_back(
1234  linearizeIndex(indices, strides, offset, indexType, loc, builder));
1235  }
1236  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
1237 }
1238 
/// Builds the pointer to the element of `baseType` addressed by `indices`.
/// When `basePtr` points to a `spirv::ArrayType`, a `spirv.AccessChain` with
/// the linearized index is created; otherwise a `spirv.PtrAccessChain` using
/// the linearized index as its element operand and no further indices.
/// Returns a null value when the strides or offset of `baseType` cannot be
/// obtained or are dynamic.
1240  MemRefType baseType, Value basePtr,
1241  ValueRange indices, Location loc,
1242  OpBuilder &builder) {
1243  // Get strides and offset of the MemRefType and verify they are static.
1244 
1245  int64_t offset;
1246  SmallVector<int64_t, 4> strides;
1247  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
1248  llvm::is_contained(strides, ShapedType::kDynamic) ||
1249  ShapedType::isDynamic(offset)) {
1250  return nullptr;
1251  }
1252 
1253  auto indexType = typeConverter.getIndexType();
1254 
1255  SmallVector<Value, 2> linearizedIndices;
1256  Value linearIndex;
  // Rank-0 memrefs have a single element, addressed with index zero.
1257  if (baseType.getRank() == 0) {
1258  linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1259  } else {
1260  linearIndex =
1261  linearizeIndex(indices, strides, offset, indexType, loc, builder);
1262  }
1263  Type pointeeType =
1264  cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
1265  if (isa<spirv::ArrayType>(pointeeType)) {
1266  linearizedIndices.push_back(linearIndex);
1267  return builder.create<spirv::AccessChainOp>(loc, basePtr,
1268  linearizedIndices);
1269  }
  // `linearizedIndices` is still empty on this path.
1270  return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
1271  linearizedIndices);
1272 }
1273 
/// Dispatches element pointer construction on the target environment: uses
/// getOpenCLElementPtr when the `Kernel` capability is allowed by
/// `typeConverter`, and getVulkanElementPtr otherwise.
1275  MemRefType baseType, Value basePtr,
1276  ValueRange indices, Location loc,
1277  OpBuilder &builder) {
1278 
1279  if (typeConverter.allows(spirv::Capability::Kernel)) {
1280  return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1281  builder);
1282  }
1283 
1284  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1285  builder);
1286 }
1287 
1288 //===----------------------------------------------------------------------===//
1289 // SPIR-V TypeConverter
1290 //===----------------------------------------------------------------------===//
1291 
1294  : targetEnv(targetAttr), options(options) {
1295  // Add conversions. The order matters here: later ones will be tried earlier.
1296 
1297  // Allow all SPIR-V dialect specific types. This assumes all builtin types
1298  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
1299  // were tried before.
1300  //
1301  // TODO: This assumes that the SPIR-V types are valid to use in the given
1302  // target environment, which should be the case if the whole pipeline is
1303  // driven by the same target environment. Still, we probably still want to
1304  // validate and convert to be safe.
1305  addConversion([](spirv::SPIRVType type) { return type; });
1306 
  // `index` maps to whatever getIndexType() returns for this converter.
1307  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
1308 
  // Integers: types that are spirv::ScalarType go through convertScalarType,
  // widths below 8 bits go through convertSubByteIntegerType, and anything
  // else yields a null Type.
1309  addConversion([this](IntegerType intType) -> std::optional<Type> {
1310  if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1311  return convertScalarType(this->targetEnv, this->options, scalarType);
1312  if (intType.getWidth() < 8)
1313  return convertSubByteIntegerType(this->options, intType);
1314  return Type();
1315  });
1316 
  // Floats: only types that are spirv::ScalarType are converted; anything
  // else yields a null Type.
1317  addConversion([this](FloatType floatType) -> std::optional<Type> {
1318  if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1319  return convertScalarType(this->targetEnv, this->options, scalarType);
1320  return Type();
1321  });
1322 
  // Complex, vector, tensor and memref types each delegate to their
  // dedicated convert*Type helper.
1323  addConversion([this](ComplexType complexType) {
1324  return convertComplexType(this->targetEnv, this->options, complexType);
1325  });
1326 
1327  addConversion([this](VectorType vectorType) {
1328  return convertVectorType(this->targetEnv, this->options, vectorType);
1329  });
1330 
1331  addConversion([this](TensorType tensorType) {
1332  return convertTensorType(this->targetEnv, this->options, tensorType);
1333  });
1334 
1335  addConversion([this](MemRefType memRefType) {
1336  return convertMemrefType(this->targetEnv, this->options, memRefType);
1337  });
1338 
1339  // Register some last line of defense casting logic.
1341  [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1342  return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1343  });
  // Target materialization: bridge the type mismatch with an
  // `UnrealizedConversionCastOp` and hand back its first result.
1344  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
1345  Location loc) {
1346  auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
1347  return std::optional<Value>(cast.getResult(0));
1348  });
1349 }
1350 
  // Delegates to the global-scope getIndexType helper, passing this
  // converter's context and conversion options.
1352  return ::getIndexType(getContext(), options);
1353 }
1354 
1355 MLIRContext *SPIRVTypeConverter::getContext() const {
1356  return targetEnv.getAttr().getContext();
1357 }
1358 
1359 bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
1360  return targetEnv.allows(capability);
1361 }
1362 
1363 //===----------------------------------------------------------------------===//
1364 // SPIR-V ConversionTarget
1365 //===----------------------------------------------------------------------===//
1366 
/// Creates a conversion target for `targetAttr` in which every SPIR-V dialect
/// op is dynamically legal according to `isLegalOp`.
1367 std::unique_ptr<SPIRVConversionTarget>
1369  std::unique_ptr<SPIRVConversionTarget> target(
1370  // std::make_unique does not work here because the constructor is private.
1371  new SPIRVConversionTarget(targetAttr));
1372  SPIRVConversionTarget *targetPtr = target.get();
1373  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1374  // Capture the raw pointer rather than the local `target`: the pointee's
1375  // address is stable, while the local unique_ptr ends with this function.
1376  [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1377  return target;
1378 }
1379 
1380 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1381  : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1382 
1383 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1384  // Make sure this op is available at the given version. Ops not implementing
1385  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1386  // SPIR-V versions.
1387  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1388  std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1389  if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1390  LLVM_DEBUG(llvm::dbgs()
1391  << op->getName() << " illegal: requiring min version "
1392  << spirv::stringifyVersion(*minVersion) << "\n");
1393  return false;
1394  }
1395  }
1396  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1397  std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1398  if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1399  LLVM_DEBUG(llvm::dbgs()
1400  << op->getName() << " illegal: requiring max version "
1401  << spirv::stringifyVersion(*maxVersion) << "\n");
1402  return false;
1403  }
1404  }
1405 
1406  // Make sure this op's required extensions are allowed to use. Ops not
1407  // implementing QueryExtensionInterface do not require extensions to be
1408  // available.
1409  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1410  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1411  extensions.getExtensions())))
1412  return false;
1413 
1414  // Make sure this op's required extensions are allowed to use. Ops not
1415  // implementing QueryCapabilityInterface do not require capabilities to be
1416  // available.
1417  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1418  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1419  capabilities.getCapabilities())))
1420  return false;
1421 
1422  SmallVector<Type, 4> valueTypes;
1423  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
1424  valueTypes.append(op->result_type_begin(), op->result_type_end());
1425 
1426  // Ensure that all types have been converted to SPIRV types.
1427  if (llvm::any_of(valueTypes,
1428  [](Type t) { return !isa<spirv::SPIRVType>(t); }))
1429  return false;
1430 
1431  // Special treatment for global variables, whose type requirements are
1432  // conveyed by type attributes.
1433  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1434  valueTypes.push_back(globalVar.getType());
1435 
1436  // Make sure the op's operands/results use types that are allowed by the
1437  // target environment.
1438  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1439  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1440  for (Type valueType : valueTypes) {
1441  typeExtensions.clear();
1442  cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1443  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1444  typeExtensions)))
1445  return false;
1446 
1447  typeCapabilities.clear();
1448  cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1449  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1450  typeCapabilities)))
1451  return false;
1452  }
1453 
1454  return true;
1455 }
1456 
1457 //===----------------------------------------------------------------------===//
1458 // Public functions for populating patterns
1459 //===----------------------------------------------------------------------===//
1460 
1462  RewritePatternSet &patterns) {
  // Registers FuncOpConversion, which needs the type converter.
1463  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
1464 }
1465 
  // Registers FuncOpVectorUnroll (no type converter involved).
1467  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
1468 }
1469 
  // Registers ReturnOpVectorUnroll (no type converter involved).
1471  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
1472 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
Definition: Block.h:31
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:159
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:200
OpListType & getOperations()
Definition: Block.h:135
Operation & front()
Definition: Block.h:151
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:191
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
FloatType getF32Type()
Definition: Builders.cpp:63
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:96
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Stores a 1:N mapping of types and provides several useful accessors.
TypeRange getConvertedTypes(unsigned originalTypeNo) const
Returns the list of types that corresponds to the original type at the given index.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:242
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:322
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setOperand(unsigned idx, Value value)
Definition: Operation.h:346
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
operand_type_iterator operand_type_end()
Definition: Operation.h:391
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_iterator result_type_end()
Definition: Operation.h:422
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
result_type_iterator result_type_begin()
Definition: Operation.h:421
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
operand_type_iterator operand_type_begin()
Definition: Operation.h:390
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from the given operation 'from'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:96
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal replacement value...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:79
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:102
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:481
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:538
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:62
MLIRContext * getContext() const
Returns the MLIRContext.
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
Definition: Utils.cpp:1878
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362