MLIR  22.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
25 #include "mlir/IR/BuiltinTypes.h"
26 #include "mlir/IR/Operation.h"
27 #include "mlir/IR/PatternMatch.h"
28 #include "mlir/Support/LLVM.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/StringExtras.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/MathExtras.h"
36 
37 #include <optional>
38 
39 #define DEBUG_TYPE "mlir-spirv-conversion"
40 
41 using namespace mlir;
42 
43 namespace {
44 
45 //===----------------------------------------------------------------------===//
46 // Utility functions
47 //===----------------------------------------------------------------------===//
48 
49 static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
50  LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
51  if (vecType.isScalable()) {
52  LLVM_DEBUG(llvm::dbgs()
53  << "--scalable vectors are not supported -> BAIL\n");
54  return std::nullopt;
55  }
56  SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
57  std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
58  1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
59  if (!targetShape) {
60  LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
61  return std::nullopt;
62  }
63  auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
64  if (!maybeShapeRatio) {
65  LLVM_DEBUG(llvm::dbgs()
66  << "--could not compute integral shape ratio -> BAIL\n");
67  return std::nullopt;
68  }
69  if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
70  LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
71  return std::nullopt;
72  }
73  LLVM_DEBUG(llvm::dbgs()
74  << "--found an integral shape ratio to unroll to -> SUCCESS\n");
75  return targetShape;
76 }
77 
78 /// Checks that `candidates` extension requirements are possible to be satisfied
79 /// with the given `targetEnv`.
80 ///
81 /// `candidates` is a vector of vector for extension requirements following
82 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
83 /// convention.
84 template <typename LabelT>
85 static LogicalResult checkExtensionRequirements(
86  LabelT label, const spirv::TargetEnv &targetEnv,
88  for (const auto &ors : candidates) {
89  if (targetEnv.allows(ors))
90  continue;
91 
92  LLVM_DEBUG({
93  SmallVector<StringRef> extStrings;
94  for (spirv::Extension ext : ors)
95  extStrings.push_back(spirv::stringifyExtension(ext));
96 
97  llvm::dbgs() << label << " illegal: requires at least one extension in ["
98  << llvm::join(extStrings, ", ")
99  << "] but none allowed in target environment\n";
100  });
101  return failure();
102  }
103  return success();
104 }
105 
106 /// Checks that `candidates`capability requirements are possible to be satisfied
107 /// with the given `isAllowedFn`.
108 ///
109 /// `candidates` is a vector of vector for capability requirements following
110 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
111 /// convention.
112 template <typename LabelT>
113 static LogicalResult checkCapabilityRequirements(
114  LabelT label, const spirv::TargetEnv &targetEnv,
115  const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
116  for (const auto &ors : candidates) {
117  if (targetEnv.allows(ors))
118  continue;
119 
120  LLVM_DEBUG({
121  SmallVector<StringRef> capStrings;
122  for (spirv::Capability cap : ors)
123  capStrings.push_back(spirv::stringifyCapability(cap));
124 
125  llvm::dbgs() << label << " illegal: requires at least one capability in ["
126  << llvm::join(capStrings, ", ")
127  << "] but none allowed in target environment\n";
128  });
129  return failure();
130  }
131  return success();
132 }
133 
134 /// Returns true if the given `storageClass` needs explicit layout when used in
135 /// Shader environments.
136 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
137  switch (storageClass) {
138  case spirv::StorageClass::PhysicalStorageBuffer:
139  case spirv::StorageClass::PushConstant:
140  case spirv::StorageClass::StorageBuffer:
141  case spirv::StorageClass::Uniform:
142  return true;
143  default:
144  return false;
145  }
146 }
147 
148 /// Wraps the given `elementType` in a struct and gets the pointer to the
149 /// struct. This is used to satisfy Vulkan interface requirements.
150 static spirv::PointerType
151 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
152  auto structType = needsExplicitLayout(storageClass)
153  ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
154  : spirv::StructType::get(elementType);
155  return spirv::PointerType::get(structType, storageClass);
156 }
157 
158 //===----------------------------------------------------------------------===//
159 // Type Conversion
160 //===----------------------------------------------------------------------===//
161 
162 static spirv::ScalarType getIndexType(MLIRContext *ctx,
164  return cast<spirv::ScalarType>(
165  IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
166 }
167 
168 // TODO: This is a utility function that should probably be exposed by the
169 // SPIR-V dialect. Keeping it local till the use case arises.
170 static std::optional<int64_t>
171 getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
172  if (isa<spirv::ScalarType>(type)) {
173  auto bitWidth = type.getIntOrFloatBitWidth();
174  // According to the SPIR-V spec:
175  // "There is no physical size or bit pattern defined for values with boolean
176  // type. If they are stored (in conjunction with OpVariable), they can only
177  // be used with logical addressing operations, not physical, and only with
178  // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
179  // Private, Function, Input, and Output."
180  if (bitWidth == 1)
181  return std::nullopt;
182  return bitWidth / 8;
183  }
184 
185  // Handle 8-bit floats.
186  if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
187  auto bitWidth = type.getIntOrFloatBitWidth();
188  if (bitWidth == 8)
189  return bitWidth / 8;
190  return std::nullopt;
191  }
192 
193  if (auto complexType = dyn_cast<ComplexType>(type)) {
194  auto elementSize = getTypeNumBytes(options, complexType.getElementType());
195  if (!elementSize)
196  return std::nullopt;
197  return 2 * *elementSize;
198  }
199 
200  if (auto vecType = dyn_cast<VectorType>(type)) {
201  auto elementSize = getTypeNumBytes(options, vecType.getElementType());
202  if (!elementSize)
203  return std::nullopt;
204  return vecType.getNumElements() * *elementSize;
205  }
206 
207  if (auto memRefType = dyn_cast<MemRefType>(type)) {
208  // TODO: Layout should also be controlled by the ABI attributes. For now
209  // using the layout from MemRef.
210  int64_t offset;
211  SmallVector<int64_t, 4> strides;
212  if (!memRefType.hasStaticShape() ||
213  failed(memRefType.getStridesAndOffset(strides, offset)))
214  return std::nullopt;
215 
216  // To get the size of the memref object in memory, the total size is the
217  // max(stride * dimension-size) computed for all dimensions times the size
218  // of the element.
219  auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
220  if (!elementSize)
221  return std::nullopt;
222 
223  if (memRefType.getRank() == 0)
224  return elementSize;
225 
226  auto dims = memRefType.getShape();
227  if (llvm::is_contained(dims, ShapedType::kDynamic) ||
228  ShapedType::isDynamic(offset) ||
229  llvm::is_contained(strides, ShapedType::kDynamic))
230  return std::nullopt;
231 
232  int64_t memrefSize = -1;
233  for (const auto &shape : enumerate(dims))
234  memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
235 
236  return (offset + memrefSize) * *elementSize;
237  }
238 
239  if (auto tensorType = dyn_cast<TensorType>(type)) {
240  if (!tensorType.hasStaticShape())
241  return std::nullopt;
242 
243  auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
244  if (!elementSize)
245  return std::nullopt;
246 
247  int64_t size = *elementSize;
248  for (auto shape : tensorType.getShape())
249  size *= shape;
250 
251  return size;
252  }
253 
254  // TODO: Add size computation for other types.
255  return std::nullopt;
256 }
257 
258 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
259 static Type
260 convertScalarType(const spirv::TargetEnv &targetEnv,
262  std::optional<spirv::StorageClass> storageClass = {}) {
263  // Get extension and capability requirements for the given type.
266  type.getExtensions(extensions, storageClass);
267  type.getCapabilities(capabilities, storageClass);
268 
269  // If all requirements are met, then we can accept this type as-is.
270  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
271  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
272  return type;
273 
274  // Otherwise we need to adjust the type, which really means adjusting the
275  // bitwidth given this is a scalar type.
276  if (!options.emulateLT32BitScalarTypes)
277  return nullptr;
278 
279  // We only emulate narrower scalar types here and do not truncate results.
280  if (type.getIntOrFloatBitWidth() > 32) {
281  LLVM_DEBUG(llvm::dbgs()
282  << type
283  << " not converted to 32-bit for SPIR-V to avoid truncation\n");
284  return nullptr;
285  }
286 
287  if (auto floatType = dyn_cast<FloatType>(type)) {
288  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
289  return Builder(targetEnv.getContext()).getF32Type();
290  }
291 
292  auto intType = cast<IntegerType>(type);
293  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
294  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
295  intType.getSignedness());
296 }
297 
298 /// Converts a sub-byte integer `type` to i32 regardless of target environment.
299 /// Returns a nullptr for unsupported integer types, including non sub-byte
300 /// types.
301 ///
302 /// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
303 /// the above given that these sub-byte types are not supported at all in
304 /// SPIR-V; there are no compute/storage capability for them like other
305 /// supported integer types.
306 static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
307  IntegerType type) {
308  if (type.getWidth() > 8) {
309  LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
310  return nullptr;
311  }
312  if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
313  LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
314  return nullptr;
315  }
316 
317  if (!llvm::isPowerOf2_32(type.getWidth())) {
318  LLVM_DEBUG(llvm::dbgs()
319  << "unsupported non-power-of-two bitwidth in sub-byte" << type
320  << "\n");
321  return nullptr;
322  }
323 
324  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
325  return IntegerType::get(type.getContext(), /*width=*/32,
326  type.getSignedness());
327 }
328 
329 /// Converts 8-bit float types to integer types with the same bit width.
330 /// Returns a nullptr for unsupported 8-bit float types.
331 static Type convert8BitFloatType(const SPIRVConversionOptions &options,
332  FloatType type) {
333  if (!options.emulateUnsupportedFloatTypes)
334  return nullptr;
335  // F8 types are converted to integer types with the same bit width.
336  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
337  Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
338  Float8E8M0FNUType>(type))
339  return IntegerType::get(type.getContext(), type.getWidth());
340  LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
341  return nullptr;
342 }
343 
344 /// Returns a type with the same shape but with any 8-bit float element type
345 /// converted to the same bit width integer type. This is a noop when the
346 /// element type is not the 8-bit float type or emulation flag is set to false.
347 static ShapedType
348 convertShaped8BitFloatType(ShapedType type,
350  if (!options.emulateUnsupportedFloatTypes)
351  return type;
352  Type srcElementType = type.getElementType();
353  Type convertedElementType = nullptr;
354  // F8 types are converted to integer types with the same bit width.
355  if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
356  Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
357  Float8E8M0FNUType>(srcElementType))
358  convertedElementType = IntegerType::get(
359  type.getContext(), srcElementType.getIntOrFloatBitWidth());
360 
361  if (!convertedElementType)
362  return type;
363 
364  return type.clone(convertedElementType);
365 }
366 
367 /// Returns a type with the same shape but with any index element type converted
368 /// to the matching integer type. This is a noop when the element type is not
369 /// the index type.
370 static ShapedType
371 convertIndexElementType(ShapedType type,
373  Type indexType = dyn_cast<IndexType>(type.getElementType());
374  if (!indexType)
375  return type;
376 
377  return type.clone(getIndexType(type.getContext(), options));
378 }
379 
380 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
381 static Type
382 convertVectorType(const spirv::TargetEnv &targetEnv,
383  const SPIRVConversionOptions &options, VectorType type,
384  std::optional<spirv::StorageClass> storageClass = {}) {
385  type = cast<VectorType>(convertIndexElementType(type, options));
386  type = cast<VectorType>(convertShaped8BitFloatType(type, options));
387  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
388  if (!scalarType) {
389  // If this is not a spec allowed scalar type, try to handle sub-byte integer
390  // types.
391  auto intType = dyn_cast<IntegerType>(type.getElementType());
392  if (!intType) {
393  LLVM_DEBUG(llvm::dbgs()
394  << type
395  << " illegal: cannot convert non-scalar element type\n");
396  return nullptr;
397  }
398 
399  Type elementType = convertSubByteIntegerType(options, intType);
400  if (!elementType)
401  return nullptr;
402 
403  if (type.getRank() <= 1 && type.getNumElements() == 1)
404  return elementType;
405 
406  if (type.getNumElements() > 4) {
407  LLVM_DEBUG(llvm::dbgs()
408  << type << " illegal: > 4-element unimplemented\n");
409  return nullptr;
410  }
411 
412  return VectorType::get(type.getShape(), elementType);
413  }
414 
415  if (type.getRank() <= 1 && type.getNumElements() == 1)
416  return convertScalarType(targetEnv, options, scalarType, storageClass);
417 
418  if (!spirv::CompositeType::isValid(type)) {
419  LLVM_DEBUG(llvm::dbgs()
420  << type << " illegal: not a valid composite type\n");
421  return nullptr;
422  }
423 
424  // Get extension and capability requirements for the given type.
427  cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
428  cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
429 
430  // If all requirements are met, then we can accept this type as-is.
431  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
432  succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
433  return type;
434 
435  auto elementType =
436  convertScalarType(targetEnv, options, scalarType, storageClass);
437  if (elementType)
438  return VectorType::get(type.getShape(), elementType);
439  return nullptr;
440 }
441 
442 static Type
443 convertComplexType(const spirv::TargetEnv &targetEnv,
444  const SPIRVConversionOptions &options, ComplexType type,
445  std::optional<spirv::StorageClass> storageClass = {}) {
446  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
447  if (!scalarType) {
448  LLVM_DEBUG(llvm::dbgs()
449  << type << " illegal: cannot convert non-scalar element type\n");
450  return nullptr;
451  }
452 
453  auto elementType =
454  convertScalarType(targetEnv, options, scalarType, storageClass);
455  if (!elementType)
456  return nullptr;
457  if (elementType != type.getElementType()) {
458  LLVM_DEBUG(llvm::dbgs()
459  << type << " illegal: complex type emulation unsupported\n");
460  return nullptr;
461  }
462 
463  return VectorType::get(2, elementType);
464 }
465 
466 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
467 ///
468 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
469 /// create composite constants with OpConstantComposite to embed relative large
470 /// constant values and use OpCompositeExtract and OpCompositeInsert to
471 /// manipulate, like what we do for vectors.
472 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
474  TensorType type) {
475  // TODO: Handle dynamic shapes.
476  if (!type.hasStaticShape()) {
477  LLVM_DEBUG(llvm::dbgs()
478  << type << " illegal: dynamic shape unimplemented\n");
479  return nullptr;
480  }
481 
482  type = cast<TensorType>(convertIndexElementType(type, options));
483  type = cast<TensorType>(convertShaped8BitFloatType(type, options));
484  auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
485  if (!scalarType) {
486  LLVM_DEBUG(llvm::dbgs()
487  << type << " illegal: cannot convert non-scalar element type\n");
488  return nullptr;
489  }
490 
491  std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
492  std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
493  if (!scalarSize || !tensorSize) {
494  LLVM_DEBUG(llvm::dbgs()
495  << type << " illegal: cannot deduce element count\n");
496  return nullptr;
497  }
498 
499  int64_t arrayElemCount = *tensorSize / *scalarSize;
500  if (arrayElemCount == 0) {
501  LLVM_DEBUG(llvm::dbgs()
502  << type << " illegal: cannot handle zero-element tensors\n");
503  return nullptr;
504  }
505 
506  Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
507  if (!arrayElemType)
508  return nullptr;
509  std::optional<int64_t> arrayElemSize =
510  getTypeNumBytes(options, arrayElemType);
511  if (!arrayElemSize) {
512  LLVM_DEBUG(llvm::dbgs()
513  << type << " illegal: cannot deduce converted element size\n");
514  return nullptr;
515  }
516 
517  return spirv::ArrayType::get(arrayElemType, arrayElemCount);
518 }
519 
520 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
522  MemRefType type,
523  spirv::StorageClass storageClass) {
524  unsigned numBoolBits = options.boolNumBits;
525  if (numBoolBits != 8) {
526  LLVM_DEBUG(llvm::dbgs()
527  << "using non-8-bit storage for bool types unimplemented");
528  return nullptr;
529  }
530  auto elementType = dyn_cast<spirv::ScalarType>(
531  IntegerType::get(type.getContext(), numBoolBits));
532  if (!elementType)
533  return nullptr;
534  Type arrayElemType =
535  convertScalarType(targetEnv, options, elementType, storageClass);
536  if (!arrayElemType)
537  return nullptr;
538  std::optional<int64_t> arrayElemSize =
539  getTypeNumBytes(options, arrayElemType);
540  if (!arrayElemSize) {
541  LLVM_DEBUG(llvm::dbgs()
542  << type << " illegal: cannot deduce converted element size\n");
543  return nullptr;
544  }
545 
546  if (!type.hasStaticShape()) {
547  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
548  // to the element.
549  if (targetEnv.allows(spirv::Capability::Kernel))
550  return spirv::PointerType::get(arrayElemType, storageClass);
551  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
552  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
553  // For Vulkan we need extra wrapping struct and array to satisfy interface
554  // needs.
555  return wrapInStructAndGetPointer(arrayType, storageClass);
556  }
557 
558  if (type.getNumElements() == 0) {
559  LLVM_DEBUG(llvm::dbgs()
560  << type << " illegal: zero-element memrefs are not supported\n");
561  return nullptr;
562  }
563 
564  int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
565  int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
566  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
567  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
568  if (targetEnv.allows(spirv::Capability::Kernel))
569  return spirv::PointerType::get(arrayType, storageClass);
570  return wrapInStructAndGetPointer(arrayType, storageClass);
571 }
572 
573 static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
575  MemRefType type,
576  spirv::StorageClass storageClass) {
577  IntegerType elementType = cast<IntegerType>(type.getElementType());
578  Type arrayElemType = convertSubByteIntegerType(options, elementType);
579  if (!arrayElemType)
580  return nullptr;
581  int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
582 
583  if (!type.hasStaticShape()) {
584  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
585  // to the element.
586  if (targetEnv.allows(spirv::Capability::Kernel))
587  return spirv::PointerType::get(arrayElemType, storageClass);
588  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
589  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
590  // For Vulkan we need extra wrapping struct and array to satisfy interface
591  // needs.
592  return wrapInStructAndGetPointer(arrayType, storageClass);
593  }
594 
595  if (type.getNumElements() == 0) {
596  LLVM_DEBUG(llvm::dbgs()
597  << type << " illegal: zero-element memrefs are not supported\n");
598  return nullptr;
599  }
600 
601  int64_t memrefSize =
602  llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
603  int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
604  int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
605  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
606  if (targetEnv.allows(spirv::Capability::Kernel))
607  return spirv::PointerType::get(arrayType, storageClass);
608  return wrapInStructAndGetPointer(arrayType, storageClass);
609 }
610 
611 static spirv::Dim convertRank(int64_t rank) {
612  switch (rank) {
613  case 1:
614  return spirv::Dim::Dim1D;
615  case 2:
616  return spirv::Dim::Dim2D;
617  case 3:
618  return spirv::Dim::Dim3D;
619  default:
620  llvm_unreachable("Invalid memref rank!");
621  }
622 }
623 
624 static spirv::ImageFormat getImageFormat(Type elementType) {
626  .Case<Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; })
627  .Case<Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; })
628  .Case<IntegerType>([](IntegerType intType) {
629  auto const isSigned = intType.isSigned() || intType.isSignless();
630 #define BIT_WIDTH_CASE(BIT_WIDTH) \
631  case BIT_WIDTH: \
632  return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
633  : spirv::ImageFormat::R##BIT_WIDTH##ui
634 
635  switch (intType.getWidth()) {
636  BIT_WIDTH_CASE(16);
637  BIT_WIDTH_CASE(32);
638  default:
639  llvm_unreachable("Unhandled integer type!");
640  }
641  })
642  .Default([](Type) {
643  llvm_unreachable("Unhandled element type!");
644  // We need to return something here to satisfy the type switch.
645  return spirv::ImageFormat::R32f;
646  });
647 #undef BIT_WIDTH_CASE
648 }
649 
650 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
652  MemRefType type) {
653  auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
654  if (!attr) {
655  LLVM_DEBUG(
656  llvm::dbgs()
657  << type
658  << " illegal: expected memory space to be a SPIR-V storage class "
659  "attribute; please use MemorySpaceToStorageClassConverter to map "
660  "numeric memory spaces beforehand\n");
661  return nullptr;
662  }
663  spirv::StorageClass storageClass = attr.getValue();
664 
665  // Images are a special case since they are an opaque type from which elements
666  // may be accessed via image specific ops or directly through a texture
667  // pointer.
668  if (storageClass == spirv::StorageClass::Image) {
669  const int64_t rank = type.getRank();
670  if (rank < 1 || rank > 3) {
671  LLVM_DEBUG(llvm::dbgs()
672  << type << " illegal: cannot lower memref of rank " << rank
673  << " to a SPIR-V Image\n");
674  return nullptr;
675  }
676 
677  // Note that we currently only support lowering to single element texels
678  // e.g. R32f.
679  auto elementType = type.getElementType();
680  if (!isa<spirv::ScalarType>(elementType)) {
681  LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
682  << elementType << " to a SPIR-V Image\n");
683  return nullptr;
684  }
685 
686  // Currently every memref in the image storage class is converted to a
687  // sampled image so we can hardcode the NeedSampler field. Future work
688  // will generalize this to support regular non-sampled images.
689  auto spvImageType = spirv::ImageType::get(
690  elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
691  spirv::ImageArrayedInfo::NonArrayed,
692  spirv::ImageSamplingInfo::SingleSampled,
693  spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
694  auto spvSampledImageType = spirv::SampledImageType::get(spvImageType);
695  auto imagePtrType = spirv::PointerType::get(
696  spvSampledImageType, spirv::StorageClass::UniformConstant);
697  return imagePtrType;
698  }
699 
700  if (isa<IntegerType>(type.getElementType())) {
701  if (type.getElementTypeBitWidth() == 1)
702  return convertBoolMemrefType(targetEnv, options, type, storageClass);
703  if (type.getElementTypeBitWidth() < 8)
704  return convertSubByteMemrefType(targetEnv, options, type, storageClass);
705  }
706 
707  Type arrayElemType;
708  Type elementType = type.getElementType();
709  if (auto vecType = dyn_cast<VectorType>(elementType)) {
710  arrayElemType =
711  convertVectorType(targetEnv, options, vecType, storageClass);
712  } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
713  arrayElemType =
714  convertComplexType(targetEnv, options, complexType, storageClass);
715  } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
716  arrayElemType =
717  convertScalarType(targetEnv, options, scalarType, storageClass);
718  } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
719  type = cast<MemRefType>(convertIndexElementType(type, options));
720  arrayElemType = type.getElementType();
721  } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
722  // Hnadle 8 bit float types.
723  type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
724  arrayElemType = type.getElementType();
725  } else {
726  LLVM_DEBUG(
727  llvm::dbgs()
728  << type
729  << " unhandled: can only convert scalar or vector element type\n");
730  return nullptr;
731  }
732  if (!arrayElemType)
733  return nullptr;
734 
735  std::optional<int64_t> arrayElemSize =
736  getTypeNumBytes(options, arrayElemType);
737  if (!arrayElemSize) {
738  LLVM_DEBUG(llvm::dbgs()
739  << type << " illegal: cannot deduce converted element size\n");
740  return nullptr;
741  }
742 
743  if (!type.hasStaticShape()) {
744  // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
745  // to the element.
746  if (targetEnv.allows(spirv::Capability::Kernel))
747  return spirv::PointerType::get(arrayElemType, storageClass);
748  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
749  auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
750  // For Vulkan we need extra wrapping struct and array to satisfy interface
751  // needs.
752  return wrapInStructAndGetPointer(arrayType, storageClass);
753  }
754 
755  std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
756  if (!memrefSize) {
757  LLVM_DEBUG(llvm::dbgs()
758  << type << " illegal: cannot deduce element count\n");
759  return nullptr;
760  }
761 
762  if (*memrefSize == 0) {
763  LLVM_DEBUG(llvm::dbgs()
764  << type << " illegal: zero-element memrefs are not supported\n");
765  return nullptr;
766  }
767 
768  int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
769  int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
770  auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
771  if (targetEnv.allows(spirv::Capability::Kernel))
772  return spirv::PointerType::get(arrayType, storageClass);
773  return wrapInStructAndGetPointer(arrayType, storageClass);
774 }
775 
776 //===----------------------------------------------------------------------===//
777 // Type casting materialization
778 //===----------------------------------------------------------------------===//
779 
780 /// Converts the given `inputs` to the original source `type` considering the
781 /// `targetEnv`'s capabilities.
782 ///
783 /// This function is meant to be used for source materialization in type
784 /// converters. When the type converter needs to materialize a cast op back
785 /// to some original source type, we need to check whether the original source
786 /// type is supported in the target environment. If so, we can insert legal
787 /// SPIR-V cast ops accordingly.
788 ///
789 /// Note that in SPIR-V the capabilities for storage and compute are separate.
790 /// This function is meant to handle the **compute** side; so it does not
791 /// involve storage classes in its logic. The storage side is expected to be
792 /// handled by MemRef conversion logic.
793 static Value castToSourceType(const spirv::TargetEnv &targetEnv,
794  OpBuilder &builder, Type type, ValueRange inputs,
795  Location loc) {
796  // We can only cast one value in SPIR-V.
797  if (inputs.size() != 1) {
798  auto castOp =
799  UnrealizedConversionCastOp::create(builder, loc, type, inputs);
800  return castOp.getResult(0);
801  }
802  Value input = inputs.front();
803 
804  // Only support integer types for now. Floating point types to be implemented.
805  if (!isa<IntegerType>(type)) {
806  auto castOp =
807  UnrealizedConversionCastOp::create(builder, loc, type, inputs);
808  return castOp.getResult(0);
809  }
810  auto inputType = cast<IntegerType>(input.getType());
811 
812  auto scalarType = dyn_cast<spirv::ScalarType>(type);
813  if (!scalarType) {
814  auto castOp =
815  UnrealizedConversionCastOp::create(builder, loc, type, inputs);
816  return castOp.getResult(0);
817  }
818 
819  // Only support source type with a smaller bitwidth. This would mean we are
820  // truncating to go back so we don't need to worry about the signedness.
821  // For extension, we cannot have enough signal here to decide which op to use.
822  if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
823  auto castOp =
824  UnrealizedConversionCastOp::create(builder, loc, type, inputs);
825  return castOp.getResult(0);
826  }
827 
828  // Boolean values would need to use different ops than normal integer values.
829  if (type.isInteger(1)) {
830  Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
831  return spirv::IEqualOp::create(builder, loc, input, one);
832  }
833 
834  // Check that the source integer type is supported by the environment.
837  scalarType.getExtensions(exts);
838  scalarType.getCapabilities(caps);
839  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
840  failed(checkExtensionRequirements(type, targetEnv, exts))) {
841  auto castOp =
842  UnrealizedConversionCastOp::create(builder, loc, type, inputs);
843  return castOp.getResult(0);
844  }
845 
846  // We've already made sure this is truncating previously, so we don't need to
847  // care about signedness here. Still try to use a corresponding op for better
848  // consistency though.
849  if (type.isSignedInteger()) {
850  return spirv::SConvertOp::create(builder, loc, type, input);
851  }
852  return spirv::UConvertOp::create(builder, loc, type, input);
853 }
854 
855 //===----------------------------------------------------------------------===//
856 // Builtin Variables
857 //===----------------------------------------------------------------------===//
858 
859 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
860  spirv::BuiltIn builtin) {
861  // Look through all global variables in the given `body` block and check if
862  // there is a spirv.GlobalVariable that has the same `builtin` attribute.
863  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
864  if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
865  spirv::SPIRVDialect::getAttributeName(
866  spirv::Decoration::BuiltIn))) {
867  auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
868  if (varBuiltIn == builtin) {
869  return varOp;
870  }
871  }
872  }
873  return nullptr;
874 }
875 
876 /// Gets name of global variable for a builtin.
877 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
878  StringRef suffix) {
879  return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
880 }
881 
882 /// Gets or inserts a global variable for a builtin within `body` block.
883 static spirv::GlobalVariableOp
884 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
885  Type integerType, OpBuilder &builder,
886  StringRef prefix, StringRef suffix) {
887  if (auto varOp = getBuiltinVariable(body, builtin))
888  return varOp;
889 
890  OpBuilder::InsertionGuard guard(builder);
891  builder.setInsertionPointToStart(&body);
892 
893  spirv::GlobalVariableOp newVarOp;
894  switch (builtin) {
895  case spirv::BuiltIn::NumWorkgroups:
896  case spirv::BuiltIn::WorkgroupSize:
897  case spirv::BuiltIn::WorkgroupId:
898  case spirv::BuiltIn::LocalInvocationId:
899  case spirv::BuiltIn::GlobalInvocationId: {
900  auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
901  spirv::StorageClass::Input);
902  std::string name = getBuiltinVarName(builtin, prefix, suffix);
903  newVarOp =
904  spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
905  break;
906  }
907  case spirv::BuiltIn::SubgroupId:
908  case spirv::BuiltIn::NumSubgroups:
909  case spirv::BuiltIn::SubgroupSize:
910  case spirv::BuiltIn::SubgroupLocalInvocationId: {
911  auto ptrType =
912  spirv::PointerType::get(integerType, spirv::StorageClass::Input);
913  std::string name = getBuiltinVarName(builtin, prefix, suffix);
914  newVarOp =
915  spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
916  break;
917  }
918  default:
919  emitError(loc, "unimplemented builtin variable generation for ")
920  << stringifyBuiltIn(builtin);
921  }
922  return newVarOp;
923 }
924 
925 //===----------------------------------------------------------------------===//
926 // Push constant storage
927 //===----------------------------------------------------------------------===//
928 
929 /// Returns the pointer type for the push constant storage containing
930 /// `elementCount` 32-bit integer values.
931 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
932  Builder &builder,
933  Type indexType) {
934  auto arrayType = spirv::ArrayType::get(indexType, elementCount,
935  /*stride=*/4);
936  auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
937  return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
938 }
939 
940 /// Returns the push constant varible containing `elementCount` 32-bit integer
941 /// values in `body`. Returns null op if such an op does not exit.
942 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
943  unsigned elementCount) {
944  for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
945  auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
946  if (!ptrType)
947  continue;
948 
949  // Note that Vulkan requires "There must be no more than one push constant
950  // block statically used per shader entry point." So we should always reuse
951  // the existing one.
952  if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
953  auto numElements = cast<spirv::ArrayType>(
954  cast<spirv::StructType>(ptrType.getPointeeType())
955  .getElementType(0))
956  .getNumElements();
957  if (numElements == elementCount)
958  return varOp;
959  }
960  }
961  return nullptr;
962 }
963 
964 /// Gets or inserts a global variable for push constant storage containing
965 /// `elementCount` 32-bit integer values in `block`.
966 static spirv::GlobalVariableOp
967 getOrInsertPushConstantVariable(Location loc, Block &block,
968  unsigned elementCount, OpBuilder &b,
969  Type indexType) {
970  if (auto varOp = getPushConstantVariable(block, elementCount))
971  return varOp;
972 
973  auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
974  auto type = getPushConstantStorageType(elementCount, builder, indexType);
975  const char *name = "__push_constant_var__";
976  return spirv::GlobalVariableOp::create(builder, loc, type, name,
977  /*initializer=*/nullptr);
978 }
979 
980 //===----------------------------------------------------------------------===//
981 // func::FuncOp Conversion Patterns
982 //===----------------------------------------------------------------------===//
983 
984 /// A pattern for rewriting function signature to convert arguments of functions
985 /// to be of valid SPIR-V types.
986 struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
987  using Base::Base;
988 
989  LogicalResult
990  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
991  ConversionPatternRewriter &rewriter) const override {
992  FunctionType fnType = funcOp.getFunctionType();
993  if (fnType.getNumResults() > 1)
994  return failure();
995 
996  TypeConverter::SignatureConversion signatureConverter(
997  fnType.getNumInputs());
998  for (const auto &argType : enumerate(fnType.getInputs())) {
999  auto convertedType = getTypeConverter()->convertType(argType.value());
1000  if (!convertedType)
1001  return failure();
1002  signatureConverter.addInputs(argType.index(), convertedType);
1003  }
1004 
1005  Type resultType;
1006  if (fnType.getNumResults() == 1) {
1007  resultType = getTypeConverter()->convertType(fnType.getResult(0));
1008  if (!resultType)
1009  return failure();
1010  }
1011 
1012  // Create the converted spirv.func op.
1013  auto newFuncOp = spirv::FuncOp::create(
1014  rewriter, funcOp.getLoc(), funcOp.getName(),
1015  rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
1016  resultType ? TypeRange(resultType)
1017  : TypeRange()));
1018 
1019  // Copy over all attributes other than the function name and type.
1020  for (const auto &namedAttr : funcOp->getAttrs()) {
1021  if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
1022  namedAttr.getName() != SymbolTable::getSymbolAttrName())
1023  newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1024  }
1025 
1026  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1027  newFuncOp.end());
1028  if (failed(rewriter.convertRegionTypes(
1029  &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
1030  return failure();
1031  rewriter.eraseOp(funcOp);
1032  return success();
1033  }
1034 };
1035 
1036 /// A pattern for rewriting function signature to convert vector arguments of
1037 /// functions to be of valid types
1038 struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
1039  using Base::Base;
1040 
1041  LogicalResult matchAndRewrite(func::FuncOp funcOp,
1042  PatternRewriter &rewriter) const override {
1043  FunctionType fnType = funcOp.getFunctionType();
1044 
1045  // TODO: Handle declarations.
1046  if (funcOp.isDeclaration()) {
1047  LLVM_DEBUG(llvm::dbgs()
1048  << fnType << " illegal: declarations are unsupported\n");
1049  return failure();
1050  }
1051 
1052  // Create a new func op with the original type and copy the function body.
1053  auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),
1054  funcOp.getName(), fnType);
1055  rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1056  newFuncOp.end());
1057 
1058  Location loc = newFuncOp.getBody().getLoc();
1059 
1060  Block &entryBlock = newFuncOp.getBlocks().front();
1061  OpBuilder::InsertionGuard guard(rewriter);
1062  rewriter.setInsertionPointToStart(&entryBlock);
1063 
1064  TypeConverter::SignatureConversion oneToNTypeMapping(
1065  fnType.getInputs().size());
1066 
1067  // For arguments that are of illegal types and require unrolling.
1068  // `unrolledInputNums` stores the indices of arguments that result from
1069  // unrolling in the new function signature. `newInputNo` is a counter.
1070  SmallVector<size_t> unrolledInputNums;
1071  size_t newInputNo = 0;
1072 
1073  // For arguments that are of legal types and do not require unrolling.
1074  // `tmpOps` stores a mapping from temporary operations that serve as
1075  // placeholders for new arguments that will be added later. These operations
1076  // will be erased once the entry block's argument list is updated.
1077  llvm::SmallDenseMap<Operation *, size_t> tmpOps;
1078 
1079  // This counts the number of new operations created.
1080  size_t newOpCount = 0;
1081 
1082  // Enumerate through the arguments.
1083  for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
1084  // Check whether the argument is of vector type.
1085  auto origVecType = dyn_cast<VectorType>(origType);
1086  if (!origVecType) {
1087  // We need a placeholder for the old argument that will be erased later.
1088  Value result = arith::ConstantOp::create(
1089  rewriter, loc, origType, rewriter.getZeroAttr(origType));
1090  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1091  tmpOps.insert({result.getDefiningOp(), newInputNo});
1092  oneToNTypeMapping.addInputs(origInputNo, origType);
1093  ++newInputNo;
1094  ++newOpCount;
1095  continue;
1096  }
1097  // Check whether the vector needs unrolling.
1098  auto targetShape = getTargetShape(origVecType);
1099  if (!targetShape) {
1100  // We need a placeholder for the old argument that will be erased later.
1101  Value result = arith::ConstantOp::create(
1102  rewriter, loc, origType, rewriter.getZeroAttr(origType));
1103  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1104  tmpOps.insert({result.getDefiningOp(), newInputNo});
1105  oneToNTypeMapping.addInputs(origInputNo, origType);
1106  ++newInputNo;
1107  ++newOpCount;
1108  continue;
1109  }
1110  VectorType unrolledType =
1111  VectorType::get(*targetShape, origVecType.getElementType());
1112  auto originalShape =
1113  llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1114 
1115  // Prepare the result vector.
1116  Value result = arith::ConstantOp::create(
1117  rewriter, loc, origVecType, rewriter.getZeroAttr(origVecType));
1118  ++newOpCount;
1119  // Prepare the placeholder for the new arguments that will be added later.
1120  Value dummy = arith::ConstantOp::create(
1121  rewriter, loc, unrolledType, rewriter.getZeroAttr(unrolledType));
1122  ++newOpCount;
1123 
1124  // Create the `vector.insert_strided_slice` ops.
1125  SmallVector<int64_t> strides(targetShape->size(), 1);
1126  SmallVector<Type> newTypes;
1127  for (SmallVector<int64_t> offsets :
1128  StaticTileOffsetRange(originalShape, *targetShape)) {
1129  result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,
1130  result, offsets, strides);
1131  newTypes.push_back(unrolledType);
1132  unrolledInputNums.push_back(newInputNo);
1133  ++newInputNo;
1134  ++newOpCount;
1135  }
1136  rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1137  oneToNTypeMapping.addInputs(origInputNo, newTypes);
1138  }
1139 
1140  // Change the function signature.
1141  auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1142  auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1143  rewriter.modifyOpInPlace(newFuncOp,
1144  [&] { newFuncOp.setFunctionType(newFnType); });
1145 
1146  // Update the arguments in the entry block.
1147  entryBlock.eraseArguments(0, fnType.getNumInputs());
1148  SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
1149  entryBlock.addArguments(convertedTypes, locs);
1150 
1151  // Replace all uses of placeholders for initially legal arguments with their
1152  // original function arguments (that were added to `newFuncOp`).
1153  for (auto &[placeholderOp, argIdx] : tmpOps) {
1154  if (!placeholderOp)
1155  continue;
1156  Value replacement = newFuncOp.getArgument(argIdx);
1157  rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement);
1158  }
1159 
1160  // Replace dummy operands of new `vector.insert_strided_slice` ops with
1161  // their corresponding new function arguments. The new
1162  // `vector.insert_strided_slice` ops are inserted only into the entry block,
1163  // so iterating over that block is sufficient.
1164  size_t unrolledInputIdx = 0;
1165  for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1166  Operation &curOp = op;
1167  // Since all newly created operations are in the beginning, reaching the
1168  // end of them means that any later `vector.insert_strided_slice` should
1169  // not be touched.
1170  if (count >= newOpCount)
1171  continue;
1172  if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1173  size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1174  rewriter.modifyOpInPlace(&curOp, [&] {
1175  curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1176  });
1177  ++unrolledInputIdx;
1178  }
1179  }
1180 
1181  // Erase the original funcOp. The `tmpOps` do not need to be erased since
1182  // they have no uses and will be handled by dead-code elimination.
1183  rewriter.eraseOp(funcOp);
1184  return success();
1185  }
1186 };
1187 
1188 //===----------------------------------------------------------------------===//
1189 // func::ReturnOp Conversion Patterns
1190 //===----------------------------------------------------------------------===//
1191 
1192 /// A pattern for rewriting function signature and the return op to convert
1193 /// vectors to be of valid types.
1194 struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
1195  using Base::Base;
1196 
1197  LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1198  PatternRewriter &rewriter) const override {
1199  // Check whether the parent funcOp is valid.
1200  auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1201  if (!funcOp)
1202  return failure();
1203 
1204  FunctionType fnType = funcOp.getFunctionType();
1205  TypeConverter::SignatureConversion oneToNTypeMapping(
1206  fnType.getResults().size());
1207  Location loc = returnOp.getLoc();
1208 
1209  // For the new return op.
1210  SmallVector<Value> newOperands;
1211 
1212  // Enumerate through the results.
1213  for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1214  // Check whether the argument is of vector type.
1215  auto origVecType = dyn_cast<VectorType>(origType);
1216  if (!origVecType) {
1217  oneToNTypeMapping.addInputs(origResultNo, origType);
1218  newOperands.push_back(returnOp.getOperand(origResultNo));
1219  continue;
1220  }
1221  // Check whether the vector needs unrolling.
1222  auto targetShape = getTargetShape(origVecType);
1223  if (!targetShape) {
1224  // The original argument can be used.
1225  oneToNTypeMapping.addInputs(origResultNo, origType);
1226  newOperands.push_back(returnOp.getOperand(origResultNo));
1227  continue;
1228  }
1229  VectorType unrolledType =
1230  VectorType::get(*targetShape, origVecType.getElementType());
1231 
1232  // Create `vector.extract_strided_slice` ops to form legal vectors from
1233  // the original operand of illegal type.
1234  auto originalShape =
1235  llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1236  SmallVector<int64_t> strides(originalShape.size(), 1);
1237  SmallVector<int64_t> extractShape(originalShape.size(), 1);
1238  extractShape.back() = targetShape->back();
1239  SmallVector<Type> newTypes;
1240  Value returnValue = returnOp.getOperand(origResultNo);
1241  for (SmallVector<int64_t> offsets :
1242  StaticTileOffsetRange(originalShape, *targetShape)) {
1243  Value result = vector::ExtractStridedSliceOp::create(
1244  rewriter, loc, returnValue, offsets, extractShape, strides);
1245  if (originalShape.size() > 1) {
1246  SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
1247  result =
1248  vector::ExtractOp::create(rewriter, loc, result, extractIndices);
1249  }
1250  newOperands.push_back(result);
1251  newTypes.push_back(unrolledType);
1252  }
1253  oneToNTypeMapping.addInputs(origResultNo, newTypes);
1254  }
1255 
1256  // Change the function signature.
1257  auto newFnType =
1258  FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
1259  TypeRange(oneToNTypeMapping.getConvertedTypes()));
1260  rewriter.modifyOpInPlace(funcOp,
1261  [&] { funcOp.setFunctionType(newFnType); });
1262 
1263  // Replace the return op using the new operands. This will automatically
1264  // update the entry block as well.
1265  rewriter.replaceOp(returnOp,
1266  func::ReturnOp::create(rewriter, loc, newOperands));
1267 
1268  return success();
1269  }
1270 };
1271 
1272 } // namespace
1273 
1274 //===----------------------------------------------------------------------===//
1275 // Public function for builtin variables
1276 //===----------------------------------------------------------------------===//
1277 
1279  spirv::BuiltIn builtin,
1280  Type integerType, OpBuilder &builder,
1281  StringRef prefix, StringRef suffix) {
1283  if (!parent) {
1284  op->emitError("expected operation to be within a module-like op");
1285  return nullptr;
1286  }
1287 
1288  spirv::GlobalVariableOp varOp =
1289  getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
1290  builtin, integerType, builder, prefix, suffix);
1291  Value ptr = spirv::AddressOfOp::create(builder, op->getLoc(), varOp);
1292  return spirv::LoadOp::create(builder, op->getLoc(), ptr);
1293 }
1294 
1295 //===----------------------------------------------------------------------===//
1296 // Public function for pushing constant storage
1297 //===----------------------------------------------------------------------===//
1298 
/// Emits, at `builder`'s insertion point, the ops that load the element at
/// index `offset` from the push constant storage holding `elementCount`
/// elements, and returns the loaded value.
///
/// The storage variable is looked up, or created, in the first block of region
/// 0 of `parent`. If `parent` is null an error is emitted on `op` and a null
/// value is returned.
1299 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
1300  unsigned offset, Type integerType,
1301  OpBuilder &builder) {
1302  Location loc = op->getLoc();
// NOTE(review): the statement defining `parent` is not visible in this
// listing; presumably it is the symbol-table op enclosing `op` - confirm
// against the original source.
1304  if (!parent) {
1305  op->emitError("expected operation to be within a module-like op");
1306  return nullptr;
1307  }
1308 
1309  spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1310  loc, parent->getRegion(0).front(), elementCount, builder, integerType);
1311 
// Access chain indices: `zeroOp` selects the struct's first member (the
// array), `offsetOp` selects the element inside that array.
1312  Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
// NOTE(review): the offset attribute is always built as a 32-bit integer
// attribute while the constant's result type is `integerType`; confirm that
// callers only pass a 32-bit integer type here.
1313  Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,
1314  builder.getI32IntegerAttr(offset));
1315  auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);
1316  auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,
1317  llvm::ArrayRef({zeroOp, offsetOp}));
1318  return spirv::LoadOp::create(builder, loc, acOp);
1319 }
1320 
1321 //===----------------------------------------------------------------------===//
1322 // Public functions for index calculation
1323 //===----------------------------------------------------------------------===//
1324 
1326  int64_t offset, Type integerType,
1327  Location loc, OpBuilder &builder) {
1328  assert(indices.size() == strides.size() &&
1329  "must provide indices for all dimensions");
1330 
1331  // TODO: Consider moving to use affine.apply and patterns converting
1332  // affine.apply to standard ops. This needs converting to SPIR-V passes to be
1333  // broken down into progressive small steps so we can have intermediate steps
1334  // using other dialects. At the moment SPIR-V is the final sink.
1335 
1336  Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
1337  loc, integerType, IntegerAttr::get(integerType, offset));
1338  for (const auto &index : llvm::enumerate(indices)) {
1339  Value strideVal = builder.createOrFold<spirv::ConstantOp>(
1340  loc, integerType,
1341  IntegerAttr::get(integerType, strides[index.index()]));
1342  Value update =
1343  builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1344  linearizedIndex =
1345  builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1346  }
1347  return linearizedIndex;
1348 }
1349 
1351  MemRefType baseType, Value basePtr,
1352  ValueRange indices, Location loc,
1353  OpBuilder &builder) {
1354  // Get base and offset of the MemRefType and verify they are static.
1355 
1356  int64_t offset;
1357  SmallVector<int64_t, 4> strides;
1358  if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1359  llvm::is_contained(strides, ShapedType::kDynamic) ||
1360  ShapedType::isDynamic(offset)) {
1361  return nullptr;
1362  }
1363 
1364  auto indexType = typeConverter.getIndexType();
1365 
1366  SmallVector<Value, 2> linearizedIndices;
1367  auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1368 
1369  // Add a '0' at the start to index into the struct.
1370  linearizedIndices.push_back(zero);
1371 
1372  if (baseType.getRank() == 0) {
1373  linearizedIndices.push_back(zero);
1374  } else {
1375  linearizedIndices.push_back(
1376  linearizeIndex(indices, strides, offset, indexType, loc, builder));
1377  }
1378  return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);
1379 }
1380 
1382  MemRefType baseType, Value basePtr,
1383  ValueRange indices, Location loc,
1384  OpBuilder &builder) {
1385  // Get base and offset of the MemRefType and verify they are static.
1386 
1387  int64_t offset;
1388  SmallVector<int64_t, 4> strides;
1389  if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1390  llvm::is_contained(strides, ShapedType::kDynamic) ||
1391  ShapedType::isDynamic(offset)) {
1392  return nullptr;
1393  }
1394 
1395  auto indexType = typeConverter.getIndexType();
1396 
1397  SmallVector<Value, 2> linearizedIndices;
1398  Value linearIndex;
1399  if (baseType.getRank() == 0) {
1400  linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1401  } else {
1402  linearIndex =
1403  linearizeIndex(indices, strides, offset, indexType, loc, builder);
1404  }
1405  Type pointeeType =
1406  cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
1407  if (isa<spirv::ArrayType>(pointeeType)) {
1408  linearizedIndices.push_back(linearIndex);
1409  return spirv::AccessChainOp::create(builder, loc, basePtr,
1410  linearizedIndices);
1411  }
1412  return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,
1413  linearizedIndices);
1414 }
1415 
1417  MemRefType baseType, Value basePtr,
1418  ValueRange indices, Location loc,
1419  OpBuilder &builder) {
1420 
1421  if (typeConverter.allows(spirv::Capability::Kernel)) {
1422  return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1423  builder);
1424  }
1425 
1426  return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1427  builder);
1428 }
1429 
1430 //===----------------------------------------------------------------------===//
1431 // Public functions for vector unrolling
1432 //===----------------------------------------------------------------------===//
1433 
1435  for (int i : {4, 3, 2}) {
1436  if (size % i == 0)
1437  return i;
1438  }
1439  return 1;
1440 }
1441 
/// Returns a one-element shape for a `vector.reduction` op: the result of
/// `getComputeVectorSize` applied to the size of the source vector's single
/// dimension.
1443 mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
1444  VectorType srcVectorType = op.getSourceVectorType();
1445  assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1446  int64_t vectorSize =
1447  mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
1448  return {vectorSize};
1449 }
1450 
/// Returns a shape for a `vector.transpose` op with the same rank as the
/// result vector: every dimension is 1 except the last, which is the result of
/// `getComputeVectorSize` applied to the result vector's last dimension.
1452 mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
1453  VectorType vectorType = op.getResultVectorType();
// Start from all-ones, then overwrite the innermost dimension.
1454  SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
1455  nativeSize.back() =
1456  mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
1457  return nativeSize;
1458 }
1459 
1460 std::optional<SmallVector<int64_t>>
1462  if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
1463  if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
1464  SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
1465  nativeSize.back() =
1466  mlir::spirv::getComputeVectorSize(vecType.getShape().back());
1467  return nativeSize;
1468  }
1469  }
1470 
1472  .Case<vector::ReductionOp, vector::TransposeOp>(
1473  [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
1474  .Default([](Operation *) { return std::nullopt; });
1475 }
1476 
1478  MLIRContext *context = op->getContext();
1479  RewritePatternSet patterns(context);
1482  // We only want to apply signature conversion once to the existing func ops.
1483  // Without specifying strictMode, the greedy pattern rewriter will keep
1484  // looking for newly created func ops.
1485  return applyPatternsGreedily(op, std::move(patterns),
1486  GreedyRewriteConfig().setStrictness(
1488 }
1489 
1491  MLIRContext *context = op->getContext();
1492 
1493  // Unroll vectors in function bodies to native vector size.
1494  {
1495  RewritePatternSet patterns(context);
1497  [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1498  populateVectorUnrollPatterns(patterns, options);
1499  if (failed(applyPatternsGreedily(op, std::move(patterns))))
1500  return failure();
1501  }
1502 
1503  // Convert transpose ops into extract and insert pairs, in preparation of
1504  // further transformations to canonicalize/cancel.
1505  {
1506  RewritePatternSet patterns(context);
1508  patterns, vector::VectorTransposeLowering::EltWise);
1510  if (failed(applyPatternsGreedily(op, std::move(patterns))))
1511  return failure();
1512  }
1513 
1514  // Run canonicalization to cast away leading size-1 dimensions.
1515  {
1516  RewritePatternSet patterns(context);
1517 
1518  // We need to pull in casting away leading one dims.
1519  vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1520  vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1521  vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1522 
1523  // Decompose different rank insert_strided_slice and n-D
1524  // extract_strided_slice.
1525  vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1526  patterns);
1527  vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1528  vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1529 
1530  // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1531  // them up.
1532  vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1533  vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1534 
1535  if (failed(applyPatternsGreedily(op, std::move(patterns))))
1536  return failure();
1537  }
1538  return success();
1539 }
1540 
1541 //===----------------------------------------------------------------------===//
1542 // SPIR-V TypeConverter
1543 //===----------------------------------------------------------------------===//
1544 
1547  : targetEnv(targetAttr), options(options) {
1548  // Add conversions. The order matters here: later ones will be tried earlier.
1549 
1550  // Allow all SPIR-V dialect specific types. This assumes all builtin types
1551  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
1552  // were tried before.
1553  //
1554  // TODO: This assumes that the SPIR-V types are valid to use in the given
1555  // target environment, which should be the case if the whole pipeline is
1556  // driven by the same target environment. Still, we probably still want to
1557  // validate and convert to be safe.
1558  addConversion([](spirv::SPIRVType type) { return type; });
1559 
1560  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
1561 
1562  addConversion([this](IntegerType intType) -> std::optional<Type> {
1563  if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1564  return convertScalarType(this->targetEnv, this->options, scalarType);
1565  if (intType.getWidth() < 8)
1566  return convertSubByteIntegerType(this->options, intType);
1567  return Type();
1568  });
1569 
1570  addConversion([this](FloatType floatType) -> std::optional<Type> {
1571  if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1572  return convertScalarType(this->targetEnv, this->options, scalarType);
1573  if (floatType.getWidth() == 8)
1574  return convert8BitFloatType(this->options, floatType);
1575  return Type();
1576  });
1577 
1578  addConversion([this](ComplexType complexType) {
1579  return convertComplexType(this->targetEnv, this->options, complexType);
1580  });
1581 
1582  addConversion([this](VectorType vectorType) {
1583  return convertVectorType(this->targetEnv, this->options, vectorType);
1584  });
1585 
1586  addConversion([this](TensorType tensorType) {
1587  return convertTensorType(this->targetEnv, this->options, tensorType);
1588  });
1589 
1590  addConversion([this](MemRefType memRefType) {
1591  return convertMemrefType(this->targetEnv, this->options, memRefType);
1592  });
1593 
1594  // Register some last line of defense casting logic.
1596  [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1597  return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1598  });
1599  addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
1600  Location loc) {
1601  auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1602  return cast.getResult(0);
1603  });
1604 }
1605 
1607  return ::getIndexType(getContext(), options);
1608 }
1609 
1610 MLIRContext *SPIRVTypeConverter::getContext() const {
1611  return targetEnv.getAttr().getContext();
1612 }
1613 
1614 bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
1615  return targetEnv.allows(capability);
1616 }
1617 
1618 //===----------------------------------------------------------------------===//
1619 // SPIR-V ConversionTarget
1620 //===----------------------------------------------------------------------===//
1621 
1622 std::unique_ptr<SPIRVConversionTarget>
1624  std::unique_ptr<SPIRVConversionTarget> target(
1625  // std::make_unique does not work here because the constructor is private.
1626  new SPIRVConversionTarget(targetAttr));
1627  SPIRVConversionTarget *targetPtr = target.get();
1628  target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1629  // We need to capture the raw pointer here because it is stable:
1630  // target will be destroyed once this function is returned.
1631  [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1632  return target;
1633 }
1634 
/// Constructs a conversion target for the environment described by
/// `targetAttr`: the base ConversionTarget is bound to the attribute's
/// MLIRContext and `targetEnv` is initialized from the attribute.
1635 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1636  : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1637 
1638 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1639  // Make sure this op is available at the given version. Ops not implementing
1640  // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1641  // SPIR-V versions.
1642  if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1643  std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1644  if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1645  LLVM_DEBUG(llvm::dbgs()
1646  << op->getName() << " illegal: requiring min version "
1647  << spirv::stringifyVersion(*minVersion) << "\n");
1648  return false;
1649  }
1650  }
1651  if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1652  std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1653  if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1654  LLVM_DEBUG(llvm::dbgs()
1655  << op->getName() << " illegal: requiring max version "
1656  << spirv::stringifyVersion(*maxVersion) << "\n");
1657  return false;
1658  }
1659  }
1660 
1661  // Make sure this op's required extensions are allowed to use. Ops not
1662  // implementing QueryExtensionInterface do not require extensions to be
1663  // available.
1664  if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1665  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1666  extensions.getExtensions())))
1667  return false;
1668 
1669  // Make sure this op's required extensions are allowed to use. Ops not
1670  // implementing QueryCapabilityInterface do not require capabilities to be
1671  // available.
1672  if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1673  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1674  capabilities.getCapabilities())))
1675  return false;
1676 
1677  SmallVector<Type, 4> valueTypes;
1678  valueTypes.append(op->operand_type_begin(), op->operand_type_end());
1679  valueTypes.append(op->result_type_begin(), op->result_type_end());
1680 
1681  // Ensure that all types have been converted to SPIRV types.
1682  if (llvm::any_of(valueTypes,
1683  [](Type t) { return !isa<spirv::SPIRVType>(t); }))
1684  return false;
1685 
1686  // Special treatment for global variables, whose type requirements are
1687  // conveyed by type attributes.
1688  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1689  valueTypes.push_back(globalVar.getType());
1690 
1691  // Make sure the op's operands/results use types that are allowed by the
1692  // target environment.
1693  SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1694  SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1695  for (Type valueType : valueTypes) {
1696  typeExtensions.clear();
1697  cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1698  if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1699  typeExtensions)))
1700  return false;
1701 
1702  typeCapabilities.clear();
1703  cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1704  if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1705  typeCapabilities)))
1706  return false;
1707  }
1708 
1709  return true;
1710 }
1711 
1712 //===----------------------------------------------------------------------===//
1713 // Public functions for populating patterns
1714 //===----------------------------------------------------------------------===//
1715 
1717  const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1718  patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
1719 }
1720 
1722  patterns.add<FuncOpVectorUnroll>(patterns.getContext());
1723 }
1724 
1726  patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
1727 }
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define BIT_WIDTH_CASE(BIT_WIDTH)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
Definition: Block.h:33
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition: Block.cpp:160
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition: Block.cpp:201
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition: Block.h:193
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:51
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:199
FloatType getF32Type()
Definition: Builders.cpp:42
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:323
MLIRContext * getContext() const
Definition: Builders.h:56
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still inside the block.
Definition: Builders.h:240
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:320
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Definition: Builders.h:519
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
operand_type_iterator operand_type_end()
Definition: Operation.h:396
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_iterator result_type_end()
Definition: Operation.h:427
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-level operation.
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:267
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
result_type_iterator result_type_begin()
Definition: Operation.h:426
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
result_type_range getResultTypes()
Definition: Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
operand_type_iterator operand_type_begin()
Definition: Operation.h:395
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:793
iterator begin()
Definition: Region.h:55
Block & front()
Definition: Region.h:65
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:638
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and UnrankedTensorType.
Definition: BuiltinTypes.h:55
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type according to a pattern's type converter.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
Definition: SPIRVTypes.cpp:100
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:170
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:447
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:504
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:793
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities, and resource limits.
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabilities/extensions.
Definition: TargetAndABI.h:29
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
Definition: TargetAndABI.h:62
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
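Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar operations to conveniently generalize their behavior to vectors/tensors, and systematize conversion between these forms.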
Definition: Operation.cpp:1397
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
Definition: Utils.cpp:2027
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next integer.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr, using the layout map of baseType.
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerType integers.
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides, adding an initial offset.
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V dialect.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:314
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)