MLIR  20.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Support/LLVM.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/LogicalResult.h"
38 #include "llvm/Support/MathExtras.h"
39 
40 #include <functional>
41 #include <optional>
42 
43 #define DEBUG_TYPE "mlir-spirv-conversion"
44 
45 using namespace mlir;
46 
47 namespace {
48 
49 //===----------------------------------------------------------------------===//
50 // Utility functions
51 //===----------------------------------------------------------------------===//
52 
53 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
54  LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
55  if (vecType.isScalable()) {
56  LLVM_DEBUG(llvm::dbgs()
57  << "--scalable vectors are not supported -> BAIL\n");
58  return std::nullopt;
59  }
60  SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
61  std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
62  1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
63  if (!targetShape) {
64  LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
65  return std::nullopt;
66  }
67  auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
68  if (!maybeShapeRatio) {
69  LLVM_DEBUG(llvm::dbgs()
70  << "--could not compute integral shape ratio -> BAIL\n");
71  return std::nullopt;
72  }
73  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
74  LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
75  return std::nullopt;
76  }
77  LLVM_DEBUG(llvm::dbgs()
78  << "--found an integral shape ratio to unroll to -> SUCCESS\n");
79  return targetShape;
80 }
81 
/// Checks that `candidates` extension requirements are possible to be satisfied
/// with the given `targetEnv`.
///
/// `candidates` is a vector of vector for extension requirements following
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
/// convention: every inner group must have at least one extension allowed by
/// `targetEnv`. Returns failure() at the first group with none allowed (and,
/// under LLVM_DEBUG, prints `label` plus that group's extensions); returns
/// success() otherwise.
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
    LabelT label, const spirv::TargetEnv &targetEnv,
    // NOTE(review): the final parameter line is elided in this view. The body
    // iterates `candidates`, presumably a const reference to the extension
    // counterpart of `spirv::SPIRVType::CapabilityArrayRefVector` -- confirm
    // against the real file.
  for (const auto &ors : candidates) {
    // At least one extension in this OR-group is allowed; check the next one.
    if (targetEnv.allows(ors))
      continue;

    LLVM_DEBUG({
      SmallVector<StringRef> extStrings;
      for (spirv::Extension ext : ors)
        extStrings.push_back(spirv::stringifyExtension(ext));

      llvm::dbgs() << label << " illegal: requires at least one extension in ["
                   << llvm::join(extStrings, ", ")
                   << "] but none allowed in target environment\n";
    });
    return failure();
  }
  return success();
}
109 
110 /// Checks that `candidates`capability requirements are possible to be satisfied
111 /// with the given `isAllowedFn`.
112 ///
113 /// `candidates` is a vector of vector for capability requirements following
114 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
115 /// convention.
116 template <typename LabelT>
117 static LogicalResult checkCapabilityRequirements(
118  LabelT label, const spirv::TargetEnv &targetEnv,
119  const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
120  for (const auto &ors : candidates) {
121  if (targetEnv.allows(ors))
122  continue;
123 
124  LLVM_DEBUG({
125  SmallVector<StringRef> capStrings;
126  for (spirv::Capability cap : ors)
127  capStrings.push_back(spirv::stringifyCapability(cap));
128 
129  llvm::dbgs() << label << " illegal: requires at least one capability in ["
130  << llvm::join(capStrings, ", ")
131  << "] but none allowed in target environment\n";
132  });
133  return failure();
134  }
135  return success();
136 }
137 
138 /// Returns true if the given `storageClass` needs explicit layout when used in
139 /// Shader environments.
140 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
141  switch (storageClass) {
142  case spirv::StorageClass::PhysicalStorageBuffer:
143  case spirv::StorageClass::PushConstant:
144  case spirv::StorageClass::StorageBuffer:
145  case spirv::StorageClass::Uniform:
146  return true;
147  default:
148  return false;
149  }
150 }
151 
152 /// Wraps the given `elementType` in a struct and gets the pointer to the
153 /// struct. This is used to satisfy Vulkan interface requirements.
154 static spirv::PointerType
155 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
156  auto structType = needsExplicitLayout(storageClass)
157  ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
158  : spirv::StructType::get(elementType);
159  return spirv::PointerType::get(structType, storageClass);
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // Type Conversion
164 //===----------------------------------------------------------------------===//
165 
/// Returns the integer type used in place of `index`: i64 when
/// `options.use64bitIndex` is set, i32 otherwise.
static spirv::ScalarType getIndexType(MLIRContext *ctx,
    // NOTE(review): a parameter line is elided in this view. The body reads
    // `options.use64bitIndex`, so presumably a `const SPIRVConversionOptions &`
    // named `options` followed by `) {` -- confirm against the real file.
  return cast<spirv::ScalarType>(
      IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
}
171 
172 // TODO: This is a utility function that should probably be exposed by the
173 // SPIR-V dialect. Keeping it local till the use case arises.
174 static std::optional<int64_t>
175 getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
176  if (isa<spirv::ScalarType>(type)) {
177  auto bitWidth = type.getIntOrFloatBitWidth();
178  // According to the SPIR-V spec:
179  // "There is no physical size or bit pattern defined for values with boolean
180  // type. If they are stored (in conjunction with OpVariable), they can only
181  // be used with logical addressing operations, not physical, and only with
182  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
183  // Private, Function, Input, and Output."
184  if (bitWidth == 1)
185  return std::nullopt;
186  return bitWidth / 8;
187  }
188 
189  if (auto complexType = dyn_cast<ComplexType>(type)) {
190  auto elementSize = getTypeNumBytes(options, complexType.getElementType());
191  if (!elementSize)
192  return std::nullopt;
193  return 2 * *elementSize;
194  }
195 
196  if (auto vecType = dyn_cast<VectorType>(type)) {
197  auto elementSize = getTypeNumBytes(options, vecType.getElementType());
198  if (!elementSize)
199  return std::nullopt;
200  return vecType.getNumElements() * *elementSize;
201  }
202 
203  if (auto memRefType = dyn_cast<MemRefType>(type)) {
204  // TODO: Layout should also be controlled by the ABI attributes. For now
205  // using the layout from MemRef.
206  int64_t offset;
207  SmallVector<int64_t, 4> strides;
208  if (!memRefType.hasStaticShape() ||
209  failed(getStridesAndOffset(memRefType, strides, offset)))
210  return std::nullopt;
211 
212  // To get the size of the memref object in memory, the total size is the
213  // max(stride * dimension-size) computed for all dimensions times the size
214  // of the element.
215  auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
216  if (!elementSize)
217  return std::nullopt;
218 
219  if (memRefType.getRank() == 0)
220  return elementSize;
221 
222  auto dims = memRefType.getShape();
223  if (llvm::is_contained(dims, ShapedType::kDynamic) ||
224  ShapedType::isDynamic(offset) ||
225  llvm::is_contained(strides, ShapedType::kDynamic))
226  return std::nullopt;
227 
228  int64_t memrefSize = -1;
229  for (const auto &shape : enumerate(dims))
230  memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
231 
232  return (offset + memrefSize) * *elementSize;
233  }
234 
235  if (auto tensorType = dyn_cast<TensorType>(type)) {
236  if (!tensorType.hasStaticShape())
237  return std::nullopt;
238 
239  auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
240  if (!elementSize)
241  return std::nullopt;
242 
243  int64_t size = *elementSize;
244  for (auto shape : tensorType.getShape())
245  size *= shape;
246 
247  return size;
248  }
249 
250  // TODO: Add size computation for other types.
251  return std::nullopt;
252 }
253 
/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
///
/// If `targetEnv` meets every capability and extension requirement of `type`
/// (for the optional `storageClass`), `type` is returned unchanged. Otherwise,
/// only when `options.emulateLT32BitScalarTypes` is set, scalars of at most 32
/// bits are widened: floats to f32, integers to i32 with the same signedness.
/// Scalars wider than 32 bits are rejected rather than truncated. Returns null
/// on failure.
static Type
convertScalarType(const spirv::TargetEnv &targetEnv,
                  // NOTE(review): one parameter line is elided in this view.
                  // The body uses `options` and `type` (a spirv::ScalarType
                  // given the getExtensions/getCapabilities calls) -- confirm
                  // the exact declarations against the real file.
                  std::optional<spirv::StorageClass> storageClass = {}) {
  // Get extension and capability requirements for the given type.
  // NOTE(review): the declarations of `extensions` and `capabilities` are
  // elided in this view -- confirm against the real file.
  type.getExtensions(extensions, storageClass);
  type.getCapabilities(capabilities, storageClass);

  // If all requirements are met, then we can accept this type as-is.
  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
    return type;

  // Otherwise we need to adjust the type, which really means adjusting the
  // bitwidth given this is a scalar type.
  if (!options.emulateLT32BitScalarTypes)
    return nullptr;

  // We only emulate narrower scalar types here and do not truncate results.
  if (type.getIntOrFloatBitWidth() > 32) {
    LLVM_DEBUG(llvm::dbgs()
               << type
               << " not converted to 32-bit for SPIR-V to avoid truncation\n");
    return nullptr;
  }

  // NOTE(review): `floatType` is never used; `isa<FloatType>(type)` would
  // express the same check without an unused variable.
  if (auto floatType = dyn_cast<FloatType>(type)) {
    LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
    return Builder(targetEnv.getContext()).getF32Type();
  }

  // Any remaining scalar is treated as an integer (`cast` asserts otherwise);
  // its signedness is preserved in the widened type.
  auto intType = cast<IntegerType>(type);
  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
                          intType.getSignedness());
}
293 
294 /// Converts a sub-byte integer `type` to i32 regardless of target environment.
295 ///
296 /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
297 /// the above given that these sub-byte types are not supported at all in
298 /// SPIR-V; there are no compute/storage capability for them like other
299 /// supported integer types.
300 static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
301  IntegerType type) {
302  if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
303  LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
304  return nullptr;
305  }
306 
307  if (!llvm::isPowerOf2_32(type.getWidth())) {
308  LLVM_DEBUG(llvm::dbgs()
309  << "unsupported non-power-of-two bitwidth in sub-byte" << type
310  << "\n");
311  return nullptr;
312  }
313 
314  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
315  return IntegerType::get(type.getContext(), /*width=*/32,
316  type.getSignedness());
317 }
318 
/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type (as chosen by getIndexType from `options`).
/// This is a noop when the element type is not the index type.
static ShapedType
convertIndexElementType(ShapedType type,
                        // NOTE(review): a parameter line is elided in this
                        // view. The body passes `options` to getIndexType, so
                        // presumably a `const SPIRVConversionOptions &` named
                        // `options` followed by `) {` -- confirm.
  Type indexType = dyn_cast<IndexType>(type.getElementType());
  if (!indexType)
    return type;

  return type.clone(getIndexType(type.getContext(), options));
}
331 
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
///
/// `index` elements are first replaced by the configured index integer type.
/// Integer element types that are not `spirv::ScalarType`s are widened via
/// convertSubByteIntegerType (vectors of more than 4 elements are rejected on
/// that path). One-element vectors of rank <= 1 collapse to their converted
/// element type. Otherwise the vector must be a valid SPIR-V composite; it is
/// returned unchanged if `targetEnv` satisfies its capability/extension
/// requirements (for the optional `storageClass`), else its element type is
/// converted with convertScalarType. Returns null on failure.
static Type
convertVectorType(const spirv::TargetEnv &targetEnv,
                  const SPIRVConversionOptions &options, VectorType type,
                  std::optional<spirv::StorageClass> storageClass = {}) {
  type = cast<VectorType>(convertIndexElementType(type, options));
  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
  if (!scalarType) {
    // If this is not a spec allowed scalar type, try to handle sub-byte integer
    // types.
    // NOTE(review): any integer width that spirv::ScalarType does not accept
    // also lands here, not only sub-byte ones -- confirm which widths that
    // covers.
    auto intType = dyn_cast<IntegerType>(type.getElementType());
    if (!intType) {
      LLVM_DEBUG(llvm::dbgs()
                 << type
                 << " illegal: cannot convert non-scalar element type\n");
      return nullptr;
    }

    // NOTE(review): `elementType` is null when convertSubByteIntegerType
    // rejects the type (non-packed storage or non-power-of-two width). The
    // multi-element path below would then call VectorType::get with a null
    // element type; a null check returning nullptr looks warranted.
    Type elementType = convertSubByteIntegerType(options, intType);
    if (type.getRank() <= 1 && type.getNumElements() == 1)
      return elementType;

    if (type.getNumElements() > 4) {
      LLVM_DEBUG(llvm::dbgs()
                 << type << " illegal: > 4-element unimplemented\n");
      return nullptr;
    }

    return VectorType::get(type.getShape(), elementType);
  }

  // Single-element vectors are represented as plain scalars.
  if (type.getRank() <= 1 && type.getNumElements() == 1)
    return convertScalarType(targetEnv, options, scalarType, storageClass);

  if (!spirv::CompositeType::isValid(type)) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: not a valid composite type\n");
    return nullptr;
  }

  // Get extension and capability requirements for the given type.
  // NOTE(review): the declarations of `extensions` and `capabilities` are
  // elided in this view -- confirm against the real file.
  cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
  cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);

  // If all requirements are met, then we can accept this type as-is.
  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
    return type;

  // Otherwise keep the shape and convert (e.g. widen) only the element type.
  auto elementType =
      convertScalarType(targetEnv, options, scalarType, storageClass);
  if (elementType)
    return VectorType::get(type.getShape(), elementType);
  return nullptr;
}
389 
390 static Type
391 convertComplexType(const spirv::TargetEnv &targetEnv,
392  const SPIRVConversionOptions &options, ComplexType type,
393  std::optional<spirv::StorageClass> storageClass = {}) {
394  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
395  if (!scalarType) {
396  LLVM_DEBUG(llvm::dbgs()
397  << type << " illegal: cannot convert non-scalar element type\n");
398  return nullptr;
399  }
400 
401  auto elementType =
402  convertScalarType(targetEnv, options, scalarType, storageClass);
403  if (!elementType)
404  return nullptr;
405  if (elementType != type.getElementType()) {
406  LLVM_DEBUG(llvm::dbgs()
407  << type << " illegal: complex type emulation unsupported\n");
408  return nullptr;
409  }
410 
411  return VectorType::get(2, elementType);
412 }
413 
/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
///
/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
/// create composite constants with OpConstantComposite to embed relatively
/// large constant values and use OpCompositeExtract and OpCompositeInsert to
/// manipulate them, like what we do for vectors.
///
/// Only statically shaped tensors with a `spirv::ScalarType` element type
/// (after `index` is replaced by the configured integer type) are supported.
/// The result is a `spirv::ArrayType` whose length is the tensor's byte size
/// divided by the original element's byte size, and whose element type comes
/// from convertScalarType. Returns null on failure, including for
/// zero-element tensors.
static Type convertTensorType(const spirv::TargetEnv &targetEnv,
                              // NOTE(review): a parameter line is elided in
                              // this view. The body uses `options`, presumably
                              // a `const SPIRVConversionOptions &` -- confirm.
                              TensorType type) {
  // TODO: Handle dynamic shapes.
  if (!type.hasStaticShape()) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: dynamic shape unimplemented\n");
    return nullptr;
  }

  type = cast<TensorType>(convertIndexElementType(type, options));
  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
  if (!scalarType) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot convert non-scalar element type\n");
    return nullptr;
  }

  // The element count is derived from byte sizes; both fail for i1 elements.
  std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
  std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
  if (!scalarSize || !tensorSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return nullptr;
  }

  int64_t arrayElemCount = *tensorSize / *scalarSize;
  if (arrayElemCount == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot handle zero-element tensors\n");
    return nullptr;
  }

  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
  if (!arrayElemType)
    return nullptr;
  // The converted element size is only checked for availability here; it is
  // not used to compute the array length.
  std::optional<int64_t> arrayElemSize =
      getTypeNumBytes(options, arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return nullptr;
  }

  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
}
466 
/// Converts a memref of i1 elements to a SPIR-V pointer type.
///
/// Booleans are stored as `options.boolNumBits`-wide integers; only 8 bits is
/// implemented. The storage integer is passed through convertScalarType for
/// `storageClass`. Dynamic shapes become a pointer to the element (Kernel
/// targets) or a struct-wrapped runtime array. Static shapes become an array
/// with enough elements to hold all booleans; Kernel targets get a pointer to
/// that array, other targets a struct-wrapped pointer. A stride is set only
/// when `storageClass` needs explicit layout. Returns null on failure.
static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
                                  // NOTE(review): a parameter line is elided in
                                  // this view. The body uses `options`,
                                  // presumably a
                                  // `const SPIRVConversionOptions &` -- confirm.
                                  MemRefType type,
                                  spirv::StorageClass storageClass) {
  unsigned numBoolBits = options.boolNumBits;
  if (numBoolBits != 8) {
    // NOTE(review): unlike the other debug messages in this file, this one has
    // no trailing "\n".
    LLVM_DEBUG(llvm::dbgs()
               << "using non-8-bit storage for bool types unimplemented");
    return nullptr;
  }
  auto elementType = dyn_cast<spirv::ScalarType>(
      IntegerType::get(type.getContext(), numBoolBits));
  if (!elementType)
    return nullptr;
  Type arrayElemType =
      convertScalarType(targetEnv, options, elementType, storageClass);
  if (!arrayElemType)
    return nullptr;
  std::optional<int64_t> arrayElemSize =
      getTypeNumBytes(options, arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return nullptr;
  }

  if (!type.hasStaticShape()) {
    // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
    // to the element.
    if (targetEnv.allows(spirv::Capability::Kernel))
      return spirv::PointerType::get(arrayElemType, storageClass);
    int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
    // For Vulkan we need extra wrapping struct and array to satisfy interface
    // needs.
    return wrapInStructAndGetPointer(arrayType, storageClass);
  }

  if (type.getNumElements() == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: zero-element memrefs are not supported\n");
    return nullptr;
  }

  // Total bytes needed for all booleans, rounded up to whole array elements.
  int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
  int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
  if (targetEnv.allows(spirv::Capability::Kernel))
    return spirv::PointerType::get(arrayType, storageClass);
  return wrapInStructAndGetPointer(arrayType, storageClass);
}
519 
/// Converts a memref of sub-byte integers to a SPIR-V pointer type.
///
/// Elements are packed into the 32-bit integer produced by
/// convertSubByteIntegerType. Dynamic shapes become a pointer to that integer
/// (Kernel targets) or a struct-wrapped runtime array. Static shapes become an
/// array of enough 32-bit words to hold every element's bits; Kernel targets
/// get a pointer to that array, other targets a struct-wrapped pointer. A
/// stride is set only when `storageClass` needs explicit layout. Returns null
/// on failure.
static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
                                     // NOTE(review): a parameter line is elided
                                     // in this view. The body uses `options`,
                                     // presumably a
                                     // `const SPIRVConversionOptions &` --
                                     // confirm.
                                     MemRefType type,
                                     spirv::StorageClass storageClass) {
  IntegerType elementType = cast<IntegerType>(type.getElementType());
  Type arrayElemType = convertSubByteIntegerType(options, elementType);
  if (!arrayElemType)
    return nullptr;
  // Dereferenced without a check: convertSubByteIntegerType only returns a
  // 32-bit integer, whose size getTypeNumBytes can compute.
  int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);

  if (!type.hasStaticShape()) {
    // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
    // to the element.
    if (targetEnv.allows(spirv::Capability::Kernel))
      return spirv::PointerType::get(arrayElemType, storageClass);
    int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
    // For Vulkan we need extra wrapping struct and array to satisfy interface
    // needs.
    return wrapInStructAndGetPointer(arrayType, storageClass);
  }

  if (type.getNumElements() == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: zero-element memrefs are not supported\n");
    return nullptr;
  }

  // Total packed bytes, rounded up to whole 32-bit array elements.
  int64_t memrefSize =
      llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
  int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
  if (targetEnv.allows(spirv::Capability::Kernel))
    return spirv::PointerType::get(arrayType, storageClass);
  return wrapInStructAndGetPointer(arrayType, storageClass);
}
557 
/// Converts a memref `type` to a SPIR-V pointer type under `targetEnv`.
///
/// The memref's memory space must already be a `spirv::StorageClassAttr`.
/// Integer element types of width 1 or below 8 are delegated to
/// convertBoolMemrefType / convertSubByteMemrefType. Otherwise the element type
/// (vector, complex, scalar or index) is converted first, and then:
///   - dynamic shapes become a pointer to the element (Kernel targets) or a
///     struct-wrapped runtime array;
///   - static shapes become an array with enough converted elements to cover
///     getTypeNumBytes(memref). Kernel targets get a pointer to that array,
///     other targets a struct-wrapped pointer.
/// A stride is set only when the storage class needs explicit layout. Returns
/// null on failure.
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
                              // NOTE(review): a parameter line is elided in
                              // this view. The body uses `options`, presumably
                              // a `const SPIRVConversionOptions &` -- confirm.
                              MemRefType type) {
  auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
  if (!attr) {
    LLVM_DEBUG(
        llvm::dbgs()
        << type
        << " illegal: expected memory space to be a SPIR-V storage class "
           "attribute; please use MemorySpaceToStorageClassConverter to map "
           "numeric memory spaces beforehand\n");
    return nullptr;
  }
  spirv::StorageClass storageClass = attr.getValue();

  // Booleans and sub-byte integers need packed storage handled separately.
  if (isa<IntegerType>(type.getElementType())) {
    if (type.getElementTypeBitWidth() == 1)
      return convertBoolMemrefType(targetEnv, options, type, storageClass);
    if (type.getElementTypeBitWidth() < 8)
      return convertSubByteMemrefType(targetEnv, options, type, storageClass);
  }

  Type arrayElemType;
  Type elementType = type.getElementType();
  if (auto vecType = dyn_cast<VectorType>(elementType)) {
    arrayElemType =
        convertVectorType(targetEnv, options, vecType, storageClass);
  } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
    arrayElemType =
        convertComplexType(targetEnv, options, complexType, storageClass);
  } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
    arrayElemType =
        convertScalarType(targetEnv, options, scalarType, storageClass);
  } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
    // NOTE(review): `indexType` is unused; `isa<IndexType>` would suffice.
    // `type` itself is rewritten so later size queries see the integer type.
    type = cast<MemRefType>(convertIndexElementType(type, options));
    arrayElemType = type.getElementType();
  } else {
    LLVM_DEBUG(
        llvm::dbgs()
        << type
        << " unhandled: can only convert scalar or vector element type\n");
    return nullptr;
  }
  if (!arrayElemType)
    return nullptr;

  std::optional<int64_t> arrayElemSize =
      getTypeNumBytes(options, arrayElemType);
  if (!arrayElemSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce converted element size\n");
    return nullptr;
  }

  if (!type.hasStaticShape()) {
    // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
    // to the element.
    if (targetEnv.allows(spirv::Capability::Kernel))
      return spirv::PointerType::get(arrayElemType, storageClass);
    int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
    auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
    // For Vulkan we need extra wrapping struct and array to satisfy interface
    // needs.
    return wrapInStructAndGetPointer(arrayType, storageClass);
  }

  std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
  if (!memrefSize) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: cannot deduce element count\n");
    return nullptr;
  }

  if (*memrefSize == 0) {
    LLVM_DEBUG(llvm::dbgs()
               << type << " illegal: zero-element memrefs are not supported\n");
    return nullptr;
  }

  // Round up so the array covers the whole memref footprint.
  int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
  if (targetEnv.allows(spirv::Capability::Kernel))
    return spirv::PointerType::get(arrayType, storageClass);
  return wrapInStructAndGetPointer(arrayType, storageClass);
}
644 
645 //===----------------------------------------------------------------------===//
646 // Type casting materialization
647 //===----------------------------------------------------------------------===//
648 
/// Converts the given `inputs` to the original source `type` considering the
/// `targetEnv`'s capabilities.
///
/// This function is meant to be used for source materialization in type
/// converters. When the type converter needs to materialize a cast op back
/// to some original source type, we need to check whether the original source
/// type is supported in the target environment. If so, we can insert legal
/// SPIR-V cast ops accordingly.
///
/// Note that in SPIR-V the capabilities for storage and compute are separate.
/// This function is meant to handle the **compute** side; so it does not
/// involve storage classes in its logic. The storage side is expected to be
/// handled by MemRef conversion logic.
///
/// Every unsupported case (multiple inputs, non-integer or non-SPIR-V-scalar
/// `type`, a widening cast, or a `type` the environment cannot use) falls back
/// to a builtin.unrealized_conversion_cast. An i1 `type` is produced with
/// spirv.IEqual against 1. Other integers use spirv.SConvert for signed
/// `type` and spirv.UConvert otherwise.
static Value castToSourceType(const spirv::TargetEnv &targetEnv,
                              OpBuilder &builder, Type type, ValueRange inputs,
                              Location loc) {
  // We can only cast one value in SPIR-V.
  if (inputs.size() != 1) {
    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return castOp.getResult(0);
  }
  Value input = inputs.front();

  // Only support integer types for now. Floating point types to be implemented.
  if (!isa<IntegerType>(type)) {
    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return castOp.getResult(0);
  }
  // NOTE(review): this `cast` asserts if `input` is not an integer (only
  // `type` was checked above); a dyn_cast with the unrealized-cast fallback
  // would be more robust.
  auto inputType = cast<IntegerType>(input.getType());

  auto scalarType = dyn_cast<spirv::ScalarType>(type);
  if (!scalarType) {
    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return castOp.getResult(0);
  }

  // Only support source type with a smaller bitwidth. This would mean we are
  // truncating to go back so we don't need to worry about the signedness.
  // For extension, we cannot have enough signal here to decide which op to use.
  // NOTE(review): equal widths also pass this check, so the S/UConvert below
  // can be emitted with identical source and result widths. The SPIR-V spec
  // appears to require differing component widths for these ops -- confirm.
  if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return castOp.getResult(0);
  }

  // Boolean values would need to use different ops than normal integer values.
  if (type.isInteger(1)) {
    Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
    return builder.create<spirv::IEqualOp>(loc, input, one);
  }

  // Check that the source integer type is supported by the environment.
  // NOTE(review): the declarations of `exts` and `caps` are elided in this
  // view -- confirm against the real file.
  scalarType.getExtensions(exts);
  scalarType.getCapabilities(caps);
  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
      failed(checkExtensionRequirements(type, targetEnv, exts))) {
    auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return castOp.getResult(0);
  }

  // We've already made sure this is truncating previously, so we don't need to
  // care about signedness here. Still try to use a corresponding op for better
  // consistency though.
  if (type.isSignedInteger()) {
    return builder.create<spirv::SConvertOp>(loc, type, input);
  }
  return builder.create<spirv::UConvertOp>(loc, type, input);
}
718 
719 //===----------------------------------------------------------------------===//
720 // Builtin Variables
721 //===----------------------------------------------------------------------===//
722 
723 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
724  spirv::BuiltIn builtin) {
725  // Look through all global variables in the given `body` block and check if
726  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
727  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
728  if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
729  spirv::SPIRVDialect::getAttributeName(
730  spirv::Decoration::BuiltIn))) {
731  auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
732  if (varBuiltIn && *varBuiltIn == builtin) {
733  return varOp;
734  }
735  }
736  }
737  return nullptr;
738 }
739 
740 /// Gets name of global variable for a builtin.
741 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
742  StringRef suffix) {
743  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
744 }
745 
746 /// Gets or inserts a global variable for a builtin within `body` block.
747 static spirv::GlobalVariableOp
748 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
749  Type integerType, OpBuilder &builder,
750  StringRef prefix, StringRef suffix) {
751  if (auto varOp = getBuiltinVariable(body, builtin))
752  return varOp;
753 
754  OpBuilder::InsertionGuard guard(builder);
755  builder.setInsertionPointToStart(&body);
756 
757  spirv::GlobalVariableOp newVarOp;
758  switch (builtin) {
759  case spirv::BuiltIn::NumWorkgroups:
760  case spirv::BuiltIn::WorkgroupSize:
761  case spirv::BuiltIn::WorkgroupId:
762  case spirv::BuiltIn::LocalInvocationId:
763  case spirv::BuiltIn::GlobalInvocationId: {
764  auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
765  spirv::StorageClass::Input);
766  std::string name = getBuiltinVarName(builtin, prefix, suffix);
767  newVarOp =
768  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
769  break;
770  }
771  case spirv::BuiltIn::SubgroupId:
772  case spirv::BuiltIn::NumSubgroups:
773  case spirv::BuiltIn::SubgroupSize: {
774  auto ptrType =
775  spirv::PointerType::get(integerType, spirv::StorageClass::Input);
776  std::string name = getBuiltinVarName(builtin, prefix, suffix);
777  newVarOp =
778  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
779  break;
780  }
781  default:
782  emitError(loc, "unimplemented builtin variable generation for ")
783  << stringifyBuiltIn(builtin);
784  }
785  return newVarOp;
786 }
787 
788 //===----------------------------------------------------------------------===//
789 // Push constant storage
790 //===----------------------------------------------------------------------===//
791 
792 /// Returns the pointer type for the push constant storage containing
793 /// `elementCount` 32-bit integer values.
794 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
795  Builder &builder,
796  Type indexType) {
797  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
798  /*stride=*/4);
799  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
800  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
801 }
802 
803 /// Returns the push constant varible containing `elementCount` 32-bit integer
804 /// values in `body`. Returns null op if such an op does not exit.
805 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
806  unsigned elementCount) {
807  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
808  auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
809  if (!ptrType)
810  continue;
811 
812  // Note that Vulkan requires "There must be no more than one push constant
813  // block statically used per shader entry point." So we should always reuse
814  // the existing one.
815  if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
816  auto numElements = cast<spirv::ArrayType>(
817  cast<spirv::StructType>(ptrType.getPointeeType())
818  .getElementType(0))
819  .getNumElements();
820  if (numElements == elementCount)
821  return varOp;
822  }
823  }
824  return nullptr;
825 }
826 
827 /// Gets or inserts a global variable for push constant storage containing
828 /// `elementCount` 32-bit integer values in `block`.
829 static spirv::GlobalVariableOp
830 getOrInsertPushConstantVariable(Location loc, Block &block,
831  unsigned elementCount, OpBuilder &b,
832  Type indexType) {
833  if (auto varOp = getPushConstantVariable(block, elementCount))
834  return varOp;
835 
836  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
837  auto type = getPushConstantStorageType(elementCount, builder, indexType);
838  const char *name = "__push_constant_var__";
839  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
840  /*initializer=*/nullptr);
841 }
842 
843 //===----------------------------------------------------------------------===//
844 // func::FuncOp Conversion Patterns
845 //===----------------------------------------------------------------------===//
846 
847 /// A pattern for rewriting function signature to convert arguments of functions
848 /// to be of valid SPIR-V types.
849 struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
851 
852  LogicalResult
853  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
854  ConversionPatternRewriter &rewriter) const override {
855  FunctionType fnType = funcOp.getFunctionType();
856  if (fnType.getNumResults() > 1)
857  return failure();
858 
859  TypeConverter::SignatureConversion signatureConverter(
860  fnType.getNumInputs());
861  for (const auto &argType : enumerate(fnType.getInputs())) {
862  auto convertedType = getTypeConverter()->convertType(argType.value());
863  if (!convertedType)
864  return failure();
865  signatureConverter.addInputs(argType.index(), convertedType);
866  }
867 
868  Type resultType;
869  if (fnType.getNumResults() == 1) {
870  resultType = getTypeConverter()->convertType(fnType.getResult(0));
871  if (!resultType)
872  return failure();
873  }
874 
875  // Create the converted spirv.func op.
876  auto newFuncOp = rewriter.create<spirv::FuncOp>(
877  funcOp.getLoc(), funcOp.getName(),
878  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
879  resultType ? TypeRange(resultType)
880  : TypeRange()));
881 
882  // Copy over all attributes other than the function name and type.
883  for (const auto &namedAttr : funcOp->getAttrs()) {
884  if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
885  namedAttr.getName() != SymbolTable::getSymbolAttrName())
886  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
887  }
888 
889  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
890  newFuncOp.end());
891  if (failed(rewriter.convertRegionTypes(
892  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
893  return failure();
894  rewriter.eraseOp(funcOp);
895  return success();
896  }
897 };
898 
899 /// A pattern for rewriting function signature to convert vector arguments of
900 /// functions to be of valid types
901 struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
903 
904  LogicalResult matchAndRewrite(func::FuncOp funcOp,
905  PatternRewriter &rewriter) const override {
906  FunctionType fnType = funcOp.getFunctionType();
907 
908  // TODO: Handle declarations.
909  if (funcOp.isDeclaration()) {
910  LLVM_DEBUG(llvm::dbgs()
911  << fnType << " illegal: declarations are unsupported\n");
912  return failure();
913  }
914 
915  // Create a new func op with the original type and copy the function body.
916  auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
917  funcOp.getName(), fnType);
918  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
919  newFuncOp.end());
920 
921  Location loc = newFuncOp.getBody().getLoc();
922 
923  Block &entryBlock = newFuncOp.getBlocks().front();
924  OpBuilder::InsertionGuard guard(rewriter);
925  rewriter.setInsertionPointToStart(&entryBlock);
926 
927  OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
928 
929  // For arguments that are of illegal types and require unrolling.
930  // `unrolledInputNums` stores the indices of arguments that result from
931  // unrolling in the new function signature. `newInputNo` is a counter.
932  SmallVector<size_t> unrolledInputNums;
933  size_t newInputNo = 0;
934 
935  // For arguments that are of legal types and do not require unrolling.
936  // `tmpOps` stores a mapping from temporary operations that serve as
937  // placeholders for new arguments that will be added later. These operations
938  // will be erased once the entry block's argument list is updated.
939  llvm::SmallDenseMap<Operation *, size_t> tmpOps;
940 
941  // This counts the number of new operations created.
942  size_t newOpCount = 0;
943 
944  // Enumerate through the arguments.
945  for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
946  // Check whether the argument is of vector type.
947  auto origVecType = dyn_cast<VectorType>(origType);
948  if (!origVecType) {
949  // We need a placeholder for the old argument that will be erased later.
950  Value result = rewriter.create<arith::ConstantOp>(
951  loc, origType, rewriter.getZeroAttr(origType));
952  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
953  tmpOps.insert({result.getDefiningOp(), newInputNo});
954  oneToNTypeMapping.addInputs(origInputNo, origType);
955  ++newInputNo;
956  ++newOpCount;
957  continue;
958  }
959  // Check whether the vector needs unrolling.
960  auto targetShape = getTargetShape(origVecType);
961  if (!targetShape) {
962  // We need a placeholder for the old argument that will be erased later.
963  Value result = rewriter.create<arith::ConstantOp>(
964  loc, origType, rewriter.getZeroAttr(origType));
965  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
966  tmpOps.insert({result.getDefiningOp(), newInputNo});
967  oneToNTypeMapping.addInputs(origInputNo, origType);
968  ++newInputNo;
969  ++newOpCount;
970  continue;
971  }
972  VectorType unrolledType =
973  VectorType::get(*targetShape, origVecType.getElementType());
974  auto originalShape =
975  llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
976 
977  // Prepare the result vector.
978  Value result = rewriter.create<arith::ConstantOp>(
979  loc, origVecType, rewriter.getZeroAttr(origVecType));
980  ++newOpCount;
981  // Prepare the placeholder for the new arguments that will be added later.
982  Value dummy = rewriter.create<arith::ConstantOp>(
983  loc, unrolledType, rewriter.getZeroAttr(unrolledType));
984  ++newOpCount;
985 
986  // Create the `vector.insert_strided_slice` ops.
987  SmallVector<int64_t> strides(targetShape->size(), 1);
988  SmallVector<Type> newTypes;
989  for (SmallVector<int64_t> offsets :
990  StaticTileOffsetRange(originalShape, *targetShape)) {
991  result = rewriter.create<vector::InsertStridedSliceOp>(
992  loc, dummy, result, offsets, strides);
993  newTypes.push_back(unrolledType);
994  unrolledInputNums.push_back(newInputNo);
995  ++newInputNo;
996  ++newOpCount;
997  }
998  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
999  oneToNTypeMapping.addInputs(origInputNo, newTypes);
1000  }
1001 
1002  // Change the function signature.
1003  auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1004  auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1005  rewriter.modifyOpInPlace(newFuncOp,
1006  [&] { newFuncOp.setFunctionType(newFnType); });
1007 
1008  // Update the arguments in the entry block.
1009  entryBlock.eraseArguments(0, fnType.getNumInputs());
1010  SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
1011  entryBlock.addArguments(convertedTypes, locs);
1012 
1013  // Replace the placeholder values with the new arguments. We assume there is
1014  // only one block for now.
1015  size_t unrolledInputIdx = 0;
1016  for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1017  // We first look for operands that are placeholders for initially legal
1018  // arguments.
1019  Operation &curOp = op;
1020  for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
1021  Operation *operandOp = operandVal.getDefiningOp();
1022  if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
1023  size_t idx = operandIdx;
1024  rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
1025  curOp.setOperand(idx, newFuncOp.getArgument(it->second));
1026  });
1027  }
1028  }
1029  // Since all newly created operations are in the beginning, reaching the
1030  // end of them means that any later `vector.insert_strided_slice` should
1031  // not be touched.
1032  if (count >= newOpCount)
1033  continue;
1034  if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1035  size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1036  rewriter.modifyOpInPlace(&curOp, [&] {
1037  curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1038  });
1039  ++unrolledInputIdx;
1040  }
1041  }
1042 
1043  // Erase the original funcOp. The `tmpOps` do not need to be erased since
1044  // they have no uses and will be handled by dead-code elimination.
1045  rewriter.eraseOp(funcOp);
1046  return success();
1047  }
1048 };
1049 
1050 //===----------------------------------------------------------------------===//
1051 // func::ReturnOp Conversion Patterns
1052 //===----------------------------------------------------------------------===//
1053 
1054 /// A pattern for rewriting function signature and the return op to convert
1055 /// vectors to be of valid types.
1056 struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
1058 
1059  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1060  PatternRewriter &rewriter) const override {
1061  // Check whether the parent funcOp is valid.
1062  auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1063  if (!funcOp)
1064  return failure();
1065 
1066  FunctionType fnType = funcOp.getFunctionType();
1067  OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
1068  Location loc = returnOp.getLoc();
1069 
1070  // For the new return op.
1071  SmallVector<Value> newOperands;
1072 
1073  // Enumerate through the results.
1074  for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1075  // Check whether the argument is of vector type.
1076  auto origVecType = dyn_cast<VectorType>(origType);
1077  if (!origVecType) {
1078  oneToNTypeMapping.addInputs(origResultNo, origType);
1079  newOperands.push_back(returnOp.getOperand(origResultNo));
1080  continue;
1081  }
1082  // Check whether the vector needs unrolling.
1083  auto targetShape = getTargetShape(origVecType);
1084  if (!targetShape) {
1085  // The original argument can be used.
1086  oneToNTypeMapping.addInputs(origResultNo, origType);
1087  newOperands.push_back(returnOp.getOperand(origResultNo));
1088  continue;
1089  }
1090  VectorType unrolledType =
1091  VectorType::get(*targetShape, origVecType.getElementType());
1092 
1093  // Create `vector.extract_strided_slice` ops to form legal vectors from
1094  // the original operand of illegal type.
1095  auto originalShape =
1096  llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1097  SmallVector<int64_t> strides(originalShape.size(), 1);
1098  SmallVector<int64_t> extractShape(originalShape.size(), 1);
1099  extractShape.back() = targetShape->back();
1100  SmallVector<Type> newTypes;
1101  Value returnValue = returnOp.getOperand(origResultNo);
1102  for (SmallVector<int64_t> offsets :
1103  StaticTileOffsetRange(originalShape, *targetShape)) {
1104  Value result = rewriter.create<vector::ExtractStridedSliceOp>(
1105  loc, returnValue, offsets, extractShape, strides);
1106  if (originalShape.size() > 1) {
1107  SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
1108  result =
1109  rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
1110  }
1111  newOperands.push_back(result);
1112  newTypes.push_back(unrolledType);
1113  }
1114  oneToNTypeMapping.addInputs(origResultNo, newTypes);
1115  }
1116 
1117  // Change the function signature.
1118  auto newFnType =
1119  FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
1120  TypeRange(oneToNTypeMapping.getConvertedTypes()));
1121  rewriter.modifyOpInPlace(funcOp,
1122  [&] { funcOp.setFunctionType(newFnType); });
1123 
1124  // Replace the return op using the new operands. This will automatically
1125  // update the entry block as well.
1126  rewriter.replaceOp(returnOp,
1127  rewriter.create<func::ReturnOp>(loc, newOperands));
1128 
1129  return success();
1130  }
1131 };
1132 
1133 } // namespace
1134 
1135 //===----------------------------------------------------------------------===//
1136 // Public function for builtin variables
1137 //===----------------------------------------------------------------------===//
1138 
1140  spirv::BuiltIn builtin,
1141  Type integerType, OpBuilder &builder,
1142  StringRef prefix, StringRef suffix) {
1144  if (!parent) {
1145  op->emitError("expected operation to be within a module-like op");
1146  return nullptr;
1147  }
1148 
1149  spirv::GlobalVariableOp varOp =
1150  getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
1151  builtin, integerType, builder, prefix, suffix);
1152  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
1153  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
1154 }
1155 
1156 //===----------------------------------------------------------------------===//
1157 // Public function for pushing constant storage
1158 //===----------------------------------------------------------------------===//
1159 
1160 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
1161  unsigned offset, Type integerType,
1162  OpBuilder &builder) {
1163  Location loc = op->getLoc();
1165  if (!parent) {
1166  op->emitError("expected operation to be within a module-like op");
1167  return nullptr;
1168  }
1169 
1170  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1171  loc, parent->getRegion(0).front(), elementCount, builder, integerType);
1172 
1173  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1174  Value offsetOp = builder.create<spirv::ConstantOp>(
1175  loc, integerType, builder.getI32IntegerAttr(offset));
1176  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
1177  auto acOp = builder.create<spirv::AccessChainOp>(
1178  loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
1179  return builder.create<spirv::LoadOp>(loc, acOp);
1180 }
1181 
1182 //===----------------------------------------------------------------------===//
1183 // Public functions for index calculation
1184 //===----------------------------------------------------------------------===//
1185 
1187  int64_t offset, Type integerType,
1188  Location loc, OpBuilder &builder) {
1189  assert(indices.size() == strides.size() &&
1190  "must provide indices for all dimensions");
1191 
1192  // TODO: Consider moving to use affine.apply and patterns converting
1193  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
1194  // broken down into progressive small steps so we can have intermediate steps
1195  // using other dialects. At the moment SPIR-V is the final sink.
1196 
1197  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
1198  loc, integerType, IntegerAttr::get(integerType, offset));
1199  for (const auto &index : llvm::enumerate(indices)) {
1200  Value strideVal = builder.createOrFold<spirv::ConstantOp>(
1201  loc, integerType,
1202  IntegerAttr::get(integerType, strides[index.index()]));
1203  Value update =
1204  builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1205  linearizedIndex =
1206  builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1207  }
1208  return linearizedIndex;
1209 }
1210 
1212  MemRefType baseType, Value basePtr,
1213  ValueRange indices, Location loc,
1214  OpBuilder &builder) {
1215  // Get base and offset of the MemRefType and verify they are static.
1216 
1217  int64_t offset;
1218  SmallVector<int64_t, 4> strides;
1219  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
1220  llvm::is_contained(strides, ShapedType::kDynamic) ||
1221  ShapedType::isDynamic(offset)) {
1222  return nullptr;
1223  }
1224 
1225  auto indexType = typeConverter.getIndexType();
1226 
1227  SmallVector<Value, 2> linearizedIndices;
1228  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1229 
1230  // Add a '0' at the start to index into the struct.
1231  linearizedIndices.push_back(zero);
1232 
1233  if (baseType.getRank() == 0) {
1234  linearizedIndices.push_back(zero);
1235  } else {
1236  linearizedIndices.push_back(
1237  linearizeIndex(indices, strides, offset, indexType, loc, builder));
1238  }
1239  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
1240 }
1241 
1243  MemRefType baseType, Value basePtr,
1244  ValueRange indices, Location loc,
1245  OpBuilder &builder) {
1246  // Get base and offset of the MemRefType and verify they are static.
1247 
1248  int64_t offset;
1249  SmallVector<int64_t, 4> strides;
1250  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
1251  llvm::is_contained(strides, ShapedType::kDynamic) ||
1252  ShapedType::isDynamic(offset)) {
1253  return nullptr;
1254  }
1255 
1256  auto indexType = typeConverter.getIndexType();
1257 
1258  SmallVector<Value, 2> linearizedIndices;
1259  Value linearIndex;
1260  if (baseType.getRank() == 0) {
1261  linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1262  } else {
1263  linearIndex =
1264  linearizeIndex(indices, strides, offset, indexType, loc, builder);
1265  }
1266  Type pointeeType =
1267  cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
1268  if (isa<spirv::ArrayType>(pointeeType)) {
1269  linearizedIndices.push_back(linearIndex);
1270  return builder.create<spirv::AccessChainOp>(loc, basePtr,
1271  linearizedIndices);
1272  }
1273  return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
1274  linearizedIndices);
1275 }
1276 
1278  MemRefType baseType, Value basePtr,
1279  ValueRange indices, Location loc,
1280  OpBuilder &builder) {
1281 
1282  if (typeConverter.allows(spirv::Capability::Kernel)) {
1283  return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1284  builder);
1285  }
1286 
1287  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1288  builder);
1289 }
1290 
1291 //===----------------------------------------------------------------------===//
1292 // Public functions for vector unrolling
1293 //===----------------------------------------------------------------------===//
1294 
1296  for (int i : {4, 3, 2}) {
1297  if (size % i == 0)
1298  return i;
1299  }
1300  return 1;
1301 }
1302 
1304 mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
1305  VectorType srcVectorType = op.getSourceVectorType();
1306  assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1307  int64_t vectorSize =
1308  mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
1309  return {vectorSize};
1310 }
1311 
1313 mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
1314  VectorType vectorType = op.getResultVectorType();
1315  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
1316  nativeSize.back() =
1317  mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
1318  return nativeSize;
1319 }
1320 
1321 std::optional<SmallVector<int64_t>>
1323  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
1324  if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
1325  SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
1326  nativeSize.back() =
1327  mlir::spirv::getComputeVectorSize(vecType.getShape().back());
1328  return nativeSize;
1329  }
1330  }
1331 
1333  .Case<vector::ReductionOp, vector::TransposeOp>(
1334  [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
1335  .Default([](Operation *) { return std::nullopt; });
1336 }
1337 
1339  MLIRContext *context = op->getContext();
1340  RewritePatternSet patterns(context);
1343  // We only want to apply signature conversion once to the existing func ops.
1344  // Without specifying strictMode, the greedy pattern rewriter will keep
1345  // looking for newly created func ops.
1346  GreedyRewriteConfig config;
1348  return applyPatternsAndFoldGreedily(op, std::move(patterns), config);
1349 }
1350 
1352  MLIRContext *context = op->getContext();
1353 
1354  // Unroll vectors in function bodies to native vector size.
1355  {
1356  RewritePatternSet patterns(context);
1358  [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1360  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
1361  return failure();
1362  }
1363 
1364  // Convert transpose ops into extract and insert pairs, in preparation of
1365  // further transformations to canonicalize/cancel.
1366  {
1367  RewritePatternSet patterns(context);
1369  vector::VectorTransposeLowering::EltWise);
1372  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
1373  return failure();
1374  }
1375 
1376  // Run canonicalization to cast away leading size-1 dimensions.
1377  {
1378  RewritePatternSet patterns(context);
1379 
1380  // We need to pull in casting way leading one dims.
1382  vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1383  vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1384 
1385  // Decompose different rank insert_strided_slice and n-D
1386  // extract_slided_slice.
1388  patterns);
1389  vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1390  vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1391 
1392  // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1393  // them up.
1394  vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1395  vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1396 
1397  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
1398  return failure();
1399  }
1400  return success();
1401 }
1402 
1403 //===----------------------------------------------------------------------===//
1404 // SPIR-V TypeConverter
1405 //===----------------------------------------------------------------------===//
1406 
1409  : targetEnv(targetAttr), options(options) {
1410  // Add conversions. The order matters here: later ones will be tried earlier.
1411 
1412  // Allow all SPIR-V dialect specific types. This assumes all builtin types
1413  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
1414  // were tried before.
1415  //
1416  // TODO: This assumes that the SPIR-V types are valid to use in the given
1417  // target environment, which should be the case if the whole pipeline is
1418  // driven by the same target environment. Still, we probably still want to
1419  // validate and convert to be safe.
1420  addConversion([](spirv::SPIRVType type) { return type; });
1421 
1422  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
1423 
1424  addConversion([this](IntegerType intType) -> std::optional<Type> {
1425  if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1426  return convertScalarType(this->targetEnv, this->options, scalarType);
1427  if (intType.getWidth() < 8)
1428  return convertSubByteIntegerType(this->options, intType);
1429  return Type();
1430  });
1431 
1432  addConversion([this](FloatType floatType) -> std::optional<Type> {
1433  if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1434  return convertScalarType(this->targetEnv, this->options, scalarType);
1435  return Type();
1436  });
1437 
1438  addConversion([this](ComplexType complexType) {
1439  return convertComplexType(this->targetEnv, this->options, complexType);
1440  });
1441 
1442  addConversion([this](VectorType vectorType) {
1443  return convertVectorType(this->targetEnv, this->options, vectorType);
1444  });
1445 
1446  addConversion([this](TensorType tensorType) {
1447  return convertTensorType(this->targetEnv, this->options, tensorType);
1448  });
1449 
1450  addConversion([this](MemRefType memRefType) {
1451  return convertMemrefType(this->targetEnv, this->options, memRefType);
1452  });
1453 
1454  // Register some last line of defense casting logic.
1456  [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1457  return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1458  });
1459  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
1460  Location loc) {
1461  auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
1462  return cast.getResult(0);
1463  });
1464 }
1465 
1467  return ::getIndexType(getContext(), options);
1468 }
1469 
1470 MLIRContext *SPIRVTypeConverter::getContext() const {
1471  return targetEnv.getAttr().getContext();
1472 }
1473 
1474 bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
1475  return targetEnv.allows(capability);
1476 }
1477 
1478 //===----------------------------------------------------------------------===//
1479 // SPIR-V ConversionTarget
1480 //===----------------------------------------------------------------------===//
1481 
1482 std::unique_ptr<SPIRVConversionTarget>
1484  std::unique_ptr<SPIRVConversionTarget> target(
1485  // std::make_unique does not work here because the constructor is private.
1486  new SPIRVConversionTarget(targetAttr));
1487  SPIRVConversionTarget *targetPtr = target.get();
1488  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1489  // We need to capture the raw pointer here because it is stable:
1490  // target will be destroyed once this function is returned.
1491  [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1492  return target;
1493 }
1494 
1495 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1496  : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1497 
1498 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1499  // Make sure this op is available at the given version. Ops not implementing
1500  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1501  // SPIR-V versions.
1502  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1503  std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1504  if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1505  LLVM_DEBUG(llvm::dbgs()
1506  << op->getName() << " illegal: requiring min version "
1507  << spirv::stringifyVersion(*minVersion) << "\n");
1508  return false;
1509  }
1510  }
1511  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1512  std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1513  if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1514  LLVM_DEBUG(llvm::dbgs()
1515  << op->getName() << " illegal: requiring max version "
1516  << spirv::stringifyVersion(*maxVersion) << "\n");
1517  return false;
1518  }
1519  }
1520 
1521  // Make sure this op's required extensions are allowed to use. Ops not
1522  // implementing QueryExtensionInterface do not require extensions to be
1523  // available.
1524  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1525  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1526  extensions.getExtensions())))
1527  return false;
1528 
1529  // Make sure this op's required extensions are allowed to use. Ops not
1530  // implementing QueryCapabilityInterface do not require capabilities to be
1531  // available.
1532  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1533  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1534  capabilities.getCapabilities())))
1535  return false;
1536 
1537  SmallVector<Type, 4> valueTypes;
1538  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
1539  valueTypes.append(op->result_type_begin(), op->result_type_end());
1540 
1541  // Ensure that all types have been converted to SPIRV types.
1542  if (llvm::any_of(valueTypes,
1543  [](Type t) { return !isa<spirv::SPIRVType>(t); }))
1544  return false;
1545 
1546  // Special treatment for global variables, whose type requirements are
1547  // conveyed by type attributes.
1548  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1549  valueTypes.push_back(globalVar.getType());
1550 
1551  // Make sure the op's operands/results use types that are allowed by the
1552  // target environment.
1553  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1554  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1555  for (Type valueType : valueTypes) {
1556  typeExtensions.clear();
1557  cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1558  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1559  typeExtensions)))
1560  return false;
1561 
1562  typeCapabilities.clear();
1563  cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1564  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1565  typeCapabilities)))
1566  return false;
1567  }
1568 
1569  return true;
1570 }
1571 
1572 //===----------------------------------------------------------------------===//
1573 // Public functions for populating patterns
1574 //===----------------------------------------------------------------------===//
1575 
1577  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1578  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
1579 }
1580 
1582  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
1583 }
1584 
1586  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
1587 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
Definition: Block.h:33
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:162
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:203
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:193
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:240
FloatType getF32Type()
Definition: Builders.cpp:87
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:120
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
MLIRContext * getContext() const
Definition: Builders.h:55
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Stores a 1:N mapping of types and provides several useful accessors.
TypeRange getConvertedTypes(unsigned originalTypeNo) const
Returns the list of types that corresponds to the original type at the given index.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:248
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:328
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setOperand(unsigned idx, Value value)
Definition: Operation.h:346
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
operand_type_iterator operand_type_end()
Definition: Operation.h:391
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_iterator result_type_end()
Definition: Operation.h:422
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
result_type_iterator result_type_begin()
Definition: Operation.h:421
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:577
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:423
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
operand_type_iterator operand_type_begin()
Definition: Operation.h:390
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:102
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal replacement value...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:87
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:102
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:463
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:961
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:62
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
Definition: Utils.cpp:2006
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
Structure to control the behavior of vector transform patterns.
VectorTransformsOptions & setVectorTransposeLowering(VectorTransposeLowering opt)