MLIR  21.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Support/LLVM.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/SmallVector.h"
34 #include "llvm/ADT/StringExtras.h"
35 #include "llvm/Support/Debug.h"
36 #include "llvm/Support/LogicalResult.h"
37 #include "llvm/Support/MathExtras.h"
38 
39 #include <functional>
40 #include <optional>
41 
42 #define DEBUG_TYPE "mlir-spirv-conversion"
43 
44 using namespace mlir;
45 
46 namespace {
47 
48 //===----------------------------------------------------------------------===//
49 // Utility functions
50 //===----------------------------------------------------------------------===//
51 
52 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
53  LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
54  if (vecType.isScalable()) {
55  LLVM_DEBUG(llvm::dbgs()
56  << "--scalable vectors are not supported -> BAIL\n");
57  return std::nullopt;
58  }
59  SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
60  std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
61  1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
62  if (!targetShape) {
63  LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
64  return std::nullopt;
65  }
66  auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
67  if (!maybeShapeRatio) {
68  LLVM_DEBUG(llvm::dbgs()
69  << "--could not compute integral shape ratio -> BAIL\n");
70  return std::nullopt;
71  }
72  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
73  LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
74  return std::nullopt;
75  }
76  LLVM_DEBUG(llvm::dbgs()
77  << "--found an integral shape ratio to unroll to -> SUCCESS\n");
78  return targetShape;
79 }
80 
81 /// Checks that `candidates` extension requirements are possible to be satisfied
82 /// with the given `targetEnv`.
83 ///
84 /// `candidates` is a vector of vector for extension requirements following
85 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
86 /// convention.
87 template <typename LabelT>
88 static LogicalResult checkExtensionRequirements(
89  LabelT label, const spirv::TargetEnv &targetEnv,
91  for (const auto &ors : candidates) {
92  if (targetEnv.allows(ors))
93  continue;
94 
95  LLVM_DEBUG({
96  SmallVector<StringRef> extStrings;
97  for (spirv::Extension ext : ors)
98  extStrings.push_back(spirv::stringifyExtension(ext));
99 
100  llvm::dbgs() << label << " illegal: requires at least one extension in ["
101  << llvm::join(extStrings, ", ")
102  << "] but none allowed in target environment\n";
103  });
104  return failure();
105  }
106  return success();
107 }
108 
109 /// Checks that `candidates`capability requirements are possible to be satisfied
110 /// with the given `isAllowedFn`.
111 ///
112 /// `candidates` is a vector of vector for capability requirements following
113 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
114 /// convention.
115 template <typename LabelT>
116 static LogicalResult checkCapabilityRequirements(
117  LabelT label, const spirv::TargetEnv &targetEnv,
118  const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
119  for (const auto &ors : candidates) {
120  if (targetEnv.allows(ors))
121  continue;
122 
123  LLVM_DEBUG({
124  SmallVector<StringRef> capStrings;
125  for (spirv::Capability cap : ors)
126  capStrings.push_back(spirv::stringifyCapability(cap));
127 
128  llvm::dbgs() << label << " illegal: requires at least one capability in ["
129  << llvm::join(capStrings, ", ")
130  << "] but none allowed in target environment\n";
131  });
132  return failure();
133  }
134  return success();
135 }
136 
137 /// Returns true if the given `storageClass` needs explicit layout when used in
138 /// Shader environments.
139 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
140  switch (storageClass) {
141  case spirv::StorageClass::PhysicalStorageBuffer:
142  case spirv::StorageClass::PushConstant:
143  case spirv::StorageClass::StorageBuffer:
144  case spirv::StorageClass::Uniform:
145  return true;
146  default:
147  return false;
148  }
149 }
150 
151 /// Wraps the given `elementType` in a struct and gets the pointer to the
152 /// struct. This is used to satisfy Vulkan interface requirements.
153 static spirv::PointerType
154 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
155  auto structType = needsExplicitLayout(storageClass)
156  ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
157  : spirv::StructType::get(elementType);
158  return spirv::PointerType::get(structType, storageClass);
159 }
160 
161 //===----------------------------------------------------------------------===//
162 // Type Conversion
163 //===----------------------------------------------------------------------===//
164 
165 static spirv::ScalarType getIndexType(MLIRContext *ctx,
167  return cast<spirv::ScalarType>(
168  IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
169 }
170 
// TODO: This is a utility function that should probably be exposed by the
// SPIR-V dialect. Keeping it local till the use case arises.
/// Returns the size of `type` in bytes, or std::nullopt when the size cannot
/// be computed: boolean scalars, dynamic shapes, non-strided memref layouts,
/// and any type kind not handled below.
static std::optional<int64_t>
getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
  if (isa<spirv::ScalarType>(type)) {
    auto bitWidth = type.getIntOrFloatBitWidth();
    // According to the SPIR-V spec:
    // "There is no physical size or bit pattern defined for values with boolean
    // type. If they are stored (in conjunction with OpVariable), they can only
    // be used with logical addressing operations, not physical, and only with
    // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
    // Private, Function, Input, and Output."
    if (bitWidth == 1)
      return std::nullopt;
    return bitWidth / 8;
  }

  if (auto complexType = dyn_cast<ComplexType>(type)) {
    // A complex value occupies two consecutive elements (real, imaginary).
    auto elementSize = getTypeNumBytes(options, complexType.getElementType());
    if (!elementSize)
      return std::nullopt;
    return 2 * *elementSize;
  }

  if (auto vecType = dyn_cast<VectorType>(type)) {
    // Vectors are densely packed: element count times element size.
    auto elementSize = getTypeNumBytes(options, vecType.getElementType());
    if (!elementSize)
      return std::nullopt;
    return vecType.getNumElements() * *elementSize;
  }

  if (auto memRefType = dyn_cast<MemRefType>(type)) {
    // TODO: Layout should also be controlled by the ABI attributes. For now
    // using the layout from MemRef.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (!memRefType.hasStaticShape() ||
        failed(memRefType.getStridesAndOffset(strides, offset)))
      return std::nullopt;

    // To get the size of the memref object in memory, the total size is the
    // max(stride * dimension-size) computed for all dimensions times the size
    // of the element.
    auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
    if (!elementSize)
      return std::nullopt;

    // Rank-0 memrefs hold exactly one element.
    if (memRefType.getRank() == 0)
      return elementSize;

    auto dims = memRefType.getShape();
    // Dynamic dimensions, offsets, or strides make the size unknowable here.
    if (llvm::is_contained(dims, ShapedType::kDynamic) ||
        ShapedType::isDynamic(offset) ||
        llvm::is_contained(strides, ShapedType::kDynamic))
      return std::nullopt;

    int64_t memrefSize = -1;
    for (const auto &shape : enumerate(dims))
      memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);

    // Account for the leading offset in the element count.
    return (offset + memrefSize) * *elementSize;
  }

  if (auto tensorType = dyn_cast<TensorType>(type)) {
    if (!tensorType.hasStaticShape())
      return std::nullopt;

    auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
    if (!elementSize)
      return std::nullopt;

    // Tensors are assumed densely packed: product of all dimensions times the
    // element size.
    int64_t size = *elementSize;
    for (auto shape : tensorType.getShape())
      size *= shape;

    return size;
  }

  // TODO: Add size computation for other types.
  return std::nullopt;
}
252 
253 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
254 static Type
255 convertScalarType(const spirv::TargetEnv &targetEnv,
257  std::optional<spirv::StorageClass> storageClass = {}) {
258  // Get extension and capability requirements for the given type.
261  type.getExtensions(extensions, storageClass);
262  type.getCapabilities(capabilities, storageClass);
263 
264  // If all requirements are met, then we can accept this type as-is.
265  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
266  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
267  return type;
268 
269  // Otherwise we need to adjust the type, which really means adjusting the
270  // bitwidth given this is a scalar type.
271  if (!options.emulateLT32BitScalarTypes)
272  return nullptr;
273 
274  // We only emulate narrower scalar types here and do not truncate results.
275  if (type.getIntOrFloatBitWidth() > 32) {
276  LLVM_DEBUG(llvm::dbgs()
277  << type
278  << " not converted to 32-bit for SPIR-V to avoid truncation\n");
279  return nullptr;
280  }
281 
282  if (auto floatType = dyn_cast<FloatType>(type)) {
283  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
284  return Builder(targetEnv.getContext()).getF32Type();
285  }
286 
287  auto intType = cast<IntegerType>(type);
288  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
289  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
290  intType.getSignedness());
291 }
292 
293 /// Converts a sub-byte integer `type` to i32 regardless of target environment.
294 /// Returns a nullptr for unsupported integer types, including non sub-byte
295 /// types.
296 ///
297 /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
298 /// the above given that these sub-byte types are not supported at all in
299 /// SPIR-V; there are no compute/storage capability for them like other
300 /// supported integer types.
301 static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
302  IntegerType type) {
303  if (type.getWidth() > 8) {
304  LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
305  return nullptr;
306  }
307  if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
308  LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
309  return nullptr;
310  }
311 
312  if (!llvm::isPowerOf2_32(type.getWidth())) {
313  LLVM_DEBUG(llvm::dbgs()
314  << "unsupported non-power-of-two bitwidth in sub-byte" << type
315  << "\n");
316  return nullptr;
317  }
318 
319  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
320  return IntegerType::get(type.getContext(), /*width=*/32,
321  type.getSignedness());
322 }
323 
324 /// Returns a type with the same shape but with any index element type converted
325 /// to the matching integer type. This is a noop when the element type is not
326 /// the index type.
327 static ShapedType
328 convertIndexElementType(ShapedType type,
330  Type indexType = dyn_cast<IndexType>(type.getElementType());
331  if (!indexType)
332  return type;
333 
334  return type.clone(getIndexType(type.getContext(), options));
335 }
336 
337 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
338 static Type
339 convertVectorType(const spirv::TargetEnv &targetEnv,
340  const SPIRVConversionOptions &options, VectorType type,
341  std::optional<spirv::StorageClass> storageClass = {}) {
342  type = cast<VectorType>(convertIndexElementType(type, options));
343  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
344  if (!scalarType) {
345  // If this is not a spec allowed scalar type, try to handle sub-byte integer
346  // types.
347  auto intType = dyn_cast<IntegerType>(type.getElementType());
348  if (!intType) {
349  LLVM_DEBUG(llvm::dbgs()
350  << type
351  << " illegal: cannot convert non-scalar element type\n");
352  return nullptr;
353  }
354 
355  Type elementType = convertSubByteIntegerType(options, intType);
356  if (!elementType)
357  return nullptr;
358 
359  if (type.getRank() <= 1 && type.getNumElements() == 1)
360  return elementType;
361 
362  if (type.getNumElements() > 4) {
363  LLVM_DEBUG(llvm::dbgs()
364  << type << " illegal: > 4-element unimplemented\n");
365  return nullptr;
366  }
367 
368  return VectorType::get(type.getShape(), elementType);
369  }
370 
371  if (type.getRank() <= 1 && type.getNumElements() == 1)
372  return convertScalarType(targetEnv, options, scalarType, storageClass);
373 
374  if (!spirv::CompositeType::isValid(type)) {
375  LLVM_DEBUG(llvm::dbgs()
376  << type << " illegal: not a valid composite type\n");
377  return nullptr;
378  }
379 
380  // Get extension and capability requirements for the given type.
383  cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
384  cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
385 
386  // If all requirements are met, then we can accept this type as-is.
387  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
388  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
389  return type;
390 
391  auto elementType =
392  convertScalarType(targetEnv, options, scalarType, storageClass);
393  if (elementType)
394  return VectorType::get(type.getShape(), elementType);
395  return nullptr;
396 }
397 
398 static Type
399 convertComplexType(const spirv::TargetEnv &targetEnv,
400  const SPIRVConversionOptions &options, ComplexType type,
401  std::optional<spirv::StorageClass> storageClass = {}) {
402  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
403  if (!scalarType) {
404  LLVM_DEBUG(llvm::dbgs()
405  << type << " illegal: cannot convert non-scalar element type\n");
406  return nullptr;
407  }
408 
409  auto elementType =
410  convertScalarType(targetEnv, options, scalarType, storageClass);
411  if (!elementType)
412  return nullptr;
413  if (elementType != type.getElementType()) {
414  LLVM_DEBUG(llvm::dbgs()
415  << type << " illegal: complex type emulation unsupported\n");
416  return nullptr;
417  }
418 
419  return VectorType::get(2, elementType);
420 }
421 
422 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
423 ///
424 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
425 /// create composite constants with OpConstantComposite to embed relative large
426 /// constant values and use OpCompositeExtract and OpCompositeInsert to
427 /// manipulate, like what we do for vectors.
428 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
430  TensorType type) {
431  // TODO: Handle dynamic shapes.
432  if (!type.hasStaticShape()) {
433  LLVM_DEBUG(llvm::dbgs()
434  << type << " illegal: dynamic shape unimplemented\n");
435  return nullptr;
436  }
437 
438  type = cast<TensorType>(convertIndexElementType(type, options));
439  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
440  if (!scalarType) {
441  LLVM_DEBUG(llvm::dbgs()
442  << type << " illegal: cannot convert non-scalar element type\n");
443  return nullptr;
444  }
445 
446  std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
447  std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
448  if (!scalarSize || !tensorSize) {
449  LLVM_DEBUG(llvm::dbgs()
450  << type << " illegal: cannot deduce element count\n");
451  return nullptr;
452  }
453 
454  int64_t arrayElemCount = *tensorSize / *scalarSize;
455  if (arrayElemCount == 0) {
456  LLVM_DEBUG(llvm::dbgs()
457  << type << " illegal: cannot handle zero-element tensors\n");
458  return nullptr;
459  }
460 
461  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
462  if (!arrayElemType)
463  return nullptr;
464  std::optional<int64_t> arrayElemSize =
465  getTypeNumBytes(options, arrayElemType);
466  if (!arrayElemSize) {
467  LLVM_DEBUG(llvm::dbgs()
468  << type << " illegal: cannot deduce converted element size\n");
469  return nullptr;
470  }
471 
472  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
473 }
474 
475 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
477  MemRefType type,
478  spirv::StorageClass storageClass) {
479  unsigned numBoolBits = options.boolNumBits;
480  if (numBoolBits != 8) {
481  LLVM_DEBUG(llvm::dbgs()
482  << "using non-8-bit storage for bool types unimplemented");
483  return nullptr;
484  }
485  auto elementType = dyn_cast<spirv::ScalarType>(
486  IntegerType::get(type.getContext(), numBoolBits));
487  if (!elementType)
488  return nullptr;
489  Type arrayElemType =
490  convertScalarType(targetEnv, options, elementType, storageClass);
491  if (!arrayElemType)
492  return nullptr;
493  std::optional<int64_t> arrayElemSize =
494  getTypeNumBytes(options, arrayElemType);
495  if (!arrayElemSize) {
496  LLVM_DEBUG(llvm::dbgs()
497  << type << " illegal: cannot deduce converted element size\n");
498  return nullptr;
499  }
500 
501  if (!type.hasStaticShape()) {
502  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
503  // to the element.
504  if (targetEnv.allows(spirv::Capability::Kernel))
505  return spirv::PointerType::get(arrayElemType, storageClass);
506  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
507  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
508  // For Vulkan we need extra wrapping struct and array to satisfy interface
509  // needs.
510  return wrapInStructAndGetPointer(arrayType, storageClass);
511  }
512 
513  if (type.getNumElements() == 0) {
514  LLVM_DEBUG(llvm::dbgs()
515  << type << " illegal: zero-element memrefs are not supported\n");
516  return nullptr;
517  }
518 
519  int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
520  int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
521  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
522  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
523  if (targetEnv.allows(spirv::Capability::Kernel))
524  return spirv::PointerType::get(arrayType, storageClass);
525  return wrapInStructAndGetPointer(arrayType, storageClass);
526 }
527 
528 static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
530  MemRefType type,
531  spirv::StorageClass storageClass) {
532  IntegerType elementType = cast<IntegerType>(type.getElementType());
533  Type arrayElemType = convertSubByteIntegerType(options, elementType);
534  if (!arrayElemType)
535  return nullptr;
536  int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
537 
538  if (!type.hasStaticShape()) {
539  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
540  // to the element.
541  if (targetEnv.allows(spirv::Capability::Kernel))
542  return spirv::PointerType::get(arrayElemType, storageClass);
543  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
544  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
545  // For Vulkan we need extra wrapping struct and array to satisfy interface
546  // needs.
547  return wrapInStructAndGetPointer(arrayType, storageClass);
548  }
549 
550  if (type.getNumElements() == 0) {
551  LLVM_DEBUG(llvm::dbgs()
552  << type << " illegal: zero-element memrefs are not supported\n");
553  return nullptr;
554  }
555 
556  int64_t memrefSize =
557  llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
558  int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
559  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
560  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
561  if (targetEnv.allows(spirv::Capability::Kernel))
562  return spirv::PointerType::get(arrayType, storageClass);
563  return wrapInStructAndGetPointer(arrayType, storageClass);
564 }
565 
566 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
568  MemRefType type) {
569  auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
570  if (!attr) {
571  LLVM_DEBUG(
572  llvm::dbgs()
573  << type
574  << " illegal: expected memory space to be a SPIR-V storage class "
575  "attribute; please use MemorySpaceToStorageClassConverter to map "
576  "numeric memory spaces beforehand\n");
577  return nullptr;
578  }
579  spirv::StorageClass storageClass = attr.getValue();
580 
581  if (isa<IntegerType>(type.getElementType())) {
582  if (type.getElementTypeBitWidth() == 1)
583  return convertBoolMemrefType(targetEnv, options, type, storageClass);
584  if (type.getElementTypeBitWidth() < 8)
585  return convertSubByteMemrefType(targetEnv, options, type, storageClass);
586  }
587 
588  Type arrayElemType;
589  Type elementType = type.getElementType();
590  if (auto vecType = dyn_cast<VectorType>(elementType)) {
591  arrayElemType =
592  convertVectorType(targetEnv, options, vecType, storageClass);
593  } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
594  arrayElemType =
595  convertComplexType(targetEnv, options, complexType, storageClass);
596  } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
597  arrayElemType =
598  convertScalarType(targetEnv, options, scalarType, storageClass);
599  } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
600  type = cast<MemRefType>(convertIndexElementType(type, options));
601  arrayElemType = type.getElementType();
602  } else {
603  LLVM_DEBUG(
604  llvm::dbgs()
605  << type
606  << " unhandled: can only convert scalar or vector element type\n");
607  return nullptr;
608  }
609  if (!arrayElemType)
610  return nullptr;
611 
612  std::optional<int64_t> arrayElemSize =
613  getTypeNumBytes(options, arrayElemType);
614  if (!arrayElemSize) {
615  LLVM_DEBUG(llvm::dbgs()
616  << type << " illegal: cannot deduce converted element size\n");
617  return nullptr;
618  }
619 
620  if (!type.hasStaticShape()) {
621  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
622  // to the element.
623  if (targetEnv.allows(spirv::Capability::Kernel))
624  return spirv::PointerType::get(arrayElemType, storageClass);
625  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
626  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
627  // For Vulkan we need extra wrapping struct and array to satisfy interface
628  // needs.
629  return wrapInStructAndGetPointer(arrayType, storageClass);
630  }
631 
632  std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
633  if (!memrefSize) {
634  LLVM_DEBUG(llvm::dbgs()
635  << type << " illegal: cannot deduce element count\n");
636  return nullptr;
637  }
638 
639  if (*memrefSize == 0) {
640  LLVM_DEBUG(llvm::dbgs()
641  << type << " illegal: zero-element memrefs are not supported\n");
642  return nullptr;
643  }
644 
645  int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
646  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
647  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
648  if (targetEnv.allows(spirv::Capability::Kernel))
649  return spirv::PointerType::get(arrayType, storageClass);
650  return wrapInStructAndGetPointer(arrayType, storageClass);
651 }
652 
653 //===----------------------------------------------------------------------===//
654 // Type casting materialization
655 //===----------------------------------------------------------------------===//
656 
657 /// Converts the given `inputs` to the original source `type` considering the
658 /// `targetEnv`'s capabilities.
659 ///
660 /// This function is meant to be used for source materialization in type
661 /// converters. When the type converter needs to materialize a cast op back
662 /// to some original source type, we need to check whether the original source
663 /// type is supported in the target environment. If so, we can insert legal
664 /// SPIR-V cast ops accordingly.
665 ///
666 /// Note that in SPIR-V the capabilities for storage and compute are separate.
667 /// This function is meant to handle the **compute** side; so it does not
668 /// involve storage classes in its logic. The storage side is expected to be
669 /// handled by MemRef conversion logic.
670 static Value castToSourceType(const spirv::TargetEnv &targetEnv,
671  OpBuilder &builder, Type type, ValueRange inputs,
672  Location loc) {
673  // We can only cast one value in SPIR-V.
674  if (inputs.size() != 1) {
675  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
676  return castOp.getResult(0);
677  }
678  Value input = inputs.front();
679 
680  // Only support integer types for now. Floating point types to be implemented.
681  if (!isa<IntegerType>(type)) {
682  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
683  return castOp.getResult(0);
684  }
685  auto inputType = cast<IntegerType>(input.getType());
686 
687  auto scalarType = dyn_cast<spirv::ScalarType>(type);
688  if (!scalarType) {
689  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
690  return castOp.getResult(0);
691  }
692 
693  // Only support source type with a smaller bitwidth. This would mean we are
694  // truncating to go back so we don't need to worry about the signedness.
695  // For extension, we cannot have enough signal here to decide which op to use.
696  if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
697  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
698  return castOp.getResult(0);
699  }
700 
701  // Boolean values would need to use different ops than normal integer values.
702  if (type.isInteger(1)) {
703  Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
704  return builder.create<spirv::IEqualOp>(loc, input, one);
705  }
706 
707  // Check that the source integer type is supported by the environment.
710  scalarType.getExtensions(exts);
711  scalarType.getCapabilities(caps);
712  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
713  failed(checkExtensionRequirements(type, targetEnv, exts))) {
714  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
715  return castOp.getResult(0);
716  }
717 
718  // We've already made sure this is truncating previously, so we don't need to
719  // care about signedness here. Still try to use a corresponding op for better
720  // consistency though.
721  if (type.isSignedInteger()) {
722  return builder.create<spirv::SConvertOp>(loc, type, input);
723  }
724  return builder.create<spirv::UConvertOp>(loc, type, input);
725 }
726 
727 //===----------------------------------------------------------------------===//
728 // Builtin Variables
729 //===----------------------------------------------------------------------===//
730 
731 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
732  spirv::BuiltIn builtin) {
733  // Look through all global variables in the given `body` block and check if
734  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
735  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
736  if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
737  spirv::SPIRVDialect::getAttributeName(
738  spirv::Decoration::BuiltIn))) {
739  auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
740  if (varBuiltIn && *varBuiltIn == builtin) {
741  return varOp;
742  }
743  }
744  }
745  return nullptr;
746 }
747 
748 /// Gets name of global variable for a builtin.
749 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
750  StringRef suffix) {
751  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
752 }
753 
/// Gets or inserts a global variable for a builtin within `body` block.
///
/// Reuses an existing decorated variable when present; otherwise creates a new
/// spirv.GlobalVariableOp at the start of `body`. The variable's pointee type
/// depends on the builtin: 3-component vectors of `integerType` for the
/// invocation/workgroup id-style builtins, a plain `integerType` for the
/// subgroup-related ones. For unimplemented builtins an error is emitted and a
/// null op is returned.
static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
                           Type integerType, OpBuilder &builder,
                           StringRef prefix, StringRef suffix) {
  // Reuse the existing variable for this builtin if one is already declared.
  if (auto varOp = getBuiltinVariable(body, builtin))
    return varOp;

  // Insert the new variable at the beginning of the block, restoring the
  // builder's insertion point afterwards.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&body);

  spirv::GlobalVariableOp newVarOp;
  switch (builtin) {
  case spirv::BuiltIn::NumWorkgroups:
  case spirv::BuiltIn::WorkgroupSize:
  case spirv::BuiltIn::WorkgroupId:
  case spirv::BuiltIn::LocalInvocationId:
  case spirv::BuiltIn::GlobalInvocationId: {
    // These builtins are 3-component vectors in the Input storage class.
    auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
                                           spirv::StorageClass::Input);
    std::string name = getBuiltinVarName(builtin, prefix, suffix);
    newVarOp =
        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
    break;
  }
  case spirv::BuiltIn::SubgroupId:
  case spirv::BuiltIn::NumSubgroups:
  case spirv::BuiltIn::SubgroupSize: {
    // These builtins are plain scalars in the Input storage class.
    auto ptrType =
        spirv::PointerType::get(integerType, spirv::StorageClass::Input);
    std::string name = getBuiltinVarName(builtin, prefix, suffix);
    newVarOp =
        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
    break;
  }
  default:
    // Unsupported builtin: report and fall through, returning a null op.
    emitError(loc, "unimplemented builtin variable generation for ")
        << stringifyBuiltIn(builtin);
  }
  return newVarOp;
}
795 
796 //===----------------------------------------------------------------------===//
797 // Push constant storage
798 //===----------------------------------------------------------------------===//
799 
800 /// Returns the pointer type for the push constant storage containing
801 /// `elementCount` 32-bit integer values.
802 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
803  Builder &builder,
804  Type indexType) {
805  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
806  /*stride=*/4);
807  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
808  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
809 }
810 
811 /// Returns the push constant varible containing `elementCount` 32-bit integer
812 /// values in `body`. Returns null op if such an op does not exit.
813 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
814  unsigned elementCount) {
815  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
816  auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
817  if (!ptrType)
818  continue;
819 
820  // Note that Vulkan requires "There must be no more than one push constant
821  // block statically used per shader entry point." So we should always reuse
822  // the existing one.
823  if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
824  auto numElements = cast<spirv::ArrayType>(
825  cast<spirv::StructType>(ptrType.getPointeeType())
826  .getElementType(0))
827  .getNumElements();
828  if (numElements == elementCount)
829  return varOp;
830  }
831  }
832  return nullptr;
833 }
834 
835 /// Gets or inserts a global variable for push constant storage containing
836 /// `elementCount` 32-bit integer values in `block`.
837 static spirv::GlobalVariableOp
838 getOrInsertPushConstantVariable(Location loc, Block &block,
839  unsigned elementCount, OpBuilder &b,
840  Type indexType) {
841  if (auto varOp = getPushConstantVariable(block, elementCount))
842  return varOp;
843 
844  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
845  auto type = getPushConstantStorageType(elementCount, builder, indexType);
846  const char *name = "__push_constant_var__";
847  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
848  /*initializer=*/nullptr);
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // func::FuncOp Conversion Patterns
853 //===----------------------------------------------------------------------===//
854 
855 /// A pattern for rewriting function signature to convert arguments of functions
856 /// to be of valid SPIR-V types.
struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {

  LogicalResult
  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    FunctionType fnType = funcOp.getFunctionType();
    // Multiple results are not supported; bail out so other patterns (if any)
    // can handle the function.
    if (fnType.getNumResults() > 1)
      return failure();

    // Convert every input type; any unconvertible argument aborts the rewrite.
    TypeConverter::SignatureConversion signatureConverter(
        fnType.getNumInputs());
    for (const auto &argType : enumerate(fnType.getInputs())) {
      auto convertedType = getTypeConverter()->convertType(argType.value());
      if (!convertedType)
        return failure();
      signatureConverter.addInputs(argType.index(), convertedType);
    }

    // Convert the (single) result type, if present.
    Type resultType;
    if (fnType.getNumResults() == 1) {
      resultType = getTypeConverter()->convertType(fnType.getResult(0));
      if (!resultType)
        return failure();
    }

    // Create the converted spirv.func op.
    auto newFuncOp = rewriter.create<spirv::FuncOp>(
        funcOp.getLoc(), funcOp.getName(),
        rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
                                 resultType ? TypeRange(resultType)
                                            : TypeRange()));

    // Copy over all attributes other than the function name and type.
    for (const auto &namedAttr : funcOp->getAttrs()) {
      if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
          namedAttr.getName() != SymbolTable::getSymbolAttrName())
        newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
    }

    // Move the body into the new function and remap block argument types to
    // the converted signature.
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());
    if (failed(rewriter.convertRegionTypes(
            &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
      return failure();
    rewriter.eraseOp(funcOp);
    return success();
  }
};
906 
907 /// A pattern for rewriting function signature to convert vector arguments of
908 /// functions to be of valid types
struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    FunctionType fnType = funcOp.getFunctionType();

    // TODO: Handle declarations.
    if (funcOp.isDeclaration()) {
      LLVM_DEBUG(llvm::dbgs()
                 << fnType << " illegal: declarations are unsupported\n");
      return failure();
    }

    // Create a new func op with the original type and copy the function body.
    auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
                                                   funcOp.getName(), fnType);
    rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                newFuncOp.end());

    Location loc = newFuncOp.getBody().getLoc();

    Block &entryBlock = newFuncOp.getBlocks().front();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&entryBlock);

    TypeConverter::SignatureConversion oneToNTypeMapping(
        fnType.getInputs().size());

    // For arguments that are of illegal types and require unrolling.
    // `unrolledInputNums` stores the indices of arguments that result from
    // unrolling in the new function signature. `newInputNo` is a counter.
    SmallVector<size_t> unrolledInputNums;
    size_t newInputNo = 0;

    // For arguments that are of legal types and do not require unrolling.
    // `tmpOps` stores a mapping from temporary operations that serve as
    // placeholders for new arguments that will be added later. These operations
    // will be erased once the entry block's argument list is updated.
    llvm::SmallDenseMap<Operation *, size_t> tmpOps;

    // This counts the number of new operations created.
    size_t newOpCount = 0;

    // Enumerate through the arguments.
    for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
      // Check whether the argument is of vector type.
      auto origVecType = dyn_cast<VectorType>(origType);
      if (!origVecType) {
        // We need a placeholder for the old argument that will be erased later.
        Value result = rewriter.create<arith::ConstantOp>(
            loc, origType, rewriter.getZeroAttr(origType));
        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
        tmpOps.insert({result.getDefiningOp(), newInputNo});
        oneToNTypeMapping.addInputs(origInputNo, origType);
        ++newInputNo;
        ++newOpCount;
        continue;
      }
      // Check whether the vector needs unrolling.
      auto targetShape = getTargetShape(origVecType);
      if (!targetShape) {
        // We need a placeholder for the old argument that will be erased later.
        Value result = rewriter.create<arith::ConstantOp>(
            loc, origType, rewriter.getZeroAttr(origType));
        rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
        tmpOps.insert({result.getDefiningOp(), newInputNo});
        oneToNTypeMapping.addInputs(origInputNo, origType);
        ++newInputNo;
        ++newOpCount;
        continue;
      }
      VectorType unrolledType =
          VectorType::get(*targetShape, origVecType.getElementType());
      auto originalShape =
          llvm::to_vector_of<int64_t, 4>(origVecType.getShape());

      // Prepare the result vector.
      Value result = rewriter.create<arith::ConstantOp>(
          loc, origVecType, rewriter.getZeroAttr(origVecType));
      ++newOpCount;
      // Prepare the placeholder for the new arguments that will be added later.
      Value dummy = rewriter.create<arith::ConstantOp>(
          loc, unrolledType, rewriter.getZeroAttr(unrolledType));
      ++newOpCount;

      // Create the `vector.insert_strided_slice` ops, one per unrolled tile of
      // the original vector argument.
      SmallVector<int64_t> strides(targetShape->size(), 1);
      SmallVector<Type> newTypes;
      for (SmallVector<int64_t> offsets :
           StaticTileOffsetRange(originalShape, *targetShape)) {
        result = rewriter.create<vector::InsertStridedSliceOp>(
            loc, dummy, result, offsets, strides);
        newTypes.push_back(unrolledType);
        unrolledInputNums.push_back(newInputNo);
        ++newInputNo;
        ++newOpCount;
      }
      rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
      oneToNTypeMapping.addInputs(origInputNo, newTypes);
    }

    // Change the function signature.
    auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
    auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
    rewriter.modifyOpInPlace(newFuncOp,
                             [&] { newFuncOp.setFunctionType(newFnType); });

    // Update the arguments in the entry block.
    entryBlock.eraseArguments(0, fnType.getNumInputs());
    SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
    entryBlock.addArguments(convertedTypes, locs);

    // Replace the placeholder values with the new arguments. We assume there is
    // only one block for now.
    size_t unrolledInputIdx = 0;
    for (auto [count, op] : enumerate(entryBlock.getOperations())) {
      // We first look for operands that are placeholders for initially legal
      // arguments.
      Operation &curOp = op;
      for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
        Operation *operandOp = operandVal.getDefiningOp();
        if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
          size_t idx = operandIdx;
          rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
            curOp.setOperand(idx, newFuncOp.getArgument(it->second));
          });
        }
      }
      // Since all newly created operations are in the beginning, reaching the
      // end of them means that any later `vector.insert_strided_slice` should
      // not be touched.
      if (count >= newOpCount)
        continue;
      if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
        // Rewire the slice source (operand 0) to the corresponding unrolled
        // argument in the new signature.
        size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
        rewriter.modifyOpInPlace(&curOp, [&] {
          curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
        });
        ++unrolledInputIdx;
      }
    }

    // Erase the original funcOp. The `tmpOps` do not need to be erased since
    // they have no uses and will be handled by dead-code elimination.
    rewriter.eraseOp(funcOp);
    return success();
  }
};
1058 
1059 //===----------------------------------------------------------------------===//
1060 // func::ReturnOp Conversion Patterns
1061 //===----------------------------------------------------------------------===//
1062 
1063 /// A pattern for rewriting function signature and the return op to convert
1064 /// vectors to be of valid types.
struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {

  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
                                PatternRewriter &rewriter) const override {
    // Check whether the parent funcOp is valid.
    auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
    if (!funcOp)
      return failure();

    FunctionType fnType = funcOp.getFunctionType();
    TypeConverter::SignatureConversion oneToNTypeMapping(
        fnType.getResults().size());
    Location loc = returnOp.getLoc();

    // For the new return op.
    SmallVector<Value> newOperands;

    // Enumerate through the results.
    for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
      // Check whether the argument is of vector type.
      auto origVecType = dyn_cast<VectorType>(origType);
      if (!origVecType) {
        // Non-vector results are passed through unchanged.
        oneToNTypeMapping.addInputs(origResultNo, origType);
        newOperands.push_back(returnOp.getOperand(origResultNo));
        continue;
      }
      // Check whether the vector needs unrolling.
      auto targetShape = getTargetShape(origVecType);
      if (!targetShape) {
        // The original argument can be used.
        oneToNTypeMapping.addInputs(origResultNo, origType);
        newOperands.push_back(returnOp.getOperand(origResultNo));
        continue;
      }
      VectorType unrolledType =
          VectorType::get(*targetShape, origVecType.getElementType());

      // Create `vector.extract_strided_slice` ops to form legal vectors from
      // the original operand of illegal type.
      auto originalShape =
          llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
      SmallVector<int64_t> strides(originalShape.size(), 1);
      SmallVector<int64_t> extractShape(originalShape.size(), 1);
      extractShape.back() = targetShape->back();
      SmallVector<Type> newTypes;
      Value returnValue = returnOp.getOperand(origResultNo);
      for (SmallVector<int64_t> offsets :
           StaticTileOffsetRange(originalShape, *targetShape)) {
        Value result = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, returnValue, offsets, extractShape, strides);
        // For rank > 1 vectors, extract the single 1-D tile out of the
        // rank-preserving slice produced above.
        if (originalShape.size() > 1) {
          SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
          result =
              rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
        }
        newOperands.push_back(result);
        newTypes.push_back(unrolledType);
      }
      oneToNTypeMapping.addInputs(origResultNo, newTypes);
    }

    // Change the function signature.
    auto newFnType =
        FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
                          TypeRange(oneToNTypeMapping.getConvertedTypes()));
    rewriter.modifyOpInPlace(funcOp,
                             [&] { funcOp.setFunctionType(newFnType); });

    // Replace the return op using the new operands. This will automatically
    // update the entry block as well.
    rewriter.replaceOp(returnOp,
                       rewriter.create<func::ReturnOp>(loc, newOperands));

    return success();
  }
};
1142 
1143 } // namespace
1144 
1145 //===----------------------------------------------------------------------===//
1146 // Public function for builtin variables
1147 //===----------------------------------------------------------------------===//
1148 
                                     spirv::BuiltIn builtin,
                                     Type integerType, OpBuilder &builder,
                                     StringRef prefix, StringRef suffix) {
  if (!parent) {
    op->emitError("expected operation to be within a module-like op");
    return nullptr;
  }

  // Get (or lazily create) the global variable for this builtin in the
  // enclosing symbol-table region, then take its address and load the value.
  spirv::GlobalVariableOp varOp =
      getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
                                 builtin, integerType, builder, prefix, suffix);
  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
}
1165 
1166 //===----------------------------------------------------------------------===//
1167 // Public function for pushing constant storage
1168 //===----------------------------------------------------------------------===//
1169 
Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
                                  unsigned offset, Type integerType,
                                  OpBuilder &builder) {
  Location loc = op->getLoc();
  if (!parent) {
    op->emitError("expected operation to be within a module-like op");
    return nullptr;
  }

  // Reuse (or create) the single push constant block for this module.
  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
      loc, parent->getRegion(0).front(), elementCount, builder, integerType);

  // Index into the wrapping struct (member 0) and then into the array at
  // `offset`, and load the selected 32-bit value.
  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
  Value offsetOp = builder.create<spirv::ConstantOp>(
      loc, integerType, builder.getI32IntegerAttr(offset));
  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
  auto acOp = builder.create<spirv::AccessChainOp>(
      loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
  return builder.create<spirv::LoadOp>(loc, acOp);
}
1191 
1192 //===----------------------------------------------------------------------===//
1193 // Public functions for index calculation
1194 //===----------------------------------------------------------------------===//
1195 
                            int64_t offset, Type integerType,
                            Location loc, OpBuilder &builder) {
  assert(indices.size() == strides.size() &&
         "must provide indices for all dimensions");

  // TODO: Consider moving to use affine.apply and patterns converting
  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
  // broken down into progressive small steps so we can have intermediate steps
  // using other dialects. At the moment SPIR-V is the final sink.

  // Start from the static base offset and accumulate index * stride terms,
  // folding constants where possible via createOrFold.
  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
      loc, integerType, IntegerAttr::get(integerType, offset));
  for (const auto &index : llvm::enumerate(indices)) {
    Value strideVal = builder.createOrFold<spirv::ConstantOp>(
        loc, integerType,
        IntegerAttr::get(integerType, strides[index.index()]));
    Value update =
        builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
    linearizedIndex =
        builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
  }
  return linearizedIndex;
}
1220 
                       MemRefType baseType, Value basePtr,
                       ValueRange indices, Location loc,
                       OpBuilder &builder) {
  // Get base and offset of the MemRefType and verify they are static.

  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(baseType.getStridesAndOffset(strides, offset)) ||
      llvm::is_contained(strides, ShapedType::kDynamic) ||
      ShapedType::isDynamic(offset)) {
    // Dynamic strides/offset cannot be linearized here; return a null value
    // to signal failure.
    return nullptr;
  }

  auto indexType = typeConverter.getIndexType();

  SmallVector<Value, 2> linearizedIndices;
  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);

  // Add a '0' at the start to index into the struct.
  linearizedIndices.push_back(zero);

  if (baseType.getRank() == 0) {
    // Rank-0 memrefs address their single element at index 0.
    linearizedIndices.push_back(zero);
  } else {
    linearizedIndices.push_back(
        linearizeIndex(indices, strides, offset, indexType, loc, builder));
  }
  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
}
1251 
                       MemRefType baseType, Value basePtr,
                       ValueRange indices, Location loc,
                       OpBuilder &builder) {
  // Get base and offset of the MemRefType and verify they are static.

  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(baseType.getStridesAndOffset(strides, offset)) ||
      llvm::is_contained(strides, ShapedType::kDynamic) ||
      ShapedType::isDynamic(offset)) {
    // Dynamic strides/offset cannot be linearized here; return a null value
    // to signal failure.
    return nullptr;
  }

  auto indexType = typeConverter.getIndexType();

  SmallVector<Value, 2> linearizedIndices;
  Value linearIndex;
  if (baseType.getRank() == 0) {
    // Rank-0 memrefs address their single element at linear index 0.
    linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
  } else {
    linearIndex =
        linearizeIndex(indices, strides, offset, indexType, loc, builder);
  }
  Type pointeeType =
      cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
  // Array pointees are indexed with a regular access chain; otherwise use a
  // ptr-access-chain, which takes the element offset as a separate operand.
  if (isa<spirv::ArrayType>(pointeeType)) {
    linearizedIndices.push_back(linearIndex);
    return builder.create<spirv::AccessChainOp>(loc, basePtr,
                                                linearizedIndices);
  }
  return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
                                                 linearizedIndices);
}
1286 
                       MemRefType baseType, Value basePtr,
                       ValueRange indices, Location loc,
                       OpBuilder &builder) {

  // Dispatch on the target environment: the Kernel capability selects the
  // OpenCL-style lowering, everything else takes the Vulkan-style one.
  if (typeConverter.allows(spirv::Capability::Kernel)) {
    return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
                               builder);
  }

  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
                             builder);
}
1300 
1301 //===----------------------------------------------------------------------===//
1302 // Public functions for vector unrolling
1303 //===----------------------------------------------------------------------===//
1304 
  // Prefer the largest vector length among 4, 3, and 2 that evenly divides
  // `size`; fall back to scalar (1) when none does.
  for (int i : {4, 3, 2}) {
    if (size % i == 0)
      return i;
  }
  return 1;
}
1312 
mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
  VectorType srcVectorType = op.getSourceVectorType();
  assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
  // Unroll the 1-D reduction source to the compute vector size.
  int64_t vectorSize =
      mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
  return {vectorSize};
}
1321 
mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
  VectorType vectorType = op.getResultVectorType();
  // Keep all leading dims at 1 and unroll only the innermost dimension of the
  // transpose result to the compute vector size.
  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
  nativeSize.back() =
      mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
  return nativeSize;
}
1330 
std::optional<SmallVector<int64_t>>
  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
    // Elementwise single-result ops: unroll only the innermost dimension of
    // the result vector, keeping leading dims at 1.
    if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
      SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
      nativeSize.back() =
          mlir::spirv::getComputeVectorSize(vecType.getShape().back());
      return nativeSize;
    }
  }

  // Otherwise dispatch to the op-specific shape implementations.
      .Case<vector::ReductionOp, vector::TransposeOp>(
          [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
      .Default([](Operation *) { return std::nullopt; });
}
1347 
  MLIRContext *context = op->getContext();
  RewritePatternSet patterns(context);
  // We only want to apply signature conversion once to the existing func ops.
  // Without specifying strictMode, the greedy pattern rewriter will keep
  // looking for newly created func ops.
  return applyPatternsGreedily(op, std::move(patterns), config);
}
1360 
  MLIRContext *context = op->getContext();

  // Unroll vectors in function bodies to native vector size.
  {
    RewritePatternSet patterns(context);
        [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
    if (failed(applyPatternsGreedily(op, std::move(patterns))))
      return failure();
  }

  // Convert transpose ops into extract and insert pairs, in preparation of
  // further transformations to canonicalize/cancel.
  {
    RewritePatternSet patterns(context);
        patterns, vector::VectorTransposeLowering::EltWise);
    if (failed(applyPatternsGreedily(op, std::move(patterns))))
      return failure();
  }

  // Run canonicalization to cast away leading size-1 dimensions.
  {
    RewritePatternSet patterns(context);

    // We need to pull in patterns for casting away leading unit dims.
    vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
    vector::TransposeOp::getCanonicalizationPatterns(patterns, context);

    // Decompose different rank insert_strided_slice and n-D
    // extract_strided_slice.
        patterns);
    vector::InsertOp::getCanonicalizationPatterns(patterns, context);
    vector::ExtractOp::getCanonicalizationPatterns(patterns, context);

    // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
    // them up.
    vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
    vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);

    if (failed(applyPatternsGreedily(op, std::move(patterns))))
      return failure();
  }
  return success();
}
1411 
1412 //===----------------------------------------------------------------------===//
1413 // SPIR-V TypeConverter
1414 //===----------------------------------------------------------------------===//
1415 
    : targetEnv(targetAttr), options(options) {
  // Add conversions. The order matters here: later ones will be tried earlier.

  // Allow all SPIR-V dialect specific types. This assumes all builtin types
  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
  // were tried before.
  //
  // TODO: This assumes that the SPIR-V types are valid to use in the given
  // target environment, which should be the case if the whole pipeline is
  // driven by the same target environment. Still, we probably still want to
  // validate and convert to be safe.
  addConversion([](spirv::SPIRVType type) { return type; });

  // `index` lowers to the integer type selected by the conversion options.
  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });

  addConversion([this](IntegerType intType) -> std::optional<Type> {
    if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
      return convertScalarType(this->targetEnv, this->options, scalarType);
    // Sub-byte integers take a dedicated conversion path.
    if (intType.getWidth() < 8)
      return convertSubByteIntegerType(this->options, intType);
    return Type();
  });

  addConversion([this](FloatType floatType) -> std::optional<Type> {
    if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
      return convertScalarType(this->targetEnv, this->options, scalarType);
    return Type();
  });

  addConversion([this](ComplexType complexType) {
    return convertComplexType(this->targetEnv, this->options, complexType);
  });

  addConversion([this](VectorType vectorType) {
    return convertVectorType(this->targetEnv, this->options, vectorType);
  });

  addConversion([this](TensorType tensorType) {
    return convertTensorType(this->targetEnv, this->options, tensorType);
  });

  addConversion([this](MemRefType memRefType) {
    return convertMemrefType(this->targetEnv, this->options, memRefType);
  });

  // Register some last line of defense casting logic.
      [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
        return castToSourceType(this->targetEnv, builder, type, inputs, loc);
      });
  // Target materialization falls back to an unrealized_conversion_cast and
  // returns its single result.
  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
                              Location loc) {
    auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
    return cast.getResult(0);
  });
}
1474 
  // Delegate to the namespace-scope helper, which derives the integer type
  // used to lower `index` from the conversion options.
  return ::getIndexType(getContext(), options);
}
1478 
1479 MLIRContext *SPIRVTypeConverter::getContext() const {
1480  return targetEnv.getAttr().getContext();
1481 }
1482 
1483 bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
1484  return targetEnv.allows(capability);
1485 }
1486 
1487 //===----------------------------------------------------------------------===//
1488 // SPIR-V ConversionTarget
1489 //===----------------------------------------------------------------------===//
1490 
std::unique_ptr<SPIRVConversionTarget>
  std::unique_ptr<SPIRVConversionTarget> target(
      // std::make_unique does not work here because the constructor is private.
      new SPIRVConversionTarget(targetAttr));
  SPIRVConversionTarget *targetPtr = target.get();
  // All SPIR-V dialect ops are legal only if `isLegalOp` accepts them for the
  // wrapped target environment.
  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
      // We need to capture the raw pointer here because it is stable:
      // target will be destroyed once this function is returned.
      [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
  return target;
}
1503 
// Seeds the ConversionTarget with the attribute's context and remembers the
// target environment for the legality checks in `isLegalOp`.
SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
    : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1506 
bool SPIRVConversionTarget::isLegalOp(Operation *op) {
  // Make sure this op is available at the given version. Ops not implementing
  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
  // SPIR-V versions.
  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
    std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
    if (minVersion && *minVersion > this->targetEnv.getVersion()) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring min version "
                 << spirv::stringifyVersion(*minVersion) << "\n");
      return false;
    }
  }
  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
    std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
    if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
      LLVM_DEBUG(llvm::dbgs()
                 << op->getName() << " illegal: requiring max version "
                 << spirv::stringifyVersion(*maxVersion) << "\n");
      return false;
    }
  }

  // Make sure this op's required extensions are allowed to use. Ops not
  // implementing QueryExtensionInterface do not require extensions to be
  // available.
  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          extensions.getExtensions())))
      return false;

  // Make sure this op's required capabilities are allowed to use. Ops not
  // implementing QueryCapabilityInterface do not require capabilities to be
  // available.
  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           capabilities.getCapabilities())))
      return false;

  SmallVector<Type, 4> valueTypes;
  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
  valueTypes.append(op->result_type_begin(), op->result_type_end());

  // Ensure that all types have been converted to SPIRV types.
  if (llvm::any_of(valueTypes,
                   [](Type t) { return !isa<spirv::SPIRVType>(t); }))
    return false;

  // Special treatment for global variables, whose type requirements are
  // conveyed by type attributes.
  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
    valueTypes.push_back(globalVar.getType());

  // Make sure the op's operands/results use types that are allowed by the
  // target environment.
  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
  for (Type valueType : valueTypes) {
    typeExtensions.clear();
    cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
                                          typeExtensions)))
      return false;

    typeCapabilities.clear();
    cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
                                           typeCapabilities)))
      return false;
  }

  return true;
}
1580 
1581 //===----------------------------------------------------------------------===//
1582 // Public functions for populating patterns
1583 //===----------------------------------------------------------------------===//
1584 
    const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // Registers the func.func -> spirv.func signature conversion pattern.
  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
}
1589 
  // Registers the pattern that unrolls illegal vector arguments in func.func
  // signatures.
  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
}
1593 
  // Registers the pattern that unrolls illegal vector results at func.return.
  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
}
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
Definition: Block.h:33
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:162
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:203
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:193
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:196
FloatType getF32Type()
Definition: Builders.cpp:43
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:238
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
operand_type_iterator operand_type_end()
Definition: Operation.h:396
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_iterator result_type_end()
Definition: Operation.h:427
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
result_type_iterator result_type_begin()
Definition: Operation.h:426
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
operand_type_iterator operand_type_begin()
Definition: Operation.h:395
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:753
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:606
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:598
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:55
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:102
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:463
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:961
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:62
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
Definition: Utils.cpp:2016
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:318
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:323
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)