MLIR 23.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
//===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities used to lower to the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
12
26#include "mlir/IR/Operation.h"
28#include "mlir/Support/LLVM.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/MathExtras.h"
36
37#include <optional>
38
39#define DEBUG_TYPE "mlir-spirv-conversion"
40
41using namespace mlir;
42
43namespace {
44
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
48
49static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
50 LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
51 if (vecType.isScalable()) {
52 LLVM_DEBUG(llvm::dbgs()
53 << "--scalable vectors are not supported -> BAIL\n");
54 return std::nullopt;
55 }
56 SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
57 std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
58 1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
59 if (!targetShape) {
60 LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
61 return std::nullopt;
62 }
63 auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
64 if (!maybeShapeRatio) {
65 LLVM_DEBUG(llvm::dbgs()
66 << "--could not compute integral shape ratio -> BAIL\n");
67 return std::nullopt;
68 }
69 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
70 LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
71 return std::nullopt;
72 }
73 LLVM_DEBUG(llvm::dbgs()
74 << "--found an integral shape ratio to unroll to -> SUCCESS\n");
75 return targetShape;
76}
77
78/// Checks that `candidates` extension requirements are possible to be satisfied
79/// with the given `targetEnv`.
80///
81/// `candidates` is a vector of vector for extension requirements following
82/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
83/// convention.
84template <typename LabelT>
85static LogicalResult checkExtensionRequirements(
86 LabelT label, const spirv::TargetEnv &targetEnv,
88 for (const auto &ors : candidates) {
89 if (targetEnv.allows(ors))
90 continue;
91
92 LLVM_DEBUG({
93 SmallVector<StringRef> extStrings;
94 for (spirv::Extension ext : ors)
95 extStrings.push_back(spirv::stringifyExtension(ext));
96
97 llvm::dbgs() << label << " illegal: requires at least one extension in ["
98 << llvm::join(extStrings, ", ")
99 << "] but none allowed in target environment\n";
100 });
101 return failure();
102 }
103 return success();
104}
105
106/// Checks that `candidates`capability requirements are possible to be satisfied
107/// with the given `isAllowedFn`.
108///
109/// `candidates` is a vector of vector for capability requirements following
110/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
111/// convention.
112template <typename LabelT>
113static LogicalResult checkCapabilityRequirements(
114 LabelT label, const spirv::TargetEnv &targetEnv,
116 for (const auto &ors : candidates) {
117 if (targetEnv.allows(ors))
118 continue;
119
120 LLVM_DEBUG({
121 SmallVector<StringRef> capStrings;
122 for (spirv::Capability cap : ors)
123 capStrings.push_back(spirv::stringifyCapability(cap));
124
125 llvm::dbgs() << label << " illegal: requires at least one capability in ["
126 << llvm::join(capStrings, ", ")
127 << "] but none allowed in target environment\n";
128 });
129 return failure();
130 }
131 return success();
132}
133
134/// Returns true if the given `storageClass` needs explicit layout when used in
135/// Shader environments.
136static bool needsExplicitLayout(spirv::StorageClass storageClass) {
137 switch (storageClass) {
138 case spirv::StorageClass::PhysicalStorageBuffer:
139 case spirv::StorageClass::PushConstant:
140 case spirv::StorageClass::StorageBuffer:
141 case spirv::StorageClass::Uniform:
142 return true;
143 default:
144 return false;
145 }
146}
147
148/// Wraps the given `elementType` in a struct and gets the pointer to the
149/// struct. This is used to satisfy Vulkan interface requirements.
151wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
152 auto structType = needsExplicitLayout(storageClass)
153 ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
154 : spirv::StructType::get(elementType);
155 return spirv::PointerType::get(structType, storageClass);
156}
157
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
161
162static spirv::ScalarType getIndexType(MLIRContext *ctx,
164 return cast<spirv::ScalarType>(
165 IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
166}
167
168// TODO: This is a utility function that should probably be exposed by the
169// SPIR-V dialect. Keeping it local till the use case arises.
170static std::optional<int64_t>
171getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
172 if (isa<spirv::ScalarType>(type)) {
173 auto bitWidth = type.getIntOrFloatBitWidth();
174 // According to the SPIR-V spec:
175 // "There is no physical size or bit pattern defined for values with boolean
176 // type. If they are stored (in conjunction with OpVariable), they can only
177 // be used with logical addressing operations, not physical, and only with
178 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
179 // Private, Function, Input, and Output."
180 if (bitWidth == 1)
181 return std::nullopt;
182 return bitWidth / 8;
183 }
184
185 // Handle 8-bit floats.
186 if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
187 auto bitWidth = type.getIntOrFloatBitWidth();
188 if (bitWidth == 8)
189 return bitWidth / 8;
190 return std::nullopt;
191 }
192
193 if (auto complexType = dyn_cast<ComplexType>(type)) {
194 auto elementSize = getTypeNumBytes(options, complexType.getElementType());
195 if (!elementSize)
196 return std::nullopt;
197 return 2 * *elementSize;
198 }
199
200 if (auto vecType = dyn_cast<VectorType>(type)) {
201 auto elementSize = getTypeNumBytes(options, vecType.getElementType());
202 if (!elementSize)
203 return std::nullopt;
204 return vecType.getNumElements() * *elementSize;
205 }
206
207 if (auto memRefType = dyn_cast<MemRefType>(type)) {
208 // TODO: Layout should also be controlled by the ABI attributes. For now
209 // using the layout from MemRef.
210 int64_t offset;
212 if (!memRefType.hasStaticShape() ||
213 failed(memRefType.getStridesAndOffset(strides, offset)))
214 return std::nullopt;
215
216 // To get the size of the memref object in memory, the total size is the
217 // max(stride * dimension-size) computed for all dimensions times the size
218 // of the element.
219 auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
220 if (!elementSize)
221 return std::nullopt;
222
223 if (memRefType.getRank() == 0)
224 return elementSize;
225
226 auto dims = memRefType.getShape();
227 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
228 ShapedType::isDynamic(offset) ||
229 llvm::is_contained(strides, ShapedType::kDynamic))
230 return std::nullopt;
231
232 int64_t memrefSize = -1;
233 for (const auto &shape : enumerate(dims))
234 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
235
236 return (offset + memrefSize) * *elementSize;
237 }
238
239 if (auto tensorType = dyn_cast<TensorType>(type)) {
240 if (!tensorType.hasStaticShape())
241 return std::nullopt;
242
243 auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
244 if (!elementSize)
245 return std::nullopt;
246
247 int64_t size = *elementSize;
248 for (auto shape : tensorType.getShape())
249 size *= shape;
250
251 return size;
252 }
253
254 // TODO: Add size computation for other types.
255 return std::nullopt;
256}
257
258/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
259static Type
260convertScalarType(const spirv::TargetEnv &targetEnv,
262 std::optional<spirv::StorageClass> storageClass = {}) {
263 // Get extension and capability requirements for the given type.
266 type.getExtensions(extensions, storageClass);
267 type.getCapabilities(capabilities, storageClass);
268
269 // If all requirements are met, then we can accept this type as-is.
270 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
271 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
272 return type;
273
274 // Otherwise we need to adjust the type, which really means adjusting the
275 // bitwidth given this is a scalar type.
276 if (!options.emulateLT32BitScalarTypes)
277 return nullptr;
278
279 // We only emulate narrower scalar types here and do not truncate results.
280 if (type.getIntOrFloatBitWidth() > 32) {
281 LLVM_DEBUG(llvm::dbgs()
282 << type
283 << " not converted to 32-bit for SPIR-V to avoid truncation\n");
284 return nullptr;
285 }
286
287 if (auto floatType = dyn_cast<FloatType>(type)) {
288 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
289 return Builder(targetEnv.getContext()).getF32Type();
290 }
291
292 auto intType = cast<IntegerType>(type);
293 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
294 return IntegerType::get(targetEnv.getContext(), /*width=*/32,
295 intType.getSignedness());
296}
297
298/// Converts a sub-byte integer `type` to i32 regardless of target environment.
299/// Returns a nullptr for unsupported integer types, including non sub-byte
300/// types.
301///
302/// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
303/// the above given that these sub-byte types are not supported at all in
304/// SPIR-V; there are no compute/storage capability for them like other
305/// supported integer types.
306static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
307 IntegerType type) {
308 if (type.getWidth() > 8) {
309 LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
310 return nullptr;
311 }
312 if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
313 LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
314 return nullptr;
315 }
316
317 if (!llvm::isPowerOf2_32(type.getWidth())) {
318 LLVM_DEBUG(llvm::dbgs()
319 << "unsupported non-power-of-two bitwidth in sub-byte" << type
320 << "\n");
321 return nullptr;
322 }
323
324 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
325 return IntegerType::get(type.getContext(), /*width=*/32,
326 type.getSignedness());
327}
328
329/// Converts 8-bit float types to integer types with the same bit width.
330/// Returns a nullptr for unsupported 8-bit float types.
331static Type convert8BitFloatType(const SPIRVConversionOptions &options,
332 FloatType type) {
333 if (!options.emulateUnsupportedFloatTypes)
334 return nullptr;
335 // F8 types are converted to integer types with the same bit width.
336 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
337 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
338 Float8E8M0FNUType>(type))
339 return IntegerType::get(type.getContext(), type.getWidth());
340 LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
341 return nullptr;
342}
343
344/// Returns a type with the same shape but with any 8-bit float element type
345/// converted to the same bit width integer type. This is a noop when the
346/// element type is not the 8-bit float type or emulation flag is set to false.
347static ShapedType
348convertShaped8BitFloatType(ShapedType type,
350 if (!options.emulateUnsupportedFloatTypes)
351 return type;
352 Type srcElementType = type.getElementType();
353 Type convertedElementType = nullptr;
354 // F8 types are converted to integer types with the same bit width.
355 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
356 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
357 Float8E8M0FNUType>(srcElementType))
358 convertedElementType = IntegerType::get(
359 type.getContext(), srcElementType.getIntOrFloatBitWidth());
360
361 if (!convertedElementType)
362 return type;
363
364 return type.clone(convertedElementType);
365}
366
367/// Returns a type with the same shape but with any index element type converted
368/// to the matching integer type. This is a noop when the element type is not
369/// the index type.
370static ShapedType
371convertIndexElementType(ShapedType type,
373 Type indexType = dyn_cast<IndexType>(type.getElementType());
374 if (!indexType)
375 return type;
376
377 return type.clone(getIndexType(type.getContext(), options));
378}
379
380/// Converts a vector `type` to a suitable type under the given `targetEnv`.
381static Type
382convertVectorType(const spirv::TargetEnv &targetEnv,
383 const SPIRVConversionOptions &options, VectorType type,
384 std::optional<spirv::StorageClass> storageClass = {}) {
385 type = cast<VectorType>(convertIndexElementType(type, options));
386 type = cast<VectorType>(convertShaped8BitFloatType(type, options));
387 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
388 if (!scalarType) {
389 // If this is not a spec allowed scalar type, try to handle sub-byte integer
390 // types.
391 auto intType = dyn_cast<IntegerType>(type.getElementType());
392 if (!intType) {
393 LLVM_DEBUG(llvm::dbgs()
394 << type
395 << " illegal: cannot convert non-scalar element type\n");
396 return nullptr;
397 }
398
399 Type elementType = convertSubByteIntegerType(options, intType);
400 if (!elementType)
401 return nullptr;
402
403 if (type.getRank() <= 1 && type.getNumElements() == 1)
404 return elementType;
405
406 if (type.getNumElements() > 4) {
407 LLVM_DEBUG(llvm::dbgs()
408 << type << " illegal: > 4-element unimplemented\n");
409 return nullptr;
410 }
411
412 return VectorType::get(type.getShape(), elementType);
413 }
414
415 if (type.getRank() <= 1 && type.getNumElements() == 1)
416 return convertScalarType(targetEnv, options, scalarType, storageClass);
417
419 LLVM_DEBUG(llvm::dbgs()
420 << type << " illegal: not a valid composite type\n");
421 return nullptr;
422 }
423
424 // Get extension and capability requirements for the given type.
427 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
428 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
429
430 // If all requirements are met, then we can accept this type as-is.
431 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
432 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
433 return type;
434
435 auto elementType =
436 convertScalarType(targetEnv, options, scalarType, storageClass);
437 if (elementType)
438 return VectorType::get(type.getShape(), elementType);
439 return nullptr;
440}
441
442static Type
443convertComplexType(const spirv::TargetEnv &targetEnv,
444 const SPIRVConversionOptions &options, ComplexType type,
445 std::optional<spirv::StorageClass> storageClass = {}) {
446 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
447 if (!scalarType) {
448 LLVM_DEBUG(llvm::dbgs()
449 << type << " illegal: cannot convert non-scalar element type\n");
450 return nullptr;
451 }
452
453 auto elementType =
454 convertScalarType(targetEnv, options, scalarType, storageClass);
455 if (!elementType)
456 return nullptr;
457 if (elementType != type.getElementType()) {
458 LLVM_DEBUG(llvm::dbgs()
459 << type << " illegal: complex type emulation unsupported\n");
460 return nullptr;
461 }
462
463 return VectorType::get(2, elementType);
464}
465
466/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
467///
468/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
469/// create composite constants with OpConstantComposite to embed relative large
470/// constant values and use OpCompositeExtract and OpCompositeInsert to
471/// manipulate, like what we do for vectors.
472static Type convertTensorType(const spirv::TargetEnv &targetEnv,
474 TensorType type) {
475 // TODO: Handle dynamic shapes.
476 if (!type.hasStaticShape()) {
477 LLVM_DEBUG(llvm::dbgs()
478 << type << " illegal: dynamic shape unimplemented\n");
479 return nullptr;
480 }
481
482 type = cast<TensorType>(convertIndexElementType(type, options));
483 type = cast<TensorType>(convertShaped8BitFloatType(type, options));
484 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
485 if (!scalarType) {
486 LLVM_DEBUG(llvm::dbgs()
487 << type << " illegal: cannot convert non-scalar element type\n");
488 return nullptr;
489 }
490
491 std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
492 std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
493 if (!scalarSize || !tensorSize) {
494 LLVM_DEBUG(llvm::dbgs()
495 << type << " illegal: cannot deduce element count\n");
496 return nullptr;
497 }
498
499 int64_t arrayElemCount = *tensorSize / *scalarSize;
500 if (arrayElemCount == 0) {
501 LLVM_DEBUG(llvm::dbgs()
502 << type << " illegal: cannot handle zero-element tensors\n");
503 return nullptr;
504 }
505 if (arrayElemCount > std::numeric_limits<unsigned>::max()) {
506 LLVM_DEBUG(llvm::dbgs()
507 << type << " illegal: cannot fit tensor into target type\n");
508 return nullptr;
509 }
510
511 Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
512 if (!arrayElemType)
513 return nullptr;
514 std::optional<int64_t> arrayElemSize =
515 getTypeNumBytes(options, arrayElemType);
516 if (!arrayElemSize) {
517 LLVM_DEBUG(llvm::dbgs()
518 << type << " illegal: cannot deduce converted element size\n");
519 return nullptr;
520 }
521
522 return spirv::ArrayType::get(arrayElemType, arrayElemCount);
523}
524
525static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
527 MemRefType type,
528 spirv::StorageClass storageClass) {
529 unsigned numBoolBits = options.boolNumBits;
530 if (numBoolBits != 8) {
531 LLVM_DEBUG(llvm::dbgs()
532 << "using non-8-bit storage for bool types unimplemented");
533 return nullptr;
534 }
535 auto elementType = dyn_cast<spirv::ScalarType>(
536 IntegerType::get(type.getContext(), numBoolBits));
537 if (!elementType)
538 return nullptr;
539 Type arrayElemType =
540 convertScalarType(targetEnv, options, elementType, storageClass);
541 if (!arrayElemType)
542 return nullptr;
543 std::optional<int64_t> arrayElemSize =
544 getTypeNumBytes(options, arrayElemType);
545 if (!arrayElemSize) {
546 LLVM_DEBUG(llvm::dbgs()
547 << type << " illegal: cannot deduce converted element size\n");
548 return nullptr;
549 }
550
551 if (!type.hasStaticShape()) {
552 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
553 // to the element.
554 if (targetEnv.allows(spirv::Capability::Kernel))
555 return spirv::PointerType::get(arrayElemType, storageClass);
556 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
557 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
558 // For Vulkan we need extra wrapping struct and array to satisfy interface
559 // needs.
560 return wrapInStructAndGetPointer(arrayType, storageClass);
561 }
562
563 if (type.getNumElements() == 0) {
564 LLVM_DEBUG(llvm::dbgs()
565 << type << " illegal: zero-element memrefs are not supported\n");
566 return nullptr;
567 }
568
569 int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
570 int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
571 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
572 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
573 if (targetEnv.allows(spirv::Capability::Kernel))
574 return spirv::PointerType::get(arrayType, storageClass);
575 return wrapInStructAndGetPointer(arrayType, storageClass);
576}
577
578static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
580 MemRefType type,
581 spirv::StorageClass storageClass) {
582 IntegerType elementType = cast<IntegerType>(type.getElementType());
583 Type arrayElemType = convertSubByteIntegerType(options, elementType);
584 if (!arrayElemType)
585 return nullptr;
586 int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
587
588 if (!type.hasStaticShape()) {
589 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
590 // to the element.
591 if (targetEnv.allows(spirv::Capability::Kernel))
592 return spirv::PointerType::get(arrayElemType, storageClass);
593 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
594 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
595 // For Vulkan we need extra wrapping struct and array to satisfy interface
596 // needs.
597 return wrapInStructAndGetPointer(arrayType, storageClass);
598 }
599
600 if (type.getNumElements() == 0) {
601 LLVM_DEBUG(llvm::dbgs()
602 << type << " illegal: zero-element memrefs are not supported\n");
603 return nullptr;
604 }
605
606 int64_t memrefSize =
607 llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
608 int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
609 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
610 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
611 if (targetEnv.allows(spirv::Capability::Kernel))
612 return spirv::PointerType::get(arrayType, storageClass);
613 return wrapInStructAndGetPointer(arrayType, storageClass);
614}
615
616static spirv::Dim convertRank(int64_t rank) {
617 switch (rank) {
618 case 1:
619 return spirv::Dim::Dim1D;
620 case 2:
621 return spirv::Dim::Dim2D;
622 case 3:
623 return spirv::Dim::Dim3D;
624 default:
625 llvm_unreachable("Invalid memref rank!");
626 }
627}
628
629static spirv::ImageFormat getImageFormat(Type elementType) {
630 return TypeSwitch<Type, spirv::ImageFormat>(elementType)
631 .Case([](Float16Type) { return spirv::ImageFormat::R16f; })
632 .Case([](Float32Type) { return spirv::ImageFormat::R32f; })
633 .Case([](IntegerType intType) {
634 auto const isSigned = intType.isSigned() || intType.isSignless();
635#define BIT_WIDTH_CASE(BIT_WIDTH) \
636 case BIT_WIDTH: \
637 return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
638 : spirv::ImageFormat::R##BIT_WIDTH##ui
639
640 switch (intType.getWidth()) {
641 BIT_WIDTH_CASE(16);
642 BIT_WIDTH_CASE(32);
643 default:
644 llvm_unreachable("Unhandled integer type!");
645 }
646 })
647 .DefaultUnreachable("Unhandled element type!");
648#undef BIT_WIDTH_CASE
649}
650
651static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
653 MemRefType type) {
654 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
655 if (!attr) {
656 LLVM_DEBUG(
657 llvm::dbgs()
658 << type
659 << " illegal: expected memory space to be a SPIR-V storage class "
660 "attribute; please use MemorySpaceToStorageClassConverter to map "
661 "numeric memory spaces beforehand\n");
662 return nullptr;
663 }
664 spirv::StorageClass storageClass = attr.getValue();
665
666 // Images are a special case since they are an opaque type from which elements
667 // may be accessed via image specific ops or directly through a texture
668 // pointer.
669 if (storageClass == spirv::StorageClass::Image) {
670 const int64_t rank = type.getRank();
671 if (rank < 1 || rank > 3) {
672 LLVM_DEBUG(llvm::dbgs()
673 << type << " illegal: cannot lower memref of rank " << rank
674 << " to a SPIR-V Image\n");
675 return nullptr;
676 }
677
678 // Note that we currently only support lowering to single element texels
679 // e.g. R32f.
680 auto elementType = type.getElementType();
681 if (!isa<spirv::ScalarType>(elementType)) {
682 LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
683 << elementType << " to a SPIR-V Image\n");
684 return nullptr;
685 }
686
687 // Currently every memref in the image storage class is converted to a
688 // sampled image so we can hardcode the NeedSampler field. Future work
689 // will generalize this to support regular non-sampled images.
690 auto spvImageType = spirv::ImageType::get(
691 elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
692 spirv::ImageArrayedInfo::NonArrayed,
693 spirv::ImageSamplingInfo::SingleSampled,
694 spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
695 auto spvSampledImageType = spirv::SampledImageType::get(spvImageType);
696 auto imagePtrType = spirv::PointerType::get(
697 spvSampledImageType, spirv::StorageClass::UniformConstant);
698 return imagePtrType;
699 }
700
701 if (isa<IntegerType>(type.getElementType())) {
702 if (type.getElementTypeBitWidth() == 1)
703 return convertBoolMemrefType(targetEnv, options, type, storageClass);
704 if (type.getElementTypeBitWidth() < 8)
705 return convertSubByteMemrefType(targetEnv, options, type, storageClass);
706 }
707
708 Type arrayElemType;
709 Type elementType = type.getElementType();
710 if (auto vecType = dyn_cast<VectorType>(elementType)) {
711 arrayElemType =
712 convertVectorType(targetEnv, options, vecType, storageClass);
713 } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
714 arrayElemType =
715 convertComplexType(targetEnv, options, complexType, storageClass);
716 } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
717 arrayElemType =
718 convertScalarType(targetEnv, options, scalarType, storageClass);
719 } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
720 type = cast<MemRefType>(convertIndexElementType(type, options));
721 arrayElemType = type.getElementType();
722 } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
723 // Hnadle 8 bit float types.
724 type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
725 arrayElemType = type.getElementType();
726 } else {
727 LLVM_DEBUG(
728 llvm::dbgs()
729 << type
730 << " unhandled: can only convert scalar or vector element type\n");
731 return nullptr;
732 }
733 if (!arrayElemType)
734 return nullptr;
735
736 std::optional<int64_t> arrayElemSize =
737 getTypeNumBytes(options, arrayElemType);
738 if (!arrayElemSize) {
739 LLVM_DEBUG(llvm::dbgs()
740 << type << " illegal: cannot deduce converted element size\n");
741 return nullptr;
742 }
743
744 if (!type.hasStaticShape()) {
745 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
746 // to the element.
747 if (targetEnv.allows(spirv::Capability::Kernel))
748 return spirv::PointerType::get(arrayElemType, storageClass);
749 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
750 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
751 // For Vulkan we need extra wrapping struct and array to satisfy interface
752 // needs.
753 return wrapInStructAndGetPointer(arrayType, storageClass);
754 }
755
756 std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
757 if (!memrefSize) {
758 LLVM_DEBUG(llvm::dbgs()
759 << type << " illegal: cannot deduce element count\n");
760 return nullptr;
761 }
762
763 if (*memrefSize == 0) {
764 LLVM_DEBUG(llvm::dbgs()
765 << type << " illegal: zero-element memrefs are not supported\n");
766 return nullptr;
767 }
768
769 int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
770 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
771 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
772 if (targetEnv.allows(spirv::Capability::Kernel))
773 return spirv::PointerType::get(arrayType, storageClass);
774 return wrapInStructAndGetPointer(arrayType, storageClass);
775}
776
//===----------------------------------------------------------------------===//
// Type casting materialization
//===----------------------------------------------------------------------===//
780
781/// Converts the given `inputs` to the original source `type` considering the
782/// `targetEnv`'s capabilities.
783///
784/// This function is meant to be used for source materialization in type
785/// converters. When the type converter needs to materialize a cast op back
786/// to some original source type, we need to check whether the original source
787/// type is supported in the target environment. If so, we can insert legal
788/// SPIR-V cast ops accordingly.
789///
790/// Note that in SPIR-V the capabilities for storage and compute are separate.
791/// This function is meant to handle the **compute** side; so it does not
792/// involve storage classes in its logic. The storage side is expected to be
793/// handled by MemRef conversion logic.
794static Value castToSourceType(const spirv::TargetEnv &targetEnv,
795 OpBuilder &builder, Type type, ValueRange inputs,
796 Location loc) {
797 // We can only cast one value in SPIR-V.
798 if (inputs.size() != 1) {
799 auto castOp =
800 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
801 return castOp.getResult(0);
802 }
803 Value input = inputs.front();
804
805 // Only support integer types for now. Floating point types to be implemented.
806 if (!isa<IntegerType>(type)) {
807 auto castOp =
808 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
809 return castOp.getResult(0);
810 }
811 auto inputType = cast<IntegerType>(input.getType());
812
813 auto scalarType = dyn_cast<spirv::ScalarType>(type);
814 if (!scalarType) {
815 auto castOp =
816 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
817 return castOp.getResult(0);
818 }
819
820 // Only support source type with a smaller bitwidth. This would mean we are
821 // truncating to go back so we don't need to worry about the signedness.
822 // For extension, we cannot have enough signal here to decide which op to use.
823 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
824 auto castOp =
825 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
826 return castOp.getResult(0);
827 }
828
829 // Boolean values would need to use different ops than normal integer values.
830 if (type.isInteger(1)) {
831 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
832 return spirv::IEqualOp::create(builder, loc, input, one);
833 }
834
835 // Check that the source integer type is supported by the environment.
838 scalarType.getExtensions(exts);
839 scalarType.getCapabilities(caps);
840 if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
841 failed(checkExtensionRequirements(type, targetEnv, exts))) {
842 auto castOp =
843 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
844 return castOp.getResult(0);
845 }
846
847 // We've already made sure this is truncating previously, so we don't need to
848 // care about signedness here. Still try to use a corresponding op for better
849 // consistency though.
850 if (type.isSignedInteger()) {
851 return spirv::SConvertOp::create(builder, loc, type, input);
852 }
853 return spirv::UConvertOp::create(builder, loc, type, input);
854}
855
//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//
859
860static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
861 spirv::BuiltIn builtin) {
862 // Look through all global variables in the given `body` block and check if
863 // there is a spirv.GlobalVariable that has the same `builtin` attribute.
864 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
865 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
866 spirv::SPIRVDialect::getAttributeName(
867 spirv::Decoration::BuiltIn))) {
868 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
869 if (varBuiltIn == builtin) {
870 return varOp;
871 }
872 }
873 }
874 return nullptr;
875}
876
877/// Gets name of global variable for a builtin.
878std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
879 StringRef suffix) {
880 return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
881}
882
883/// Gets or inserts a global variable for a builtin within `body` block.
884static spirv::GlobalVariableOp
885getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
886 Type integerType, OpBuilder &builder,
887 StringRef prefix, StringRef suffix) {
888 if (auto varOp = getBuiltinVariable(body, builtin))
889 return varOp;
890
891 OpBuilder::InsertionGuard guard(builder);
892 builder.setInsertionPointToStart(&body);
893
894 spirv::GlobalVariableOp newVarOp;
895 switch (builtin) {
896 case spirv::BuiltIn::NumWorkgroups:
897 case spirv::BuiltIn::WorkgroupSize:
898 case spirv::BuiltIn::WorkgroupId:
899 case spirv::BuiltIn::LocalInvocationId:
900 case spirv::BuiltIn::GlobalInvocationId: {
901 auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
902 spirv::StorageClass::Input);
903 std::string name = getBuiltinVarName(builtin, prefix, suffix);
904 newVarOp =
905 spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
906 break;
907 }
908 case spirv::BuiltIn::SubgroupId:
909 case spirv::BuiltIn::NumSubgroups:
910 case spirv::BuiltIn::SubgroupSize:
911 case spirv::BuiltIn::SubgroupLocalInvocationId: {
912 auto ptrType =
913 spirv::PointerType::get(integerType, spirv::StorageClass::Input);
914 std::string name = getBuiltinVarName(builtin, prefix, suffix);
915 newVarOp =
916 spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
917 break;
918 }
919 default:
920 emitError(loc, "unimplemented builtin variable generation for ")
921 << stringifyBuiltIn(builtin);
922 }
923 return newVarOp;
924}
925
926//===----------------------------------------------------------------------===//
927// Push constant storage
928//===----------------------------------------------------------------------===//
929
930/// Returns the pointer type for the push constant storage containing
931/// `elementCount` 32-bit integer values.
932static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
933 Builder &builder,
934 Type indexType) {
935 auto arrayType = spirv::ArrayType::get(indexType, elementCount,
936 /*stride=*/4);
937 auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
938 return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
939}
940
941/// Returns the push constant varible containing `elementCount` 32-bit integer
942/// values in `body`. Returns null op if such an op does not exit.
943static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
944 unsigned elementCount) {
945 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
946 auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
947 if (!ptrType)
948 continue;
949
950 // Note that Vulkan requires "There must be no more than one push constant
951 // block statically used per shader entry point." So we should always reuse
952 // the existing one.
953 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
954 auto numElements = cast<spirv::ArrayType>(
955 cast<spirv::StructType>(ptrType.getPointeeType())
956 .getElementType(0))
957 .getNumElements();
958 if (numElements == elementCount)
959 return varOp;
960 }
961 }
962 return nullptr;
963}
964
965/// Gets or inserts a global variable for push constant storage containing
966/// `elementCount` 32-bit integer values in `block`.
967static spirv::GlobalVariableOp
968getOrInsertPushConstantVariable(Location loc, Block &block,
969 unsigned elementCount, OpBuilder &b,
970 Type indexType) {
971 if (auto varOp = getPushConstantVariable(block, elementCount))
972 return varOp;
973
974 auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
975 auto type = getPushConstantStorageType(elementCount, builder, indexType);
976 const char *name = "__push_constant_var__";
977 return spirv::GlobalVariableOp::create(builder, loc, type, name,
978 /*initializer=*/nullptr);
979}
980
981//===----------------------------------------------------------------------===//
982// func::FuncOp Conversion Patterns
983//===----------------------------------------------------------------------===//
984
985/// A pattern for rewriting function signature to convert arguments of functions
986/// to be of valid SPIR-V types.
987struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
988 using Base::Base;
989
990 LogicalResult
991 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
992 ConversionPatternRewriter &rewriter) const override {
993 FunctionType fnType = funcOp.getFunctionType();
994 if (fnType.getNumResults() > 1)
995 return failure();
996
997 TypeConverter::SignatureConversion signatureConverter(
998 fnType.getNumInputs());
999 for (const auto &argType : enumerate(fnType.getInputs())) {
1000 auto convertedType = getTypeConverter()->convertType(argType.value());
1001 if (!convertedType)
1002 return failure();
1003 signatureConverter.addInputs(argType.index(), convertedType);
1004 }
1005
1006 Type resultType;
1007 if (fnType.getNumResults() == 1) {
1008 resultType = getTypeConverter()->convertType(fnType.getResult(0));
1009 if (!resultType)
1010 return failure();
1011 }
1012
1013 // Create the converted spirv.func op.
1014 auto newFuncOp = spirv::FuncOp::create(
1015 rewriter, funcOp.getLoc(), funcOp.getName(),
1016 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
1017 resultType ? TypeRange(resultType)
1018 : TypeRange()));
1019
1020 // Copy over all attributes other than the function name and type.
1021 for (NamedAttribute namedAttr : funcOp->getAttrs()) {
1022 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
1023 namedAttr.getName() != SymbolTable::getSymbolAttrName())
1024 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1025 }
1026
1027 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1028 newFuncOp.end());
1029 if (failed(rewriter.convertRegionTypes(
1030 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
1031 return failure();
1032 rewriter.eraseOp(funcOp);
1033 return success();
1034 }
1035};
1036
1037/// A pattern for rewriting function signature to convert vector arguments of
1038/// functions to be of valid types
1039struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
1040 using Base::Base;
1041
1042 LogicalResult matchAndRewrite(func::FuncOp funcOp,
1043 PatternRewriter &rewriter) const override {
1044 FunctionType fnType = funcOp.getFunctionType();
1045
1046 // TODO: Handle declarations.
1047 if (funcOp.isDeclaration()) {
1048 LLVM_DEBUG(llvm::dbgs()
1049 << fnType << " illegal: declarations are unsupported\n");
1050 return failure();
1051 }
1052
1053 // Bail out early for dynamically-shaped argument types: getZeroAttr
1054 // requires a statically-shaped type. VectorType is always statically
1055 // shaped, so this correctly skips it without a special-case guard.
1056 if (llvm::any_of(fnType.getInputs(), [](Type argType) {
1057 auto shapedType = dyn_cast<ShapedType>(argType);
1058 return shapedType && !shapedType.hasStaticShape();
1059 }))
1060 return failure();
1061
1062 // Create a new func op with the original type and copy the function body.
1063 auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),
1064 funcOp.getName(), fnType);
1065 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1066 newFuncOp.end());
1067
1068 Location loc = newFuncOp.getBody().getLoc();
1069
1070 Block &entryBlock = newFuncOp.getBlocks().front();
1071 OpBuilder::InsertionGuard guard(rewriter);
1072 rewriter.setInsertionPointToStart(&entryBlock);
1073
1074 TypeConverter::SignatureConversion oneToNTypeMapping(
1075 fnType.getInputs().size());
1076
1077 // For arguments that are of illegal types and require unrolling.
1078 // `unrolledInputNums` stores the indices of arguments that result from
1079 // unrolling in the new function signature. `newInputNo` is a counter.
1080 SmallVector<size_t> unrolledInputNums;
1081 size_t newInputNo = 0;
1082
1083 // For arguments that are of legal types and do not require unrolling.
1084 // `tmpOps` stores a mapping from temporary operations that serve as
1085 // placeholders for new arguments that will be added later. These operations
1086 // will be erased once the entry block's argument list is updated.
1087 llvm::SmallDenseMap<Operation *, size_t> tmpOps;
1088
1089 // This counts the number of new operations created.
1090 size_t newOpCount = 0;
1091
1092 // Enumerate through the arguments.
1093 for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
1094 // Check whether the argument is of vector type.
1095 auto origVecType = dyn_cast<VectorType>(origType);
1096 if (!origVecType) {
1097 // We need a placeholder for the old argument that will be erased later.
1098 Value result = arith::ConstantOp::create(
1099 rewriter, loc, origType, rewriter.getZeroAttr(origType));
1100 rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1101 tmpOps.insert({result.getDefiningOp(), newInputNo});
1102 oneToNTypeMapping.addInputs(origInputNo, origType);
1103 ++newInputNo;
1104 ++newOpCount;
1105 continue;
1106 }
1107 // Check whether the vector needs unrolling.
1108 auto targetShape = getTargetShape(origVecType);
1109 if (!targetShape) {
1110 // We need a placeholder for the old argument that will be erased later.
1111 Value result = arith::ConstantOp::create(
1112 rewriter, loc, origType, rewriter.getZeroAttr(origType));
1113 rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1114 tmpOps.insert({result.getDefiningOp(), newInputNo});
1115 oneToNTypeMapping.addInputs(origInputNo, origType);
1116 ++newInputNo;
1117 ++newOpCount;
1118 continue;
1119 }
1120 VectorType unrolledType =
1121 VectorType::get(*targetShape, origVecType.getElementType());
1122 auto originalShape =
1123 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1124
1125 // Prepare the result vector.
1126 Value result = arith::ConstantOp::create(
1127 rewriter, loc, origVecType, rewriter.getZeroAttr(origVecType));
1128 ++newOpCount;
1129 // Prepare the placeholder for the new arguments that will be added later.
1130 Value dummy = arith::ConstantOp::create(
1131 rewriter, loc, unrolledType, rewriter.getZeroAttr(unrolledType));
1132 ++newOpCount;
1133
1134 // Create the `vector.insert_strided_slice` ops.
1135 SmallVector<int64_t> strides(targetShape->size(), 1);
1136 SmallVector<Type> newTypes;
1137 for (SmallVector<int64_t> offsets :
1138 StaticTileOffsetRange(originalShape, *targetShape)) {
1139 result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,
1140 result, offsets, strides);
1141 newTypes.push_back(unrolledType);
1142 unrolledInputNums.push_back(newInputNo);
1143 ++newInputNo;
1144 ++newOpCount;
1145 }
1146 rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1147 oneToNTypeMapping.addInputs(origInputNo, newTypes);
1148 }
1149
1150 // Change the function signature.
1151 auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1152 auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1153 rewriter.modifyOpInPlace(newFuncOp,
1154 [&] { newFuncOp.setFunctionType(newFnType); });
1155
1156 // Update the arguments in the entry block.
1157 entryBlock.eraseArguments(0, fnType.getNumInputs());
1158 SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
1159 entryBlock.addArguments(convertedTypes, locs);
1160
1161 // Replace all uses of placeholders for initially legal arguments with their
1162 // original function arguments (that were added to `newFuncOp`).
1163 for (auto &[placeholderOp, argIdx] : tmpOps) {
1164 if (!placeholderOp)
1165 continue;
1166 Value replacement = newFuncOp.getArgument(argIdx);
1167 rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement);
1168 }
1169
1170 // Replace dummy operands of new `vector.insert_strided_slice` ops with
1171 // their corresponding new function arguments. The new
1172 // `vector.insert_strided_slice` ops are inserted only into the entry block,
1173 // so iterating over that block is sufficient.
1174 size_t unrolledInputIdx = 0;
1175 for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1176 Operation &curOp = op;
1177 // Since all newly created operations are in the beginning, reaching the
1178 // end of them means that any later `vector.insert_strided_slice` should
1179 // not be touched.
1180 if (count >= newOpCount)
1181 continue;
1182 if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1183 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1184 rewriter.modifyOpInPlace(&curOp, [&] {
1185 curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1186 });
1187 ++unrolledInputIdx;
1188 }
1189 }
1190
1191 // Erase the original funcOp. The `tmpOps` do not need to be erased since
1192 // they have no uses and will be handled by dead-code elimination.
1193 rewriter.eraseOp(funcOp);
1194 return success();
1195 }
1196};
1197
1198//===----------------------------------------------------------------------===//
1199// func::ReturnOp Conversion Patterns
1200//===----------------------------------------------------------------------===//
1201
1202/// A pattern for rewriting function signature and the return op to convert
1203/// vectors to be of valid types.
1204struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
1205 using Base::Base;
1206
1207 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1208 PatternRewriter &rewriter) const override {
1209 // Check whether the parent funcOp is valid.
1210 auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1211 if (!funcOp)
1212 return failure();
1213
1214 FunctionType fnType = funcOp.getFunctionType();
1215 TypeConverter::SignatureConversion oneToNTypeMapping(
1216 fnType.getResults().size());
1217 Location loc = returnOp.getLoc();
1218
1219 // For the new return op.
1220 SmallVector<Value> newOperands;
1221
1222 // Enumerate through the results.
1223 for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1224 // Check whether the argument is of vector type.
1225 auto origVecType = dyn_cast<VectorType>(origType);
1226 if (!origVecType) {
1227 oneToNTypeMapping.addInputs(origResultNo, origType);
1228 newOperands.push_back(returnOp.getOperand(origResultNo));
1229 continue;
1230 }
1231 // Check whether the vector needs unrolling.
1232 auto targetShape = getTargetShape(origVecType);
1233 if (!targetShape) {
1234 // The original argument can be used.
1235 oneToNTypeMapping.addInputs(origResultNo, origType);
1236 newOperands.push_back(returnOp.getOperand(origResultNo));
1237 continue;
1238 }
1239 VectorType unrolledType =
1240 VectorType::get(*targetShape, origVecType.getElementType());
1241
1242 // Create `vector.extract_strided_slice` ops to form legal vectors from
1243 // the original operand of illegal type.
1244 auto originalShape =
1245 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1246 SmallVector<int64_t> strides(originalShape.size(), 1);
1247 SmallVector<int64_t> extractShape(originalShape.size(), 1);
1248 extractShape.back() = targetShape->back();
1249 SmallVector<Type> newTypes;
1250 Value returnValue = returnOp.getOperand(origResultNo);
1251 for (SmallVector<int64_t> offsets :
1252 StaticTileOffsetRange(originalShape, *targetShape)) {
1253 Value result = vector::ExtractStridedSliceOp::create(
1254 rewriter, loc, returnValue, offsets, extractShape, strides);
1255 if (originalShape.size() > 1) {
1256 SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
1257 result =
1258 vector::ExtractOp::create(rewriter, loc, result, extractIndices);
1259 }
1260 newOperands.push_back(result);
1261 newTypes.push_back(unrolledType);
1262 }
1263 oneToNTypeMapping.addInputs(origResultNo, newTypes);
1264 }
1265
1266 // Change the function signature.
1267 auto newFnType =
1268 FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
1269 TypeRange(oneToNTypeMapping.getConvertedTypes()));
1270 rewriter.modifyOpInPlace(funcOp,
1271 [&] { funcOp.setFunctionType(newFnType); });
1272
1273 // Replace the return op using the new operands. This will automatically
1274 // update the entry block as well.
1275 rewriter.replaceOp(returnOp,
1276 func::ReturnOp::create(rewriter, loc, newOperands));
1277
1278 return success();
1279 }
1280};
1281
1282} // namespace
1283
1284//===----------------------------------------------------------------------===//
1285// Public function for builtin variables
1286//===----------------------------------------------------------------------===//
1287
1289 spirv::BuiltIn builtin,
1290 Type integerType, OpBuilder &builder,
1291 StringRef prefix, StringRef suffix) {
1293 if (!parent) {
1294 op->emitError("expected operation to be within a module-like op");
1295 return nullptr;
1296 }
1297
1298 spirv::GlobalVariableOp varOp =
1299 getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
1300 builtin, integerType, builder, prefix, suffix);
1301 Value ptr = spirv::AddressOfOp::create(builder, op->getLoc(), varOp);
1302 return spirv::LoadOp::create(builder, op->getLoc(), ptr);
1303}
1304
1305//===----------------------------------------------------------------------===//
1306// Public function for pushing constant storage
1307//===----------------------------------------------------------------------===//
1308
1310 unsigned offset, Type integerType,
1311 OpBuilder &builder) {
1312 Location loc = op->getLoc();
1314 if (!parent) {
1315 op->emitError("expected operation to be within a module-like op");
1316 return nullptr;
1317 }
1318
1319 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1320 loc, parent->getRegion(0).front(), elementCount, builder, integerType);
1321
1322 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1323 Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,
1324 builder.getI32IntegerAttr(offset));
1325 auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);
1326 auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,
1327 llvm::ArrayRef({zeroOp, offsetOp}));
1328 return spirv::LoadOp::create(builder, loc, acOp);
1329}
1330
1331//===----------------------------------------------------------------------===//
1332// Public functions for index calculation
1333//===----------------------------------------------------------------------===//
1334
1336 int64_t offset, Type integerType,
1337 Location loc, OpBuilder &builder) {
1338 assert(indices.size() == strides.size() &&
1339 "must provide indices for all dimensions");
1340
1341 // TODO: Consider moving to use affine.apply and patterns converting
1342 // affine.apply to standard ops. This needs converting to SPIR-V passes to be
1343 // broken down into progressive small steps so we can have intermediate steps
1344 // using other dialects. At the moment SPIR-V is the final sink.
1345
1346 Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
1347 loc, integerType, IntegerAttr::get(integerType, offset));
1348 for (const auto &index : llvm::enumerate(indices)) {
1349 Value strideVal = builder.createOrFold<spirv::ConstantOp>(
1350 loc, integerType,
1351 IntegerAttr::get(integerType, strides[index.index()]));
1352 Value update =
1353 builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1354 linearizedIndex =
1355 builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1356 }
1357 return linearizedIndex;
1358}
1359
1361 MemRefType baseType, Value basePtr,
1363 OpBuilder &builder) {
1364 // Get base and offset of the MemRefType and verify they are static.
1365
1366 int64_t offset;
1368 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1369 llvm::is_contained(strides, ShapedType::kDynamic) ||
1370 ShapedType::isDynamic(offset)) {
1371 return nullptr;
1372 }
1373
1374 auto indexType = typeConverter.getIndexType();
1375
1376 SmallVector<Value, 2> linearizedIndices;
1377 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1378
1379 // Add a '0' at the start to index into the struct.
1380 linearizedIndices.push_back(zero);
1381
1382 if (baseType.getRank() == 0) {
1383 linearizedIndices.push_back(zero);
1384 } else {
1385 linearizedIndices.push_back(
1386 linearizeIndex(indices, strides, offset, indexType, loc, builder));
1387 }
1388 return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);
1389}
1390
1392 MemRefType baseType, Value basePtr,
1394 OpBuilder &builder) {
1395 // Get base and offset of the MemRefType and verify they are static.
1396
1397 int64_t offset;
1399 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1400 llvm::is_contained(strides, ShapedType::kDynamic) ||
1401 ShapedType::isDynamic(offset)) {
1402 return nullptr;
1403 }
1404
1405 auto indexType = typeConverter.getIndexType();
1406
1407 SmallVector<Value, 2> linearizedIndices;
1408 Value linearIndex;
1409 if (baseType.getRank() == 0) {
1410 linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1411 } else {
1412 linearIndex =
1413 linearizeIndex(indices, strides, offset, indexType, loc, builder);
1414 }
1415 Type pointeeType =
1416 cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
1417 if (isa<spirv::ArrayType>(pointeeType)) {
1418 linearizedIndices.push_back(linearIndex);
1419 return spirv::AccessChainOp::create(builder, loc, basePtr,
1420 linearizedIndices);
1421 }
1422 return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,
1423 linearizedIndices);
1424}
1425
1427 MemRefType baseType, Value basePtr,
1429 OpBuilder &builder) {
1430
1431 if (typeConverter.allows(spirv::Capability::Kernel)) {
1432 return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1433 builder);
1434 }
1435
1436 return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1437 builder);
1438}
1439
1440//===----------------------------------------------------------------------===//
1441// Public functions for vector unrolling
1442//===----------------------------------------------------------------------===//
1443
1445 for (int i : {4, 3, 2}) {
1446 if (size % i == 0)
1447 return i;
1448 }
1449 return 1;
1450}
1451
1454 VectorType srcVectorType = op.getSourceVectorType();
1455 assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1456 int64_t vectorSize =
1457 mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
1458 return {vectorSize};
1459}
1460
1463 VectorType vectorType = op.getResultVectorType();
1464 SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
1465 nativeSize.back() =
1466 mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
1467 return nativeSize;
1468}
1469
1470std::optional<SmallVector<int64_t>>
1473 if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
1474 SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
1475 nativeSize.back() =
1476 mlir::spirv::getComputeVectorSize(vecType.getShape().back());
1477 return nativeSize;
1478 }
1479 }
1480
1482 .Case<vector::ReductionOp, vector::TransposeOp>(
1483 [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
1484 .Default(std::nullopt);
1485}
1486
1488 MLIRContext *context = op->getContext();
1489 RewritePatternSet patterns(context);
1492 // We only want to apply signature conversion once to the existing func ops.
1493 // Without specifying strictMode, the greedy pattern rewriter will keep
1494 // looking for newly created func ops.
1495 return applyPatternsGreedily(op, std::move(patterns),
1496 GreedyRewriteConfig().setStrictness(
1498}
1499
1501 MLIRContext *context = op->getContext();
1502
1503 // Unroll vectors in function bodies to native vector size.
1504 {
1505 RewritePatternSet patterns(context);
1507 [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1508 populateVectorUnrollPatterns(patterns, options);
1509 if (failed(applyPatternsGreedily(op, std::move(patterns))))
1510 return failure();
1511 }
1512
1513 // Convert transpose ops into extract and insert pairs, in preparation of
1514 // further transformations to canonicalize/cancel.
1515 {
1516 RewritePatternSet patterns(context);
1518 patterns, vector::VectorTransposeLowering::EltWise);
1520 if (failed(applyPatternsGreedily(op, std::move(patterns))))
1521 return failure();
1522 }
1523
1524 // Run canonicalization to cast away leading size-1 dimensions.
1525 {
1526 RewritePatternSet patterns(context);
1527
1528 // We need to pull in casting way leading one dims.
1529 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1530 vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1531 vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1532
1533 // Decompose different rank insert_strided_slice and n-D
1534 // extract_slided_slice.
1535 vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1536 patterns);
1537 vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1538 vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1539
1540 // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1541 // them up.
1542 vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1543 vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1544
1545 if (failed(applyPatternsGreedily(op, std::move(patterns))))
1546 return failure();
1547 }
1548 return success();
1549}
1550
1551//===----------------------------------------------------------------------===//
1552// SPIR-V TypeConverter
1553//===----------------------------------------------------------------------===//
1554
1556 const SPIRVConversionOptions &options)
1557 : targetEnv(targetAttr), options(options) {
1558 // Add conversions. The order matters here: later ones will be tried earlier.
1559
1560 // Allow all SPIR-V dialect specific types. This assumes all builtin types
1561 // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
1562 // were tried before.
1563 //
1564 // TODO: This assumes that the SPIR-V types are valid to use in the given
1565 // target environment, which should be the case if the whole pipeline is
1566 // driven by the same target environment. Still, we probably still want to
1567 // validate and convert to be safe.
1568 addConversion([](spirv::SPIRVType type) { return type; });
1569
1570 addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
1571
1572 addConversion([this](IntegerType intType) -> std::optional<Type> {
1573 if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1574 return convertScalarType(this->targetEnv, this->options, scalarType);
1575 if (intType.getWidth() < 8)
1576 return convertSubByteIntegerType(this->options, intType);
1577 return Type();
1578 });
1579
1580 addConversion([this](FloatType floatType) -> std::optional<Type> {
1581 if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1582 return convertScalarType(this->targetEnv, this->options, scalarType);
1583 if (floatType.getWidth() == 8)
1584 return convert8BitFloatType(this->options, floatType);
1585 return Type();
1586 });
1587
1588 addConversion([this](ComplexType complexType) {
1589 return convertComplexType(this->targetEnv, this->options, complexType);
1590 });
1591
1592 addConversion([this](VectorType vectorType) {
1593 return convertVectorType(this->targetEnv, this->options, vectorType);
1594 });
1595
1596 addConversion([this](TensorType tensorType) {
1597 return convertTensorType(this->targetEnv, this->options, tensorType);
1598 });
1599
1600 addConversion([this](MemRefType memRefType) {
1601 return convertMemrefType(this->targetEnv, this->options, memRefType);
1602 });
1603
1604 // Register some last line of defense casting logic.
1605 addSourceMaterialization(
1606 [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1607 return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1608 });
1609 addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
1610 Location loc) {
1611 auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1612 return cast.getResult(0);
1613 });
1614}
1615
1617 return ::getIndexType(getContext(), options);
1618}
1619
1620MLIRContext *SPIRVTypeConverter::getContext() const {
1621 return targetEnv.getAttr().getContext();
1622}
1623
1624bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
1625 return targetEnv.allows(capability);
1626}
1627
1628//===----------------------------------------------------------------------===//
1629// SPIR-V ConversionTarget
1630//===----------------------------------------------------------------------===//
1631
1632std::unique_ptr<SPIRVConversionTarget>
1634 std::unique_ptr<SPIRVConversionTarget> target(
1635 // std::make_unique does not work here because the constructor is private.
1636 new SPIRVConversionTarget(targetAttr));
1637 SPIRVConversionTarget *targetPtr = target.get();
1638 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1639 // We need to capture the raw pointer here because it is stable:
1640 // target will be destroyed once this function is returned.
1641 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1642 return target;
1643}
1644
1645SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1646 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1647
1648bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1649 // Make sure this op is available at the given version. Ops not implementing
1650 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1651 // SPIR-V versions.
1652 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1653 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1654 if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1655 LLVM_DEBUG(llvm::dbgs()
1656 << op->getName() << " illegal: requiring min version "
1657 << spirv::stringifyVersion(*minVersion) << "\n");
1658 return false;
1659 }
1660 }
1661 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1662 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1663 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1664 LLVM_DEBUG(llvm::dbgs()
1665 << op->getName() << " illegal: requiring max version "
1666 << spirv::stringifyVersion(*maxVersion) << "\n");
1667 return false;
1668 }
1669 }
1670
1671 // Make sure this op's required extensions are allowed to use. Ops not
1672 // implementing QueryExtensionInterface do not require extensions to be
1673 // available.
1674 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1675 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1676 extensions.getExtensions())))
1677 return false;
1678
1679 // Make sure this op's required extensions are allowed to use. Ops not
1680 // implementing QueryCapabilityInterface do not require capabilities to be
1681 // available.
1682 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1683 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1684 capabilities.getCapabilities())))
1685 return false;
1686
1687 SmallVector<Type, 4> valueTypes;
1688 valueTypes.append(op->operand_type_begin(), op->operand_type_end());
1689 valueTypes.append(op->result_type_begin(), op->result_type_end());
1690
1691 // Ensure that all types have been converted to SPIRV types.
1692 if (llvm::any_of(valueTypes,
1693 [](Type t) { return !isa<spirv::SPIRVType>(t); }))
1694 return false;
1695
1696 // Special treatment for global variables, whose type requirements are
1697 // conveyed by type attributes.
1698 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1699 valueTypes.push_back(globalVar.getType());
1700
1701 // Make sure the op's operands/results use types that are allowed by the
1702 // target environment.
1703 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1704 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1705 for (Type valueType : valueTypes) {
1706 typeExtensions.clear();
1707 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1708 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1709 typeExtensions)))
1710 return false;
1711
1712 typeCapabilities.clear();
1713 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1714 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1715 typeCapabilities)))
1716 return false;
1717 }
1718
1719 return true;
1720}
1721
1722//===----------------------------------------------------------------------===//
1723// Public functions for populating patterns
1724//===----------------------------------------------------------------------===//
1725
1727 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1728 patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
1729}
1730
1732 patterns.add<FuncOpVectorUnroll>(patterns.getContext());
1733}
1734
1736 patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
1737}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
#define BIT_WIDTH_CASE(BIT_WIDTH)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
Definition Block.h:33
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:203
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
OpListType & getOperations()
Definition Block.h:147
Operation & front()
Definition Block.h:163
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition Block.cpp:206
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
FloatType getF32Type()
Definition Builders.cpp:47
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
MLIRContext * getContext() const
Definition Builders.h:56
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition Builders.h:242
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:694
void setOperand(unsigned idx, Value value)
Definition Operation.h:359
operand_type_iterator operand_type_end()
Definition Operation.h:404
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
result_type_iterator result_type_end()
Definition Operation.h:435
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_type_iterator result_type_begin()
Definition Operation.h:434
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
result_type_range getResultTypes()
Definition Operation.h:436
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:412
operand_type_iterator operand_type_begin()
Definition Operation.h:403
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Block & front()
Definition Region.h:65
iterator begin()
Definition Region.h:55
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ArrayType get(Type elementType, unsigned elementCount)
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:147
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
Definition SPIRVTypes.h:65
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
Definition SPIRVTypes.h:54
static SampledImageType get(Type imageType)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)