MLIR  22.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/Support/LLVM.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/StringExtras.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/MathExtras.h"
36 
37 #include <optional>
38 
39 #define DEBUG_TYPE "mlir-spirv-conversion"
40 
41 using namespace mlir;
42 
43 namespace {
44 
45 //===----------------------------------------------------------------------===//
46 // Utility functions
47 //===----------------------------------------------------------------------===//
48 
49 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
50  LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
51  if (vecType.isScalable()) {
52  LLVM_DEBUG(llvm::dbgs()
53  << "--scalable vectors are not supported -> BAIL\n");
54  return std::nullopt;
55  }
56  SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
57  std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
58  1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
59  if (!targetShape) {
60  LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
61  return std::nullopt;
62  }
63  auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
64  if (!maybeShapeRatio) {
65  LLVM_DEBUG(llvm::dbgs()
66  << "--could not compute integral shape ratio -> BAIL\n");
67  return std::nullopt;
68  }
69  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
70  LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
71  return std::nullopt;
72  }
73  LLVM_DEBUG(llvm::dbgs()
74  << "--found an integral shape ratio to unroll to -> SUCCESS\n");
75  return targetShape;
76 }
77 
78 /// Checks that `candidates` extension requirements are possible to be satisfied
79 /// with the given `targetEnv`.
80 ///
81 /// `candidates` is a vector of vector for extension requirements following
82 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
83 /// convention.
84 template <typename LabelT>
85 static LogicalResult checkExtensionRequirements(
86  LabelT label, const spirv::TargetEnv &targetEnv,
88  for (const auto &ors : candidates) {
89  if (targetEnv.allows(ors))
90  continue;
91 
92  LLVM_DEBUG({
93  SmallVector<StringRef> extStrings;
94  for (spirv::Extension ext : ors)
95  extStrings.push_back(spirv::stringifyExtension(ext));
96 
97  llvm::dbgs() << label << " illegal: requires at least one extension in ["
98  << llvm::join(extStrings, ", ")
99  << "] but none allowed in target environment\n";
100  });
101  return failure();
102  }
103  return success();
104 }
105 
106 /// Checks that `candidates`capability requirements are possible to be satisfied
107 /// with the given `isAllowedFn`.
108 ///
109 /// `candidates` is a vector of vector for capability requirements following
110 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
111 /// convention.
112 template <typename LabelT>
113 static LogicalResult checkCapabilityRequirements(
114  LabelT label, const spirv::TargetEnv &targetEnv,
115  const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
116  for (const auto &ors : candidates) {
117  if (targetEnv.allows(ors))
118  continue;
119 
120  LLVM_DEBUG({
121  SmallVector<StringRef> capStrings;
122  for (spirv::Capability cap : ors)
123  capStrings.push_back(spirv::stringifyCapability(cap));
124 
125  llvm::dbgs() << label << " illegal: requires at least one capability in ["
126  << llvm::join(capStrings, ", ")
127  << "] but none allowed in target environment\n";
128  });
129  return failure();
130  }
131  return success();
132 }
133 
134 /// Returns true if the given `storageClass` needs explicit layout when used in
135 /// Shader environments.
136 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
137  switch (storageClass) {
138  case spirv::StorageClass::PhysicalStorageBuffer:
139  case spirv::StorageClass::PushConstant:
140  case spirv::StorageClass::StorageBuffer:
141  case spirv::StorageClass::Uniform:
142  return true;
143  default:
144  return false;
145  }
146 }
147 
148 /// Wraps the given `elementType` in a struct and gets the pointer to the
149 /// struct. This is used to satisfy Vulkan interface requirements.
150 static spirv::PointerType
151 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
152  auto structType = needsExplicitLayout(storageClass)
153  ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
154  : spirv::StructType::get(elementType);
155  return spirv::PointerType::get(structType, storageClass);
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // Type Conversion
160 //===----------------------------------------------------------------------===//
161 
162 static spirv::ScalarType getIndexType(MLIRContext *ctx,
164  return cast<spirv::ScalarType>(
165  IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
166 }
167 
168 // TODO: This is a utility function that should probably be exposed by the
169 // SPIR-V dialect. Keeping it local till the use case arises.
170 static std::optional<int64_t>
171 getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
172  if (isa<spirv::ScalarType>(type)) {
173  auto bitWidth = type.getIntOrFloatBitWidth();
174  // According to the SPIR-V spec:
175  // "There is no physical size or bit pattern defined for values with boolean
176  // type. If they are stored (in conjunction with OpVariable), they can only
177  // be used with logical addressing operations, not physical, and only with
178  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
179  // Private, Function, Input, and Output."
180  if (bitWidth == 1)
181  return std::nullopt;
182  return bitWidth / 8;
183  }
184 
185  // Handle 8-bit floats.
186  if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
187  auto bitWidth = type.getIntOrFloatBitWidth();
188  if (bitWidth == 8)
189  return bitWidth / 8;
190  return std::nullopt;
191  }
192 
193  if (auto complexType = dyn_cast<ComplexType>(type)) {
194  auto elementSize = getTypeNumBytes(options, complexType.getElementType());
195  if (!elementSize)
196  return std::nullopt;
197  return 2 * *elementSize;
198  }
199 
200  if (auto vecType = dyn_cast<VectorType>(type)) {
201  auto elementSize = getTypeNumBytes(options, vecType.getElementType());
202  if (!elementSize)
203  return std::nullopt;
204  return vecType.getNumElements() * *elementSize;
205  }
206 
207  if (auto memRefType = dyn_cast<MemRefType>(type)) {
208  // TODO: Layout should also be controlled by the ABI attributes. For now
209  // using the layout from MemRef.
210  int64_t offset;
211  SmallVector<int64_t, 4> strides;
212  if (!memRefType.hasStaticShape() ||
213  failed(memRefType.getStridesAndOffset(strides, offset)))
214  return std::nullopt;
215 
216  // To get the size of the memref object in memory, the total size is the
217  // max(stride * dimension-size) computed for all dimensions times the size
218  // of the element.
219  auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
220  if (!elementSize)
221  return std::nullopt;
222 
223  if (memRefType.getRank() == 0)
224  return elementSize;
225 
226  auto dims = memRefType.getShape();
227  if (llvm::is_contained(dims, ShapedType::kDynamic) ||
228  ShapedType::isDynamic(offset) ||
229  llvm::is_contained(strides, ShapedType::kDynamic))
230  return std::nullopt;
231 
232  int64_t memrefSize = -1;
233  for (const auto &shape : enumerate(dims))
234  memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
235 
236  return (offset + memrefSize) * *elementSize;
237  }
238 
239  if (auto tensorType = dyn_cast<TensorType>(type)) {
240  if (!tensorType.hasStaticShape())
241  return std::nullopt;
242 
243  auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
244  if (!elementSize)
245  return std::nullopt;
246 
247  int64_t size = *elementSize;
248  for (auto shape : tensorType.getShape())
249  size *= shape;
250 
251  return size;
252  }
253 
254  // TODO: Add size computation for other types.
255  return std::nullopt;
256 }
257 
258 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
259 static Type
260 convertScalarType(const spirv::TargetEnv &targetEnv,
262  std::optional<spirv::StorageClass> storageClass = {}) {
263  // Get extension and capability requirements for the given type.
266  type.getExtensions(extensions, storageClass);
267  type.getCapabilities(capabilities, storageClass);
268 
269  // If all requirements are met, then we can accept this type as-is.
270  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
271  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
272  return type;
273 
274  // Otherwise we need to adjust the type, which really means adjusting the
275  // bitwidth given this is a scalar type.
276  if (!options.emulateLT32BitScalarTypes)
277  return nullptr;
278 
279  // We only emulate narrower scalar types here and do not truncate results.
280  if (type.getIntOrFloatBitWidth() > 32) {
281  LLVM_DEBUG(llvm::dbgs()
282  << type
283  << " not converted to 32-bit for SPIR-V to avoid truncation\n");
284  return nullptr;
285  }
286 
287  if (auto floatType = dyn_cast<FloatType>(type)) {
288  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
289  return Builder(targetEnv.getContext()).getF32Type();
290  }
291 
292  auto intType = cast<IntegerType>(type);
293  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
294  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
295  intType.getSignedness());
296 }
297 
298 /// Converts a sub-byte integer `type` to i32 regardless of target environment.
299 /// Returns a nullptr for unsupported integer types, including non sub-byte
300 /// types.
301 ///
302 /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
303 /// the above given that these sub-byte types are not supported at all in
304 /// SPIR-V; there are no compute/storage capability for them like other
305 /// supported integer types.
306 static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
307  IntegerType type) {
308  if (type.getWidth() > 8) {
309  LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
310  return nullptr;
311  }
312  if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
313  LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
314  return nullptr;
315  }
316 
317  if (!llvm::isPowerOf2_32(type.getWidth())) {
318  LLVM_DEBUG(llvm::dbgs()
319  << "unsupported non-power-of-two bitwidth in sub-byte" << type
320  << "\n");
321  return nullptr;
322  }
323 
324  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
325  return IntegerType::get(type.getContext(), /*width=*/32,
326  type.getSignedness());
327 }
328 
329 /// Converts 8-bit float types to integer types with the same bit width.
330 /// Returns a nullptr for unsupported 8-bit float types.
331 static Type convert8BitFloatType(const SPIRVConversionOptions &options,
332  FloatType type) {
333  if (!options.emulateUnsupportedFloatTypes)
334  return nullptr;
335  // F8 types are converted to integer types with the same bit width.
336  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
337  Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
338  Float8E8M0FNUType>(type))
339  return IntegerType::get(type.getContext(), type.getWidth());
340  LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
341  return nullptr;
342 }
343 
344 /// Returns a type with the same shape but with any 8-bit float element type
345 /// converted to the same bit width integer type. This is a noop when the
346 /// element type is not the 8-bit float type or emulation flag is set to false.
347 static ShapedType
348 convertShaped8BitFloatType(ShapedType type,
350  if (!options.emulateUnsupportedFloatTypes)
351  return type;
352  Type srcElementType = type.getElementType();
353  Type convertedElementType = nullptr;
354  // F8 types are converted to integer types with the same bit width.
355  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
356  Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
357  Float8E8M0FNUType>(srcElementType))
358  convertedElementType = IntegerType::get(
359  type.getContext(), srcElementType.getIntOrFloatBitWidth());
360 
361  if (!convertedElementType)
362  return type;
363 
364  return type.clone(convertedElementType);
365 }
366 
367 /// Returns a type with the same shape but with any index element type converted
368 /// to the matching integer type. This is a noop when the element type is not
369 /// the index type.
370 static ShapedType
371 convertIndexElementType(ShapedType type,
373  Type indexType = dyn_cast<IndexType>(type.getElementType());
374  if (!indexType)
375  return type;
376 
377  return type.clone(getIndexType(type.getContext(), options));
378 }
379 
380 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
381 static Type
382 convertVectorType(const spirv::TargetEnv &targetEnv,
383  const SPIRVConversionOptions &options, VectorType type,
384  std::optional<spirv::StorageClass> storageClass = {}) {
385  type = cast<VectorType>(convertIndexElementType(type, options));
386  type = cast<VectorType>(convertShaped8BitFloatType(type, options));
387  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
388  if (!scalarType) {
389  // If this is not a spec allowed scalar type, try to handle sub-byte integer
390  // types.
391  auto intType = dyn_cast<IntegerType>(type.getElementType());
392  if (!intType) {
393  LLVM_DEBUG(llvm::dbgs()
394  << type
395  << " illegal: cannot convert non-scalar element type\n");
396  return nullptr;
397  }
398 
399  Type elementType = convertSubByteIntegerType(options, intType);
400  if (!elementType)
401  return nullptr;
402 
403  if (type.getRank() <= 1 && type.getNumElements() == 1)
404  return elementType;
405 
406  if (type.getNumElements() > 4) {
407  LLVM_DEBUG(llvm::dbgs()
408  << type << " illegal: > 4-element unimplemented\n");
409  return nullptr;
410  }
411 
412  return VectorType::get(type.getShape(), elementType);
413  }
414 
415  if (type.getRank() <= 1 && type.getNumElements() == 1)
416  return convertScalarType(targetEnv, options, scalarType, storageClass);
417 
418  if (!spirv::CompositeType::isValid(type)) {
419  LLVM_DEBUG(llvm::dbgs()
420  << type << " illegal: not a valid composite type\n");
421  return nullptr;
422  }
423 
424  // Get extension and capability requirements for the given type.
427  cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
428  cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
429 
430  // If all requirements are met, then we can accept this type as-is.
431  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
432  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
433  return type;
434 
435  auto elementType =
436  convertScalarType(targetEnv, options, scalarType, storageClass);
437  if (elementType)
438  return VectorType::get(type.getShape(), elementType);
439  return nullptr;
440 }
441 
442 static Type
443 convertComplexType(const spirv::TargetEnv &targetEnv,
444  const SPIRVConversionOptions &options, ComplexType type,
445  std::optional<spirv::StorageClass> storageClass = {}) {
446  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
447  if (!scalarType) {
448  LLVM_DEBUG(llvm::dbgs()
449  << type << " illegal: cannot convert non-scalar element type\n");
450  return nullptr;
451  }
452 
453  auto elementType =
454  convertScalarType(targetEnv, options, scalarType, storageClass);
455  if (!elementType)
456  return nullptr;
457  if (elementType != type.getElementType()) {
458  LLVM_DEBUG(llvm::dbgs()
459  << type << " illegal: complex type emulation unsupported\n");
460  return nullptr;
461  }
462 
463  return VectorType::get(2, elementType);
464 }
465 
466 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
467 ///
468 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
469 /// create composite constants with OpConstantComposite to embed relative large
470 /// constant values and use OpCompositeExtract and OpCompositeInsert to
471 /// manipulate, like what we do for vectors.
472 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
474  TensorType type) {
475  // TODO: Handle dynamic shapes.
476  if (!type.hasStaticShape()) {
477  LLVM_DEBUG(llvm::dbgs()
478  << type << " illegal: dynamic shape unimplemented\n");
479  return nullptr;
480  }
481 
482  type = cast<TensorType>(convertIndexElementType(type, options));
483  type = cast<TensorType>(convertShaped8BitFloatType(type, options));
484  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
485  if (!scalarType) {
486  LLVM_DEBUG(llvm::dbgs()
487  << type << " illegal: cannot convert non-scalar element type\n");
488  return nullptr;
489  }
490 
491  std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
492  std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
493  if (!scalarSize || !tensorSize) {
494  LLVM_DEBUG(llvm::dbgs()
495  << type << " illegal: cannot deduce element count\n");
496  return nullptr;
497  }
498 
499  int64_t arrayElemCount = *tensorSize / *scalarSize;
500  if (arrayElemCount == 0) {
501  LLVM_DEBUG(llvm::dbgs()
502  << type << " illegal: cannot handle zero-element tensors\n");
503  return nullptr;
504  }
505 
506  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
507  if (!arrayElemType)
508  return nullptr;
509  std::optional<int64_t> arrayElemSize =
510  getTypeNumBytes(options, arrayElemType);
511  if (!arrayElemSize) {
512  LLVM_DEBUG(llvm::dbgs()
513  << type << " illegal: cannot deduce converted element size\n");
514  return nullptr;
515  }
516 
517  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
518 }
519 
520 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
522  MemRefType type,
523  spirv::StorageClass storageClass) {
524  unsigned numBoolBits = options.boolNumBits;
525  if (numBoolBits != 8) {
526  LLVM_DEBUG(llvm::dbgs()
527  << "using non-8-bit storage for bool types unimplemented");
528  return nullptr;
529  }
530  auto elementType = dyn_cast<spirv::ScalarType>(
531  IntegerType::get(type.getContext(), numBoolBits));
532  if (!elementType)
533  return nullptr;
534  Type arrayElemType =
535  convertScalarType(targetEnv, options, elementType, storageClass);
536  if (!arrayElemType)
537  return nullptr;
538  std::optional<int64_t> arrayElemSize =
539  getTypeNumBytes(options, arrayElemType);
540  if (!arrayElemSize) {
541  LLVM_DEBUG(llvm::dbgs()
542  << type << " illegal: cannot deduce converted element size\n");
543  return nullptr;
544  }
545 
546  if (!type.hasStaticShape()) {
547  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
548  // to the element.
549  if (targetEnv.allows(spirv::Capability::Kernel))
550  return spirv::PointerType::get(arrayElemType, storageClass);
551  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
552  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
553  // For Vulkan we need extra wrapping struct and array to satisfy interface
554  // needs.
555  return wrapInStructAndGetPointer(arrayType, storageClass);
556  }
557 
558  if (type.getNumElements() == 0) {
559  LLVM_DEBUG(llvm::dbgs()
560  << type << " illegal: zero-element memrefs are not supported\n");
561  return nullptr;
562  }
563 
564  int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
565  int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
566  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
567  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
568  if (targetEnv.allows(spirv::Capability::Kernel))
569  return spirv::PointerType::get(arrayType, storageClass);
570  return wrapInStructAndGetPointer(arrayType, storageClass);
571 }
572 
573 static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
575  MemRefType type,
576  spirv::StorageClass storageClass) {
577  IntegerType elementType = cast<IntegerType>(type.getElementType());
578  Type arrayElemType = convertSubByteIntegerType(options, elementType);
579  if (!arrayElemType)
580  return nullptr;
581  int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
582 
583  if (!type.hasStaticShape()) {
584  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
585  // to the element.
586  if (targetEnv.allows(spirv::Capability::Kernel))
587  return spirv::PointerType::get(arrayElemType, storageClass);
588  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
589  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
590  // For Vulkan we need extra wrapping struct and array to satisfy interface
591  // needs.
592  return wrapInStructAndGetPointer(arrayType, storageClass);
593  }
594 
595  if (type.getNumElements() == 0) {
596  LLVM_DEBUG(llvm::dbgs()
597  << type << " illegal: zero-element memrefs are not supported\n");
598  return nullptr;
599  }
600 
601  int64_t memrefSize =
602  llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
603  int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
604  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
605  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
606  if (targetEnv.allows(spirv::Capability::Kernel))
607  return spirv::PointerType::get(arrayType, storageClass);
608  return wrapInStructAndGetPointer(arrayType, storageClass);
609 }
610 
611 static spirv::Dim convertRank(int64_t rank) {
612  switch (rank) {
613  case 1:
614  return spirv::Dim::Dim1D;
615  case 2:
616  return spirv::Dim::Dim2D;
617  case 3:
618  return spirv::Dim::Dim3D;
619  default:
620  llvm_unreachable("Invalid memref rank!");
621  }
622 }
623 
624 static spirv::ImageFormat getImageFormat(Type elementType) {
625  return TypeSwitch<Type, spirv::ImageFormat>(elementType)
626  .Case<Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; })
627  .Case<Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; })
628  .Case<IntegerType>([](IntegerType intType) {
629  auto const isSigned = intType.isSigned() || intType.isSignless();
630 #define BIT_WIDTH_CASE(BIT_WIDTH) \
631  case BIT_WIDTH: \
632  return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
633  : spirv::ImageFormat::R##BIT_WIDTH##ui
634 
635  switch (intType.getWidth()) {
636  BIT_WIDTH_CASE(16);
637  BIT_WIDTH_CASE(32);
638  default:
639  llvm_unreachable("Unhandled integer type!");
640  }
641  })
642  .DefaultUnreachable("Unhandled element type!");
643 #undef BIT_WIDTH_CASE
644 }
645 
646 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
648  MemRefType type) {
649  auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
650  if (!attr) {
651  LLVM_DEBUG(
652  llvm::dbgs()
653  << type
654  << " illegal: expected memory space to be a SPIR-V storage class "
655  "attribute; please use MemorySpaceToStorageClassConverter to map "
656  "numeric memory spaces beforehand\n");
657  return nullptr;
658  }
659  spirv::StorageClass storageClass = attr.getValue();
660 
661  // Images are a special case since they are an opaque type from which elements
662  // may be accessed via image specific ops or directly through a texture
663  // pointer.
664  if (storageClass == spirv::StorageClass::Image) {
665  const int64_t rank = type.getRank();
666  if (rank < 1 || rank > 3) {
667  LLVM_DEBUG(llvm::dbgs()
668  << type << " illegal: cannot lower memref of rank " << rank
669  << " to a SPIR-V Image\n");
670  return nullptr;
671  }
672 
673  // Note that we currently only support lowering to single element texels
674  // e.g. R32f.
675  auto elementType = type.getElementType();
676  if (!isa<spirv::ScalarType>(elementType)) {
677  LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
678  << elementType << " to a SPIR-V Image\n");
679  return nullptr;
680  }
681 
682  // Currently every memref in the image storage class is converted to a
683  // sampled image so we can hardcode the NeedSampler field. Future work
684  // will generalize this to support regular non-sampled images.
685  auto spvImageType = spirv::ImageType::get(
686  elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
687  spirv::ImageArrayedInfo::NonArrayed,
688  spirv::ImageSamplingInfo::SingleSampled,
689  spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
690  auto spvSampledImageType = spirv::SampledImageType::get(spvImageType);
691  auto imagePtrType = spirv::PointerType::get(
692  spvSampledImageType, spirv::StorageClass::UniformConstant);
693  return imagePtrType;
694  }
695 
696  if (isa<IntegerType>(type.getElementType())) {
697  if (type.getElementTypeBitWidth() == 1)
698  return convertBoolMemrefType(targetEnv, options, type, storageClass);
699  if (type.getElementTypeBitWidth() < 8)
700  return convertSubByteMemrefType(targetEnv, options, type, storageClass);
701  }
702 
703  Type arrayElemType;
704  Type elementType = type.getElementType();
705  if (auto vecType = dyn_cast<VectorType>(elementType)) {
706  arrayElemType =
707  convertVectorType(targetEnv, options, vecType, storageClass);
708  } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
709  arrayElemType =
710  convertComplexType(targetEnv, options, complexType, storageClass);
711  } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
712  arrayElemType =
713  convertScalarType(targetEnv, options, scalarType, storageClass);
714  } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
715  type = cast<MemRefType>(convertIndexElementType(type, options));
716  arrayElemType = type.getElementType();
717  } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
718  // Hnadle 8 bit float types.
719  type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
720  arrayElemType = type.getElementType();
721  } else {
722  LLVM_DEBUG(
723  llvm::dbgs()
724  << type
725  << " unhandled: can only convert scalar or vector element type\n");
726  return nullptr;
727  }
728  if (!arrayElemType)
729  return nullptr;
730 
731  std::optional<int64_t> arrayElemSize =
732  getTypeNumBytes(options, arrayElemType);
733  if (!arrayElemSize) {
734  LLVM_DEBUG(llvm::dbgs()
735  << type << " illegal: cannot deduce converted element size\n");
736  return nullptr;
737  }
738 
739  if (!type.hasStaticShape()) {
740  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
741  // to the element.
742  if (targetEnv.allows(spirv::Capability::Kernel))
743  return spirv::PointerType::get(arrayElemType, storageClass);
744  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
745  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
746  // For Vulkan we need extra wrapping struct and array to satisfy interface
747  // needs.
748  return wrapInStructAndGetPointer(arrayType, storageClass);
749  }
750 
751  std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
752  if (!memrefSize) {
753  LLVM_DEBUG(llvm::dbgs()
754  << type << " illegal: cannot deduce element count\n");
755  return nullptr;
756  }
757 
758  if (*memrefSize == 0) {
759  LLVM_DEBUG(llvm::dbgs()
760  << type << " illegal: zero-element memrefs are not supported\n");
761  return nullptr;
762  }
763 
764  int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
765  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
766  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
767  if (targetEnv.allows(spirv::Capability::Kernel))
768  return spirv::PointerType::get(arrayType, storageClass);
769  return wrapInStructAndGetPointer(arrayType, storageClass);
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // Type casting materialization
774 //===----------------------------------------------------------------------===//
775 
776 /// Converts the given `inputs` to the original source `type` considering the
777 /// `targetEnv`'s capabilities.
778 ///
779 /// This function is meant to be used for source materialization in type
780 /// converters. When the type converter needs to materialize a cast op back
781 /// to some original source type, we need to check whether the original source
782 /// type is supported in the target environment. If so, we can insert legal
783 /// SPIR-V cast ops accordingly.
784 ///
785 /// Note that in SPIR-V the capabilities for storage and compute are separate.
786 /// This function is meant to handle the **compute** side; so it does not
787 /// involve storage classes in its logic. The storage side is expected to be
788 /// handled by MemRef conversion logic.
789 static Value castToSourceType(const spirv::TargetEnv &targetEnv,
790  OpBuilder &builder, Type type, ValueRange inputs,
791  Location loc) {
792  // We can only cast one value in SPIR-V.
793  if (inputs.size() != 1) {
794  auto castOp =
795  UnrealizedConversionCastOp::create(builder, loc, type, inputs);
796  return castOp.getResult(0);
797  }
798  Value input = inputs.front();
799 
800  // Only support integer types for now. Floating point types to be implemented.
801  if (!isa<IntegerType>(type)) {
802  auto castOp =
803  UnrealizedConversionCastOp::create(builder, loc, type, inputs);
804  return castOp.getResult(0);
805  }
806  auto inputType = cast<IntegerType>(input.getType());
807 
808  auto scalarType = dyn_cast<spirv::ScalarType>(type);
809  if (!scalarType) {
810  auto castOp =
811  UnrealizedConversionCastOp::create(builder, loc, type, inputs);
812  return castOp.getResult(0);
813  }
814 
815  // Only support source type with a smaller bitwidth. This would mean we are
816  // truncating to go back so we don't need to worry about the signedness.
817  // For extension, we cannot have enough signal here to decide which op to use.
818  if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
819  auto castOp =
820  UnrealizedConversionCastOp::create(builder, loc, type, inputs);
821  return castOp.getResult(0);
822  }
823 
824  // Boolean values would need to use different ops than normal integer values.
825  if (type.isInteger(1)) {
826  Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
827  return spirv::IEqualOp::create(builder, loc, input, one);
828  }
829 
830  // Check that the source integer type is supported by the environment.
833  scalarType.getExtensions(exts);
834  scalarType.getCapabilities(caps);
835  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
836  failed(checkExtensionRequirements(type, targetEnv, exts))) {
837  auto castOp =
838  UnrealizedConversionCastOp::create(builder, loc, type, inputs);
839  return castOp.getResult(0);
840  }
841 
842  // We've already made sure this is truncating previously, so we don't need to
843  // care about signedness here. Still try to use a corresponding op for better
844  // consistency though.
845  if (type.isSignedInteger()) {
846  return spirv::SConvertOp::create(builder, loc, type, input);
847  }
848  return spirv::UConvertOp::create(builder, loc, type, input);
849 }
850 
851 //===----------------------------------------------------------------------===//
852 // Builtin Variables
853 //===----------------------------------------------------------------------===//
854 
855 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
856  spirv::BuiltIn builtin) {
857  // Look through all global variables in the given `body` block and check if
858  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
859  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
860  if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
861  spirv::SPIRVDialect::getAttributeName(
862  spirv::Decoration::BuiltIn))) {
863  auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
864  if (varBuiltIn == builtin) {
865  return varOp;
866  }
867  }
868  }
869  return nullptr;
870 }
871 
872 /// Gets name of global variable for a builtin.
873 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
874  StringRef suffix) {
875  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
876 }
877 
878 /// Gets or inserts a global variable for a builtin within `body` block.
879 static spirv::GlobalVariableOp
880 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
881  Type integerType, OpBuilder &builder,
882  StringRef prefix, StringRef suffix) {
883  if (auto varOp = getBuiltinVariable(body, builtin))
884  return varOp;
885 
886  OpBuilder::InsertionGuard guard(builder);
887  builder.setInsertionPointToStart(&body);
888 
889  spirv::GlobalVariableOp newVarOp;
890  switch (builtin) {
891  case spirv::BuiltIn::NumWorkgroups:
892  case spirv::BuiltIn::WorkgroupSize:
893  case spirv::BuiltIn::WorkgroupId:
894  case spirv::BuiltIn::LocalInvocationId:
895  case spirv::BuiltIn::GlobalInvocationId: {
896  auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
897  spirv::StorageClass::Input);
898  std::string name = getBuiltinVarName(builtin, prefix, suffix);
899  newVarOp =
900  spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
901  break;
902  }
903  case spirv::BuiltIn::SubgroupId:
904  case spirv::BuiltIn::NumSubgroups:
905  case spirv::BuiltIn::SubgroupSize:
906  case spirv::BuiltIn::SubgroupLocalInvocationId: {
907  auto ptrType =
908  spirv::PointerType::get(integerType, spirv::StorageClass::Input);
909  std::string name = getBuiltinVarName(builtin, prefix, suffix);
910  newVarOp =
911  spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
912  break;
913  }
914  default:
915  emitError(loc, "unimplemented builtin variable generation for ")
916  << stringifyBuiltIn(builtin);
917  }
918  return newVarOp;
919 }
920 
921 //===----------------------------------------------------------------------===//
922 // Push constant storage
923 //===----------------------------------------------------------------------===//
924 
925 /// Returns the pointer type for the push constant storage containing
926 /// `elementCount` 32-bit integer values.
927 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
928  Builder &builder,
929  Type indexType) {
930  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
931  /*stride=*/4);
932  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
933  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
934 }
935 
936 /// Returns the push constant varible containing `elementCount` 32-bit integer
937 /// values in `body`. Returns null op if such an op does not exit.
938 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
939  unsigned elementCount) {
940  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
941  auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
942  if (!ptrType)
943  continue;
944 
945  // Note that Vulkan requires "There must be no more than one push constant
946  // block statically used per shader entry point." So we should always reuse
947  // the existing one.
948  if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
949  auto numElements = cast<spirv::ArrayType>(
950  cast<spirv::StructType>(ptrType.getPointeeType())
951  .getElementType(0))
952  .getNumElements();
953  if (numElements == elementCount)
954  return varOp;
955  }
956  }
957  return nullptr;
958 }
959 
960 /// Gets or inserts a global variable for push constant storage containing
961 /// `elementCount` 32-bit integer values in `block`.
962 static spirv::GlobalVariableOp
963 getOrInsertPushConstantVariable(Location loc, Block &block,
964  unsigned elementCount, OpBuilder &b,
965  Type indexType) {
966  if (auto varOp = getPushConstantVariable(block, elementCount))
967  return varOp;
968 
969  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
970  auto type = getPushConstantStorageType(elementCount, builder, indexType);
971  const char *name = "__push_constant_var__";
972  return spirv::GlobalVariableOp::create(builder, loc, type, name,
973  /*initializer=*/nullptr);
974 }
975 
976 //===----------------------------------------------------------------------===//
977 // func::FuncOp Conversion Patterns
978 //===----------------------------------------------------------------------===//
979 
980 /// A pattern for rewriting function signature to convert arguments of functions
981 /// to be of valid SPIR-V types.
982 struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
983  using Base::Base;
984 
985  LogicalResult
986  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
987  ConversionPatternRewriter &rewriter) const override {
988  FunctionType fnType = funcOp.getFunctionType();
989  if (fnType.getNumResults() > 1)
990  return failure();
991 
992  TypeConverter::SignatureConversion signatureConverter(
993  fnType.getNumInputs());
994  for (const auto &argType : enumerate(fnType.getInputs())) {
995  auto convertedType = getTypeConverter()->convertType(argType.value());
996  if (!convertedType)
997  return failure();
998  signatureConverter.addInputs(argType.index(), convertedType);
999  }
1000 
1001  Type resultType;
1002  if (fnType.getNumResults() == 1) {
1003  resultType = getTypeConverter()->convertType(fnType.getResult(0));
1004  if (!resultType)
1005  return failure();
1006  }
1007 
1008  // Create the converted spirv.func op.
1009  auto newFuncOp = spirv::FuncOp::create(
1010  rewriter, funcOp.getLoc(), funcOp.getName(),
1011  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
1012  resultType ? TypeRange(resultType)
1013  : TypeRange()));
1014 
1015  // Copy over all attributes other than the function name and type.
1016  for (const auto &namedAttr : funcOp->getAttrs()) {
1017  if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
1018  namedAttr.getName() != SymbolTable::getSymbolAttrName())
1019  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1020  }
1021 
1022  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1023  newFuncOp.end());
1024  if (failed(rewriter.convertRegionTypes(
1025  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
1026  return failure();
1027  rewriter.eraseOp(funcOp);
1028  return success();
1029  }
1030 };
1031 
1032 /// A pattern for rewriting function signature to convert vector arguments of
1033 /// functions to be of valid types
1034 struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
1035  using Base::Base;
1036 
1037  LogicalResult matchAndRewrite(func::FuncOp funcOp,
1038  PatternRewriter &rewriter) const override {
1039  FunctionType fnType = funcOp.getFunctionType();
1040 
1041  // TODO: Handle declarations.
1042  if (funcOp.isDeclaration()) {
1043  LLVM_DEBUG(llvm::dbgs()
1044  << fnType << " illegal: declarations are unsupported\n");
1045  return failure();
1046  }
1047 
1048  // Create a new func op with the original type and copy the function body.
1049  auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),
1050  funcOp.getName(), fnType);
1051  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1052  newFuncOp.end());
1053 
1054  Location loc = newFuncOp.getBody().getLoc();
1055 
1056  Block &entryBlock = newFuncOp.getBlocks().front();
1057  OpBuilder::InsertionGuard guard(rewriter);
1058  rewriter.setInsertionPointToStart(&entryBlock);
1059 
1060  TypeConverter::SignatureConversion oneToNTypeMapping(
1061  fnType.getInputs().size());
1062 
1063  // For arguments that are of illegal types and require unrolling.
1064  // `unrolledInputNums` stores the indices of arguments that result from
1065  // unrolling in the new function signature. `newInputNo` is a counter.
1066  SmallVector<size_t> unrolledInputNums;
1067  size_t newInputNo = 0;
1068 
1069  // For arguments that are of legal types and do not require unrolling.
1070  // `tmpOps` stores a mapping from temporary operations that serve as
1071  // placeholders for new arguments that will be added later. These operations
1072  // will be erased once the entry block's argument list is updated.
1073  llvm::SmallDenseMap<Operation *, size_t> tmpOps;
1074 
1075  // This counts the number of new operations created.
1076  size_t newOpCount = 0;
1077 
1078  // Enumerate through the arguments.
1079  for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
1080  // Check whether the argument is of vector type.
1081  auto origVecType = dyn_cast<VectorType>(origType);
1082  if (!origVecType) {
1083  // We need a placeholder for the old argument that will be erased later.
1084  Value result = arith::ConstantOp::create(
1085  rewriter, loc, origType, rewriter.getZeroAttr(origType));
1086  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1087  tmpOps.insert({result.getDefiningOp(), newInputNo});
1088  oneToNTypeMapping.addInputs(origInputNo, origType);
1089  ++newInputNo;
1090  ++newOpCount;
1091  continue;
1092  }
1093  // Check whether the vector needs unrolling.
1094  auto targetShape = getTargetShape(origVecType);
1095  if (!targetShape) {
1096  // We need a placeholder for the old argument that will be erased later.
1097  Value result = arith::ConstantOp::create(
1098  rewriter, loc, origType, rewriter.getZeroAttr(origType));
1099  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1100  tmpOps.insert({result.getDefiningOp(), newInputNo});
1101  oneToNTypeMapping.addInputs(origInputNo, origType);
1102  ++newInputNo;
1103  ++newOpCount;
1104  continue;
1105  }
1106  VectorType unrolledType =
1107  VectorType::get(*targetShape, origVecType.getElementType());
1108  auto originalShape =
1109  llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1110 
1111  // Prepare the result vector.
1112  Value result = arith::ConstantOp::create(
1113  rewriter, loc, origVecType, rewriter.getZeroAttr(origVecType));
1114  ++newOpCount;
1115  // Prepare the placeholder for the new arguments that will be added later.
1116  Value dummy = arith::ConstantOp::create(
1117  rewriter, loc, unrolledType, rewriter.getZeroAttr(unrolledType));
1118  ++newOpCount;
1119 
1120  // Create the `vector.insert_strided_slice` ops.
1121  SmallVector<int64_t> strides(targetShape->size(), 1);
1122  SmallVector<Type> newTypes;
1123  for (SmallVector<int64_t> offsets :
1124  StaticTileOffsetRange(originalShape, *targetShape)) {
1125  result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,
1126  result, offsets, strides);
1127  newTypes.push_back(unrolledType);
1128  unrolledInputNums.push_back(newInputNo);
1129  ++newInputNo;
1130  ++newOpCount;
1131  }
1132  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1133  oneToNTypeMapping.addInputs(origInputNo, newTypes);
1134  }
1135 
1136  // Change the function signature.
1137  auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1138  auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1139  rewriter.modifyOpInPlace(newFuncOp,
1140  [&] { newFuncOp.setFunctionType(newFnType); });
1141 
1142  // Update the arguments in the entry block.
1143  entryBlock.eraseArguments(0, fnType.getNumInputs());
1144  SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
1145  entryBlock.addArguments(convertedTypes, locs);
1146 
1147  // Replace all uses of placeholders for initially legal arguments with their
1148  // original function arguments (that were added to `newFuncOp`).
1149  for (auto &[placeholderOp, argIdx] : tmpOps) {
1150  if (!placeholderOp)
1151  continue;
1152  Value replacement = newFuncOp.getArgument(argIdx);
1153  rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement);
1154  }
1155 
1156  // Replace dummy operands of new `vector.insert_strided_slice` ops with
1157  // their corresponding new function arguments. The new
1158  // `vector.insert_strided_slice` ops are inserted only into the entry block,
1159  // so iterating over that block is sufficient.
1160  size_t unrolledInputIdx = 0;
1161  for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1162  Operation &curOp = op;
1163  // Since all newly created operations are in the beginning, reaching the
1164  // end of them means that any later `vector.insert_strided_slice` should
1165  // not be touched.
1166  if (count >= newOpCount)
1167  continue;
1168  if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1169  size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1170  rewriter.modifyOpInPlace(&curOp, [&] {
1171  curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1172  });
1173  ++unrolledInputIdx;
1174  }
1175  }
1176 
1177  // Erase the original funcOp. The `tmpOps` do not need to be erased since
1178  // they have no uses and will be handled by dead-code elimination.
1179  rewriter.eraseOp(funcOp);
1180  return success();
1181  }
1182 };
1183 
1184 //===----------------------------------------------------------------------===//
1185 // func::ReturnOp Conversion Patterns
1186 //===----------------------------------------------------------------------===//
1187 
1188 /// A pattern for rewriting function signature and the return op to convert
1189 /// vectors to be of valid types.
1190 struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
1191  using Base::Base;
1192 
1193  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1194  PatternRewriter &rewriter) const override {
1195  // Check whether the parent funcOp is valid.
1196  auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1197  if (!funcOp)
1198  return failure();
1199 
1200  FunctionType fnType = funcOp.getFunctionType();
1201  TypeConverter::SignatureConversion oneToNTypeMapping(
1202  fnType.getResults().size());
1203  Location loc = returnOp.getLoc();
1204 
1205  // For the new return op.
1206  SmallVector<Value> newOperands;
1207 
1208  // Enumerate through the results.
1209  for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1210  // Check whether the argument is of vector type.
1211  auto origVecType = dyn_cast<VectorType>(origType);
1212  if (!origVecType) {
1213  oneToNTypeMapping.addInputs(origResultNo, origType);
1214  newOperands.push_back(returnOp.getOperand(origResultNo));
1215  continue;
1216  }
1217  // Check whether the vector needs unrolling.
1218  auto targetShape = getTargetShape(origVecType);
1219  if (!targetShape) {
1220  // The original argument can be used.
1221  oneToNTypeMapping.addInputs(origResultNo, origType);
1222  newOperands.push_back(returnOp.getOperand(origResultNo));
1223  continue;
1224  }
1225  VectorType unrolledType =
1226  VectorType::get(*targetShape, origVecType.getElementType());
1227 
1228  // Create `vector.extract_strided_slice` ops to form legal vectors from
1229  // the original operand of illegal type.
1230  auto originalShape =
1231  llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1232  SmallVector<int64_t> strides(originalShape.size(), 1);
1233  SmallVector<int64_t> extractShape(originalShape.size(), 1);
1234  extractShape.back() = targetShape->back();
1235  SmallVector<Type> newTypes;
1236  Value returnValue = returnOp.getOperand(origResultNo);
1237  for (SmallVector<int64_t> offsets :
1238  StaticTileOffsetRange(originalShape, *targetShape)) {
1239  Value result = vector::ExtractStridedSliceOp::create(
1240  rewriter, loc, returnValue, offsets, extractShape, strides);
1241  if (originalShape.size() > 1) {
1242  SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
1243  result =
1244  vector::ExtractOp::create(rewriter, loc, result, extractIndices);
1245  }
1246  newOperands.push_back(result);
1247  newTypes.push_back(unrolledType);
1248  }
1249  oneToNTypeMapping.addInputs(origResultNo, newTypes);
1250  }
1251 
1252  // Change the function signature.
1253  auto newFnType =
1254  FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
1255  TypeRange(oneToNTypeMapping.getConvertedTypes()));
1256  rewriter.modifyOpInPlace(funcOp,
1257  [&] { funcOp.setFunctionType(newFnType); });
1258 
1259  // Replace the return op using the new operands. This will automatically
1260  // update the entry block as well.
1261  rewriter.replaceOp(returnOp,
1262  func::ReturnOp::create(rewriter, loc, newOperands));
1263 
1264  return success();
1265  }
1266 };
1267 
1268 } // namespace
1269 
1270 //===----------------------------------------------------------------------===//
1271 // Public function for builtin variables
1272 //===----------------------------------------------------------------------===//
1273 
1275  spirv::BuiltIn builtin,
1276  Type integerType, OpBuilder &builder,
1277  StringRef prefix, StringRef suffix) {
1279  if (!parent) {
1280  op->emitError("expected operation to be within a module-like op");
1281  return nullptr;
1282  }
1283 
1284  spirv::GlobalVariableOp varOp =
1285  getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
1286  builtin, integerType, builder, prefix, suffix);
1287  Value ptr = spirv::AddressOfOp::create(builder, op->getLoc(), varOp);
1288  return spirv::LoadOp::create(builder, op->getLoc(), ptr);
1289 }
1290 
1291 //===----------------------------------------------------------------------===//
1292 // Public function for pushing constant storage
1293 //===----------------------------------------------------------------------===//
1294 
1295 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
1296  unsigned offset, Type integerType,
1297  OpBuilder &builder) {
1298  Location loc = op->getLoc();
1300  if (!parent) {
1301  op->emitError("expected operation to be within a module-like op");
1302  return nullptr;
1303  }
1304 
1305  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1306  loc, parent->getRegion(0).front(), elementCount, builder, integerType);
1307 
1308  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1309  Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,
1310  builder.getI32IntegerAttr(offset));
1311  auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);
1312  auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,
1313  llvm::ArrayRef({zeroOp, offsetOp}));
1314  return spirv::LoadOp::create(builder, loc, acOp);
1315 }
1316 
1317 //===----------------------------------------------------------------------===//
1318 // Public functions for index calculation
1319 //===----------------------------------------------------------------------===//
1320 
1322  int64_t offset, Type integerType,
1323  Location loc, OpBuilder &builder) {
  // Computes `offset + sum_i(indices[i] * strides[i])` as a value of
  // `integerType`. The constants, products and sums are built with
  // `createOrFold`, so they are folded where the operands allow it.
1324  assert(indices.size() == strides.size() &&
1325  "must provide indices for all dimensions");
1326 
1327  // TODO: Consider moving to use affine.apply and patterns converting
1328  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
1329  // broken down into progressive small steps so we can have intermediate steps
1330  // using other dialects. At the moment SPIR-V is the final sink.
1331 
1332  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
1333  loc, integerType, IntegerAttr::get(integerType, offset));
1334  for (const auto &index : llvm::enumerate(indices)) {
1335  Value strideVal = builder.createOrFold<spirv::ConstantOp>(
1336  loc, integerType,
1337  IntegerAttr::get(integerType, strides[index.index()]));
1338  Value update =
1339  builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1340  linearizedIndex =
1341  builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1342  }
1343  return linearizedIndex;
1344 }
1345 
1347  MemRefType baseType, Value basePtr,
1348  ValueRange indices, Location loc,
1349  OpBuilder &builder) {
1350  // Get the strides and offset of the MemRefType and verify they are static;
  // otherwise no element pointer can be built and null is returned.
1351 
1352  int64_t offset;
1353  SmallVector<int64_t, 4> strides;
1354  if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1355  llvm::is_contained(strides, ShapedType::kDynamic) ||
1356  ShapedType::isDynamic(offset)) {
1357  return nullptr;
1358  }
1359 
1360  auto indexType = typeConverter.getIndexType();
1361 
1362  SmallVector<Value, 2> linearizedIndices;
1363  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1364 
1365  // Add a '0' at the start to index into the struct.
1366  linearizedIndices.push_back(zero);
1367 
  // NOTE(review): for rank-0 memrefs the static `offset` is not applied (the
  // element index is a plain zero), whereas `linearizeIndex` with no indices
  // would yield `offset`. Confirm non-zero offsets cannot reach this path.
1368  if (baseType.getRank() == 0) {
1369  linearizedIndices.push_back(zero);
1370  } else {
1371  linearizedIndices.push_back(
1372  linearizeIndex(indices, strides, offset, indexType, loc, builder));
1373  }
  // The resulting access chain is [0, element index].
1374  return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);
1375 }
1376 
1378  MemRefType baseType, Value basePtr,
1379  ValueRange indices, Location loc,
1380  OpBuilder &builder) {
1381  // Get the strides and offset of the MemRefType and verify they are static;
  // otherwise no element pointer can be built and null is returned.
1382 
1383  int64_t offset;
1384  SmallVector<int64_t, 4> strides;
1385  if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1386  llvm::is_contained(strides, ShapedType::kDynamic) ||
1387  ShapedType::isDynamic(offset)) {
1388  return nullptr;
1389  }
1390 
1391  auto indexType = typeConverter.getIndexType();
1392 
1393  SmallVector<Value, 2> linearizedIndices;
1394  Value linearIndex;
  // NOTE(review): as in the Vulkan variant, rank-0 memrefs ignore the static
  // `offset` here (the index is a plain zero); confirm this is intended.
1395  if (baseType.getRank() == 0) {
1396  linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1397  } else {
1398  linearIndex =
1399  linearizeIndex(indices, strides, offset, indexType, loc, builder);
1400  }
  // A pointer to an array is indexed with spirv.AccessChain [linearIndex].
  // Any other pointee type uses spirv.PtrAccessChain with `linearIndex` as the
  // element offset and no further indices.
1401  Type pointeeType =
1402  cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
1403  if (isa<spirv::ArrayType>(pointeeType)) {
1404  linearizedIndices.push_back(linearIndex);
1405  return spirv::AccessChainOp::create(builder, loc, basePtr,
1406  linearizedIndices);
1407  }
1408  return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,
1409  linearizedIndices);
1410 }
1411 
1413  MemRefType baseType, Value basePtr,
1414  ValueRange indices, Location loc,
1415  OpBuilder &builder) {
  // Dispatches on the target. When the type converter allows the Kernel
  // capability, the OpenCL-flavored helper is used; otherwise the
  // Vulkan-flavored one is used.
1416 
1417  if (typeConverter.allows(spirv::Capability::Kernel)) {
1418  return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1419  builder);
1420  }
1421 
1422  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1423  builder);
1424 }
1425 
1426 //===----------------------------------------------------------------------===//
1427 // Public functions for vector unrolling
1428 //===----------------------------------------------------------------------===//
1429 
  // Returns the first of 4, 3, 2 that evenly divides `size`, or 1 if none
  // does. Note that a `size` of 0 yields 4.
1431  for (int i : {4, 3, 2}) {
1432  if (size % i == 0)
1433  return i;
1434  }
1435  return 1;
1436 }
1437 
1439 mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
  // Reductions have a 1-D source; its only dimension is unrolled to the
  // compute vector size.
1440  VectorType srcVectorType = op.getSourceVectorType();
1441  assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1442  int64_t vectorSize =
1443  mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
1444  return {vectorSize};
1445 }
1446 
1448 mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
  // Every result dimension is 1 except the innermost one, which is unrolled
  // to the compute vector size.
  // NOTE(review): a 0-D result would make `nativeSize` empty and `.back()`
  // invalid; confirm such transposes cannot reach here.
1449  VectorType vectorType = op.getResultVectorType();
1450  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
1451  nativeSize.back() =
1452  mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
1453  return nativeSize;
1454 }
1455 
1456 std::optional<SmallVector<int64_t>>
  // Elementwise-mappable ops with a single vector result: keep all dimensions
  // at 1 except the innermost, which is unrolled to the compute vector size.
  // NOTE(review): a 0-D vector result gives an empty `nativeSize`, so
  // `.back()` below would be out of bounds; confirm 0-D vectors cannot reach
  // this path.
1458  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
1459  if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
1460  SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
1461  nativeSize.back() =
1462  mlir::spirv::getComputeVectorSize(vecType.getShape().back());
1463  return nativeSize;
1464  }
1465  }
1466 
  // Otherwise dispatch reductions and transposes to their implementations.
  // Every other op has no native shape (std::nullopt).
1468  .Case<vector::ReductionOp, vector::TransposeOp>(
1469  [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
1470  .Default([](Operation *) { return std::nullopt; });
1471 }
1472 
  // Applies the func/return signature-unrolling patterns under `op` with a
  // greedy driver restricted to ops that exist before the rewrite starts.
1474  MLIRContext *context = op->getContext();
1475  RewritePatternSet patterns(context);
1478  // We only want to apply signature conversion once to the existing func ops.
1479  // Without restricting the driver's strictness to existing ops, the greedy
1480  // pattern rewriter would keep looking for newly created func ops.
1481  return applyPatternsGreedily(op, std::move(patterns),
1482  GreedyRewriteConfig().setStrictness(
1484 }
1485 
  // Runs three greedy rewrite phases on `op`:
  //   1. unroll vector ops to native vector shapes,
  //   2. lower transposes element-wise,
  //   3. clean up and canonicalize the result.
  // Returns failure as soon as any phase fails.
1487  MLIRContext *context = op->getContext();
1488 
1489  // Unroll vectors in function bodies to native vector size.
1490  {
1491  RewritePatternSet patterns(context);
1493  [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1494  populateVectorUnrollPatterns(patterns, options);
1495  if (failed(applyPatternsGreedily(op, std::move(patterns))))
1496  return failure();
1497  }
1498 
1499  // Convert transpose ops into extract and insert pairs, in preparation of
1500  // further transformations to canonicalize/cancel.
1501  {
1502  RewritePatternSet patterns(context);
1504  patterns, vector::VectorTransposeLowering::EltWise);
1506  if (failed(applyPatternsGreedily(op, std::move(patterns))))
1507  return failure();
1508  }
1509 
1510  // Run canonicalization to cast away leading size-1 dimensions.
1511  {
1512  RewritePatternSet patterns(context);
1513 
1514  // We need to pull in casting away leading one dims.
1515  vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1516  vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1517  vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1518 
1519  // Decompose different rank insert_strided_slice and n-D
1520  // extract_strided_slice.
1521  vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1522  patterns);
1523  vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1524  vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1525 
1526  // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1527  // them up.
1528  vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1529  vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1530 
1531  if (failed(applyPatternsGreedily(op, std::move(patterns))))
1532  return failure();
1533  }
1534  return success();
1535 }
1536 
1537 //===----------------------------------------------------------------------===//
1538 // SPIR-V TypeConverter
1539 //===----------------------------------------------------------------------===//
1540 
1543  : targetEnv(targetAttr), options(options) {
1544  // Add conversions. The order matters here: later ones will be tried earlier.
1545 
1546  // Allow all SPIR-V dialect specific types. This assumes all builtin types
1547  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
1548  // were tried before.
1549  //
1550  // TODO: This assumes that the SPIR-V types are valid to use in the given
1551  // target environment, which should be the case if the whole pipeline is
1552  // driven by the same target environment. Still, we probably want to
1553  // validate and convert to be safe.
1554  addConversion([](spirv::SPIRVType type) { return type; });
1555 
1556  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
1557 
  // Integers that are SPIR-V scalar types are converted per the target
  // environment. Integers narrower than 8 bits go through sub-byte handling.
  // Anything else yields a null Type, i.e. a failed conversion.
1558  addConversion([this](IntegerType intType) -> std::optional<Type> {
1559  if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1560  return convertScalarType(this->targetEnv, this->options, scalarType);
1561  if (intType.getWidth() < 8)
1562  return convertSubByteIntegerType(this->options, intType);
1563  return Type();
1564  });
1565 
  // Floats mirror the integer handling, with a dedicated path for 8-bit
  // float types.
1566  addConversion([this](FloatType floatType) -> std::optional<Type> {
1567  if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1568  return convertScalarType(this->targetEnv, this->options, scalarType);
1569  if (floatType.getWidth() == 8)
1570  return convert8BitFloatType(this->options, floatType);
1571  return Type();
1572  });
1573 
1574  addConversion([this](ComplexType complexType) {
1575  return convertComplexType(this->targetEnv, this->options, complexType);
1576  });
1577 
1578  addConversion([this](VectorType vectorType) {
1579  return convertVectorType(this->targetEnv, this->options, vectorType);
1580  });
1581 
1582  addConversion([this](TensorType tensorType) {
1583  return convertTensorType(this->targetEnv, this->options, tensorType);
1584  });
1585 
1586  addConversion([this](MemRefType memRefType) {
1587  return convertMemrefType(this->targetEnv, this->options, memRefType);
1588  });
1589 
1590  // Register some last line of defense casting logic.
  // Source materializations go through castToSourceType. Target
  // materializations always produce an unrealized_conversion_cast.
1592  [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1593  return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1594  });
1595  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
1596  Location loc) {
1597  auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1598  return cast.getResult(0);
1599  });
1600 }
1601 
  // Delegates to the file-scope `getIndexType` helper (defined outside this
  // view), passing this converter's context and conversion options.
1603  return ::getIndexType(getContext(), options);
1604 }
1605 
1606 MLIRContext *SPIRVTypeConverter::getContext() const {
1607  return targetEnv.getAttr().getContext();
1608 }
1609 
1610 bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
1611  return targetEnv.allows(capability);
1612 }
1613 
1614 //===----------------------------------------------------------------------===//
1615 // SPIR-V ConversionTarget
1616 //===----------------------------------------------------------------------===//
1617 
1618 std::unique_ptr<SPIRVConversionTarget>
  // Creates a conversion target in which SPIR-V dialect ops are dynamically
  // legal, as decided per op by isLegalOp against `targetAttr`.
1620  std::unique_ptr<SPIRVConversionTarget> target(
1621  // std::make_unique does not work here because the constructor is private.
1622  new SPIRVConversionTarget(targetAttr));
1623  SPIRVConversionTarget *targetPtr = target.get();
1624  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1625  // Capture the raw pointer rather than the unique_ptr: the heap object it
1626  // points to stays in place when `target` is returned to the caller.
1627  [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1628  return target;
1629 }
1630 
1631 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1632  : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1633 
1634 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1635  // Make sure this op is available at the given version. Ops not implementing
1636  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1637  // SPIR-V versions.
1638  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1639  std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1640  if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1641  LLVM_DEBUG(llvm::dbgs()
1642  << op->getName() << " illegal: requiring min version "
1643  << spirv::stringifyVersion(*minVersion) << "\n");
1644  return false;
1645  }
1646  }
1647  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1648  std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1649  if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1650  LLVM_DEBUG(llvm::dbgs()
1651  << op->getName() << " illegal: requiring max version "
1652  << spirv::stringifyVersion(*maxVersion) << "\n");
1653  return false;
1654  }
1655  }
1656 
1657  // Make sure this op's required extensions are allowed to use. Ops not
1658  // implementing QueryExtensionInterface do not require extensions to be
1659  // available.
1660  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1661  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1662  extensions.getExtensions())))
1663  return false;
1664 
1665  // Make sure this op's required extensions are allowed to use. Ops not
1666  // implementing QueryCapabilityInterface do not require capabilities to be
1667  // available.
1668  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1669  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1670  capabilities.getCapabilities())))
1671  return false;
1672 
1673  SmallVector<Type, 4> valueTypes;
1674  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
1675  valueTypes.append(op->result_type_begin(), op->result_type_end());
1676 
1677  // Ensure that all types have been converted to SPIRV types.
1678  if (llvm::any_of(valueTypes,
1679  [](Type t) { return !isa<spirv::SPIRVType>(t); }))
1680  return false;
1681 
1682  // Special treatment for global variables, whose type requirements are
1683  // conveyed by type attributes.
1684  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1685  valueTypes.push_back(globalVar.getType());
1686 
1687  // Make sure the op's operands/results use types that are allowed by the
1688  // target environment.
1689  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1690  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1691  for (Type valueType : valueTypes) {
1692  typeExtensions.clear();
1693  cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1694  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1695  typeExtensions)))
1696  return false;
1697 
1698  typeCapabilities.clear();
1699  cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1700  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1701  typeCapabilities)))
1702  return false;
1703  }
1704 
1705  return true;
1706 }
1707 
1708 //===----------------------------------------------------------------------===//
1709 // Public functions for populating patterns
1710 //===----------------------------------------------------------------------===//
1711 
1713  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // Registers the func.func -> spirv.func conversion pattern using
  // `typeConverter`.
1714  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
1715 }
1716 
  // Registers the pattern that unrolls vector arguments in func signatures.
1718  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
1719 }
1720 
  // Registers the pattern that unrolls returned vectors (and the enclosing
  // function's result types).
1722  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
1723 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define BIT_WIDTH_CASE(BIT_WIDTH)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
Definition: Block.h:33
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:160
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:201
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:193
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:51
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:200
FloatType getF32Type()
Definition: Builders.cpp:43
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:324
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still inside the block.
Definition: Builders.h:240
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:320
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Definition: Builders.h:525
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
operand_type_iterator operand_type_end()
Definition: Operation.h:396
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_iterator result_type_end()
Definition: Operation.h:427
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:267
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
result_type_iterator result_type_begin()
Definition: Operation.h:426
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
operand_type_iterator operand_type_begin()
Definition: Operation.h:395
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:793
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
Definition: BuiltinTypes.h:55
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:158
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:188
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:147
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:451
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:509
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:756
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabilities/extensions.
Definition: TargetAndABI.h:29
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:62
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Definition: Operation.cpp:1397
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
Definition: Utils.cpp:2027
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr, using the layout map of baseType.
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides, adding an initial offset.
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)