1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Support/LLVM.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/LogicalResult.h"
37 #include "llvm/Support/MathExtras.h"
38 
39 #include <functional>
40 #include <optional>
41 
42 #define DEBUG_TYPE "mlir-spirv-conversion"
43 
44 using namespace mlir;
45 
46 namespace {
47 
48 //===----------------------------------------------------------------------===//
49 // Utility functions
50 //===----------------------------------------------------------------------===//
51 
52 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
53  LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
54  if (vecType.isScalable()) {
55  LLVM_DEBUG(llvm::dbgs()
56  << "--scalable vectors are not supported -> BAIL\n");
57  return std::nullopt;
58  }
59  SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
60  std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
61  1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
62  if (!targetShape) {
63  LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
64  return std::nullopt;
65  }
66  auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
67  if (!maybeShapeRatio) {
68  LLVM_DEBUG(llvm::dbgs()
69  << "--could not compute integral shape ratio -> BAIL\n");
70  return std::nullopt;
71  }
72  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
73  LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
74  return std::nullopt;
75  }
76  LLVM_DEBUG(llvm::dbgs()
77  << "--found an integral shape ratio to unroll to -> SUCCESS\n");
78  return targetShape;
79 }
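// getTargetShape examples (illustrative, not from the original source): for a
// hypothetical vector<3x8xf32>, getComputeVectorSize(8) == 4 gives a target
// shape of [4] with an integral shape ratio of [3, 2], so [4] is returned; for
// vector<4xf32> the ratio is [1], so std::nullopt is returned (no unrolling).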
80 
81 /// Checks that the extension requirements in `candidates` can be satisfied by
82 /// the given `targetEnv`.
83 ///
84 /// `candidates` is a vector of vectors of extension requirements, following the
85 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
86 /// convention.
87 template <typename LabelT>
88 static LogicalResult checkExtensionRequirements(
89  LabelT label, const spirv::TargetEnv &targetEnv,
90  const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
91  for (const auto &ors : candidates) {
92  if (targetEnv.allows(ors))
93  continue;
94 
95  LLVM_DEBUG({
96  SmallVector<StringRef> extStrings;
97  for (spirv::Extension ext : ors)
98  extStrings.push_back(spirv::stringifyExtension(ext));
99 
100  llvm::dbgs() << label << " illegal: requires at least one extension in ["
101  << llvm::join(extStrings, ", ")
102  << "] but none allowed in target environment\n";
103  });
104  return failure();
105  }
106  return success();
107 }
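// Illustrative example: if candidates == {{SPV_KHR_8bit_storage,
// SPV_KHR_16bit_storage}}, the check succeeds only when the target environment
// allows at least one of those two extensions.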
108 
109 /// Checks that the capability requirements in `candidates` can be satisfied
110 /// by the given `targetEnv`.
111 ///
112 /// `candidates` is a vector of vectors of capability requirements, following
113 /// the ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
114 /// convention.
115 template <typename LabelT>
116 static LogicalResult checkCapabilityRequirements(
117  LabelT label, const spirv::TargetEnv &targetEnv,
118  const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
119  for (const auto &ors : candidates) {
120  if (targetEnv.allows(ors))
121  continue;
122 
123  LLVM_DEBUG({
124  SmallVector<StringRef> capStrings;
125  for (spirv::Capability cap : ors)
126  capStrings.push_back(spirv::stringifyCapability(cap));
127 
128  llvm::dbgs() << label << " illegal: requires at least one capability in ["
129  << llvm::join(capStrings, ", ")
130  << "] but none allowed in target environment\n";
131  });
132  return failure();
133  }
134  return success();
135 }
136 
137 /// Returns true if the given `storageClass` needs explicit layout when used in
138 /// Shader environments.
139 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
140  switch (storageClass) {
141  case spirv::StorageClass::PhysicalStorageBuffer:
142  case spirv::StorageClass::PushConstant:
143  case spirv::StorageClass::StorageBuffer:
144  case spirv::StorageClass::Uniform:
145  return true;
146  default:
147  return false;
148  }
149 }
150 
151 /// Wraps the given `elementType` in a struct and gets the pointer to the
152 /// struct. This is used to satisfy Vulkan interface requirements.
153 static spirv::PointerType
154 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
155  auto structType = needsExplicitLayout(storageClass)
156  ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
157  : spirv::StructType::get(elementType);
158  return spirv::PointerType::get(structType, storageClass);
159 }
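// For instance (illustrative), wrapping f32 for the StorageBuffer storage class
// yields !spirv.ptr<!spirv.struct<(f32 [0])>, StorageBuffer>; for the Function
// storage class no explicit layout is needed, giving
// !spirv.ptr<!spirv.struct<(f32)>, Function>.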
160 
161 //===----------------------------------------------------------------------===//
162 // Type Conversion
163 //===----------------------------------------------------------------------===//
164 
165 static spirv::ScalarType getIndexType(MLIRContext *ctx,
166  const SPIRVConversionOptions &options) {
167  return cast<spirv::ScalarType>(
168  IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
169 }
170 
171 // TODO: This is a utility function that should probably be exposed by the
172 // SPIR-V dialect. Keeping it local till the use case arises.
173 static std::optional<int64_t>
174 getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
175  if (isa<spirv::ScalarType>(type)) {
176  auto bitWidth = type.getIntOrFloatBitWidth();
177  // According to the SPIR-V spec:
178  // "There is no physical size or bit pattern defined for values with boolean
179  // type. If they are stored (in conjunction with OpVariable), they can only
180  // be used with logical addressing operations, not physical, and only with
181  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
182  // Private, Function, Input, and Output."
183  if (bitWidth == 1)
184  return std::nullopt;
185  return bitWidth / 8;
186  }
187 
188  if (auto complexType = dyn_cast<ComplexType>(type)) {
189  auto elementSize = getTypeNumBytes(options, complexType.getElementType());
190  if (!elementSize)
191  return std::nullopt;
192  return 2 * *elementSize;
193  }
194 
195  if (auto vecType = dyn_cast<VectorType>(type)) {
196  auto elementSize = getTypeNumBytes(options, vecType.getElementType());
197  if (!elementSize)
198  return std::nullopt;
199  return vecType.getNumElements() * *elementSize;
200  }
201 
202  if (auto memRefType = dyn_cast<MemRefType>(type)) {
203  // TODO: Layout should also be controlled by the ABI attributes. For now
204  // using the layout from MemRef.
205  int64_t offset;
206  SmallVector<int64_t, 4> strides;
207  if (!memRefType.hasStaticShape() ||
208  failed(memRefType.getStridesAndOffset(strides, offset)))
209  return std::nullopt;
210 
211  // To get the size of the memref object in memory, the total size is the
212  // max(stride * dimension-size) computed for all dimensions times the size
213  // of the element.
214  auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
215  if (!elementSize)
216  return std::nullopt;
217 
218  if (memRefType.getRank() == 0)
219  return elementSize;
220 
221  auto dims = memRefType.getShape();
222  if (llvm::is_contained(dims, ShapedType::kDynamic) ||
223  ShapedType::isDynamic(offset) ||
224  llvm::is_contained(strides, ShapedType::kDynamic))
225  return std::nullopt;
226 
227  int64_t memrefSize = -1;
228  for (const auto &shape : enumerate(dims))
229  memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
230 
231  return (offset + memrefSize) * *elementSize;
232  }
233 
234  if (auto tensorType = dyn_cast<TensorType>(type)) {
235  if (!tensorType.hasStaticShape())
236  return std::nullopt;
237 
238  auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
239  if (!elementSize)
240  return std::nullopt;
241 
242  int64_t size = *elementSize;
243  for (auto shape : tensorType.getShape())
244  size *= shape;
245 
246  return size;
247  }
248 
249  // TODO: Add size computation for other types.
250  return std::nullopt;
251 }
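// Worked example (illustrative): for memref<2x3xf32> with the default
// row-major layout, strides are [3, 1] and the offset is 0; the element size
// is 4 bytes and max(stride * dim-size) = max(3 * 2, 1 * 3) = 6, so the
// reported size is (0 + 6) * 4 = 24 bytes.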
252 
253 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
254 static Type
255 convertScalarType(const spirv::TargetEnv &targetEnv,
256  const SPIRVConversionOptions &options, spirv::ScalarType type,
257  std::optional<spirv::StorageClass> storageClass = {}) {
258  // Get extension and capability requirements for the given type.
259  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
260  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
261  type.getExtensions(extensions, storageClass);
262  type.getCapabilities(capabilities, storageClass);
263 
264  // If all requirements are met, then we can accept this type as-is.
265  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
266  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
267  return type;
268 
269  // Otherwise we need to adjust the type, which really means adjusting the
270  // bitwidth given this is a scalar type.
271  if (!options.emulateLT32BitScalarTypes)
272  return nullptr;
273 
274  // We only emulate narrower scalar types here and do not truncate results.
275  if (type.getIntOrFloatBitWidth() > 32) {
276  LLVM_DEBUG(llvm::dbgs()
277  << type
278  << " not converted to 32-bit for SPIR-V to avoid truncation\n");
279  return nullptr;
280  }
281 
282  if (auto floatType = dyn_cast<FloatType>(type)) {
283  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
284  return Builder(targetEnv.getContext()).getF32Type();
285  }
286 
287  auto intType = cast<IntegerType>(type);
288  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
289  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
290  intType.getSignedness());
291 }
292 
293 /// Converts a sub-byte integer `type` to i32 regardless of target environment.
294 /// Returns a nullptr for unsupported integer types, including non-sub-byte
295 /// types.
296 ///
297 /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
298 /// the above given that these sub-byte types are not supported at all in
299 /// SPIR-V; there are no compute/storage capabilities for them like for other
300 /// supported integer types.
301 static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
302  IntegerType type) {
303  if (type.getWidth() > 8) {
304  LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
305  return nullptr;
306  }
307  if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
308  LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
309  return nullptr;
310  }
311 
312  if (!llvm::isPowerOf2_32(type.getWidth())) {
313  LLVM_DEBUG(llvm::dbgs()
314  << "unsupported non-power-of-two bitwidth in sub-byte " << type
315  << "\n");
316  return nullptr;
317  }
318 
319  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
320  return IntegerType::get(type.getContext(), /*width=*/32,
321  type.getSignedness());
322 }
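// For example (illustrative), an i4 under packed sub-byte storage converts to
// an i32 of the same signedness, while i3 is rejected because its bitwidth is
// not a power of two.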
323 
324 /// Returns a type with the same shape but with any index element type converted
325 /// to the matching integer type. This is a noop when the element type is not
326 /// the index type.
327 static ShapedType
328 convertIndexElementType(ShapedType type,
329  const SPIRVConversionOptions &options) {
330  Type indexType = dyn_cast<IndexType>(type.getElementType());
331  if (!indexType)
332  return type;
333 
334  return type.clone(getIndexType(type.getContext(), options));
335 }
336 
337 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
338 static Type
339 convertVectorType(const spirv::TargetEnv &targetEnv,
340  const SPIRVConversionOptions &options, VectorType type,
341  std::optional<spirv::StorageClass> storageClass = {}) {
342  type = cast<VectorType>(convertIndexElementType(type, options));
343  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
344  if (!scalarType) {
345  // If this is not a spec allowed scalar type, try to handle sub-byte integer
346  // types.
347  auto intType = dyn_cast<IntegerType>(type.getElementType());
348  if (!intType) {
349  LLVM_DEBUG(llvm::dbgs()
350  << type
351  << " illegal: cannot convert non-scalar element type\n");
352  return nullptr;
353  }
354 
355  Type elementType = convertSubByteIntegerType(options, intType);
356  if (!elementType)
357  return nullptr;
358 
359  if (type.getRank() <= 1 && type.getNumElements() == 1)
360  return elementType;
361 
362  if (type.getNumElements() > 4) {
363  LLVM_DEBUG(llvm::dbgs()
364  << type << " illegal: > 4-element unimplemented\n");
365  return nullptr;
366  }
367 
368  return VectorType::get(type.getShape(), elementType);
369  }
370 
371  if (type.getRank() <= 1 && type.getNumElements() == 1)
372  return convertScalarType(targetEnv, options, scalarType, storageClass);
373 
374  if (!spirv::CompositeType::isValid(type)) {
375  LLVM_DEBUG(llvm::dbgs()
376  << type << " illegal: not a valid composite type\n");
377  return nullptr;
378  }
379 
380  // Get extension and capability requirements for the given type.
381  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
382  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
383  cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
384  cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
385 
386  // If all requirements are met, then we can accept this type as-is.
387  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
388  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
389  return type;
390 
391  auto elementType =
392  convertScalarType(targetEnv, options, scalarType, storageClass);
393  if (elementType)
394  return VectorType::get(type.getShape(), elementType);
395  return nullptr;
396 }
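// Examples (illustrative): a single-element vector such as vector<1xf16> is
// converted as a scalar; assuming a target without the Float16 capability and
// with emulateLT32BitScalarTypes enabled, vector<3xf16> becomes vector<3xf32>.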
397 
398 static Type
399 convertComplexType(const spirv::TargetEnv &targetEnv,
400  const SPIRVConversionOptions &options, ComplexType type,
401  std::optional<spirv::StorageClass> storageClass = {}) {
402  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
403  if (!scalarType) {
404  LLVM_DEBUG(llvm::dbgs()
405  << type << " illegal: cannot convert non-scalar element type\n");
406  return nullptr;
407  }
408 
409  auto elementType =
410  convertScalarType(targetEnv, options, scalarType, storageClass);
411  if (!elementType)
412  return nullptr;
413  if (elementType != type.getElementType()) {
414  LLVM_DEBUG(llvm::dbgs()
415  << type << " illegal: complex type emulation unsupported\n");
416  return nullptr;
417  }
418 
419  return VectorType::get(2, elementType);
420 }
421 
422 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
423 ///
424 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
425 /// create composite constants with OpConstantComposite to embed relatively
426 /// large constant values and use OpCompositeExtract and OpCompositeInsert to
427 /// manipulate them, like what we do for vectors.
428 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
429  const SPIRVConversionOptions &options,
430  TensorType type) {
431  // TODO: Handle dynamic shapes.
432  if (!type.hasStaticShape()) {
433  LLVM_DEBUG(llvm::dbgs()
434  << type << " illegal: dynamic shape unimplemented\n");
435  return nullptr;
436  }
437 
438  type = cast<TensorType>(convertIndexElementType(type, options));
439  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
440  if (!scalarType) {
441  LLVM_DEBUG(llvm::dbgs()
442  << type << " illegal: cannot convert non-scalar element type\n");
443  return nullptr;
444  }
445 
446  std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
447  std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
448  if (!scalarSize || !tensorSize) {
449  LLVM_DEBUG(llvm::dbgs()
450  << type << " illegal: cannot deduce element count\n");
451  return nullptr;
452  }
453 
454  int64_t arrayElemCount = *tensorSize / *scalarSize;
455  if (arrayElemCount == 0) {
456  LLVM_DEBUG(llvm::dbgs()
457  << type << " illegal: cannot handle zero-element tensors\n");
458  return nullptr;
459  }
460 
461  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
462  if (!arrayElemType)
463  return nullptr;
464  std::optional<int64_t> arrayElemSize =
465  getTypeNumBytes(options, arrayElemType);
466  if (!arrayElemSize) {
467  LLVM_DEBUG(llvm::dbgs()
468  << type << " illegal: cannot deduce converted element size\n");
469  return nullptr;
470  }
471 
472  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
473 }
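// Example (illustrative): a constant of type tensor<2x3xi32> occupies 24 bytes
// with 4-byte elements, so it converts to !spirv.array<6 x i32>.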
474 
475 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
476  const SPIRVConversionOptions &options,
477  MemRefType type,
478  spirv::StorageClass storageClass) {
479  unsigned numBoolBits = options.boolNumBits;
480  if (numBoolBits != 8) {
481  LLVM_DEBUG(llvm::dbgs()
482  << "using non-8-bit storage for bool types unimplemented\n");
483  return nullptr;
484  }
485  auto elementType = dyn_cast<spirv::ScalarType>(
486  IntegerType::get(type.getContext(), numBoolBits));
487  if (!elementType)
488  return nullptr;
489  Type arrayElemType =
490  convertScalarType(targetEnv, options, elementType, storageClass);
491  if (!arrayElemType)
492  return nullptr;
493  std::optional<int64_t> arrayElemSize =
494  getTypeNumBytes(options, arrayElemType);
495  if (!arrayElemSize) {
496  LLVM_DEBUG(llvm::dbgs()
497  << type << " illegal: cannot deduce converted element size\n");
498  return nullptr;
499  }
500 
501  if (!type.hasStaticShape()) {
502  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
503  // to the element.
504  if (targetEnv.allows(spirv::Capability::Kernel))
505  return spirv::PointerType::get(arrayElemType, storageClass);
506  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
507  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
508  // For Vulkan we need extra wrapping struct and array to satisfy interface
509  // needs.
510  return wrapInStructAndGetPointer(arrayType, storageClass);
511  }
512 
513  if (type.getNumElements() == 0) {
514  LLVM_DEBUG(llvm::dbgs()
515  << type << " illegal: zero-element memrefs are not supported\n");
516  return nullptr;
517  }
518 
519  int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
520  int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
521  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
522  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
523  if (targetEnv.allows(spirv::Capability::Kernel))
524  return spirv::PointerType::get(arrayType, storageClass);
525  return wrapInStructAndGetPointer(arrayType, storageClass);
526 }
527 
528 static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
529  const SPIRVConversionOptions &options,
530  MemRefType type,
531  spirv::StorageClass storageClass) {
532  IntegerType elementType = cast<IntegerType>(type.getElementType());
533  Type arrayElemType = convertSubByteIntegerType(options, elementType);
534  if (!arrayElemType)
535  return nullptr;
536  int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
537 
538  if (!type.hasStaticShape()) {
539  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
540  // to the element.
541  if (targetEnv.allows(spirv::Capability::Kernel))
542  return spirv::PointerType::get(arrayElemType, storageClass);
543  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
544  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
545  // For Vulkan we need extra wrapping struct and array to satisfy interface
546  // needs.
547  return wrapInStructAndGetPointer(arrayType, storageClass);
548  }
549 
550  if (type.getNumElements() == 0) {
551  LLVM_DEBUG(llvm::dbgs()
552  << type << " illegal: zero-element memrefs are not supported\n");
553  return nullptr;
554  }
555 
556  int64_t memrefSize =
557  llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
558  int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
559  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
560  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
561  if (targetEnv.allows(spirv::Capability::Kernel))
562  return spirv::PointerType::get(arrayType, storageClass);
563  return wrapInStructAndGetPointer(arrayType, storageClass);
564 }
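// Worked example (illustrative): for
// memref<8xi4, #spirv.storage_class<StorageBuffer>>, i4 is packed into i32
// (4 bytes each), the data occupies ceil(8 * 4 / 8) = 4 bytes, so
// arrayElemCount = ceil(4 / 4) = 1 and the result is a pointer to a struct
// wrapping !spirv.array<1 x i32, stride=4>.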
565 
566 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
567  const SPIRVConversionOptions &options,
568  MemRefType type) {
569  auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
570  if (!attr) {
571  LLVM_DEBUG(
572  llvm::dbgs()
573  << type
574  << " illegal: expected memory space to be a SPIR-V storage class "
575  "attribute; please use MemorySpaceToStorageClassConverter to map "
576  "numeric memory spaces beforehand\n");
577  return nullptr;
578  }
579  spirv::StorageClass storageClass = attr.getValue();
580 
581  if (isa<IntegerType>(type.getElementType())) {
582  if (type.getElementTypeBitWidth() == 1)
583  return convertBoolMemrefType(targetEnv, options, type, storageClass);
584  if (type.getElementTypeBitWidth() < 8)
585  return convertSubByteMemrefType(targetEnv, options, type, storageClass);
586  }
587 
588  Type arrayElemType;
589  Type elementType = type.getElementType();
590  if (auto vecType = dyn_cast<VectorType>(elementType)) {
591  arrayElemType =
592  convertVectorType(targetEnv, options, vecType, storageClass);
593  } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
594  arrayElemType =
595  convertComplexType(targetEnv, options, complexType, storageClass);
596  } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
597  arrayElemType =
598  convertScalarType(targetEnv, options, scalarType, storageClass);
599  } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
600  type = cast<MemRefType>(convertIndexElementType(type, options));
601  arrayElemType = type.getElementType();
602  } else {
603  LLVM_DEBUG(
604  llvm::dbgs()
605  << type
606  << " unhandled: can only convert scalar or vector element type\n");
607  return nullptr;
608  }
609  if (!arrayElemType)
610  return nullptr;
611 
612  std::optional<int64_t> arrayElemSize =
613  getTypeNumBytes(options, arrayElemType);
614  if (!arrayElemSize) {
615  LLVM_DEBUG(llvm::dbgs()
616  << type << " illegal: cannot deduce converted element size\n");
617  return nullptr;
618  }
619 
620  if (!type.hasStaticShape()) {
621  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
622  // to the element.
623  if (targetEnv.allows(spirv::Capability::Kernel))
624  return spirv::PointerType::get(arrayElemType, storageClass);
625  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
626  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
627  // For Vulkan we need extra wrapping struct and array to satisfy interface
628  // needs.
629  return wrapInStructAndGetPointer(arrayType, storageClass);
630  }
631 
632  std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
633  if (!memrefSize) {
634  LLVM_DEBUG(llvm::dbgs()
635  << type << " illegal: cannot deduce element count\n");
636  return nullptr;
637  }
638 
639  if (*memrefSize == 0) {
640  LLVM_DEBUG(llvm::dbgs()
641  << type << " illegal: zero-element memrefs are not supported\n");
642  return nullptr;
643  }
644 
645  int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
646  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
647  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
648  if (targetEnv.allows(spirv::Capability::Kernel))
649  return spirv::PointerType::get(arrayType, storageClass);
650  return wrapInStructAndGetPointer(arrayType, storageClass);
651 }
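// Example (illustrative): under a Vulkan-style target environment,
// memref<4xf32, #spirv.storage_class<StorageBuffer>> converts to
// !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>.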
652 
653 //===----------------------------------------------------------------------===//
654 // Type casting materialization
655 //===----------------------------------------------------------------------===//
656 
657 /// Converts the given `inputs` to the original source `type` considering the
658 /// `targetEnv`'s capabilities.
659 ///
660 /// This function is meant to be used for source materialization in type
661 /// converters. When the type converter needs to materialize a cast op back
662 /// to some original source type, we need to check whether the original source
663 /// type is supported in the target environment. If so, we can insert legal
664 /// SPIR-V cast ops accordingly.
665 ///
666 /// Note that in SPIR-V the capabilities for storage and compute are separate.
667 /// This function is meant to handle the **compute** side; so it does not
668 /// involve storage classes in its logic. The storage side is expected to be
669 /// handled by MemRef conversion logic.
670 static Value castToSourceType(const spirv::TargetEnv &targetEnv,
671  OpBuilder &builder, Type type, ValueRange inputs,
672  Location loc) {
673  // We can only cast one value in SPIR-V.
674  if (inputs.size() != 1) {
675  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
676  return castOp.getResult(0);
677  }
678  Value input = inputs.front();
679 
680  // Only support integer types for now. Floating point types to be implemented.
681  if (!isa<IntegerType>(type)) {
682  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
683  return castOp.getResult(0);
684  }
685  auto inputType = cast<IntegerType>(input.getType());
686 
687  auto scalarType = dyn_cast<spirv::ScalarType>(type);
688  if (!scalarType) {
689  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
690  return castOp.getResult(0);
691  }
692 
693  // Only support a source type with a smaller bitwidth. This means we are
694  // truncating to go back, so we don't need to worry about the signedness.
695  // For extension, we don't have enough signal here to decide which op to use.
696  if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
697  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
698  return castOp.getResult(0);
699  }
700 
701  // Boolean values would need to use different ops than normal integer values.
702  if (type.isInteger(1)) {
703  Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
704  return builder.create<spirv::IEqualOp>(loc, input, one);
705  }
706 
707  // Check that the source integer type is supported by the environment.
708  SmallVector<ArrayRef<spirv::Extension>, 1> exts;
709  SmallVector<ArrayRef<spirv::Capability>, 2> caps;
710  scalarType.getExtensions(exts);
711  scalarType.getCapabilities(caps);
712  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
713  failed(checkExtensionRequirements(type, targetEnv, exts))) {
714  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
715  return castOp.getResult(0);
716  }
717 
718  // We've already made sure this is truncating previously, so we don't need to
719  // care about signedness here. Still try to use a corresponding op for better
720  // consistency though.
721  if (type.isSignedInteger()) {
722  return builder.create<spirv::SConvertOp>(loc, type, input);
723  }
724  return builder.create<spirv::UConvertOp>(loc, type, input);
725 }
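// Example (illustrative): materializing back to an i16 source value from an
// i32 input on a target that allows the Int16 capability produces a
// spirv.UConvert op (spirv.SConvert for signed integers); otherwise an
// unrealized_conversion_cast is used as a fallback.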
726 
727 //===----------------------------------------------------------------------===//
728 // Builtin Variables
729 //===----------------------------------------------------------------------===//
730 
731 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
732  spirv::BuiltIn builtin) {
733  // Look through all global variables in the given `body` block and check if
734  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
735  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
736  if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
737  spirv::SPIRVDialect::getAttributeName(
738  spirv::Decoration::BuiltIn))) {
739  auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
740  if (varBuiltIn == builtin) {
741  return varOp;
742  }
743  }
744  }
745  return nullptr;
746 }
747 
748 /// Gets the name of the global variable for a builtin.
749 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
750  StringRef suffix) {
751  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
752 }
753 
754 /// Gets or inserts a global variable for a builtin within `body` block.
755 static spirv::GlobalVariableOp
756 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
757  Type integerType, OpBuilder &builder,
758  StringRef prefix, StringRef suffix) {
759  if (auto varOp = getBuiltinVariable(body, builtin))
760  return varOp;
761 
762  OpBuilder::InsertionGuard guard(builder);
763  builder.setInsertionPointToStart(&body);
764 
765  spirv::GlobalVariableOp newVarOp;
766  switch (builtin) {
767  case spirv::BuiltIn::NumWorkgroups:
768  case spirv::BuiltIn::WorkgroupSize:
769  case spirv::BuiltIn::WorkgroupId:
770  case spirv::BuiltIn::LocalInvocationId:
771  case spirv::BuiltIn::GlobalInvocationId: {
772  auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
773  spirv::StorageClass::Input);
774  std::string name = getBuiltinVarName(builtin, prefix, suffix);
775  newVarOp =
776  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
777  break;
778  }
779  case spirv::BuiltIn::SubgroupId:
780  case spirv::BuiltIn::NumSubgroups:
781  case spirv::BuiltIn::SubgroupSize:
782  case spirv::BuiltIn::SubgroupLocalInvocationId: {
783  auto ptrType =
784  spirv::PointerType::get(integerType, spirv::StorageClass::Input);
785  std::string name = getBuiltinVarName(builtin, prefix, suffix);
786  newVarOp =
787  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
788  break;
789  }
790  default:
791  emitError(loc, "unimplemented builtin variable generation for ")
792  << stringifyBuiltIn(builtin);
793  }
794  return newVarOp;
795 }
796 
797 //===----------------------------------------------------------------------===//
798 // Push constant storage
799 //===----------------------------------------------------------------------===//
800 
801 /// Returns the pointer type for the push constant storage containing
802 /// `elementCount` 32-bit integer values.
803 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
804  Builder &builder,
805  Type indexType) {
806  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
807  /*stride=*/4);
808  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
809  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
810 }
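// For example (illustrative), elementCount == 2 with an i32 index type yields
// !spirv.ptr<!spirv.struct<(!spirv.array<2 x i32, stride=4> [0])>, PushConstant>.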
811 
812 /// Returns the push constant variable containing `elementCount` 32-bit integer
813 /// values in `body`. Returns a null op if such an op does not exist.
814 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
815  unsigned elementCount) {
816  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
817  auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
818  if (!ptrType)
819  continue;
820 
821  // Note that Vulkan requires "There must be no more than one push constant
822  // block statically used per shader entry point." So we should always reuse
823  // the existing one.
824  if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
825  auto numElements = cast<spirv::ArrayType>(
826  cast<spirv::StructType>(ptrType.getPointeeType())
827  .getElementType(0))
828  .getNumElements();
829  if (numElements == elementCount)
830  return varOp;
831  }
832  }
833  return nullptr;
834 }
835 
836 /// Gets or inserts a global variable for push constant storage containing
837 /// `elementCount` 32-bit integer values in `block`.
838 static spirv::GlobalVariableOp
839 getOrInsertPushConstantVariable(Location loc, Block &block,
840  unsigned elementCount, OpBuilder &b,
841  Type indexType) {
842  if (auto varOp = getPushConstantVariable(block, elementCount))
843  return varOp;
844 
845  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
846  auto type = getPushConstantStorageType(elementCount, builder, indexType);
847  const char *name = "__push_constant_var__";
848  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
849  /*initializer=*/nullptr);
850 }
851 
852 //===----------------------------------------------------------------------===//
853 // func::FuncOp Conversion Patterns
854 //===----------------------------------------------------------------------===//
855 
856 /// A pattern for rewriting a function signature so that its arguments are of
857 /// valid SPIR-V types.
858 struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
859  using OpConversionPattern::OpConversionPattern;
860 
861  LogicalResult
862  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
863  ConversionPatternRewriter &rewriter) const override {
864  FunctionType fnType = funcOp.getFunctionType();
865  if (fnType.getNumResults() > 1)
866  return failure();
867 
868  TypeConverter::SignatureConversion signatureConverter(
869  fnType.getNumInputs());
870  for (const auto &argType : enumerate(fnType.getInputs())) {
871  auto convertedType = getTypeConverter()->convertType(argType.value());
872  if (!convertedType)
873  return failure();
874  signatureConverter.addInputs(argType.index(), convertedType);
875  }
876 
877  Type resultType;
878  if (fnType.getNumResults() == 1) {
879  resultType = getTypeConverter()->convertType(fnType.getResult(0));
880  if (!resultType)
881  return failure();
882  }
883 
884  // Create the converted spirv.func op.
885  auto newFuncOp = rewriter.create<spirv::FuncOp>(
886  funcOp.getLoc(), funcOp.getName(),
887  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
888  resultType ? TypeRange(resultType)
889  : TypeRange()));
890 
891  // Copy over all attributes other than the function name and type.
892  for (const auto &namedAttr : funcOp->getAttrs()) {
893  if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
894  namedAttr.getName() != SymbolTable::getSymbolAttrName())
895  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
896  }
897 
898  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
899  newFuncOp.end());
900  if (failed(rewriter.convertRegionTypes(
901  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
902  return failure();
903  rewriter.eraseOp(funcOp);
904  return success();
905  }
906 };
907 
908 /// A pattern for rewriting a function signature so that its vector arguments
909 /// are of valid types.
910 struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
911  using OpRewritePattern::OpRewritePattern;
912 
913  LogicalResult matchAndRewrite(func::FuncOp funcOp,
914  PatternRewriter &rewriter) const override {
915  FunctionType fnType = funcOp.getFunctionType();
916 
917  // TODO: Handle declarations.
918  if (funcOp.isDeclaration()) {
919  LLVM_DEBUG(llvm::dbgs()
920  << fnType << " illegal: declarations are unsupported\n");
921  return failure();
922  }
923 
924  // Create a new func op with the original type and copy the function body.
925  auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
926  funcOp.getName(), fnType);
927  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
928  newFuncOp.end());
929 
930  Location loc = newFuncOp.getBody().getLoc();
931 
932  Block &entryBlock = newFuncOp.getBlocks().front();
933  OpBuilder::InsertionGuard guard(rewriter);
934  rewriter.setInsertionPointToStart(&entryBlock);
935 
936  TypeConverter::SignatureConversion oneToNTypeMapping(
937  fnType.getInputs().size());
938 
939  // For arguments of illegal types that require unrolling:
940  // `unrolledInputNums` stores the indices of the arguments that result from
941  // unrolling in the new function signature. `newInputNo` is a counter.
942  SmallVector<size_t> unrolledInputNums;
943  size_t newInputNo = 0;
944 
945  // For arguments of legal types that do not require unrolling:
946  // `tmpOps` maps temporary placeholder operations to the indices of the new
947  // arguments that will be added later. The placeholders are replaced once the
948  // entry block's argument list is updated.
949  llvm::SmallDenseMap<Operation *, size_t> tmpOps;
950 
951  // This counts the number of new operations created.
952  size_t newOpCount = 0;
953 
954  // Enumerate through the arguments.
955  for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
956  // Check whether the argument is of vector type.
957  auto origVecType = dyn_cast<VectorType>(origType);
958  if (!origVecType) {
959  // We need a placeholder for the old argument that will be erased later.
960  Value result = rewriter.create<arith::ConstantOp>(
961  loc, origType, rewriter.getZeroAttr(origType));
962  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
963  tmpOps.insert({result.getDefiningOp(), newInputNo});
964  oneToNTypeMapping.addInputs(origInputNo, origType);
965  ++newInputNo;
966  ++newOpCount;
967  continue;
968  }
969  // Check whether the vector needs unrolling.
970  auto targetShape = getTargetShape(origVecType);
971  if (!targetShape) {
972  // We need a placeholder for the old argument that will be erased later.
973  Value result = rewriter.create<arith::ConstantOp>(
974  loc, origType, rewriter.getZeroAttr(origType));
975  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
976  tmpOps.insert({result.getDefiningOp(), newInputNo});
977  oneToNTypeMapping.addInputs(origInputNo, origType);
978  ++newInputNo;
979  ++newOpCount;
980  continue;
981  }
982  VectorType unrolledType =
983  VectorType::get(*targetShape, origVecType.getElementType());
984  auto originalShape =
985  llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
986 
987  // Prepare the result vector.
988  Value result = rewriter.create<arith::ConstantOp>(
989  loc, origVecType, rewriter.getZeroAttr(origVecType));
990  ++newOpCount;
991  // Prepare the placeholder for the new arguments that will be added later.
992  Value dummy = rewriter.create<arith::ConstantOp>(
993  loc, unrolledType, rewriter.getZeroAttr(unrolledType));
994  ++newOpCount;
995 
996  // Create the `vector.insert_strided_slice` ops.
997  SmallVector<int64_t> strides(targetShape->size(), 1);
998  SmallVector<Type> newTypes;
999  for (SmallVector<int64_t> offsets :
1000  StaticTileOffsetRange(originalShape, *targetShape)) {
1001  result = rewriter.create<vector::InsertStridedSliceOp>(
1002  loc, dummy, result, offsets, strides);
1003  newTypes.push_back(unrolledType);
1004  unrolledInputNums.push_back(newInputNo);
1005  ++newInputNo;
1006  ++newOpCount;
1007  }
1008  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1009  oneToNTypeMapping.addInputs(origInputNo, newTypes);
1010  }
1011 
1012  // Change the function signature.
1013  auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1014  auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1015  rewriter.modifyOpInPlace(newFuncOp,
1016  [&] { newFuncOp.setFunctionType(newFnType); });
1017 
1018  // Update the arguments in the entry block.
1019  entryBlock.eraseArguments(0, fnType.getNumInputs());
1020  SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
1021  entryBlock.addArguments(convertedTypes, locs);
1022 
1023  // Replace all uses of placeholders for initially legal arguments with their
1024  // original function arguments (that were added to `newFuncOp`).
1025  for (auto &[placeholderOp, argIdx] : tmpOps) {
1026  if (!placeholderOp)
1027  continue;
1028  Value replacement = newFuncOp.getArgument(argIdx);
1029  rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement);
1030  }
1031 
1032  // Replace dummy operands of new `vector.insert_strided_slice` ops with
1033  // their corresponding new function arguments. The new
1034  // `vector.insert_strided_slice` ops are inserted only into the entry block,
1035  // so iterating over that block is sufficient.
1036  size_t unrolledInputIdx = 0;
1037  for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1038  Operation &curOp = op;
1039  // Since all newly created operations are in the beginning, reaching the
1040  // end of them means that any later `vector.insert_strided_slice` should
1041  // not be touched.
1042  if (count >= newOpCount)
1043  continue;
1044  if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1045  size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1046  rewriter.modifyOpInPlace(&curOp, [&] {
1047  curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1048  });
1049  ++unrolledInputIdx;
1050  }
1051  }
1052 
1053  // Erase the original funcOp. The `tmpOps` do not need to be erased since
1054  // they have no uses and will be handled by dead-code elimination.
1055  rewriter.eraseOp(funcOp);
1056  return success();
1057  }
1058 };
1059 
1060 //===----------------------------------------------------------------------===//
1061 // func::ReturnOp Conversion Patterns
1062 //===----------------------------------------------------------------------===//
1063 
1064 /// A pattern for rewriting the function signature and the return op to
1065 /// convert returned vectors to valid types.
1066 struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
1067  using OpRewritePattern::OpRewritePattern;
1068 
1069  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1070  PatternRewriter &rewriter) const override {
1071  // Check whether the parent funcOp is valid.
1072  auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1073  if (!funcOp)
1074  return failure();
1075 
1076  FunctionType fnType = funcOp.getFunctionType();
1077  TypeConverter::SignatureConversion oneToNTypeMapping(
1078  fnType.getResults().size());
1079  Location loc = returnOp.getLoc();
1080 
1081  // For the new return op.
1082  SmallVector<Value> newOperands;
1083 
1084  // Enumerate through the results.
1085  for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1086  // Check whether the result is of vector type.
1087  auto origVecType = dyn_cast<VectorType>(origType);
1088  if (!origVecType) {
1089  oneToNTypeMapping.addInputs(origResultNo, origType);
1090  newOperands.push_back(returnOp.getOperand(origResultNo));
1091  continue;
1092  }
1093  // Check whether the vector needs unrolling.
1094  auto targetShape = getTargetShape(origVecType);
1095  if (!targetShape) {
1096  // The original argument can be used.
1097  oneToNTypeMapping.addInputs(origResultNo, origType);
1098  newOperands.push_back(returnOp.getOperand(origResultNo));
1099  continue;
1100  }
1101  VectorType unrolledType =
1102  VectorType::get(*targetShape, origVecType.getElementType());
1103 
1104  // Create `vector.extract_strided_slice` ops to form legal vectors from
1105  // the original operand of illegal type.
1106  auto originalShape =
1107  llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1108  SmallVector<int64_t> strides(originalShape.size(), 1);
1109  SmallVector<int64_t> extractShape(originalShape.size(), 1);
1110  extractShape.back() = targetShape->back();
1111  SmallVector<Type> newTypes;
1112  Value returnValue = returnOp.getOperand(origResultNo);
1113  for (SmallVector<int64_t> offsets :
1114  StaticTileOffsetRange(originalShape, *targetShape)) {
1115  Value result = rewriter.create<vector::ExtractStridedSliceOp>(
1116  loc, returnValue, offsets, extractShape, strides);
1117  if (originalShape.size() > 1) {
1118  SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
1119  result =
1120  rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
1121  }
1122  newOperands.push_back(result);
1123  newTypes.push_back(unrolledType);
1124  }
1125  oneToNTypeMapping.addInputs(origResultNo, newTypes);
1126  }
1127 
1128  // Change the function signature.
1129  auto newFnType =
1130  FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
1131  TypeRange(oneToNTypeMapping.getConvertedTypes()));
1132  rewriter.modifyOpInPlace(funcOp,
1133  [&] { funcOp.setFunctionType(newFnType); });
1134 
1135  // Replace the return op using the new operands. This will automatically
1136  // update the entry block as well.
1137  rewriter.replaceOp(returnOp,
1138  rewriter.create<func::ReturnOp>(loc, newOperands));
1139 
1140  return success();
1141  }
1142 };
1143 
1144 } // namespace
1145 
1146 //===----------------------------------------------------------------------===//
1147 // Public function for builtin variables
1148 //===----------------------------------------------------------------------===//
1149 
1150 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
1151  spirv::BuiltIn builtin,
1152  Type integerType, OpBuilder &builder,
1153  StringRef prefix, StringRef suffix) {
1154  Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
1155  if (!parent) {
1156  op->emitError("expected operation to be within a module-like op");
1157  return nullptr;
1158  }
1159 
1160  spirv::GlobalVariableOp varOp =
1161  getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
1162  builtin, integerType, builder, prefix, suffix);
1163  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
1164  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
1165 }
1166 
1167 //===----------------------------------------------------------------------===//
1168 // Public function for push constant storage
1169 //===----------------------------------------------------------------------===//
1170 
1171 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
1172  unsigned offset, Type integerType,
1173  OpBuilder &builder) {
1174  Location loc = op->getLoc();
1175  Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
1176  if (!parent) {
1177  op->emitError("expected operation to be within a module-like op");
1178  return nullptr;
1179  }
1180 
1181  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1182  loc, parent->getRegion(0).front(), elementCount, builder, integerType);
1183 
1184  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1185  Value offsetOp = builder.create<spirv::ConstantOp>(
1186  loc, integerType, builder.getI32IntegerAttr(offset));
1187  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
1188  auto acOp = builder.create<spirv::AccessChainOp>(
1189  loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
1190  return builder.create<spirv::LoadOp>(loc, acOp);
1191 }
1192 
1193 //===----------------------------------------------------------------------===//
1194 // Public functions for index calculation
1195 //===----------------------------------------------------------------------===//
1196 
1197 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
1198  int64_t offset, Type integerType,
1199  Location loc, OpBuilder &builder) {
1200  assert(indices.size() == strides.size() &&
1201  "must provide indices for all dimensions");
1202 
1203  // TODO: Consider moving to use affine.apply and patterns converting
1204  // affine.apply to standard ops. This needs the conversion to SPIR-V to be
1205  // broken down into progressively smaller steps so we can have intermediate
1206  // steps using other dialects. At the moment SPIR-V is the final sink.
1207 
1208  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
1209  loc, integerType, IntegerAttr::get(integerType, offset));
1210  for (const auto &index : llvm::enumerate(indices)) {
1211  Value strideVal = builder.createOrFold<spirv::ConstantOp>(
1212  loc, integerType,
1213  IntegerAttr::get(integerType, strides[index.index()]));
1214  Value update =
1215  builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1216  linearizedIndex =
1217  builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1218  }
1219  return linearizedIndex;
1220 }
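// Worked example (illustrative): for indices (%i, %j), strides [5, 1], and
// offset 2, the emitted ops compute 2 + %i * 5 + %j * 1.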
1221 
1222 Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
1223  MemRefType baseType, Value basePtr,
1224  ValueRange indices, Location loc,
1225  OpBuilder &builder) {
1226  // Get the strides and offset of the MemRefType and verify they are static.
1227 
1228  int64_t offset;
1229  SmallVector<int64_t, 4> strides;
1230  if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1231  llvm::is_contained(strides, ShapedType::kDynamic) ||
1232  ShapedType::isDynamic(offset)) {
1233  return nullptr;
1234  }
1235 
1236  auto indexType = typeConverter.getIndexType();
1237 
1238  SmallVector<Value, 2> linearizedIndices;
1239  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1240 
1241  // Add a '0' at the start to index into the struct.
1242  linearizedIndices.push_back(zero);
1243 
1244  if (baseType.getRank() == 0) {
1245  linearizedIndices.push_back(zero);
1246  } else {
1247  linearizedIndices.push_back(
1248  linearizeIndex(indices, strides, offset, indexType, loc, builder));
1249  }
1250  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
1251 }
1252 
1253 Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
1254  MemRefType baseType, Value basePtr,
1255  ValueRange indices, Location loc,
1256  OpBuilder &builder) {
1257  // Get the strides and offset of the MemRefType and verify they are static.
1258 
1259  int64_t offset;
1260  SmallVector<int64_t, 4> strides;
1261  if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1262  llvm::is_contained(strides, ShapedType::kDynamic) ||
1263  ShapedType::isDynamic(offset)) {
1264  return nullptr;
1265  }
1266 
1267  auto indexType = typeConverter.getIndexType();
1268 
1269  SmallVector<Value, 2> linearizedIndices;
1270  Value linearIndex;
1271  if (baseType.getRank() == 0) {
1272  linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1273  } else {
1274  linearIndex =
1275  linearizeIndex(indices, strides, offset, indexType, loc, builder);
1276  }
1277  Type pointeeType =
1278  cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
1279  if (isa<spirv::ArrayType>(pointeeType)) {
1280  linearizedIndices.push_back(linearIndex);
1281  return builder.create<spirv::AccessChainOp>(loc, basePtr,
1282  linearizedIndices);
1283  }
1284  return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
1285  linearizedIndices);
1286 }
1287 
1288 Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
1289  MemRefType baseType, Value basePtr,
1290  ValueRange indices, Location loc,
1291  OpBuilder &builder) {
1292 
1293  if (typeConverter.allows(spirv::Capability::Kernel)) {
1294  return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1295  builder);
1296  }
1297 
1298  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1299  builder);
1300 }
1301 
1302 //===----------------------------------------------------------------------===//
1303 // Public functions for vector unrolling
1304 //===----------------------------------------------------------------------===//
1305 
1306 int mlir::spirv::getComputeVectorSize(int64_t size) {
1307  for (int i : {4, 3, 2}) {
1308  if (size % i == 0)
1309  return i;
1310  }
1311  return 1;
1312 }
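// For example, getComputeVectorSize(8) == 4, getComputeVectorSize(6) == 3, and
// getComputeVectorSize(7) == 1.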
1313 
1314 SmallVector<int64_t>
1315 mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
1316  VectorType srcVectorType = op.getSourceVectorType();
1317  assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1318  int64_t vectorSize =
1319  mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
1320  return {vectorSize};
1321 }
1322 
1323 SmallVector<int64_t>
1324 mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
1325  VectorType vectorType = op.getResultVectorType();
1326  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
1327  nativeSize.back() =
1328  mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
1329  return nativeSize;
1330 }
1331 
1332 std::optional<SmallVector<int64_t>>
1333 mlir::spirv::getNativeVectorShape(Operation *op) {
1334  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
1335  if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
1336  SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
1337  nativeSize.back() =
1338  mlir::spirv::getComputeVectorSize(vecType.getShape().back());
1339  return nativeSize;
1340  }
1341  }
1342 
1343  return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
1344  .Case<vector::ReductionOp, vector::TransposeOp>(
1345  [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
1346  .Default([](Operation *) { return std::nullopt; });
1347 }
1348 
1349 LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
1350  MLIRContext *context = op->getContext();
1351  RewritePatternSet patterns(context);
1352  populateFuncOpVectorRewritePatterns(patterns);
1353  populateReturnOpVectorRewritePatterns(patterns);
1354  // We only want to apply signature conversion once to the existing func ops.
1355  // Without specifying strictMode, the greedy pattern rewriter will keep
1356  // looking for newly created func ops.
1357  return applyPatternsGreedily(op, std::move(patterns),
1358  GreedyRewriteConfig().setStrictness(
1359  GreedyRewriteStrictness::ExistingOps));
1360 }
1361 
1362 LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
1363  MLIRContext *context = op->getContext();
1364 
1365  // Unroll vectors in function bodies to native vector size.
1366  {
1367  RewritePatternSet patterns(context);
1368  auto options = vector::UnrollVectorOptions().setNativeShapeFn(
1369  [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1370  populateVectorUnrollPatterns(patterns, options);
1371  if (failed(applyPatternsGreedily(op, std::move(patterns))))
1372  return failure();
1373  }
1374 
1375  // Convert transpose ops into extract and insert pairs, in preparation for
1376  // further transformations to canonicalize/cancel.
1377  {
1378  RewritePatternSet patterns(context);
1379  vector::populateVectorTransposeLoweringPatterns(
1380  patterns, vector::VectorTransposeLowering::EltWise);
1382  if (failed(applyPatternsGreedily(op, std::move(patterns))))
1383  return failure();
1384  }
1385 
1386  // Run canonicalization to cast away leading size-1 dimensions.
1387  {
1388  RewritePatternSet patterns(context);
1389 
1390  // We need to pull in the patterns for casting away leading one dims.
1391  vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1392  vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1393  vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1394 
1395  // Decompose different rank insert_strided_slice and n-D
1396  // extract_strided_slice.
1397  vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1398  patterns);
1399  vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1400  vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1401 
1402  // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1403  // them up.
1404  vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1405  vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1406 
1407  if (failed(applyPatternsGreedily(op, std::move(patterns))))
1408  return failure();
1409  }
1410  return success();
1411 }
1412 
1413 //===----------------------------------------------------------------------===//
1414 // SPIR-V TypeConverter
1415 //===----------------------------------------------------------------------===//
1416 
1417 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
1418  const SPIRVConversionOptions &options)
1419  : targetEnv(targetAttr), options(options) {
1420  // Add conversions. The order matters here: later ones will be tried earlier.
1421 
1422  // Allow all SPIR-V dialect specific types. This assumes all builtin types
1423  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
1424  // were tried before.
1425  //
1426  // TODO: This assumes that the SPIR-V types are valid to use in the given
1427  // target environment, which should be the case if the whole pipeline is
1428  // driven by the same target environment. Still, we probably still want to
1429  // validate and convert to be safe.
1430  addConversion([](spirv::SPIRVType type) { return type; });
1431 
1432  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
1433 
1434  addConversion([this](IntegerType intType) -> std::optional<Type> {
1435  if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1436  return convertScalarType(this->targetEnv, this->options, scalarType);
1437  if (intType.getWidth() < 8)
1438  return convertSubByteIntegerType(this->options, intType);
1439  return Type();
1440  });
1441 
1442  addConversion([this](FloatType floatType) -> std::optional<Type> {
1443  if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1444  return convertScalarType(this->targetEnv, this->options, scalarType);
1445  return Type();
1446  });
1447 
1448  addConversion([this](ComplexType complexType) {
1449  return convertComplexType(this->targetEnv, this->options, complexType);
1450  });
1451 
1452  addConversion([this](VectorType vectorType) {
1453  return convertVectorType(this->targetEnv, this->options, vectorType);
1454  });
1455 
1456  addConversion([this](TensorType tensorType) {
1457  return convertTensorType(this->targetEnv, this->options, tensorType);
1458  });
1459 
1460  addConversion([this](MemRefType memRefType) {
1461  return convertMemrefType(this->targetEnv, this->options, memRefType);
1462  });
1463 
1464  // Register some last line of defense casting logic.
1465  addSourceMaterialization(
1466  [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1467  return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1468  });
1469  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
1470  Location loc) {
1471  auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
1472  return cast.getResult(0);
1473  });
1474 }
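// Editor's note: a minimal usage sketch (not part of the upstream file)
// exercising the conversions registered above. `targetAttr` is assumed to be
// a valid spirv::TargetEnvAttr; the concrete results depend on it and on the
// conversion options, and convertType() returns a null Type when no
// registered conversion applies (e.g. a memref whose memory space maps to no
// SPIR-V storage class).
static void exampleConvertTypes(MLIRContext *ctx,
                                spirv::TargetEnvAttr targetAttr) {
  SPIRVTypeConverter converter(targetAttr);
  // Scalars go through convertScalarType and may be emulated by wider types.
  Type convertedInt = converter.convertType(IntegerType::get(ctx, 32));
  // Memrefs go through convertMemrefType and typically become pointers to
  // struct-wrapped (runtime) arrays in the matching storage class.
  Type convertedMemref =
      converter.convertType(MemRefType::get({4}, Float32Type::get(ctx)));
  (void)convertedInt;
  (void)convertedMemref;
}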
1475 
1476 Type SPIRVTypeConverter::getIndexType() const {
1477  return ::getIndexType(getContext(), options);
1478 }
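// Editor's note: a minimal sketch (not part of the upstream file) showing how
// the produced index type follows the conversion options; the `use64bitIndex`
// field is assumed to be available on SPIRVConversionOptions.
static Type exampleIndexType(spirv::TargetEnvAttr targetAttr) {
  SPIRVConversionOptions options;
  options.use64bitIndex = true; // i64 when set; i32 by default.
  SPIRVTypeConverter converter(targetAttr, options);
  return converter.getIndexType();
}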
1479 
1480 MLIRContext *SPIRVTypeConverter::getContext() const {
1481  return targetEnv.getAttr().getContext();
1482 }
1483 
1484 bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
1485  return targetEnv.allows(capability);
1486 }
1487 
1488 //===----------------------------------------------------------------------===//
1489 // SPIR-V ConversionTarget
1490 //===----------------------------------------------------------------------===//
1491 
1492 std::unique_ptr<SPIRVConversionTarget>
1493 SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
1494  std::unique_ptr<SPIRVConversionTarget> target(
1495  // std::make_unique does not work here because the constructor is private.
1496  new SPIRVConversionTarget(targetAttr));
1497  SPIRVConversionTarget *targetPtr = target.get();
1498  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1499  // Capture the raw pointer here: the local `target` unique_ptr is moved
1500  // out on return, but the pointed-to object stays alive, so the pointer is stable.
1501  [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1502  return target;
1503 }
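// Editor's note: a minimal sketch (not part of the upstream file) of the
// typical wiring of this conversion target with the type converter and the
// func-to-SPIR-V patterns in a lowering pass. `exampleLowerToSPIRV` is a
// hypothetical helper; it assumes spirv::lookupTargetEnvOrDefault from
// TargetAndABI.h is available in this translation unit.
static LogicalResult exampleLowerToSPIRV(Operation *root) {
  spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(root);
  std::unique_ptr<SPIRVConversionTarget> target =
      SPIRVConversionTarget::get(targetAttr);
  SPIRVTypeConverter typeConverter(targetAttr);
  RewritePatternSet patterns(root->getContext());
  populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
  // Ops left with non-SPIR-V types or unsatisfied version/extension/capability
  // requirements make the full conversion fail, per isLegalOp() below.
  return applyFullConversion(root, *target, std::move(patterns));
}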
1504 
1505 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1506  : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1507 
1508 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1509  // Make sure this op is available at the given version. Ops not implementing
1510  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1511  // SPIR-V versions.
1512  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1513  std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1514  if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1515  LLVM_DEBUG(llvm::dbgs()
1516  << op->getName() << " illegal: requiring min version "
1517  << spirv::stringifyVersion(*minVersion) << "\n");
1518  return false;
1519  }
1520  }
1521  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1522  std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1523  if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1524  LLVM_DEBUG(llvm::dbgs()
1525  << op->getName() << " illegal: requiring max version "
1526  << spirv::stringifyVersion(*maxVersion) << "\n");
1527  return false;
1528  }
1529  }
1530 
1531  // Make sure this op's required extensions are allowed by the target
1532  // environment. Ops not implementing QueryExtensionInterface do not require
1533  // any extensions to be available.
1534  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1535  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1536  extensions.getExtensions())))
1537  return false;
1538 
1539  // Make sure this op's required capabilities are allowed by the target
1540  // environment. Ops not implementing QueryCapabilityInterface do not require
1541  // any capabilities to be available.
1542  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1543  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1544  capabilities.getCapabilities())))
1545  return false;
1546 
1547  SmallVector<Type, 4> valueTypes;
1548  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
1549  valueTypes.append(op->result_type_begin(), op->result_type_end());
1550 
1551  // Ensure that all types have been converted to SPIR-V types.
1552  if (llvm::any_of(valueTypes,
1553  [](Type t) { return !isa<spirv::SPIRVType>(t); }))
1554  return false;
1555 
1556  // Special treatment for global variables, whose type requirements are
1557  // conveyed by type attributes.
1558  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1559  valueTypes.push_back(globalVar.getType());
1560 
1561  // Make sure the op's operands/results use types that are allowed by the
1562  // target environment.
1563  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1564  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1565  for (Type valueType : valueTypes) {
1566  typeExtensions.clear();
1567  cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1568  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1569  typeExtensions)))
1570  return false;
1571 
1572  typeCapabilities.clear();
1573  cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1574  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1575  typeCapabilities)))
1576  return false;
1577  }
1578 
1579  return true;
1580 }
1581 
1582 //===----------------------------------------------------------------------===//
1583 // Public functions for populating patterns
1584 //===----------------------------------------------------------------------===//
1585 
1586 void mlir::populateBuiltinFuncToSPIRVPatterns(
1587  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1588  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
1589 }
1590 
1591 void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
1592  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
1593 }
1594 
1595 void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
1596  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
1597 }
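// Editor's note: a minimal sketch (not part of the upstream file) of how the
// two populate functions above are typically consumed: collected into one
// pattern set and applied greedily to a function-containing op, mirroring
// unrollVectorsInSignatures defined earlier in this file.
static LogicalResult exampleUnrollSignatures(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  populateReturnOpVectorRewritePatterns(patterns);
  populateFuncOpVectorRewritePatterns(patterns);
  return applyPatternsGreedily(op, std::move(patterns));
}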