SPIRVConversion.cpp
1//===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities used to lower to the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
26#include "mlir/IR/Operation.h"
28#include "mlir/Support/LLVM.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/MathExtras.h"
36
37#include <optional>
38
39#define DEBUG_TYPE "mlir-spirv-conversion"
40
41using namespace mlir;
42
43namespace {
44
45//===----------------------------------------------------------------------===//
46// Utility functions
47//===----------------------------------------------------------------------===//
48
49static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
50 LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
51 if (vecType.isScalable()) {
52 LLVM_DEBUG(llvm::dbgs()
53 << "--scalable vectors are not supported -> BAIL\n");
54 return std::nullopt;
55 }
56 SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
57 std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
58 1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
59 if (!targetShape) {
60 LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
61 return std::nullopt;
62 }
63 auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
64 if (!maybeShapeRatio) {
65 LLVM_DEBUG(llvm::dbgs()
66 << "--could not compute integral shape ratio -> BAIL\n");
67 return std::nullopt;
68 }
69 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
70 LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
71 return std::nullopt;
72 }
73 LLVM_DEBUG(llvm::dbgs()
74 << "--found an integral shape ratio to unroll to -> SUCCESS\n");
75 return targetShape;
76}
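// Illustrative example (editorial note, not part of the original source): for a
// vector<3x8xf32>, getComputeVectorSize(8) picks a native width of 4, so the
// target shape is {4}; the shape ratio against {3, 8} is {3, 2}, meaning the
// vector unrolls into six 4-element pieces. For vector<4xf32> the ratio is all
// ones, so std::nullopt is returned and no unrolling happens.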
77
78/// Checks that the extension requirements in `candidates` can be satisfied by
79/// the given `targetEnv`.
80///
81/// `candidates` is a vector of vectors of extension requirements, following the
82/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
83/// convention.
84template <typename LabelT>
85static LogicalResult checkExtensionRequirements(
86 LabelT label, const spirv::TargetEnv &targetEnv,
87 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
88 for (const auto &ors : candidates) {
89 if (targetEnv.allows(ors))
90 continue;
91
92 LLVM_DEBUG({
93 SmallVector<StringRef> extStrings;
94 for (spirv::Extension ext : ors)
95 extStrings.push_back(spirv::stringifyExtension(ext));
96
97 llvm::dbgs() << label << " illegal: requires at least one extension in ["
98 << llvm::join(extStrings, ", ")
99 << "] but none allowed in target environment\n";
100 });
101 return failure();
102 }
103 return success();
104}
105
106/// Checks that the capability requirements in `candidates` can be satisfied by
107/// the given `targetEnv`.
108///
109/// `candidates` is a vector of vectors of capability requirements, following the
110/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
111/// convention.
112template <typename LabelT>
113static LogicalResult checkCapabilityRequirements(
114 LabelT label, const spirv::TargetEnv &targetEnv,
115 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
116 for (const auto &ors : candidates) {
117 if (targetEnv.allows(ors))
118 continue;
119
120 LLVM_DEBUG({
121 SmallVector<StringRef> capStrings;
122 for (spirv::Capability cap : ors)
123 capStrings.push_back(spirv::stringifyCapability(cap));
124
125 llvm::dbgs() << label << " illegal: requires at least one capability in ["
126 << llvm::join(capStrings, ", ")
127 << "] but none allowed in target environment\n";
128 });
129 return failure();
130 }
131 return success();
132}
133
134/// Returns true if the given `storageClass` needs explicit layout when used in
135/// Shader environments.
136static bool needsExplicitLayout(spirv::StorageClass storageClass) {
137 switch (storageClass) {
138 case spirv::StorageClass::PhysicalStorageBuffer:
139 case spirv::StorageClass::PushConstant:
140 case spirv::StorageClass::StorageBuffer:
141 case spirv::StorageClass::Uniform:
142 return true;
143 default:
144 return false;
145 }
146}
147
148/// Wraps the given `elementType` in a struct and gets the pointer to the
149/// struct. This is used to satisfy Vulkan interface requirements.
150static spirv::PointerType
151wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
152 auto structType = needsExplicitLayout(storageClass)
153 ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
154 : spirv::StructType::get(elementType);
155 return spirv::PointerType::get(structType, storageClass);
156}
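// Illustrative example (editorial note): wrapping !spirv.array<16 x f32,
// stride=4> for the StorageBuffer storage class yields roughly
// !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>,
// StorageBuffer>; for storage classes that need no explicit layout (e.g.
// Function or Workgroup) the struct member carries no offset decoration.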
157
158//===----------------------------------------------------------------------===//
159// Type Conversion
160//===----------------------------------------------------------------------===//
161
162static spirv::ScalarType getIndexType(MLIRContext *ctx,
163 const SPIRVConversionOptions &options) {
164 return cast<spirv::ScalarType>(
165 IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
166}
167
168// TODO: This is a utility function that should probably be exposed by the
169// SPIR-V dialect. Keeping it local till the use case arises.
170static std::optional<int64_t>
171getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
172 if (isa<spirv::ScalarType>(type)) {
173 auto bitWidth = type.getIntOrFloatBitWidth();
174 // According to the SPIR-V spec:
175 // "There is no physical size or bit pattern defined for values with boolean
176 // type. If they are stored (in conjunction with OpVariable), they can only
177 // be used with logical addressing operations, not physical, and only with
178 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
179 // Private, Function, Input, and Output."
180 if (bitWidth == 1)
181 return std::nullopt;
182 return bitWidth / 8;
183 }
184
185 // Handle 8-bit floats.
186 if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
187 auto bitWidth = type.getIntOrFloatBitWidth();
188 if (bitWidth == 8)
189 return bitWidth / 8;
190 return std::nullopt;
191 }
192
193 if (auto complexType = dyn_cast<ComplexType>(type)) {
194 auto elementSize = getTypeNumBytes(options, complexType.getElementType());
195 if (!elementSize)
196 return std::nullopt;
197 return 2 * *elementSize;
198 }
199
200 if (auto vecType = dyn_cast<VectorType>(type)) {
201 auto elementSize = getTypeNumBytes(options, vecType.getElementType());
202 if (!elementSize)
203 return std::nullopt;
204 return vecType.getNumElements() * *elementSize;
205 }
206
207 if (auto memRefType = dyn_cast<MemRefType>(type)) {
208 // TODO: Layout should also be controlled by the ABI attributes. For now
209 // using the layout from MemRef.
210 int64_t offset;
211 SmallVector<int64_t, 4> strides;
212 if (!memRefType.hasStaticShape() ||
213 failed(memRefType.getStridesAndOffset(strides, offset)))
214 return std::nullopt;
215
216 // To get the size of the memref object in memory, the total size is the
217 // max(stride * dimension-size) computed for all dimensions times the size
218 // of the element.
219 auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
220 if (!elementSize)
221 return std::nullopt;
222
223 if (memRefType.getRank() == 0)
224 return elementSize;
225
226 auto dims = memRefType.getShape();
227 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
228 ShapedType::isDynamic(offset) ||
229 llvm::is_contained(strides, ShapedType::kDynamic))
230 return std::nullopt;
231
232 int64_t memrefSize = -1;
233 for (const auto &shape : enumerate(dims))
234 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
235
236 return (offset + memrefSize) * *elementSize;
237 }
238
239 if (auto tensorType = dyn_cast<TensorType>(type)) {
240 if (!tensorType.hasStaticShape())
241 return std::nullopt;
242
243 auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
244 if (!elementSize)
245 return std::nullopt;
246
247 int64_t size = *elementSize;
248 for (auto shape : tensorType.getShape())
249 size *= shape;
250
251 return size;
252 }
253
254 // TODO: Add size computation for other types.
255 return std::nullopt;
256}
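// Illustrative example (editorial note): for memref<2x3xf32> with the identity
// layout, strides are [3, 1] and offset is 0, so the size is
// max(2 * 3, 3 * 1) = 6 elements, i.e. (0 + 6) * 4 = 24 bytes. With a layout of
// strided<[4, 1]> the result would be (0 + max(2 * 4, 3 * 1)) * 4 = 32 bytes.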
257
258/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
259static Type
260convertScalarType(const spirv::TargetEnv &targetEnv,
261 const SPIRVConversionOptions &options, spirv::ScalarType type,
262 std::optional<spirv::StorageClass> storageClass = {}) {
263 // Get extension and capability requirements for the given type.
264 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
265 SmallVector<ArrayRef<spirv::Capability>, 1> capabilities;
266 type.getExtensions(extensions, storageClass);
267 type.getCapabilities(capabilities, storageClass);
268
269 // If all requirements are met, then we can accept this type as-is.
270 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
271 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
272 return type;
273
274 // Otherwise we need to adjust the type, which really means adjusting the
275 // bitwidth given this is a scalar type.
276 if (!options.emulateLT32BitScalarTypes)
277 return nullptr;
278
279 // We only emulate narrower scalar types here and do not truncate results.
280 if (type.getIntOrFloatBitWidth() > 32) {
281 LLVM_DEBUG(llvm::dbgs()
282 << type
283 << " not converted to 32-bit for SPIR-V to avoid truncation\n");
284 return nullptr;
285 }
286
287 if (auto floatType = dyn_cast<FloatType>(type)) {
288 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
289 return Builder(targetEnv.getContext()).getF32Type();
290 }
291
292 auto intType = cast<IntegerType>(type);
293 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
294 return IntegerType::get(targetEnv.getContext(), /*width=*/32,
295 intType.getSignedness());
296}
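// Illustrative example (editorial note): with emulateLT32BitScalarTypes
// enabled, an i16 (or f16) whose capability is missing from the target is
// widened to i32 (or f32), while an unsupported i64/f64 is rejected (nullptr)
// rather than truncated to 32 bits.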
297
298/// Converts a sub-byte integer `type` to i32 regardless of target environment.
299/// Returns nullptr for unsupported integer types, including non-sub-byte
300/// types.
301///
302/// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
303/// this function instead, given that these sub-byte types are not supported at
304/// all in SPIR-V; there are no compute/storage capabilities for them as there
305/// are for other supported integer types.
306static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
307 IntegerType type) {
308 if (type.getWidth() > 8) {
309 LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
310 return nullptr;
311 }
312 if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
313 LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
314 return nullptr;
315 }
316
317 if (!llvm::isPowerOf2_32(type.getWidth())) {
318 LLVM_DEBUG(llvm::dbgs()
319 << "unsupported non-power-of-two bitwidth in sub-byte " << type
320 << "\n");
321 return nullptr;
322 }
323
324 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
325 return IntegerType::get(type.getContext(), /*width=*/32,
326 type.getSignedness());
327}
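// Illustrative example (editorial note): with SPIRVSubByteTypeStorage::Packed,
// i4 and si4 convert to 32-bit integers of the same signedness; i3 is rejected
// because its bitwidth is not a power of two.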
328
329/// Converts 8-bit float types to integer types with the same bit width.
330/// Returns a nullptr for unsupported 8-bit float types.
331static Type convert8BitFloatType(const SPIRVConversionOptions &options,
332 FloatType type) {
333 if (!options.emulateUnsupportedFloatTypes)
334 return nullptr;
335 // F8 types are converted to integer types with the same bit width.
336 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
337 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
338 Float8E8M0FNUType>(type))
339 return IntegerType::get(type.getContext(), type.getWidth());
340 LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
341 return nullptr;
342}
343
344/// Returns a type with the same shape but with any 8-bit float element type
345/// converted to an integer type of the same bit width. This is a no-op when the
346/// element type is not an 8-bit float type or the emulation flag is set to false.
347static ShapedType
348convertShaped8BitFloatType(ShapedType type,
349 const SPIRVConversionOptions &options) {
350 if (!options.emulateUnsupportedFloatTypes)
351 return type;
352 Type srcElementType = type.getElementType();
353 Type convertedElementType = nullptr;
354 // F8 types are converted to integer types with the same bit width.
355 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
356 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
357 Float8E8M0FNUType>(srcElementType))
358 convertedElementType = IntegerType::get(
359 type.getContext(), srcElementType.getIntOrFloatBitWidth());
360
361 if (!convertedElementType)
362 return type;
363
364 return type.clone(convertedElementType);
365}
366
367/// Returns a type with the same shape but with any index element type converted
368/// to the matching integer type. This is a noop when the element type is not
369/// the index type.
370static ShapedType
371convertIndexElementType(ShapedType type,
372 const SPIRVConversionOptions &options) {
373 Type indexType = dyn_cast<IndexType>(type.getElementType());
374 if (!indexType)
375 return type;
376
377 return type.clone(getIndexType(type.getContext(), options));
378}
379
380/// Converts a vector `type` to a suitable type under the given `targetEnv`.
381static Type
382convertVectorType(const spirv::TargetEnv &targetEnv,
383 const SPIRVConversionOptions &options, VectorType type,
384 std::optional<spirv::StorageClass> storageClass = {}) {
385 type = cast<VectorType>(convertIndexElementType(type, options));
386 type = cast<VectorType>(convertShaped8BitFloatType(type, options));
387 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
388 if (!scalarType) {
389 // If this is not a spec allowed scalar type, try to handle sub-byte integer
390 // types.
391 auto intType = dyn_cast<IntegerType>(type.getElementType());
392 if (!intType) {
393 LLVM_DEBUG(llvm::dbgs()
394 << type
395 << " illegal: cannot convert non-scalar element type\n");
396 return nullptr;
397 }
398
399 Type elementType = convertSubByteIntegerType(options, intType);
400 if (!elementType)
401 return nullptr;
402
403 if (type.getRank() <= 1 && type.getNumElements() == 1)
404 return elementType;
405
406 if (type.getNumElements() > 4) {
407 LLVM_DEBUG(llvm::dbgs()
408 << type << " illegal: > 4-element unimplemented\n");
409 return nullptr;
410 }
411
412 return VectorType::get(type.getShape(), elementType);
413 }
414
415 if (type.getRank() <= 1 && type.getNumElements() == 1)
416 return convertScalarType(targetEnv, options, scalarType, storageClass);
417
418 if (!spirv::CompositeType::isValid(type)) {
419 LLVM_DEBUG(llvm::dbgs()
420 << type << " illegal: not a valid composite type\n");
421 return nullptr;
422 }
423
424 // Get extension and capability requirements for the given type.
425 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
426 SmallVector<ArrayRef<spirv::Capability>, 1> capabilities;
427 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
428 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
429
430 // If all requirements are met, then we can accept this type as-is.
431 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
432 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
433 return type;
434
435 auto elementType =
436 convertScalarType(targetEnv, options, scalarType, storageClass);
437 if (elementType)
438 return VectorType::get(type.getShape(), elementType);
439 return nullptr;
440}
441
442static Type
443convertComplexType(const spirv::TargetEnv &targetEnv,
444 const SPIRVConversionOptions &options, ComplexType type,
445 std::optional<spirv::StorageClass> storageClass = {}) {
446 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
447 if (!scalarType) {
448 LLVM_DEBUG(llvm::dbgs()
449 << type << " illegal: cannot convert non-scalar element type\n");
450 return nullptr;
451 }
452
453 auto elementType =
454 convertScalarType(targetEnv, options, scalarType, storageClass);
455 if (!elementType)
456 return nullptr;
457 if (elementType != type.getElementType()) {
458 LLVM_DEBUG(llvm::dbgs()
459 << type << " illegal: complex type emulation unsupported\n");
460 return nullptr;
461 }
462
463 return VectorType::get(2, elementType);
464}
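// Illustrative example (editorial note): complex<f32> converts to vector<2xf32>
// holding the real and imaginary parts; a complex type whose element type would
// need emulation (e.g. complex<f64> without the Float64 capability) is
// rejected.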
465
466/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
467///
468/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
469/// create composite constants with OpConstantComposite to embed relatively
470/// large constant values, and use OpCompositeExtract and OpCompositeInsert to
471/// manipulate them, as we do for vectors.
472static Type convertTensorType(const spirv::TargetEnv &targetEnv,
473 const SPIRVConversionOptions &options,
474 TensorType type) {
475 // TODO: Handle dynamic shapes.
476 if (!type.hasStaticShape()) {
477 LLVM_DEBUG(llvm::dbgs()
478 << type << " illegal: dynamic shape unimplemented\n");
479 return nullptr;
480 }
481
482 type = cast<TensorType>(convertIndexElementType(type, options));
483 type = cast<TensorType>(convertShaped8BitFloatType(type, options));
484 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
485 if (!scalarType) {
486 LLVM_DEBUG(llvm::dbgs()
487 << type << " illegal: cannot convert non-scalar element type\n");
488 return nullptr;
489 }
490
491 std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
492 std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
493 if (!scalarSize || !tensorSize) {
494 LLVM_DEBUG(llvm::dbgs()
495 << type << " illegal: cannot deduce element count\n");
496 return nullptr;
497 }
498
499 int64_t arrayElemCount = *tensorSize / *scalarSize;
500 if (arrayElemCount == 0) {
501 LLVM_DEBUG(llvm::dbgs()
502 << type << " illegal: cannot handle zero-element tensors\n");
503 return nullptr;
504 }
505
506 Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
507 if (!arrayElemType)
508 return nullptr;
509 std::optional<int64_t> arrayElemSize =
510 getTypeNumBytes(options, arrayElemType);
511 if (!arrayElemSize) {
512 LLVM_DEBUG(llvm::dbgs()
513 << type << " illegal: cannot deduce converted element size\n");
514 return nullptr;
515 }
516
517 return spirv::ArrayType::get(arrayElemType, arrayElemCount);
518}
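// Illustrative example (editorial note): tensor<2x3xf32> converts to
// !spirv.array<6 x f32>, which can then be used for spirv.Constant composite
// values accessed via OpCompositeExtract/OpCompositeInsert.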
519
520static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
521 const SPIRVConversionOptions &options,
522 MemRefType type,
523 spirv::StorageClass storageClass) {
524 unsigned numBoolBits = options.boolNumBits;
525 if (numBoolBits != 8) {
526 LLVM_DEBUG(llvm::dbgs()
527 << "using non-8-bit storage for bool types unimplemented");
528 return nullptr;
529 }
530 auto elementType = dyn_cast<spirv::ScalarType>(
531 IntegerType::get(type.getContext(), numBoolBits));
532 if (!elementType)
533 return nullptr;
534 Type arrayElemType =
535 convertScalarType(targetEnv, options, elementType, storageClass);
536 if (!arrayElemType)
537 return nullptr;
538 std::optional<int64_t> arrayElemSize =
539 getTypeNumBytes(options, arrayElemType);
540 if (!arrayElemSize) {
541 LLVM_DEBUG(llvm::dbgs()
542 << type << " illegal: cannot deduce converted element size\n");
543 return nullptr;
544 }
545
546 if (!type.hasStaticShape()) {
547 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
548 // to the element.
549 if (targetEnv.allows(spirv::Capability::Kernel))
550 return spirv::PointerType::get(arrayElemType, storageClass);
551 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
552 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
553 // For Vulkan we need extra wrapping struct and array to satisfy interface
554 // needs.
555 return wrapInStructAndGetPointer(arrayType, storageClass);
556 }
557
558 if (type.getNumElements() == 0) {
559 LLVM_DEBUG(llvm::dbgs()
560 << type << " illegal: zero-element memrefs are not supported\n");
561 return nullptr;
562 }
563
564 int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
565 int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
566 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
567 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
568 if (targetEnv.allows(spirv::Capability::Kernel))
569 return spirv::PointerType::get(arrayType, storageClass);
570 return wrapInStructAndGetPointer(arrayType, storageClass);
571}
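// Illustrative example (editorial note): assuming the default boolNumBits of 8,
// memref<16xi1, #spirv.storage_class<StorageBuffer>> is stored as 16 i8 values,
// i.e. roughly !spirv.ptr<!spirv.struct<(!spirv.array<16 x i8, stride=1> [0])>,
// StorageBuffer> on Vulkan-flavored targets.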
572
573static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
574 const SPIRVConversionOptions &options,
575 MemRefType type,
576 spirv::StorageClass storageClass) {
577 IntegerType elementType = cast<IntegerType>(type.getElementType());
578 Type arrayElemType = convertSubByteIntegerType(options, elementType);
579 if (!arrayElemType)
580 return nullptr;
581 int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
582
583 if (!type.hasStaticShape()) {
584 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
585 // to the element.
586 if (targetEnv.allows(spirv::Capability::Kernel))
587 return spirv::PointerType::get(arrayElemType, storageClass);
588 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
589 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
590 // For Vulkan we need extra wrapping struct and array to satisfy interface
591 // needs.
592 return wrapInStructAndGetPointer(arrayType, storageClass);
593 }
594
595 if (type.getNumElements() == 0) {
596 LLVM_DEBUG(llvm::dbgs()
597 << type << " illegal: zero-element memrefs are not supported\n");
598 return nullptr;
599 }
600
601 int64_t memrefSize =
602 llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
603 int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
604 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
605 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
606 if (targetEnv.allows(spirv::Capability::Kernel))
607 return spirv::PointerType::get(arrayType, storageClass);
608 return wrapInStructAndGetPointer(arrayType, storageClass);
609}
610
611static spirv::Dim convertRank(int64_t rank) {
612 switch (rank) {
613 case 1:
614 return spirv::Dim::Dim1D;
615 case 2:
616 return spirv::Dim::Dim2D;
617 case 3:
618 return spirv::Dim::Dim3D;
619 default:
620 llvm_unreachable("Invalid memref rank!");
621 }
622}
623
624static spirv::ImageFormat getImageFormat(Type elementType) {
625 return TypeSwitch<Type, spirv::ImageFormat>(elementType)
626 .Case<Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; })
627 .Case<Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; })
628 .Case<IntegerType>([](IntegerType intType) {
629 auto const isSigned = intType.isSigned() || intType.isSignless();
630#define BIT_WIDTH_CASE(BIT_WIDTH) \
631 case BIT_WIDTH: \
632 return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
633 : spirv::ImageFormat::R##BIT_WIDTH##ui
634
635 switch (intType.getWidth()) {
636 BIT_WIDTH_CASE(16);
637 BIT_WIDTH_CASE(32);
638 default:
639 llvm_unreachable("Unhandled integer type!");
640 }
641 })
642 .DefaultUnreachable("Unhandled element type!");
643#undef BIT_WIDTH_CASE
644}
645
646static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
647 const SPIRVConversionOptions &options,
648 MemRefType type) {
649 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
650 if (!attr) {
651 LLVM_DEBUG(
652 llvm::dbgs()
653 << type
654 << " illegal: expected memory space to be a SPIR-V storage class "
655 "attribute; please use MemorySpaceToStorageClassConverter to map "
656 "numeric memory spaces beforehand\n");
657 return nullptr;
658 }
659 spirv::StorageClass storageClass = attr.getValue();
660
661 // Images are a special case since they are an opaque type from which elements
662 // may be accessed via image specific ops or directly through a texture
663 // pointer.
664 if (storageClass == spirv::StorageClass::Image) {
665 const int64_t rank = type.getRank();
666 if (rank < 1 || rank > 3) {
667 LLVM_DEBUG(llvm::dbgs()
668 << type << " illegal: cannot lower memref of rank " << rank
669 << " to a SPIR-V Image\n");
670 return nullptr;
671 }
672
673 // Note that we currently only support lowering to single element texels
674 // e.g. R32f.
675 auto elementType = type.getElementType();
676 if (!isa<spirv::ScalarType>(elementType)) {
677 LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
678 << elementType << " to a SPIR-V Image\n");
679 return nullptr;
680 }
681
682 // Currently every memref in the image storage class is converted to a
683 // sampled image so we can hardcode the NeedSampler field. Future work
684 // will generalize this to support regular non-sampled images.
685 auto spvImageType = spirv::ImageType::get(
686 elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
687 spirv::ImageArrayedInfo::NonArrayed,
688 spirv::ImageSamplingInfo::SingleSampled,
689 spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
690 auto spvSampledImageType = spirv::SampledImageType::get(spvImageType);
691 auto imagePtrType = spirv::PointerType::get(
692 spvSampledImageType, spirv::StorageClass::UniformConstant);
693 return imagePtrType;
694 }
695
696 if (isa<IntegerType>(type.getElementType())) {
697 if (type.getElementTypeBitWidth() == 1)
698 return convertBoolMemrefType(targetEnv, options, type, storageClass);
699 if (type.getElementTypeBitWidth() < 8)
700 return convertSubByteMemrefType(targetEnv, options, type, storageClass);
701 }
702
703 Type arrayElemType;
704 Type elementType = type.getElementType();
705 if (auto vecType = dyn_cast<VectorType>(elementType)) {
706 arrayElemType =
707 convertVectorType(targetEnv, options, vecType, storageClass);
708 } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
709 arrayElemType =
710 convertComplexType(targetEnv, options, complexType, storageClass);
711 } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
712 arrayElemType =
713 convertScalarType(targetEnv, options, scalarType, storageClass);
714 } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
715 type = cast<MemRefType>(convertIndexElementType(type, options));
716 arrayElemType = type.getElementType();
717 } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
718 // Handle 8-bit float types.
719 type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
720 arrayElemType = type.getElementType();
721 } else {
722 LLVM_DEBUG(
723 llvm::dbgs()
724 << type
725 << " unhandled: can only convert scalar or vector element type\n");
726 return nullptr;
727 }
728 if (!arrayElemType)
729 return nullptr;
730
731 std::optional<int64_t> arrayElemSize =
732 getTypeNumBytes(options, arrayElemType);
733 if (!arrayElemSize) {
734 LLVM_DEBUG(llvm::dbgs()
735 << type << " illegal: cannot deduce converted element size\n");
736 return nullptr;
737 }
738
739 if (!type.hasStaticShape()) {
740 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
741 // to the element.
742 if (targetEnv.allows(spirv::Capability::Kernel))
743 return spirv::PointerType::get(arrayElemType, storageClass);
744 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
745 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
746 // For Vulkan we need extra wrapping struct and array to satisfy interface
747 // needs.
748 return wrapInStructAndGetPointer(arrayType, storageClass);
749 }
750
751 std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
752 if (!memrefSize) {
753 LLVM_DEBUG(llvm::dbgs()
754 << type << " illegal: cannot deduce element count\n");
755 return nullptr;
756 }
757
758 if (*memrefSize == 0) {
759 LLVM_DEBUG(llvm::dbgs()
760 << type << " illegal: zero-element memrefs are not supported\n");
761 return nullptr;
762 }
763
764 int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
765 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
766 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
767 if (targetEnv.allows(spirv::Capability::Kernel))
768 return spirv::PointerType::get(arrayType, storageClass);
769 return wrapInStructAndGetPointer(arrayType, storageClass);
770}
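// Illustrative example (editorial note): memref<4x4xf32,
// #spirv.storage_class<StorageBuffer>> becomes roughly
// !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>,
// StorageBuffer> on Vulkan-flavored targets, whereas targets with the Kernel
// capability get a pointer to the array without the wrapping struct.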
771
772//===----------------------------------------------------------------------===//
773// Type casting materialization
774//===----------------------------------------------------------------------===//
775
776/// Converts the given `inputs` to the original source `type` considering the
777/// `targetEnv`'s capabilities.
778///
779/// This function is meant to be used for source materialization in type
780/// converters. When the type converter needs to materialize a cast op back
781/// to some original source type, we need to check whether the original source
782/// type is supported in the target environment. If so, we can insert legal
783/// SPIR-V cast ops accordingly.
784///
785/// Note that in SPIR-V the capabilities for storage and compute are separate.
786/// This function is meant to handle the **compute** side, so it does not
787/// involve storage classes in its logic. The storage side is expected to be
788/// handled by MemRef conversion logic.
789static Value castToSourceType(const spirv::TargetEnv &targetEnv,
790 OpBuilder &builder, Type type, ValueRange inputs,
791 Location loc) {
792 // We can only cast one value in SPIR-V.
793 if (inputs.size() != 1) {
794 auto castOp =
795 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
796 return castOp.getResult(0);
797 }
798 Value input = inputs.front();
799
800 // Only support integer types for now. Floating point types to be implemented.
801 if (!isa<IntegerType>(type)) {
802 auto castOp =
803 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
804 return castOp.getResult(0);
805 }
806 auto inputType = cast<IntegerType>(input.getType());
807
808 auto scalarType = dyn_cast<spirv::ScalarType>(type);
809 if (!scalarType) {
810 auto castOp =
811 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
812 return castOp.getResult(0);
813 }
814
815 // Only support source types with a smaller bitwidth. This would mean we are
816 // truncating to go back, so we don't need to worry about the signedness.
817 // For extension, we don't have enough signal here to decide which op to use.
818 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
819 auto castOp =
820 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
821 return castOp.getResult(0);
822 }
823
824 // Boolean values would need to use different ops than normal integer values.
825 if (type.isInteger(1)) {
826 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
827 return spirv::IEqualOp::create(builder, loc, input, one);
828 }
829
830 // Check that the source integer type is supported by the environment.
831 SmallVector<ArrayRef<spirv::Extension>, 1> exts;
832 SmallVector<ArrayRef<spirv::Capability>, 1> caps;
833 scalarType.getExtensions(exts);
834 scalarType.getCapabilities(caps);
835 if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
836 failed(checkExtensionRequirements(type, targetEnv, exts))) {
837 auto castOp =
838 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
839 return castOp.getResult(0);
840 }
841
842 // We've already made sure this is truncating previously, so we don't need to
843 // care about signedness here. Still try to use a corresponding op for better
844 // consistency though.
845 if (type.isSignedInteger()) {
846 return spirv::SConvertOp::create(builder, loc, type, input);
847 }
848 return spirv::UConvertOp::create(builder, loc, type, input);
849}
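// Illustrative example (editorial note): when materializing an original i16
// value from a converted i32 on a target that supports 16-bit integers, this
// emits spirv.UConvert (or spirv.SConvert for signed types); an original i1 is
// recovered with spirv.IEqual against the constant 1.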
850
851//===----------------------------------------------------------------------===//
852// Builtin Variables
853//===----------------------------------------------------------------------===//
854
855static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
856 spirv::BuiltIn builtin) {
857 // Look through all global variables in the given `body` block and check if
858 // there is a spirv.GlobalVariable that has the same `builtin` attribute.
859 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
860 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
861 spirv::SPIRVDialect::getAttributeName(
862 spirv::Decoration::BuiltIn))) {
863 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
864 if (varBuiltIn == builtin) {
865 return varOp;
866 }
867 }
868 }
869 return nullptr;
870}
871
872/// Gets name of global variable for a builtin.
873std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
874 StringRef suffix) {
875 return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
876}
877
878/// Gets or inserts a global variable for a builtin within `body` block.
879static spirv::GlobalVariableOp
880getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
881 Type integerType, OpBuilder &builder,
882 StringRef prefix, StringRef suffix) {
883 if (auto varOp = getBuiltinVariable(body, builtin))
884 return varOp;
885
886 OpBuilder::InsertionGuard guard(builder);
887 builder.setInsertionPointToStart(&body);
888
889 spirv::GlobalVariableOp newVarOp;
890 switch (builtin) {
891 case spirv::BuiltIn::NumWorkgroups:
892 case spirv::BuiltIn::WorkgroupSize:
893 case spirv::BuiltIn::WorkgroupId:
894 case spirv::BuiltIn::LocalInvocationId:
895 case spirv::BuiltIn::GlobalInvocationId: {
896 auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
897 spirv::StorageClass::Input);
898 std::string name = getBuiltinVarName(builtin, prefix, suffix);
899 newVarOp =
900 spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
901 break;
902 }
903 case spirv::BuiltIn::SubgroupId:
904 case spirv::BuiltIn::NumSubgroups:
905 case spirv::BuiltIn::SubgroupSize:
906 case spirv::BuiltIn::SubgroupLocalInvocationId: {
907 auto ptrType =
908 spirv::PointerType::get(integerType, spirv::StorageClass::Input);
909 std::string name = getBuiltinVarName(builtin, prefix, suffix);
910 newVarOp =
911 spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
912 break;
913 }
914 default:
915 emitError(loc, "unimplemented builtin variable generation for ")
916 << stringifyBuiltIn(builtin);
917 }
918 return newVarOp;
919}
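// Illustrative example (editorial note): for spirv::BuiltIn::GlobalInvocationId
// with a hypothetical prefix "__builtin__" and suffix "__", this creates
// roughly
//   spirv.GlobalVariable @__builtin__GlobalInvocationId__
//       built_in("GlobalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
// at the start of the surrounding module-like op.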
920
921//===----------------------------------------------------------------------===//
922// Push constant storage
923//===----------------------------------------------------------------------===//
924
925/// Returns the pointer type for the push constant storage containing
926/// `elementCount` 32-bit integer values.
927static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
928 Builder &builder,
929 Type indexType) {
930 auto arrayType = spirv::ArrayType::get(indexType, elementCount,
931 /*stride=*/4);
932 auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
933 return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
934}
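// Illustrative example (editorial note): for elementCount = 4 and a 32-bit
// index type, this returns roughly
//   !spirv.ptr<!spirv.struct<(!spirv.array<4 x i32, stride=4> [0])>,
//              PushConstant>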
935
936/// Returns the push constant variable containing `elementCount` 32-bit integer
937/// values in `body`. Returns a null op if such an op does not exist.
938static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
939 unsigned elementCount) {
940 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
941 auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
942 if (!ptrType)
943 continue;
944
945 // Note that Vulkan requires "There must be no more than one push constant
946 // block statically used per shader entry point." So we should always reuse
947 // the existing one.
948 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
949 auto numElements = cast<spirv::ArrayType>(
950 cast<spirv::StructType>(ptrType.getPointeeType())
951 .getElementType(0))
952 .getNumElements();
953 if (numElements == elementCount)
954 return varOp;
955 }
956 }
957 return nullptr;
958}
959
960/// Gets or inserts a global variable for push constant storage containing
961/// `elementCount` 32-bit integer values in `block`.
962static spirv::GlobalVariableOp
963getOrInsertPushConstantVariable(Location loc, Block &block,
964 unsigned elementCount, OpBuilder &b,
965 Type indexType) {
966 if (auto varOp = getPushConstantVariable(block, elementCount))
967 return varOp;
968
969 auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
970 auto type = getPushConstantStorageType(elementCount, builder, indexType);
971 const char *name = "__push_constant_var__";
972 return spirv::GlobalVariableOp::create(builder, loc, type, name,
973 /*initializer=*/nullptr);
974}
975
976//===----------------------------------------------------------------------===//
977// func::FuncOp Conversion Patterns
978//===----------------------------------------------------------------------===//
979
980/// A pattern for rewriting function signature to convert arguments of functions
981/// to be of valid SPIR-V types.
982struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
983 using Base::Base;
984
985 LogicalResult
986 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
987 ConversionPatternRewriter &rewriter) const override {
988 FunctionType fnType = funcOp.getFunctionType();
989 if (fnType.getNumResults() > 1)
990 return failure();
991
992 TypeConverter::SignatureConversion signatureConverter(
993 fnType.getNumInputs());
994 for (const auto &argType : enumerate(fnType.getInputs())) {
995 auto convertedType = getTypeConverter()->convertType(argType.value());
996 if (!convertedType)
997 return failure();
998 signatureConverter.addInputs(argType.index(), convertedType);
999 }
1000
1001 Type resultType;
1002 if (fnType.getNumResults() == 1) {
1003 resultType = getTypeConverter()->convertType(fnType.getResult(0));
1004 if (!resultType)
1005 return failure();
1006 }
1007
1008 // Create the converted spirv.func op.
1009 auto newFuncOp = spirv::FuncOp::create(
1010 rewriter, funcOp.getLoc(), funcOp.getName(),
1011 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
1012 resultType ? TypeRange(resultType)
1013 : TypeRange()));
1014
1015 // Copy over all attributes other than the function name and type.
1016 for (const auto &namedAttr : funcOp->getAttrs()) {
1017 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
1018 namedAttr.getName() != SymbolTable::getSymbolAttrName())
1019 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1020 }
1021
1022 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1023 newFuncOp.end());
1024 if (failed(rewriter.convertRegionTypes(
1025 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
1026 return failure();
1027 rewriter.eraseOp(funcOp);
1028 return success();
1029 }
1030};
1031
1032/// A pattern for rewriting a function signature to convert vector arguments of
1033/// functions to be of valid types.
1034struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
1035 using Base::Base;
1036
1037 LogicalResult matchAndRewrite(func::FuncOp funcOp,
1038 PatternRewriter &rewriter) const override {
1039 FunctionType fnType = funcOp.getFunctionType();
1040
1041 // TODO: Handle declarations.
1042 if (funcOp.isDeclaration()) {
1043 LLVM_DEBUG(llvm::dbgs()
1044 << fnType << " illegal: declarations are unsupported\n");
1045 return failure();
1046 }
1047
1048 // Create a new func op with the original type and copy the function body.
1049 auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),
1050 funcOp.getName(), fnType);
1051 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1052 newFuncOp.end());
1053
1054 Location loc = newFuncOp.getBody().getLoc();
1055
1056 Block &entryBlock = newFuncOp.getBlocks().front();
1057 OpBuilder::InsertionGuard guard(rewriter);
1058 rewriter.setInsertionPointToStart(&entryBlock);
1059
1060 TypeConverter::SignatureConversion oneToNTypeMapping(
1061 fnType.getInputs().size());
1062
1063 // For arguments that are of illegal types and require unrolling.
1064 // `unrolledInputNums` stores the indices of arguments that result from
1065 // unrolling in the new function signature. `newInputNo` is a counter.
1066 SmallVector<size_t> unrolledInputNums;
1067 size_t newInputNo = 0;
1068
1069 // For arguments that are of legal types and do not require unrolling,
1070 // `tmpOps` maps the temporary placeholder operations created for those
1071 // arguments to their indices in the new signature. These operations
1072 // will be erased once the entry block's argument list is updated.
1073 llvm::SmallDenseMap<Operation *, size_t> tmpOps;
1074
1075 // This counts the number of new operations created.
1076 size_t newOpCount = 0;
1077
1078 // Enumerate through the arguments.
1079 for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
1080 // Check whether the argument is of vector type.
1081 auto origVecType = dyn_cast<VectorType>(origType);
1082 if (!origVecType) {
1083 // We need a placeholder for the old argument that will be erased later.
1084 Value result = arith::ConstantOp::create(
1085 rewriter, loc, origType, rewriter.getZeroAttr(origType));
1086 rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1087 tmpOps.insert({result.getDefiningOp(), newInputNo});
1088 oneToNTypeMapping.addInputs(origInputNo, origType);
1089 ++newInputNo;
1090 ++newOpCount;
1091 continue;
1092 }
1093 // Check whether the vector needs unrolling.
1094 auto targetShape = getTargetShape(origVecType);
1095 if (!targetShape) {
1096 // We need a placeholder for the old argument that will be erased later.
1097 Value result = arith::ConstantOp::create(
1098 rewriter, loc, origType, rewriter.getZeroAttr(origType));
1099 rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1100 tmpOps.insert({result.getDefiningOp(), newInputNo});
1101 oneToNTypeMapping.addInputs(origInputNo, origType);
1102 ++newInputNo;
1103 ++newOpCount;
1104 continue;
1105 }
1106 VectorType unrolledType =
1107 VectorType::get(*targetShape, origVecType.getElementType());
1108 auto originalShape =
1109 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1110
1111 // Prepare the result vector.
1112 Value result = arith::ConstantOp::create(
1113 rewriter, loc, origVecType, rewriter.getZeroAttr(origVecType));
1114 ++newOpCount;
1115 // Prepare the placeholder for the new arguments that will be added later.
1116 Value dummy = arith::ConstantOp::create(
1117 rewriter, loc, unrolledType, rewriter.getZeroAttr(unrolledType));
1118 ++newOpCount;
1119
1120 // Create the `vector.insert_strided_slice` ops.
1121 SmallVector<int64_t> strides(targetShape->size(), 1);
1122 SmallVector<Type> newTypes;
1123 for (SmallVector<int64_t> offsets :
1124 StaticTileOffsetRange(originalShape, *targetShape)) {
1125 result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,
1126 result, offsets, strides);
1127 newTypes.push_back(unrolledType);
1128 unrolledInputNums.push_back(newInputNo);
1129 ++newInputNo;
1130 ++newOpCount;
1131 }
1132 rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1133 oneToNTypeMapping.addInputs(origInputNo, newTypes);
1134 }
1135
1136 // Change the function signature.
1137 auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1138 auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1139 rewriter.modifyOpInPlace(newFuncOp,
1140 [&] { newFuncOp.setFunctionType(newFnType); });
1141
1142 // Update the arguments in the entry block.
1143 entryBlock.eraseArguments(0, fnType.getNumInputs());
1144 SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
1145 entryBlock.addArguments(convertedTypes, locs);
1146
1147 // Replace all uses of placeholders for initially legal arguments with their
1148 // original function arguments (that were added to `newFuncOp`).
1149 for (auto &[placeholderOp, argIdx] : tmpOps) {
1150 if (!placeholderOp)
1151 continue;
1152 Value replacement = newFuncOp.getArgument(argIdx);
1153 rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement);
1154 }
1155
1156 // Replace dummy operands of new `vector.insert_strided_slice` ops with
1157 // their corresponding new function arguments. The new
1158 // `vector.insert_strided_slice` ops are inserted only into the entry block,
1159 // so iterating over that block is sufficient.
1160 size_t unrolledInputIdx = 0;
1161 for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1162 Operation &curOp = op;
1163 // Since all newly created operations are in the beginning, reaching the
1164 // end of them means that any later `vector.insert_strided_slice` should
1165 // not be touched.
1166 if (count >= newOpCount)
1167 continue;
1168 if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1169 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1170 rewriter.modifyOpInPlace(&curOp, [&] {
1171 curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1172 });
1173 ++unrolledInputIdx;
1174 }
1175 }
1176
1177 // Erase the original funcOp. The `tmpOps` do not need to be erased since
1178 // they have no uses and will be handled by dead-code elimination.
1179 rewriter.eraseOp(funcOp);
1180 return success();
1181 }
1182};
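// Illustrative example (editorial note): assuming a native size of 4, a
// signature like
//   func.func @f(%arg0: vector<8xf32>)
// is rewritten into
//   func.func @f(%arg0: vector<4xf32>, %arg1: vector<4xf32>)
// with vector.insert_strided_slice ops at the top of the entry block
// reassembling the original vector<8xf32> value for existing uses.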
1183
1184//===----------------------------------------------------------------------===//
1185// func::ReturnOp Conversion Patterns
1186//===----------------------------------------------------------------------===//
1187
1188/// A pattern for rewriting function signature and the return op to convert
1189/// vectors to be of valid types.
1190struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
1191 using Base::Base;
1192
1193 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1194 PatternRewriter &rewriter) const override {
1195 // Check whether the parent funcOp is valid.
1196 auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1197 if (!funcOp)
1198 return failure();
1199
1200 FunctionType fnType = funcOp.getFunctionType();
1201 TypeConverter::SignatureConversion oneToNTypeMapping(
1202 fnType.getResults().size());
1203 Location loc = returnOp.getLoc();
1204
1205 // For the new return op.
1206 SmallVector<Value> newOperands;
1207
1208 // Enumerate through the results.
1209 for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1210 // Check whether the result is of vector type.
1211 auto origVecType = dyn_cast<VectorType>(origType);
1212 if (!origVecType) {
1213 oneToNTypeMapping.addInputs(origResultNo, origType);
1214 newOperands.push_back(returnOp.getOperand(origResultNo));
1215 continue;
1216 }
1217 // Check whether the vector needs unrolling.
1218 auto targetShape = getTargetShape(origVecType);
1219 if (!targetShape) {
1220 // The original operand can be used.
1221 oneToNTypeMapping.addInputs(origResultNo, origType);
1222 newOperands.push_back(returnOp.getOperand(origResultNo));
1223 continue;
1224 }
1225 VectorType unrolledType =
1226 VectorType::get(*targetShape, origVecType.getElementType());
1227
1228 // Create `vector.extract_strided_slice` ops to form legal vectors from
1229 // the original operand of illegal type.
1230 auto originalShape =
1231 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1232 SmallVector<int64_t> strides(originalShape.size(), 1);
1233 SmallVector<int64_t> extractShape(originalShape.size(), 1);
1234 extractShape.back() = targetShape->back();
1235 SmallVector<Type> newTypes;
1236 Value returnValue = returnOp.getOperand(origResultNo);
1237 for (SmallVector<int64_t> offsets :
1238 StaticTileOffsetRange(originalShape, *targetShape)) {
1239 Value result = vector::ExtractStridedSliceOp::create(
1240 rewriter, loc, returnValue, offsets, extractShape, strides);
1241 if (originalShape.size() > 1) {
1242 SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
1243 result =
1244 vector::ExtractOp::create(rewriter, loc, result, extractIndices);
1245 }
1246 newOperands.push_back(result);
1247 newTypes.push_back(unrolledType);
1248 }
1249 oneToNTypeMapping.addInputs(origResultNo, newTypes);
1250 }
1251
1252 // Change the function signature.
1253 auto newFnType =
1254 FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
1255 TypeRange(oneToNTypeMapping.getConvertedTypes()));
1256 rewriter.modifyOpInPlace(funcOp,
1257 [&] { funcOp.setFunctionType(newFnType); });
1258
1259 // Replace the return op using the new operands. This will automatically
1260 // update the entry block as well.
1261 rewriter.replaceOp(returnOp,
1262 func::ReturnOp::create(rewriter, loc, newOperands));
1263
1264 return success();
1265 }
1266};
1267
1268} // namespace
1269
1270//===----------------------------------------------------------------------===//
1271// Public function for builtin variables
1272//===----------------------------------------------------------------------===//
1273
1274Value mlir::spirv::getBuiltinVariableValue(Operation *op,
1275 spirv::BuiltIn builtin,
1276 Type integerType, OpBuilder &builder,
1277 StringRef prefix, StringRef suffix) {
1278 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
1279 if (!parent) {
1280 op->emitError("expected operation to be within a module-like op");
1281 return nullptr;
1282 }
1283
1284 spirv::GlobalVariableOp varOp =
1285 getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
1286 builtin, integerType, builder, prefix, suffix);
1287 Value ptr = spirv::AddressOfOp::create(builder, op->getLoc(), varOp);
1288 return spirv::LoadOp::create(builder, op->getLoc(), ptr);
1289}
1290
1291//===----------------------------------------------------------------------===//
1292// Public function for pushing constant storage
1293//===----------------------------------------------------------------------===//
1294
1295Value mlir::spirv::getPushConstantValue(Operation *op, unsigned elementCount,
1296 unsigned offset, Type integerType,
1297 OpBuilder &builder) {
1298 Location loc = op->getLoc();
1299 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
1300 if (!parent) {
1301 op->emitError("expected operation to be within a module-like op");
1302 return nullptr;
1303 }
1304
1305 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1306 loc, parent->getRegion(0).front(), elementCount, builder, integerType);
1307
1308 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1309 Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,
1310 builder.getI32IntegerAttr(offset));
1311 auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);
1312 auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,
1313 llvm::ArrayRef({zeroOp, offsetOp}));
1314 return spirv::LoadOp::create(builder, loc, acOp);
1315}
1316
1317//===----------------------------------------------------------------------===//
1318// Public functions for index calculation
1319//===----------------------------------------------------------------------===//
1320
1321Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
1322 int64_t offset, Type integerType,
1323 Location loc, OpBuilder &builder) {
1324 assert(indices.size() == strides.size() &&
1325 "must provide indices for all dimensions");
1326
1327 // TODO: Consider moving to use affine.apply and patterns converting
1328 // affine.apply to standard ops. This requires the convert-to-SPIR-V passes to
1329 // be broken down into progressively smaller steps so we can have intermediate
1330 // steps using other dialects. At the moment SPIR-V is the final sink.
1331
1332 Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
1333 loc, integerType, IntegerAttr::get(integerType, offset));
1334 for (const auto &index : llvm::enumerate(indices)) {
1335 Value strideVal = builder.createOrFold<spirv::ConstantOp>(
1336 loc, integerType,
1337 IntegerAttr::get(integerType, strides[index.index()]));
1338 Value update =
1339 builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1340 linearizedIndex =
1341 builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1342 }
1343 return linearizedIndex;
1344}
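// Illustrative example (editorial note): for indices (%i, %j) into a memref
// with strides [4, 1] and offset 2, the emitted computation is equivalent to
// 2 + %i * 4 + %j * 1, built from spirv.Constant, spirv.IMul, and spirv.IAdd
// (createOrFold collapses the parts that are statically known).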
1345
1346Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter,
1347 MemRefType baseType, Value basePtr,
1348 ValueRange indices, Location loc,
1349 OpBuilder &builder) {
1350 // Get the strides and offset of the MemRefType and verify they are static.
1351
1352 int64_t offset;
1353 SmallVector<int64_t, 4> strides;
1354 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1355 llvm::is_contained(strides, ShapedType::kDynamic) ||
1356 ShapedType::isDynamic(offset)) {
1357 return nullptr;
1358 }
1359
1360 auto indexType = typeConverter.getIndexType();
1361
1362 SmallVector<Value, 2> linearizedIndices;
1363 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1364
1365 // Add a '0' at the start to index into the struct.
1366 linearizedIndices.push_back(zero);
1367
1368 if (baseType.getRank() == 0) {
1369 linearizedIndices.push_back(zero);
1370 } else {
1371 linearizedIndices.push_back(
1372 linearizeIndex(indices, strides, offset, indexType, loc, builder));
1373 }
1374 return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);
1375}
1376
1377Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter,
1378 MemRefType baseType, Value basePtr,
1379 ValueRange indices, Location loc,
1380 OpBuilder &builder) {
1381 // Get the strides and offset of the MemRefType and verify they are static.
1382
1383 int64_t offset;
1384 SmallVector<int64_t, 4> strides;
1385 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1386 llvm::is_contained(strides, ShapedType::kDynamic) ||
1387 ShapedType::isDynamic(offset)) {
1388 return nullptr;
1389 }
1390
1391 auto indexType = typeConverter.getIndexType();
1392
1393 SmallVector<Value, 2> linearizedIndices;
1394 Value linearIndex;
1395 if (baseType.getRank() == 0) {
1396 linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1397 } else {
1398 linearIndex =
1399 linearizeIndex(indices, strides, offset, indexType, loc, builder);
1400 }
1401 Type pointeeType =
1402 cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
1403 if (isa<spirv::ArrayType>(pointeeType)) {
1404 linearizedIndices.push_back(linearIndex);
1405 return spirv::AccessChainOp::create(builder, loc, basePtr,
1406 linearizedIndices);
1407 }
1408 return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,
1409 linearizedIndices);
1410}
1411
1412Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter,
1413 MemRefType baseType, Value basePtr,
1414 ValueRange indices, Location loc,
1415 OpBuilder &builder) {
1416
1417 if (typeConverter.allows(spirv::Capability::Kernel)) {
1418 return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1419 builder);
1420 }
1421
1422 return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1423 builder);
1424}
1425
1426//===----------------------------------------------------------------------===//
1427// Public functions for vector unrolling
1428//===----------------------------------------------------------------------===//
1429
1430int mlir::spirv::getComputeVectorSize(int64_t size) {
1431 for (int i : {4, 3, 2}) {
1432 if (size % i == 0)
1433 return i;
1434 }
1435 return 1;
1436}
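// Illustrative example (editorial note): getComputeVectorSize(8) == 4,
// getComputeVectorSize(6) == 3, and getComputeVectorSize(7) == 1, i.e. the
// largest of {4, 3, 2} that evenly divides the size, falling back to scalars.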
1437
1438SmallVector<int64_t>
1439mlir::spirv::getNativeVectorShapeImpl(vector::ReductionOp op) {
1440 VectorType srcVectorType = op.getSourceVectorType();
1441 assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1442 int64_t vectorSize =
1443 mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
1444 return {vectorSize};
1445}
1446
1447SmallVector<int64_t>
1448mlir::spirv::getNativeVectorShapeImpl(vector::TransposeOp op) {
1449 VectorType vectorType = op.getResultVectorType();
1450 SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
1451 nativeSize.back() =
1452 mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
1453 return nativeSize;
1454}
1455
1456std::optional<SmallVector<int64_t>>
1457mlir::spirv::getNativeVectorShape(Operation *op) {
1458 if (OpTrait::hasElementwiseMappableTraits(op) && op->getNumResults() == 1) {
1459 if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
1460 SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
1461 nativeSize.back() =
1462 mlir::spirv::getComputeVectorSize(vecType.getShape().back());
1463 return nativeSize;
1464 }
1465 }
1466
1467 return TypeSwitch<Operation *, std::optional<SmallVector<int64_t>>>(op)
1468 .Case<vector::ReductionOp, vector::TransposeOp>(
1469 [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
1470 .Default(std::nullopt);
1471}
1472
1473LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
1474 MLIRContext *context = op->getContext();
1475 RewritePatternSet patterns(context);
1478 // We only want to apply signature conversion once to the existing func ops.
1479 // Without specifying strictMode, the greedy pattern rewriter will keep
1480 // looking for newly created func ops.
1481 return applyPatternsGreedily(op, std::move(patterns),
1482 GreedyRewriteConfig().setStrictness(
1483 GreedyRewriteStrictness::ExistingOps));
1484}
1485
1486LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
1487 MLIRContext *context = op->getContext();
1488
1489 // Unroll vectors in function bodies to native vector size.
1490 {
1491 RewritePatternSet patterns(context);
1492 auto options = vector::UnrollVectorOptions().setNativeShapeFn(
1493 [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1494 populateVectorUnrollPatterns(patterns, options);
1495 if (failed(applyPatternsGreedily(op, std::move(patterns))))
1496 return failure();
1497 }
1498
1499 // Convert transpose ops into extract and insert pairs, in preparation of
1500 // further transformations to canonicalize/cancel.
1501 {
1502 RewritePatternSet patterns(context);
1503 vector::populateVectorTransposeLoweringPatterns(
1504 patterns, vector::VectorTransposeLowering::EltWise);
1506 if (failed(applyPatternsGreedily(op, std::move(patterns))))
1507 return failure();
1508 }
1509
1510 // Run canonicalization to cast away leading size-1 dimensions.
1511 {
1512 RewritePatternSet patterns(context);
1513
1514 // We need to pull in patterns for casting away leading one dims.
1515 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1516 vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1517 vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1518
1519 // Decompose different-rank insert_strided_slice and n-D
1520 // extract_strided_slice.
1521 vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1522 patterns);
1523 vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1524 vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1525
1526 // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1527 // them up.
1528 vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1529 vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1530
1531 if (failed(applyPatternsGreedily(op, std::move(patterns))))
1532 return failure();
1533 }
1534 return success();
1535}
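// A minimal sketch of how a lowering pipeline might use the two helpers above
// (illustrative only; `module` stands for whatever op drives the conversion):
//
//   if (failed(mlir::spirv::unrollVectorsInSignatures(module)) ||
//       failed(mlir::spirv::unrollVectorsInFuncBodies(module)))
//     return signalPassFailure();
//
// Signatures are rewritten first so that function boundaries already use
// native-sized vectors before the bodies are unrolled and canonicalized.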
1536
1537//===----------------------------------------------------------------------===//
1538// SPIR-V TypeConverter
1539//===----------------------------------------------------------------------===//
1540
1541SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
1542 const SPIRVConversionOptions &options)
1543 : targetEnv(targetAttr), options(options) {
1544 // Add conversions. The order matters here: later ones will be tried earlier.
1545
1546 // Allow all SPIR-V dialect specific types. This assumes all builtin types
1547 // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
1548 // were tried before.
1549 //
1550 // TODO: This assumes that the SPIR-V types are valid to use in the given
1551 // target environment, which should be the case if the whole pipeline is
1552 // driven by the same target environment. Still, we probably want to
1553 // validate and convert to be safe.
1554 addConversion([](spirv::SPIRVType type) { return type; });
1555
1556 addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
1557
1558 addConversion([this](IntegerType intType) -> std::optional<Type> {
1559 if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1560 return convertScalarType(this->targetEnv, this->options, scalarType);
1561 if (intType.getWidth() < 8)
1562 return convertSubByteIntegerType(this->options, intType);
1563 return Type();
1564 });
1565
1566 addConversion([this](FloatType floatType) -> std::optional<Type> {
1567 if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1568 return convertScalarType(this->targetEnv, this->options, scalarType);
1569 if (floatType.getWidth() == 8)
1570 return convert8BitFloatType(this->options, floatType);
1571 return Type();
1572 });
1573
1574 addConversion([this](ComplexType complexType) {
1575 return convertComplexType(this->targetEnv, this->options, complexType);
1576 });
1577
1578 addConversion([this](VectorType vectorType) {
1579 return convertVectorType(this->targetEnv, this->options, vectorType);
1580 });
1581
1582 addConversion([this](TensorType tensorType) {
1583 return convertTensorType(this->targetEnv, this->options, tensorType);
1584 });
1585
1586 addConversion([this](MemRefType memRefType) {
1587 return convertMemrefType(this->targetEnv, this->options, memRefType);
1588 });
1589
1590 // Register some last line of defense casting logic.
1591 addSourceMaterialization(
1592 [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1593 return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1594 });
1595 addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
1596 Location loc) {
1597 auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1598 return cast.getResult(0);
1599 });
1600}
1601
1602Type SPIRVTypeConverter::getIndexType() const {
1603 return ::getIndexType(getContext(), options);
1604}
1605
1606MLIRContext *SPIRVTypeConverter::getContext() const {
1607 return targetEnv.getAttr().getContext();
1608}
1609
1610bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
1611 return targetEnv.allows(capability);
1612}
1613
1614//===----------------------------------------------------------------------===//
1615// SPIR-V ConversionTarget
1616//===----------------------------------------------------------------------===//
1617
1618std::unique_ptr<SPIRVConversionTarget>
1619SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
1620 std::unique_ptr<SPIRVConversionTarget> target(
1621 // std::make_unique does not work here because the constructor is private.
1622 new SPIRVConversionTarget(targetAttr));
1623 SPIRVConversionTarget *targetPtr = target.get();
1624 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1625 // We need to capture the raw pointer here because it is stable; the local
1626 // `target` variable is moved away once this function returns.
1627 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1628 return target;
1629}
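// A minimal end-to-end sketch of wiring the converter and target together
// (hedged; `op` is whatever module-like op drives the conversion, and the
// pattern population shown is only one representative choice):
//
//   spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
//   SPIRVTypeConverter typeConverter(targetAttr);
//   auto target = SPIRVConversionTarget::get(targetAttr);
//   RewritePatternSet patterns(op->getContext());
//   populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
//   if (failed(applyPartialConversion(op, *target, std::move(patterns))))
//     return failure();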
1630
1631SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1632 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1633
1634bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1635 // Make sure this op is available at the given version. Ops not implementing
1636 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1637 // SPIR-V versions.
1638 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1639 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1640 if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1641 LLVM_DEBUG(llvm::dbgs()
1642 << op->getName() << " illegal: requiring min version "
1643 << spirv::stringifyVersion(*minVersion) << "\n");
1644 return false;
1645 }
1646 }
1647 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1648 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1649 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1650 LLVM_DEBUG(llvm::dbgs()
1651 << op->getName() << " illegal: requiring max version "
1652 << spirv::stringifyVersion(*maxVersion) << "\n");
1653 return false;
1654 }
1655 }
1656
1657 // Make sure this op's required extensions are allowed in the target
1658 // environment. Ops not implementing QueryExtensionInterface do not require
1659 // extensions to be available.
1660 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1661 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1662 extensions.getExtensions())))
1663 return false;
1664
1665 // Make sure this op's required capabilities are allowed in the target
1666 // environment. Ops not implementing QueryCapabilityInterface do not require
1667 // capabilities to be available.
1668 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1669 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1670 capabilities.getCapabilities())))
1671 return false;
1672
1673 SmallVector<Type, 4> valueTypes;
1674 valueTypes.append(op->operand_type_begin(), op->operand_type_end());
1675 valueTypes.append(op->result_type_begin(), op->result_type_end());
1676
1677 // Ensure that all types have been converted to SPIR-V types.
1678 if (llvm::any_of(valueTypes,
1679 [](Type t) { return !isa<spirv::SPIRVType>(t); }))
1680 return false;
1681
1682 // Special treatment for global variables, whose type requirements are
1683 // conveyed by type attributes.
1684 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1685 valueTypes.push_back(globalVar.getType());
1686
1687 // Make sure the op's operands/results use types that are allowed by the
1688 // target environment.
1689 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1690 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1691 for (Type valueType : valueTypes) {
1692 typeExtensions.clear();
1693 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1694 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1695 typeExtensions)))
1696 return false;
1697
1698 typeCapabilities.clear();
1699 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1700 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1701 typeCapabilities)))
1702 return false;
1703 }
1704
1705 return true;
1706}
1707
1708//===----------------------------------------------------------------------===//
1709// Public functions for populating patterns
1710//===----------------------------------------------------------------------===//
1711
1712void mlir::populateBuiltinFuncToSPIRVPatterns(
1713 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1714 patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
1715}
1716
1717void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
1718 patterns.add<FuncOpVectorUnroll>(patterns.getContext());
1719}
1720
1721void mlir::populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns) {
1722 patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
1723}