MLIR  20.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/Pass/Pass.h"
29 #include "mlir/Support/LLVM.h"
33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/LogicalResult.h"
38 #include "llvm/Support/MathExtras.h"
39 
40 #include <functional>
41 #include <optional>
42 
43 #define DEBUG_TYPE "mlir-spirv-conversion"
44 
45 using namespace mlir;
46 
47 namespace {
48 
49 //===----------------------------------------------------------------------===//
50 // Utility functions
51 //===----------------------------------------------------------------------===//
52 
53 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
54  LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
55  if (vecType.isScalable()) {
56  LLVM_DEBUG(llvm::dbgs()
57  << "--scalable vectors are not supported -> BAIL\n");
58  return std::nullopt;
59  }
60  SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
61  std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
62  1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
63  if (!targetShape) {
64  LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
65  return std::nullopt;
66  }
67  auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
68  if (!maybeShapeRatio) {
69  LLVM_DEBUG(llvm::dbgs()
70  << "--could not compute integral shape ratio -> BAIL\n");
71  return std::nullopt;
72  }
73  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
74  LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
75  return std::nullopt;
76  }
77  LLVM_DEBUG(llvm::dbgs()
78  << "--found an integral shape ratio to unroll to -> SUCCESS\n");
79  return targetShape;
80 }
81 
82 /// Checks that `candidates` extension requirements are possible to be satisfied
83 /// with the given `targetEnv`.
84 ///
85 /// `candidates` is a vector of vector for extension requirements following
86 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
87 /// convention.
88 template <typename LabelT>
89 static LogicalResult checkExtensionRequirements(
90  LabelT label, const spirv::TargetEnv &targetEnv,
92  for (const auto &ors : candidates) {
93  if (targetEnv.allows(ors))
94  continue;
95 
96  LLVM_DEBUG({
97  SmallVector<StringRef> extStrings;
98  for (spirv::Extension ext : ors)
99  extStrings.push_back(spirv::stringifyExtension(ext));
100 
101  llvm::dbgs() << label << " illegal: requires at least one extension in ["
102  << llvm::join(extStrings, ", ")
103  << "] but none allowed in target environment\n";
104  });
105  return failure();
106  }
107  return success();
108 }
109 
110 /// Checks that `candidates`capability requirements are possible to be satisfied
111 /// with the given `isAllowedFn`.
112 ///
113 /// `candidates` is a vector of vector for capability requirements following
114 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
115 /// convention.
116 template <typename LabelT>
117 static LogicalResult checkCapabilityRequirements(
118  LabelT label, const spirv::TargetEnv &targetEnv,
119  const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
120  for (const auto &ors : candidates) {
121  if (targetEnv.allows(ors))
122  continue;
123 
124  LLVM_DEBUG({
125  SmallVector<StringRef> capStrings;
126  for (spirv::Capability cap : ors)
127  capStrings.push_back(spirv::stringifyCapability(cap));
128 
129  llvm::dbgs() << label << " illegal: requires at least one capability in ["
130  << llvm::join(capStrings, ", ")
131  << "] but none allowed in target environment\n";
132  });
133  return failure();
134  }
135  return success();
136 }
137 
138 /// Returns true if the given `storageClass` needs explicit layout when used in
139 /// Shader environments.
140 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
141  switch (storageClass) {
142  case spirv::StorageClass::PhysicalStorageBuffer:
143  case spirv::StorageClass::PushConstant:
144  case spirv::StorageClass::StorageBuffer:
145  case spirv::StorageClass::Uniform:
146  return true;
147  default:
148  return false;
149  }
150 }
151 
152 /// Wraps the given `elementType` in a struct and gets the pointer to the
153 /// struct. This is used to satisfy Vulkan interface requirements.
154 static spirv::PointerType
155 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
156  auto structType = needsExplicitLayout(storageClass)
157  ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
158  : spirv::StructType::get(elementType);
159  return spirv::PointerType::get(structType, storageClass);
160 }
161 
162 //===----------------------------------------------------------------------===//
163 // Type Conversion
164 //===----------------------------------------------------------------------===//
165 
166 static spirv::ScalarType getIndexType(MLIRContext *ctx,
168  return cast<spirv::ScalarType>(
169  IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
170 }
171 
172 // TODO: This is a utility function that should probably be exposed by the
173 // SPIR-V dialect. Keeping it local till the use case arises.
174 static std::optional<int64_t>
175 getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
176  if (isa<spirv::ScalarType>(type)) {
177  auto bitWidth = type.getIntOrFloatBitWidth();
178  // According to the SPIR-V spec:
179  // "There is no physical size or bit pattern defined for values with boolean
180  // type. If they are stored (in conjunction with OpVariable), they can only
181  // be used with logical addressing operations, not physical, and only with
182  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
183  // Private, Function, Input, and Output."
184  if (bitWidth == 1)
185  return std::nullopt;
186  return bitWidth / 8;
187  }
188 
189  if (auto complexType = dyn_cast<ComplexType>(type)) {
190  auto elementSize = getTypeNumBytes(options, complexType.getElementType());
191  if (!elementSize)
192  return std::nullopt;
193  return 2 * *elementSize;
194  }
195 
196  if (auto vecType = dyn_cast<VectorType>(type)) {
197  auto elementSize = getTypeNumBytes(options, vecType.getElementType());
198  if (!elementSize)
199  return std::nullopt;
200  return vecType.getNumElements() * *elementSize;
201  }
202 
203  if (auto memRefType = dyn_cast<MemRefType>(type)) {
204  // TODO: Layout should also be controlled by the ABI attributes. For now
205  // using the layout from MemRef.
206  int64_t offset;
207  SmallVector<int64_t, 4> strides;
208  if (!memRefType.hasStaticShape() ||
209  failed(getStridesAndOffset(memRefType, strides, offset)))
210  return std::nullopt;
211 
212  // To get the size of the memref object in memory, the total size is the
213  // max(stride * dimension-size) computed for all dimensions times the size
214  // of the element.
215  auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
216  if (!elementSize)
217  return std::nullopt;
218 
219  if (memRefType.getRank() == 0)
220  return elementSize;
221 
222  auto dims = memRefType.getShape();
223  if (llvm::is_contained(dims, ShapedType::kDynamic) ||
224  ShapedType::isDynamic(offset) ||
225  llvm::is_contained(strides, ShapedType::kDynamic))
226  return std::nullopt;
227 
228  int64_t memrefSize = -1;
229  for (const auto &shape : enumerate(dims))
230  memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
231 
232  return (offset + memrefSize) * *elementSize;
233  }
234 
235  if (auto tensorType = dyn_cast<TensorType>(type)) {
236  if (!tensorType.hasStaticShape())
237  return std::nullopt;
238 
239  auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
240  if (!elementSize)
241  return std::nullopt;
242 
243  int64_t size = *elementSize;
244  for (auto shape : tensorType.getShape())
245  size *= shape;
246 
247  return size;
248  }
249 
250  // TODO: Add size computation for other types.
251  return std::nullopt;
252 }
253 
254 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
255 static Type
256 convertScalarType(const spirv::TargetEnv &targetEnv,
258  std::optional<spirv::StorageClass> storageClass = {}) {
259  // Get extension and capability requirements for the given type.
262  type.getExtensions(extensions, storageClass);
263  type.getCapabilities(capabilities, storageClass);
264 
265  // If all requirements are met, then we can accept this type as-is.
266  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
267  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
268  return type;
269 
270  // Otherwise we need to adjust the type, which really means adjusting the
271  // bitwidth given this is a scalar type.
272  if (!options.emulateLT32BitScalarTypes)
273  return nullptr;
274 
275  // We only emulate narrower scalar types here and do not truncate results.
276  if (type.getIntOrFloatBitWidth() > 32) {
277  LLVM_DEBUG(llvm::dbgs()
278  << type
279  << " not converted to 32-bit for SPIR-V to avoid truncation\n");
280  return nullptr;
281  }
282 
283  if (auto floatType = dyn_cast<FloatType>(type)) {
284  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
285  return Builder(targetEnv.getContext()).getF32Type();
286  }
287 
288  auto intType = cast<IntegerType>(type);
289  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
290  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
291  intType.getSignedness());
292 }
293 
294 /// Converts a sub-byte integer `type` to i32 regardless of target environment.
295 /// Returns a nullptr for unsupported integer types, including non sub-byte
296 /// types.
297 ///
298 /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
299 /// the above given that these sub-byte types are not supported at all in
300 /// SPIR-V; there are no compute/storage capability for them like other
301 /// supported integer types.
302 static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
303  IntegerType type) {
304  if (type.getWidth() > 8) {
305  LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
306  return nullptr;
307  }
308  if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
309  LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
310  return nullptr;
311  }
312 
313  if (!llvm::isPowerOf2_32(type.getWidth())) {
314  LLVM_DEBUG(llvm::dbgs()
315  << "unsupported non-power-of-two bitwidth in sub-byte" << type
316  << "\n");
317  return nullptr;
318  }
319 
320  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
321  return IntegerType::get(type.getContext(), /*width=*/32,
322  type.getSignedness());
323 }
324 
325 /// Returns a type with the same shape but with any index element type converted
326 /// to the matching integer type. This is a noop when the element type is not
327 /// the index type.
328 static ShapedType
329 convertIndexElementType(ShapedType type,
331  Type indexType = dyn_cast<IndexType>(type.getElementType());
332  if (!indexType)
333  return type;
334 
335  return type.clone(getIndexType(type.getContext(), options));
336 }
337 
338 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
339 static Type
340 convertVectorType(const spirv::TargetEnv &targetEnv,
341  const SPIRVConversionOptions &options, VectorType type,
342  std::optional<spirv::StorageClass> storageClass = {}) {
343  type = cast<VectorType>(convertIndexElementType(type, options));
344  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
345  if (!scalarType) {
346  // If this is not a spec allowed scalar type, try to handle sub-byte integer
347  // types.
348  auto intType = dyn_cast<IntegerType>(type.getElementType());
349  if (!intType) {
350  LLVM_DEBUG(llvm::dbgs()
351  << type
352  << " illegal: cannot convert non-scalar element type\n");
353  return nullptr;
354  }
355 
356  Type elementType = convertSubByteIntegerType(options, intType);
357  if (!elementType)
358  return nullptr;
359 
360  if (type.getRank() <= 1 && type.getNumElements() == 1)
361  return elementType;
362 
363  if (type.getNumElements() > 4) {
364  LLVM_DEBUG(llvm::dbgs()
365  << type << " illegal: > 4-element unimplemented\n");
366  return nullptr;
367  }
368 
369  return VectorType::get(type.getShape(), elementType);
370  }
371 
372  if (type.getRank() <= 1 && type.getNumElements() == 1)
373  return convertScalarType(targetEnv, options, scalarType, storageClass);
374 
375  if (!spirv::CompositeType::isValid(type)) {
376  LLVM_DEBUG(llvm::dbgs()
377  << type << " illegal: not a valid composite type\n");
378  return nullptr;
379  }
380 
381  // Get extension and capability requirements for the given type.
384  cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
385  cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
386 
387  // If all requirements are met, then we can accept this type as-is.
388  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
389  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
390  return type;
391 
392  auto elementType =
393  convertScalarType(targetEnv, options, scalarType, storageClass);
394  if (elementType)
395  return VectorType::get(type.getShape(), elementType);
396  return nullptr;
397 }
398 
399 static Type
400 convertComplexType(const spirv::TargetEnv &targetEnv,
401  const SPIRVConversionOptions &options, ComplexType type,
402  std::optional<spirv::StorageClass> storageClass = {}) {
403  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
404  if (!scalarType) {
405  LLVM_DEBUG(llvm::dbgs()
406  << type << " illegal: cannot convert non-scalar element type\n");
407  return nullptr;
408  }
409 
410  auto elementType =
411  convertScalarType(targetEnv, options, scalarType, storageClass);
412  if (!elementType)
413  return nullptr;
414  if (elementType != type.getElementType()) {
415  LLVM_DEBUG(llvm::dbgs()
416  << type << " illegal: complex type emulation unsupported\n");
417  return nullptr;
418  }
419 
420  return VectorType::get(2, elementType);
421 }
422 
423 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
424 ///
425 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
426 /// create composite constants with OpConstantComposite to embed relative large
427 /// constant values and use OpCompositeExtract and OpCompositeInsert to
428 /// manipulate, like what we do for vectors.
429 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
431  TensorType type) {
432  // TODO: Handle dynamic shapes.
433  if (!type.hasStaticShape()) {
434  LLVM_DEBUG(llvm::dbgs()
435  << type << " illegal: dynamic shape unimplemented\n");
436  return nullptr;
437  }
438 
439  type = cast<TensorType>(convertIndexElementType(type, options));
440  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
441  if (!scalarType) {
442  LLVM_DEBUG(llvm::dbgs()
443  << type << " illegal: cannot convert non-scalar element type\n");
444  return nullptr;
445  }
446 
447  std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
448  std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
449  if (!scalarSize || !tensorSize) {
450  LLVM_DEBUG(llvm::dbgs()
451  << type << " illegal: cannot deduce element count\n");
452  return nullptr;
453  }
454 
455  int64_t arrayElemCount = *tensorSize / *scalarSize;
456  if (arrayElemCount == 0) {
457  LLVM_DEBUG(llvm::dbgs()
458  << type << " illegal: cannot handle zero-element tensors\n");
459  return nullptr;
460  }
461 
462  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
463  if (!arrayElemType)
464  return nullptr;
465  std::optional<int64_t> arrayElemSize =
466  getTypeNumBytes(options, arrayElemType);
467  if (!arrayElemSize) {
468  LLVM_DEBUG(llvm::dbgs()
469  << type << " illegal: cannot deduce converted element size\n");
470  return nullptr;
471  }
472 
473  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
474 }
475 
476 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
478  MemRefType type,
479  spirv::StorageClass storageClass) {
480  unsigned numBoolBits = options.boolNumBits;
481  if (numBoolBits != 8) {
482  LLVM_DEBUG(llvm::dbgs()
483  << "using non-8-bit storage for bool types unimplemented");
484  return nullptr;
485  }
486  auto elementType = dyn_cast<spirv::ScalarType>(
487  IntegerType::get(type.getContext(), numBoolBits));
488  if (!elementType)
489  return nullptr;
490  Type arrayElemType =
491  convertScalarType(targetEnv, options, elementType, storageClass);
492  if (!arrayElemType)
493  return nullptr;
494  std::optional<int64_t> arrayElemSize =
495  getTypeNumBytes(options, arrayElemType);
496  if (!arrayElemSize) {
497  LLVM_DEBUG(llvm::dbgs()
498  << type << " illegal: cannot deduce converted element size\n");
499  return nullptr;
500  }
501 
502  if (!type.hasStaticShape()) {
503  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
504  // to the element.
505  if (targetEnv.allows(spirv::Capability::Kernel))
506  return spirv::PointerType::get(arrayElemType, storageClass);
507  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
508  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
509  // For Vulkan we need extra wrapping struct and array to satisfy interface
510  // needs.
511  return wrapInStructAndGetPointer(arrayType, storageClass);
512  }
513 
514  if (type.getNumElements() == 0) {
515  LLVM_DEBUG(llvm::dbgs()
516  << type << " illegal: zero-element memrefs are not supported\n");
517  return nullptr;
518  }
519 
520  int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
521  int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
522  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
523  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
524  if (targetEnv.allows(spirv::Capability::Kernel))
525  return spirv::PointerType::get(arrayType, storageClass);
526  return wrapInStructAndGetPointer(arrayType, storageClass);
527 }
528 
529 static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
531  MemRefType type,
532  spirv::StorageClass storageClass) {
533  IntegerType elementType = cast<IntegerType>(type.getElementType());
534  Type arrayElemType = convertSubByteIntegerType(options, elementType);
535  if (!arrayElemType)
536  return nullptr;
537  int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
538 
539  if (!type.hasStaticShape()) {
540  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
541  // to the element.
542  if (targetEnv.allows(spirv::Capability::Kernel))
543  return spirv::PointerType::get(arrayElemType, storageClass);
544  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
545  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
546  // For Vulkan we need extra wrapping struct and array to satisfy interface
547  // needs.
548  return wrapInStructAndGetPointer(arrayType, storageClass);
549  }
550 
551  if (type.getNumElements() == 0) {
552  LLVM_DEBUG(llvm::dbgs()
553  << type << " illegal: zero-element memrefs are not supported\n");
554  return nullptr;
555  }
556 
557  int64_t memrefSize =
558  llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
559  int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
560  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
561  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
562  if (targetEnv.allows(spirv::Capability::Kernel))
563  return spirv::PointerType::get(arrayType, storageClass);
564  return wrapInStructAndGetPointer(arrayType, storageClass);
565 }
566 
567 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
569  MemRefType type) {
570  auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
571  if (!attr) {
572  LLVM_DEBUG(
573  llvm::dbgs()
574  << type
575  << " illegal: expected memory space to be a SPIR-V storage class "
576  "attribute; please use MemorySpaceToStorageClassConverter to map "
577  "numeric memory spaces beforehand\n");
578  return nullptr;
579  }
580  spirv::StorageClass storageClass = attr.getValue();
581 
582  if (isa<IntegerType>(type.getElementType())) {
583  if (type.getElementTypeBitWidth() == 1)
584  return convertBoolMemrefType(targetEnv, options, type, storageClass);
585  if (type.getElementTypeBitWidth() < 8)
586  return convertSubByteMemrefType(targetEnv, options, type, storageClass);
587  }
588 
589  Type arrayElemType;
590  Type elementType = type.getElementType();
591  if (auto vecType = dyn_cast<VectorType>(elementType)) {
592  arrayElemType =
593  convertVectorType(targetEnv, options, vecType, storageClass);
594  } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
595  arrayElemType =
596  convertComplexType(targetEnv, options, complexType, storageClass);
597  } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
598  arrayElemType =
599  convertScalarType(targetEnv, options, scalarType, storageClass);
600  } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
601  type = cast<MemRefType>(convertIndexElementType(type, options));
602  arrayElemType = type.getElementType();
603  } else {
604  LLVM_DEBUG(
605  llvm::dbgs()
606  << type
607  << " unhandled: can only convert scalar or vector element type\n");
608  return nullptr;
609  }
610  if (!arrayElemType)
611  return nullptr;
612 
613  std::optional<int64_t> arrayElemSize =
614  getTypeNumBytes(options, arrayElemType);
615  if (!arrayElemSize) {
616  LLVM_DEBUG(llvm::dbgs()
617  << type << " illegal: cannot deduce converted element size\n");
618  return nullptr;
619  }
620 
621  if (!type.hasStaticShape()) {
622  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
623  // to the element.
624  if (targetEnv.allows(spirv::Capability::Kernel))
625  return spirv::PointerType::get(arrayElemType, storageClass);
626  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
627  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
628  // For Vulkan we need extra wrapping struct and array to satisfy interface
629  // needs.
630  return wrapInStructAndGetPointer(arrayType, storageClass);
631  }
632 
633  std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
634  if (!memrefSize) {
635  LLVM_DEBUG(llvm::dbgs()
636  << type << " illegal: cannot deduce element count\n");
637  return nullptr;
638  }
639 
640  if (*memrefSize == 0) {
641  LLVM_DEBUG(llvm::dbgs()
642  << type << " illegal: zero-element memrefs are not supported\n");
643  return nullptr;
644  }
645 
646  int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
647  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
648  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
649  if (targetEnv.allows(spirv::Capability::Kernel))
650  return spirv::PointerType::get(arrayType, storageClass);
651  return wrapInStructAndGetPointer(arrayType, storageClass);
652 }
653 
654 //===----------------------------------------------------------------------===//
655 // Type casting materialization
656 //===----------------------------------------------------------------------===//
657 
658 /// Converts the given `inputs` to the original source `type` considering the
659 /// `targetEnv`'s capabilities.
660 ///
661 /// This function is meant to be used for source materialization in type
662 /// converters. When the type converter needs to materialize a cast op back
663 /// to some original source type, we need to check whether the original source
664 /// type is supported in the target environment. If so, we can insert legal
665 /// SPIR-V cast ops accordingly.
666 ///
667 /// Note that in SPIR-V the capabilities for storage and compute are separate.
668 /// This function is meant to handle the **compute** side; so it does not
669 /// involve storage classes in its logic. The storage side is expected to be
670 /// handled by MemRef conversion logic.
671 static Value castToSourceType(const spirv::TargetEnv &targetEnv,
672  OpBuilder &builder, Type type, ValueRange inputs,
673  Location loc) {
674  // We can only cast one value in SPIR-V.
675  if (inputs.size() != 1) {
676  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
677  return castOp.getResult(0);
678  }
679  Value input = inputs.front();
680 
681  // Only support integer types for now. Floating point types to be implemented.
682  if (!isa<IntegerType>(type)) {
683  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
684  return castOp.getResult(0);
685  }
686  auto inputType = cast<IntegerType>(input.getType());
687 
688  auto scalarType = dyn_cast<spirv::ScalarType>(type);
689  if (!scalarType) {
690  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
691  return castOp.getResult(0);
692  }
693 
694  // Only support source type with a smaller bitwidth. This would mean we are
695  // truncating to go back so we don't need to worry about the signedness.
696  // For extension, we cannot have enough signal here to decide which op to use.
697  if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
698  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
699  return castOp.getResult(0);
700  }
701 
702  // Boolean values would need to use different ops than normal integer values.
703  if (type.isInteger(1)) {
704  Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
705  return builder.create<spirv::IEqualOp>(loc, input, one);
706  }
707 
708  // Check that the source integer type is supported by the environment.
711  scalarType.getExtensions(exts);
712  scalarType.getCapabilities(caps);
713  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
714  failed(checkExtensionRequirements(type, targetEnv, exts))) {
715  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
716  return castOp.getResult(0);
717  }
718 
719  // We've already made sure this is truncating previously, so we don't need to
720  // care about signedness here. Still try to use a corresponding op for better
721  // consistency though.
722  if (type.isSignedInteger()) {
723  return builder.create<spirv::SConvertOp>(loc, type, input);
724  }
725  return builder.create<spirv::UConvertOp>(loc, type, input);
726 }
727 
728 //===----------------------------------------------------------------------===//
729 // Builtin Variables
730 //===----------------------------------------------------------------------===//
731 
732 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
733  spirv::BuiltIn builtin) {
734  // Look through all global variables in the given `body` block and check if
735  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
736  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
737  if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
738  spirv::SPIRVDialect::getAttributeName(
739  spirv::Decoration::BuiltIn))) {
740  auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
741  if (varBuiltIn && *varBuiltIn == builtin) {
742  return varOp;
743  }
744  }
745  }
746  return nullptr;
747 }
748 
749 /// Gets name of global variable for a builtin.
750 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
751  StringRef suffix) {
752  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
753 }
754 
755 /// Gets or inserts a global variable for a builtin within `body` block.
756 static spirv::GlobalVariableOp
757 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
758  Type integerType, OpBuilder &builder,
759  StringRef prefix, StringRef suffix) {
760  if (auto varOp = getBuiltinVariable(body, builtin))
761  return varOp;
762 
763  OpBuilder::InsertionGuard guard(builder);
764  builder.setInsertionPointToStart(&body);
765 
766  spirv::GlobalVariableOp newVarOp;
767  switch (builtin) {
768  case spirv::BuiltIn::NumWorkgroups:
769  case spirv::BuiltIn::WorkgroupSize:
770  case spirv::BuiltIn::WorkgroupId:
771  case spirv::BuiltIn::LocalInvocationId:
772  case spirv::BuiltIn::GlobalInvocationId: {
773  auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
774  spirv::StorageClass::Input);
775  std::string name = getBuiltinVarName(builtin, prefix, suffix);
776  newVarOp =
777  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
778  break;
779  }
780  case spirv::BuiltIn::SubgroupId:
781  case spirv::BuiltIn::NumSubgroups:
782  case spirv::BuiltIn::SubgroupSize: {
783  auto ptrType =
784  spirv::PointerType::get(integerType, spirv::StorageClass::Input);
785  std::string name = getBuiltinVarName(builtin, prefix, suffix);
786  newVarOp =
787  builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
788  break;
789  }
790  default:
791  emitError(loc, "unimplemented builtin variable generation for ")
792  << stringifyBuiltIn(builtin);
793  }
794  return newVarOp;
795 }
796 
797 //===----------------------------------------------------------------------===//
798 // Push constant storage
799 //===----------------------------------------------------------------------===//
800 
801 /// Returns the pointer type for the push constant storage containing
802 /// `elementCount` 32-bit integer values.
803 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
804  Builder &builder,
805  Type indexType) {
806  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
807  /*stride=*/4);
808  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
809  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
810 }
811 
812 /// Returns the push constant varible containing `elementCount` 32-bit integer
813 /// values in `body`. Returns null op if such an op does not exit.
814 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
815  unsigned elementCount) {
816  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
817  auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
818  if (!ptrType)
819  continue;
820 
821  // Note that Vulkan requires "There must be no more than one push constant
822  // block statically used per shader entry point." So we should always reuse
823  // the existing one.
824  if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
825  auto numElements = cast<spirv::ArrayType>(
826  cast<spirv::StructType>(ptrType.getPointeeType())
827  .getElementType(0))
828  .getNumElements();
829  if (numElements == elementCount)
830  return varOp;
831  }
832  }
833  return nullptr;
834 }
835 
836 /// Gets or inserts a global variable for push constant storage containing
837 /// `elementCount` 32-bit integer values in `block`.
838 static spirv::GlobalVariableOp
839 getOrInsertPushConstantVariable(Location loc, Block &block,
840  unsigned elementCount, OpBuilder &b,
841  Type indexType) {
842  if (auto varOp = getPushConstantVariable(block, elementCount))
843  return varOp;
844 
845  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
846  auto type = getPushConstantStorageType(elementCount, builder, indexType);
847  const char *name = "__push_constant_var__";
848  return builder.create<spirv::GlobalVariableOp>(loc, type, name,
849  /*initializer=*/nullptr);
850 }
851 
852 //===----------------------------------------------------------------------===//
853 // func::FuncOp Conversion Patterns
854 //===----------------------------------------------------------------------===//
855 
856 /// A pattern for rewriting function signature to convert arguments of functions
857 /// to be of valid SPIR-V types.
858 struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
860 
861  LogicalResult
862  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
863  ConversionPatternRewriter &rewriter) const override {
864  FunctionType fnType = funcOp.getFunctionType();
865  if (fnType.getNumResults() > 1)
866  return failure();
867 
868  TypeConverter::SignatureConversion signatureConverter(
869  fnType.getNumInputs());
870  for (const auto &argType : enumerate(fnType.getInputs())) {
871  auto convertedType = getTypeConverter()->convertType(argType.value());
872  if (!convertedType)
873  return failure();
874  signatureConverter.addInputs(argType.index(), convertedType);
875  }
876 
877  Type resultType;
878  if (fnType.getNumResults() == 1) {
879  resultType = getTypeConverter()->convertType(fnType.getResult(0));
880  if (!resultType)
881  return failure();
882  }
883 
884  // Create the converted spirv.func op.
885  auto newFuncOp = rewriter.create<spirv::FuncOp>(
886  funcOp.getLoc(), funcOp.getName(),
887  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
888  resultType ? TypeRange(resultType)
889  : TypeRange()));
890 
891  // Copy over all attributes other than the function name and type.
892  for (const auto &namedAttr : funcOp->getAttrs()) {
893  if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
894  namedAttr.getName() != SymbolTable::getSymbolAttrName())
895  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
896  }
897 
898  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
899  newFuncOp.end());
900  if (failed(rewriter.convertRegionTypes(
901  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
902  return failure();
903  rewriter.eraseOp(funcOp);
904  return success();
905  }
906 };
907 
908 /// A pattern for rewriting function signature to convert vector arguments of
909 /// functions to be of valid types
910 struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
912 
913  LogicalResult matchAndRewrite(func::FuncOp funcOp,
914  PatternRewriter &rewriter) const override {
915  FunctionType fnType = funcOp.getFunctionType();
916 
917  // TODO: Handle declarations.
918  if (funcOp.isDeclaration()) {
919  LLVM_DEBUG(llvm::dbgs()
920  << fnType << " illegal: declarations are unsupported\n");
921  return failure();
922  }
923 
924  // Create a new func op with the original type and copy the function body.
925  auto newFuncOp = rewriter.create<func::FuncOp>(funcOp.getLoc(),
926  funcOp.getName(), fnType);
927  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
928  newFuncOp.end());
929 
930  Location loc = newFuncOp.getBody().getLoc();
931 
932  Block &entryBlock = newFuncOp.getBlocks().front();
933  OpBuilder::InsertionGuard guard(rewriter);
934  rewriter.setInsertionPointToStart(&entryBlock);
935 
936  OneToNTypeMapping oneToNTypeMapping(fnType.getInputs());
937 
938  // For arguments that are of illegal types and require unrolling.
939  // `unrolledInputNums` stores the indices of arguments that result from
940  // unrolling in the new function signature. `newInputNo` is a counter.
941  SmallVector<size_t> unrolledInputNums;
942  size_t newInputNo = 0;
943 
944  // For arguments that are of legal types and do not require unrolling.
945  // `tmpOps` stores a mapping from temporary operations that serve as
946  // placeholders for new arguments that will be added later. These operations
947  // will be erased once the entry block's argument list is updated.
948  llvm::SmallDenseMap<Operation *, size_t> tmpOps;
949 
950  // This counts the number of new operations created.
951  size_t newOpCount = 0;
952 
953  // Enumerate through the arguments.
954  for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
955  // Check whether the argument is of vector type.
956  auto origVecType = dyn_cast<VectorType>(origType);
957  if (!origVecType) {
958  // We need a placeholder for the old argument that will be erased later.
959  Value result = rewriter.create<arith::ConstantOp>(
960  loc, origType, rewriter.getZeroAttr(origType));
961  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
962  tmpOps.insert({result.getDefiningOp(), newInputNo});
963  oneToNTypeMapping.addInputs(origInputNo, origType);
964  ++newInputNo;
965  ++newOpCount;
966  continue;
967  }
968  // Check whether the vector needs unrolling.
969  auto targetShape = getTargetShape(origVecType);
970  if (!targetShape) {
971  // We need a placeholder for the old argument that will be erased later.
972  Value result = rewriter.create<arith::ConstantOp>(
973  loc, origType, rewriter.getZeroAttr(origType));
974  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
975  tmpOps.insert({result.getDefiningOp(), newInputNo});
976  oneToNTypeMapping.addInputs(origInputNo, origType);
977  ++newInputNo;
978  ++newOpCount;
979  continue;
980  }
981  VectorType unrolledType =
982  VectorType::get(*targetShape, origVecType.getElementType());
983  auto originalShape =
984  llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
985 
986  // Prepare the result vector.
987  Value result = rewriter.create<arith::ConstantOp>(
988  loc, origVecType, rewriter.getZeroAttr(origVecType));
989  ++newOpCount;
990  // Prepare the placeholder for the new arguments that will be added later.
991  Value dummy = rewriter.create<arith::ConstantOp>(
992  loc, unrolledType, rewriter.getZeroAttr(unrolledType));
993  ++newOpCount;
994 
995  // Create the `vector.insert_strided_slice` ops.
996  SmallVector<int64_t> strides(targetShape->size(), 1);
997  SmallVector<Type> newTypes;
998  for (SmallVector<int64_t> offsets :
999  StaticTileOffsetRange(originalShape, *targetShape)) {
1000  result = rewriter.create<vector::InsertStridedSliceOp>(
1001  loc, dummy, result, offsets, strides);
1002  newTypes.push_back(unrolledType);
1003  unrolledInputNums.push_back(newInputNo);
1004  ++newInputNo;
1005  ++newOpCount;
1006  }
1007  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1008  oneToNTypeMapping.addInputs(origInputNo, newTypes);
1009  }
1010 
1011  // Change the function signature.
1012  auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1013  auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1014  rewriter.modifyOpInPlace(newFuncOp,
1015  [&] { newFuncOp.setFunctionType(newFnType); });
1016 
1017  // Update the arguments in the entry block.
1018  entryBlock.eraseArguments(0, fnType.getNumInputs());
1019  SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
1020  entryBlock.addArguments(convertedTypes, locs);
1021 
1022  // Replace the placeholder values with the new arguments. We assume there is
1023  // only one block for now.
1024  size_t unrolledInputIdx = 0;
1025  for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1026  // We first look for operands that are placeholders for initially legal
1027  // arguments.
1028  Operation &curOp = op;
1029  for (auto [operandIdx, operandVal] : llvm::enumerate(op.getOperands())) {
1030  Operation *operandOp = operandVal.getDefiningOp();
1031  if (auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
1032  size_t idx = operandIdx;
1033  rewriter.modifyOpInPlace(&curOp, [&curOp, &newFuncOp, it, idx] {
1034  curOp.setOperand(idx, newFuncOp.getArgument(it->second));
1035  });
1036  }
1037  }
1038  // Since all newly created operations are in the beginning, reaching the
1039  // end of them means that any later `vector.insert_strided_slice` should
1040  // not be touched.
1041  if (count >= newOpCount)
1042  continue;
1043  if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1044  size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1045  rewriter.modifyOpInPlace(&curOp, [&] {
1046  curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1047  });
1048  ++unrolledInputIdx;
1049  }
1050  }
1051 
1052  // Erase the original funcOp. The `tmpOps` do not need to be erased since
1053  // they have no uses and will be handled by dead-code elimination.
1054  rewriter.eraseOp(funcOp);
1055  return success();
1056  }
1057 };
1058 
1059 //===----------------------------------------------------------------------===//
1060 // func::ReturnOp Conversion Patterns
1061 //===----------------------------------------------------------------------===//
1062 
1063 /// A pattern for rewriting function signature and the return op to convert
1064 /// vectors to be of valid types.
1065 struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
1067 
1068  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1069  PatternRewriter &rewriter) const override {
1070  // Check whether the parent funcOp is valid.
1071  auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1072  if (!funcOp)
1073  return failure();
1074 
1075  FunctionType fnType = funcOp.getFunctionType();
1076  OneToNTypeMapping oneToNTypeMapping(fnType.getResults());
1077  Location loc = returnOp.getLoc();
1078 
1079  // For the new return op.
1080  SmallVector<Value> newOperands;
1081 
1082  // Enumerate through the results.
1083  for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1084  // Check whether the argument is of vector type.
1085  auto origVecType = dyn_cast<VectorType>(origType);
1086  if (!origVecType) {
1087  oneToNTypeMapping.addInputs(origResultNo, origType);
1088  newOperands.push_back(returnOp.getOperand(origResultNo));
1089  continue;
1090  }
1091  // Check whether the vector needs unrolling.
1092  auto targetShape = getTargetShape(origVecType);
1093  if (!targetShape) {
1094  // The original argument can be used.
1095  oneToNTypeMapping.addInputs(origResultNo, origType);
1096  newOperands.push_back(returnOp.getOperand(origResultNo));
1097  continue;
1098  }
1099  VectorType unrolledType =
1100  VectorType::get(*targetShape, origVecType.getElementType());
1101 
1102  // Create `vector.extract_strided_slice` ops to form legal vectors from
1103  // the original operand of illegal type.
1104  auto originalShape =
1105  llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1106  SmallVector<int64_t> strides(originalShape.size(), 1);
1107  SmallVector<int64_t> extractShape(originalShape.size(), 1);
1108  extractShape.back() = targetShape->back();
1109  SmallVector<Type> newTypes;
1110  Value returnValue = returnOp.getOperand(origResultNo);
1111  for (SmallVector<int64_t> offsets :
1112  StaticTileOffsetRange(originalShape, *targetShape)) {
1113  Value result = rewriter.create<vector::ExtractStridedSliceOp>(
1114  loc, returnValue, offsets, extractShape, strides);
1115  if (originalShape.size() > 1) {
1116  SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
1117  result =
1118  rewriter.create<vector::ExtractOp>(loc, result, extractIndices);
1119  }
1120  newOperands.push_back(result);
1121  newTypes.push_back(unrolledType);
1122  }
1123  oneToNTypeMapping.addInputs(origResultNo, newTypes);
1124  }
1125 
1126  // Change the function signature.
1127  auto newFnType =
1128  FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
1129  TypeRange(oneToNTypeMapping.getConvertedTypes()));
1130  rewriter.modifyOpInPlace(funcOp,
1131  [&] { funcOp.setFunctionType(newFnType); });
1132 
1133  // Replace the return op using the new operands. This will automatically
1134  // update the entry block as well.
1135  rewriter.replaceOp(returnOp,
1136  rewriter.create<func::ReturnOp>(loc, newOperands));
1137 
1138  return success();
1139  }
1140 };
1141 
1142 } // namespace
1143 
1144 //===----------------------------------------------------------------------===//
1145 // Public function for builtin variables
1146 //===----------------------------------------------------------------------===//
1147 
1149  spirv::BuiltIn builtin,
1150  Type integerType, OpBuilder &builder,
1151  StringRef prefix, StringRef suffix) {
1153  if (!parent) {
1154  op->emitError("expected operation to be within a module-like op");
1155  return nullptr;
1156  }
1157 
1158  spirv::GlobalVariableOp varOp =
1159  getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
1160  builtin, integerType, builder, prefix, suffix);
1161  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
1162  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
1163 }
1164 
1165 //===----------------------------------------------------------------------===//
1166 // Public function for pushing constant storage
1167 //===----------------------------------------------------------------------===//
1168 
1169 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
1170  unsigned offset, Type integerType,
1171  OpBuilder &builder) {
1172  Location loc = op->getLoc();
1174  if (!parent) {
1175  op->emitError("expected operation to be within a module-like op");
1176  return nullptr;
1177  }
1178 
1179  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1180  loc, parent->getRegion(0).front(), elementCount, builder, integerType);
1181 
1182  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1183  Value offsetOp = builder.create<spirv::ConstantOp>(
1184  loc, integerType, builder.getI32IntegerAttr(offset));
1185  auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
1186  auto acOp = builder.create<spirv::AccessChainOp>(
1187  loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp}));
1188  return builder.create<spirv::LoadOp>(loc, acOp);
1189 }
1190 
1191 //===----------------------------------------------------------------------===//
1192 // Public functions for index calculation
1193 //===----------------------------------------------------------------------===//
1194 
1196  int64_t offset, Type integerType,
1197  Location loc, OpBuilder &builder) {
1198  assert(indices.size() == strides.size() &&
1199  "must provide indices for all dimensions");
1200 
1201  // TODO: Consider moving to use affine.apply and patterns converting
1202  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
1203  // broken down into progressive small steps so we can have intermediate steps
1204  // using other dialects. At the moment SPIR-V is the final sink.
1205 
1206  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
1207  loc, integerType, IntegerAttr::get(integerType, offset));
1208  for (const auto &index : llvm::enumerate(indices)) {
1209  Value strideVal = builder.createOrFold<spirv::ConstantOp>(
1210  loc, integerType,
1211  IntegerAttr::get(integerType, strides[index.index()]));
1212  Value update =
1213  builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1214  linearizedIndex =
1215  builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1216  }
1217  return linearizedIndex;
1218 }
1219 
1221  MemRefType baseType, Value basePtr,
1222  ValueRange indices, Location loc,
1223  OpBuilder &builder) {
1224  // Get base and offset of the MemRefType and verify they are static.
1225 
1226  int64_t offset;
1227  SmallVector<int64_t, 4> strides;
1228  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
1229  llvm::is_contained(strides, ShapedType::kDynamic) ||
1230  ShapedType::isDynamic(offset)) {
1231  return nullptr;
1232  }
1233 
1234  auto indexType = typeConverter.getIndexType();
1235 
1236  SmallVector<Value, 2> linearizedIndices;
1237  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1238 
1239  // Add a '0' at the start to index into the struct.
1240  linearizedIndices.push_back(zero);
1241 
1242  if (baseType.getRank() == 0) {
1243  linearizedIndices.push_back(zero);
1244  } else {
1245  linearizedIndices.push_back(
1246  linearizeIndex(indices, strides, offset, indexType, loc, builder));
1247  }
1248  return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
1249 }
1250 
1252  MemRefType baseType, Value basePtr,
1253  ValueRange indices, Location loc,
1254  OpBuilder &builder) {
1255  // Get base and offset of the MemRefType and verify they are static.
1256 
1257  int64_t offset;
1258  SmallVector<int64_t, 4> strides;
1259  if (failed(getStridesAndOffset(baseType, strides, offset)) ||
1260  llvm::is_contained(strides, ShapedType::kDynamic) ||
1261  ShapedType::isDynamic(offset)) {
1262  return nullptr;
1263  }
1264 
1265  auto indexType = typeConverter.getIndexType();
1266 
1267  SmallVector<Value, 2> linearizedIndices;
1268  Value linearIndex;
1269  if (baseType.getRank() == 0) {
1270  linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1271  } else {
1272  linearIndex =
1273  linearizeIndex(indices, strides, offset, indexType, loc, builder);
1274  }
1275  Type pointeeType =
1276  cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
1277  if (isa<spirv::ArrayType>(pointeeType)) {
1278  linearizedIndices.push_back(linearIndex);
1279  return builder.create<spirv::AccessChainOp>(loc, basePtr,
1280  linearizedIndices);
1281  }
1282  return builder.create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
1283  linearizedIndices);
1284 }
1285 
1287  MemRefType baseType, Value basePtr,
1288  ValueRange indices, Location loc,
1289  OpBuilder &builder) {
1290 
1291  if (typeConverter.allows(spirv::Capability::Kernel)) {
1292  return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1293  builder);
1294  }
1295 
1296  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1297  builder);
1298 }
1299 
1300 //===----------------------------------------------------------------------===//
1301 // Public functions for vector unrolling
1302 //===----------------------------------------------------------------------===//
1303 
1305  for (int i : {4, 3, 2}) {
1306  if (size % i == 0)
1307  return i;
1308  }
1309  return 1;
1310 }
1311 
1313 mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
1314  VectorType srcVectorType = op.getSourceVectorType();
1315  assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1316  int64_t vectorSize =
1317  mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
1318  return {vectorSize};
1319 }
1320 
1322 mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
1323  VectorType vectorType = op.getResultVectorType();
1324  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
1325  nativeSize.back() =
1326  mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
1327  return nativeSize;
1328 }
1329 
1330 std::optional<SmallVector<int64_t>>
1332  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
1333  if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
1334  SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
1335  nativeSize.back() =
1336  mlir::spirv::getComputeVectorSize(vecType.getShape().back());
1337  return nativeSize;
1338  }
1339  }
1340 
1342  .Case<vector::ReductionOp, vector::TransposeOp>(
1343  [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
1344  .Default([](Operation *) { return std::nullopt; });
1345 }
1346 
1348  MLIRContext *context = op->getContext();
1349  RewritePatternSet patterns(context);
1352  // We only want to apply signature conversion once to the existing func ops.
1353  // Without specifying strictMode, the greedy pattern rewriter will keep
1354  // looking for newly created func ops.
1357  return applyPatternsGreedily(op, std::move(patterns), config);
1358 }
1359 
1361  MLIRContext *context = op->getContext();
1362 
1363  // Unroll vectors in function bodies to native vector size.
1364  {
1365  RewritePatternSet patterns(context);
1367  [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1369  if (failed(applyPatternsGreedily(op, std::move(patterns))))
1370  return failure();
1371  }
1372 
1373  // Convert transpose ops into extract and insert pairs, in preparation of
1374  // further transformations to canonicalize/cancel.
1375  {
1376  RewritePatternSet patterns(context);
1378  vector::VectorTransposeLowering::EltWise);
1381  if (failed(applyPatternsGreedily(op, std::move(patterns))))
1382  return failure();
1383  }
1384 
1385  // Run canonicalization to cast away leading size-1 dimensions.
1386  {
1387  RewritePatternSet patterns(context);
1388 
1389  // We need to pull in casting way leading one dims.
1391  vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1392  vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1393 
1394  // Decompose different rank insert_strided_slice and n-D
1395  // extract_slided_slice.
1397  patterns);
1398  vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1399  vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1400 
1401  // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1402  // them up.
1403  vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1404  vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1405 
1406  if (failed(applyPatternsGreedily(op, std::move(patterns))))
1407  return failure();
1408  }
1409  return success();
1410 }
1411 
1412 //===----------------------------------------------------------------------===//
1413 // SPIR-V TypeConverter
1414 //===----------------------------------------------------------------------===//
1415 
1418  : targetEnv(targetAttr), options(options) {
1419  // Add conversions. The order matters here: later ones will be tried earlier.
1420 
1421  // Allow all SPIR-V dialect specific types. This assumes all builtin types
1422  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
1423  // were tried before.
1424  //
1425  // TODO: This assumes that the SPIR-V types are valid to use in the given
1426  // target environment, which should be the case if the whole pipeline is
1427  // driven by the same target environment. Still, we probably still want to
1428  // validate and convert to be safe.
1429  addConversion([](spirv::SPIRVType type) { return type; });
1430 
1431  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
1432 
1433  addConversion([this](IntegerType intType) -> std::optional<Type> {
1434  if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1435  return convertScalarType(this->targetEnv, this->options, scalarType);
1436  if (intType.getWidth() < 8)
1437  return convertSubByteIntegerType(this->options, intType);
1438  return Type();
1439  });
1440 
1441  addConversion([this](FloatType floatType) -> std::optional<Type> {
1442  if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1443  return convertScalarType(this->targetEnv, this->options, scalarType);
1444  return Type();
1445  });
1446 
1447  addConversion([this](ComplexType complexType) {
1448  return convertComplexType(this->targetEnv, this->options, complexType);
1449  });
1450 
1451  addConversion([this](VectorType vectorType) {
1452  return convertVectorType(this->targetEnv, this->options, vectorType);
1453  });
1454 
1455  addConversion([this](TensorType tensorType) {
1456  return convertTensorType(this->targetEnv, this->options, tensorType);
1457  });
1458 
1459  addConversion([this](MemRefType memRefType) {
1460  return convertMemrefType(this->targetEnv, this->options, memRefType);
1461  });
1462 
1463  // Register some last line of defense casting logic.
1465  [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1466  return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1467  });
1468  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
1469  Location loc) {
1470  auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
1471  return cast.getResult(0);
1472  });
1473 }
1474 
1476  return ::getIndexType(getContext(), options);
1477 }
1478 
1479 MLIRContext *SPIRVTypeConverter::getContext() const {
1480  return targetEnv.getAttr().getContext();
1481 }
1482 
1483 bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
1484  return targetEnv.allows(capability);
1485 }
1486 
1487 //===----------------------------------------------------------------------===//
1488 // SPIR-V ConversionTarget
1489 //===----------------------------------------------------------------------===//
1490 
1491 std::unique_ptr<SPIRVConversionTarget>
1493  std::unique_ptr<SPIRVConversionTarget> target(
1494  // std::make_unique does not work here because the constructor is private.
1495  new SPIRVConversionTarget(targetAttr));
1496  SPIRVConversionTarget *targetPtr = target.get();
1497  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1498  // We need to capture the raw pointer here because it is stable:
1499  // target will be destroyed once this function is returned.
1500  [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1501  return target;
1502 }
1503 
1504 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1505  : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1506 
1507 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1508  // Make sure this op is available at the given version. Ops not implementing
1509  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1510  // SPIR-V versions.
1511  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1512  std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1513  if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1514  LLVM_DEBUG(llvm::dbgs()
1515  << op->getName() << " illegal: requiring min version "
1516  << spirv::stringifyVersion(*minVersion) << "\n");
1517  return false;
1518  }
1519  }
1520  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1521  std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1522  if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1523  LLVM_DEBUG(llvm::dbgs()
1524  << op->getName() << " illegal: requiring max version "
1525  << spirv::stringifyVersion(*maxVersion) << "\n");
1526  return false;
1527  }
1528  }
1529 
1530  // Make sure this op's required extensions are allowed to use. Ops not
1531  // implementing QueryExtensionInterface do not require extensions to be
1532  // available.
1533  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1534  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1535  extensions.getExtensions())))
1536  return false;
1537 
1538  // Make sure this op's required extensions are allowed to use. Ops not
1539  // implementing QueryCapabilityInterface do not require capabilities to be
1540  // available.
1541  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1542  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1543  capabilities.getCapabilities())))
1544  return false;
1545 
1546  SmallVector<Type, 4> valueTypes;
1547  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
1548  valueTypes.append(op->result_type_begin(), op->result_type_end());
1549 
1550  // Ensure that all types have been converted to SPIRV types.
1551  if (llvm::any_of(valueTypes,
1552  [](Type t) { return !isa<spirv::SPIRVType>(t); }))
1553  return false;
1554 
1555  // Special treatment for global variables, whose type requirements are
1556  // conveyed by type attributes.
1557  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1558  valueTypes.push_back(globalVar.getType());
1559 
1560  // Make sure the op's operands/results use types that are allowed by the
1561  // target environment.
1562  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1563  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1564  for (Type valueType : valueTypes) {
1565  typeExtensions.clear();
1566  cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1567  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1568  typeExtensions)))
1569  return false;
1570 
1571  typeCapabilities.clear();
1572  cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1573  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1574  typeCapabilities)))
1575  return false;
1576  }
1577 
1578  return true;
1579 }
1580 
1581 //===----------------------------------------------------------------------===//
1582 // Public functions for populating patterns
1583 //===----------------------------------------------------------------------===//
1584 
1586  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1587  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
1588 }
1589 
1591  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
1592 }
1593 
1595  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
1596 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
Definition: Block.h:33
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:162
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:203
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:193
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:240
FloatType getF32Type()
Definition: Builders.cpp:87
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:120
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Stores a 1:N mapping of types and provides several useful accessors.
TypeRange getConvertedTypes(unsigned originalTypeNo) const
Returns the list of types that corresponds to the original type at the given index.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition: Builders.h:249
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:440
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:329
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:529
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
operand_type_iterator operand_type_end()
Definition: Operation.h:396
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_iterator result_type_end()
Definition: Operation.h:427
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
result_type_iterator result_type_begin()
Definition: Operation.h:426
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition: Operation.h:582
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
operand_type_iterator operand_type_begin()
Definition: Operation.h:395
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:644
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:636
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:102
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a replacement value back ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:87
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type; asserts on any other type.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:52
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:102
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:406
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:463
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
Definition: SPIRVTypes.cpp:961
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Definition: TargetAndABI.h:29
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:62
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1393
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
Definition: Utils.cpp:2006
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of patterns to unroll vector operations to smaller shapes.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)
Structure to control the behavior of vector transform patterns.
VectorTransformsOptions & setVectorTransposeLowering(VectorTransposeLowering opt)