MLIR 22.0.0git
SPIRVConversion.cpp
Go to the documentation of this file.
1//===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities used to lower to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
26#include "mlir/IR/Operation.h"
28#include "mlir/Support/LLVM.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/MathExtras.h"
36
37#include <optional>
38
39#define DEBUG_TYPE "mlir-spirv-conversion"
40
41using namespace mlir;
42
43namespace {
44
45//===----------------------------------------------------------------------===//
46// Utility functions
47//===----------------------------------------------------------------------===//
48
49static std::optional<SmallVector<int64_t>> getTargetShape(VectorType vecType) {
50 LLVM_DEBUG(llvm::dbgs() << "Get target shape\n");
51 if (vecType.isScalable()) {
52 LLVM_DEBUG(llvm::dbgs()
53 << "--scalable vectors are not supported -> BAIL\n");
54 return std::nullopt;
55 }
56 SmallVector<int64_t> unrollShape = llvm::to_vector<4>(vecType.getShape());
57 std::optional<SmallVector<int64_t>> targetShape = SmallVector<int64_t>(
58 1, mlir::spirv::getComputeVectorSize(vecType.getShape().back()));
59 if (!targetShape) {
60 LLVM_DEBUG(llvm::dbgs() << "--no unrolling target shape defined\n");
61 return std::nullopt;
62 }
63 auto maybeShapeRatio = computeShapeRatio(unrollShape, *targetShape);
64 if (!maybeShapeRatio) {
65 LLVM_DEBUG(llvm::dbgs()
66 << "--could not compute integral shape ratio -> BAIL\n");
67 return std::nullopt;
68 }
69 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) {
70 LLVM_DEBUG(llvm::dbgs() << "--no unrolling needed -> SKIP\n");
71 return std::nullopt;
72 }
73 LLVM_DEBUG(llvm::dbgs()
74 << "--found an integral shape ratio to unroll to -> SUCCESS\n");
75 return targetShape;
76}
77
78/// Checks that `candidates` extension requirements are possible to be satisfied
79/// with the given `targetEnv`.
80///
81/// `candidates` is a vector of vector for extension requirements following
82/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
83/// convention.
///
/// Each inner list is checked with `targetEnv.allows`. On the first inner list
/// that is not allowed, a debug message prefixed with `label` is printed
/// (under LLVM_DEBUG) and failure() is returned. If every inner list is
/// allowed, success() is returned.
84template <typename LabelT>
85static LogicalResult checkExtensionRequirements(
86 LabelT label, const spirv::TargetEnv &targetEnv,
88 for (const auto &ors : candidates) {
89 if (targetEnv.allows(ors))
90 continue;
91
// This group cannot be satisfied; report which alternatives were needed.
92 LLVM_DEBUG({
93 SmallVector<StringRef> extStrings;
94 for (spirv::Extension ext : ors)
95 extStrings.push_back(spirv::stringifyExtension(ext));
96
97 llvm::dbgs() << label << " illegal: requires at least one extension in ["
98 << llvm::join(extStrings, ", ")
99 << "] but none allowed in target environment\n";
100 });
101 return failure();
102 }
103 return success();
104}
105
106/// Checks that `candidates` capability requirements are possible to be
107/// satisfied with the given `targetEnv`.
108///
109/// `candidates` is a vector of vector for capability requirements following
110/// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
111/// convention.
///
/// Each inner list is checked with `targetEnv.allows`. On the first inner list
/// that is not allowed, a debug message prefixed with `label` is printed
/// (under LLVM_DEBUG) and failure() is returned. If every inner list is
/// allowed, success() is returned.
112template <typename LabelT>
113static LogicalResult checkCapabilityRequirements(
114 LabelT label, const spirv::TargetEnv &targetEnv,
116 for (const auto &ors : candidates) {
117 if (targetEnv.allows(ors))
118 continue;
119
// This group cannot be satisfied; report which alternatives were needed.
120 LLVM_DEBUG({
121 SmallVector<StringRef> capStrings;
122 for (spirv::Capability cap : ors)
123 capStrings.push_back(spirv::stringifyCapability(cap));
124
125 llvm::dbgs() << label << " illegal: requires at least one capability in ["
126 << llvm::join(capStrings, ", ")
127 << "] but none allowed in target environment\n";
128 });
129 return failure();
130 }
131 return success();
132}
133
134/// Returns true if the given `storageClass` needs explicit layout when used in
135/// Shader environments.
136static bool needsExplicitLayout(spirv::StorageClass storageClass) {
137 switch (storageClass) {
138 case spirv::StorageClass::PhysicalStorageBuffer:
139 case spirv::StorageClass::PushConstant:
140 case spirv::StorageClass::StorageBuffer:
141 case spirv::StorageClass::Uniform:
142 return true;
143 default:
144 return false;
145 }
146}
147
148/// Wraps the given `elementType` in a struct and gets the pointer to the
149/// struct. This is used to satisfy Vulkan interface requirements.
///
/// When `storageClass` needs explicit layout (see needsExplicitLayout), the
/// struct is created with an explicit member offset of 0. Otherwise it carries
/// no offset information. The returned pointer uses `storageClass`.
151wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
152 auto structType = needsExplicitLayout(storageClass)
153 ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
154 : spirv::StructType::get(elementType);
155 return spirv::PointerType::get(structType, storageClass);
156}
157
158//===----------------------------------------------------------------------===//
159// Type Conversion
160//===----------------------------------------------------------------------===//
161
/// Returns the integer type used in place of the builtin `index` type: i64
/// when `options.use64bitIndex` is set, i32 otherwise.
162static spirv::ScalarType getIndexType(MLIRContext *ctx,
164 return cast<spirv::ScalarType>(
165 IntegerType::get(ctx, options.use64bitIndex ? 64 : 32));
166}
167
168// TODO: This is a utility function that should probably be exposed by the
169// SPIR-V dialect. Keeping it local till the use case arises.
/// Returns the size in bytes of `type`, or std::nullopt when it cannot be
/// computed. Handles:
///   - SPIR-V scalar types, except 1-bit booleans,
///   - 8-bit float types when `options.emulateUnsupportedFloatTypes` is set,
///   - complex and vector types whose element size is known,
///   - statically shaped memrefs with static strides and offset,
///   - statically shaped tensors.
170static std::optional<int64_t>
171getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
172 if (isa<spirv::ScalarType>(type)) {
173 auto bitWidth = type.getIntOrFloatBitWidth();
174 // According to the SPIR-V spec:
175 // "There is no physical size or bit pattern defined for values with boolean
176 // type. If they are stored (in conjunction with OpVariable), they can only
177 // be used with logical addressing operations, not physical, and only with
178 // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
179 // Private, Function, Input, and Output."
180 if (bitWidth == 1)
181 return std::nullopt;
182 return bitWidth / 8;
183 }
184
185 // Handle 8-bit floats.
// Only float types that are not spirv::ScalarType reach this point; any width
// other than 8 is reported as unknown.
186 if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
187 auto bitWidth = type.getIntOrFloatBitWidth();
188 if (bitWidth == 8)
189 return bitWidth / 8;
190 return std::nullopt;
191 }
192
// A complex value occupies two elements.
193 if (auto complexType = dyn_cast<ComplexType>(type)) {
194 auto elementSize = getTypeNumBytes(options, complexType.getElementType());
195 if (!elementSize)
196 return std::nullopt;
197 return 2 * *elementSize;
198 }
199
200 if (auto vecType = dyn_cast<VectorType>(type)) {
201 auto elementSize = getTypeNumBytes(options, vecType.getElementType());
202 if (!elementSize)
203 return std::nullopt;
204 return vecType.getNumElements() * *elementSize;
205 }
206
207 if (auto memRefType = dyn_cast<MemRefType>(type)) {
208 // TODO: Layout should also be controlled by the ABI attributes. For now
209 // using the layout from MemRef.
210 int64_t offset;
212 if (!memRefType.hasStaticShape() ||
213 failed(memRefType.getStridesAndOffset(strides, offset)))
214 return std::nullopt;
215
216 // To get the size of the memref object in memory, the total size is the
217 // offset plus max(stride * dimension-size) over all dimensions, times the
218 // size of the element.
219 auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
220 if (!elementSize)
221 return std::nullopt;
222
223 if (memRefType.getRank() == 0)
224 return elementSize;
225
226 auto dims = memRefType.getShape();
227 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
228 ShapedType::isDynamic(offset) ||
229 llvm::is_contained(strides, ShapedType::kDynamic))
230 return std::nullopt;
231
232 int64_t memrefSize = -1;
233 for (const auto &shape : enumerate(dims))
234 memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
235
236 return (offset + memrefSize) * *elementSize;
237 }
238
// Tensors are treated as densely packed: product of dims times element size.
239 if (auto tensorType = dyn_cast<TensorType>(type)) {
240 if (!tensorType.hasStaticShape())
241 return std::nullopt;
242
243 auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
244 if (!elementSize)
245 return std::nullopt;
246
247 int64_t size = *elementSize;
248 for (auto shape : tensorType.getShape())
249 size *= shape;
250
251 return size;
252 }
253
254 // TODO: Add size computation for other types.
255 return std::nullopt;
256}
257
258/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
///
/// Returns `type` unchanged when `targetEnv` allows all of its capability and
/// extension requirements, queried for the optional `storageClass`.
/// Otherwise, if `options.emulateLT32BitScalarTypes` is set and `type` is at
/// most 32 bits wide, it returns one of two types:
///   - f32 for float types,
///   - a 32-bit integer with the same signedness for integer types.
/// Returns nullptr in every other case.
259static Type
260convertScalarType(const spirv::TargetEnv &targetEnv,
262 std::optional<spirv::StorageClass> storageClass = {}) {
263 // Get extension and capability requirements for the given type.
266 type.getExtensions(extensions, storageClass);
267 type.getCapabilities(capabilities, storageClass);
268
269 // If all requirements are met, then we can accept this type as-is.
270 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
271 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
272 return type;
273
274 // Otherwise we need to adjust the type, which really means adjusting the
275 // bitwidth given this is a scalar type.
276 if (!options.emulateLT32BitScalarTypes)
277 return nullptr;
278
279 // We only emulate narrower scalar types here and do not truncate results.
280 if (type.getIntOrFloatBitWidth() > 32) {
281 LLVM_DEBUG(llvm::dbgs()
282 << type
283 << " not converted to 32-bit for SPIR-V to avoid truncation\n");
284 return nullptr;
285 }
286
// NOTE(review): `floatType` is unused; only the type check matters here.
287 if (auto floatType = dyn_cast<FloatType>(type)) {
288 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
289 return Builder(targetEnv.getContext()).getF32Type();
290 }
291
// Every non-float scalar reaching here is treated as an integer.
292 auto intType = cast<IntegerType>(type);
293 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
294 return IntegerType::get(targetEnv.getContext(), /*width=*/32,
295 intType.getSignedness());
296}
297
298/// Converts a sub-byte integer `type` to i32 regardless of target environment.
299/// Returns a nullptr for unsupported integer types, including non sub-byte
300/// types.
301///
302/// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
303/// the above given that these sub-byte types are not supported at all in
304/// SPIR-V; there are no compute/storage capability for them like other
305/// supported integer types.
306static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
307 IntegerType type) {
308 if (type.getWidth() > 8) {
309 LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
310 return nullptr;
311 }
312 if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
313 LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
314 return nullptr;
315 }
316
317 if (!llvm::isPowerOf2_32(type.getWidth())) {
318 LLVM_DEBUG(llvm::dbgs()
319 << "unsupported non-power-of-two bitwidth in sub-byte" << type
320 << "\n");
321 return nullptr;
322 }
323
324 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
325 return IntegerType::get(type.getContext(), /*width=*/32,
326 type.getSignedness());
327}
328
329/// Converts 8-bit float types to integer types with the same bit width.
330/// Returns a nullptr for unsupported 8-bit float types.
331static Type convert8BitFloatType(const SPIRVConversionOptions &options,
332 FloatType type) {
333 if (!options.emulateUnsupportedFloatTypes)
334 return nullptr;
335 // F8 types are converted to integer types with the same bit width.
336 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
337 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
338 Float8E8M0FNUType>(type))
339 return IntegerType::get(type.getContext(), type.getWidth());
340 LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
341 return nullptr;
342}
343
344/// Returns a type with the same shape but with any 8-bit float element type
345/// converted to the same bit width integer type. This is a noop when the
346/// element type is not one of the recognized 8-bit float types, or when
/// `options.emulateUnsupportedFloatTypes` is false.
347static ShapedType
348convertShaped8BitFloatType(ShapedType type,
350 if (!options.emulateUnsupportedFloatTypes)
351 return type;
352 Type srcElementType = type.getElementType();
353 Type convertedElementType = nullptr;
354 // F8 types are converted to integer types with the same bit width.
// NOTE(review): this list duplicates the one in convert8BitFloatType; keep the
// two in sync.
355 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
356 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
357 Float8E8M0FNUType>(srcElementType))
358 convertedElementType = IntegerType::get(
359 type.getContext(), srcElementType.getIntOrFloatBitWidth());
360
361 if (!convertedElementType)
362 return type;
363
// Only the element type changes; the shape is preserved by clone().
364 return type.clone(convertedElementType);
365}
366
367/// Returns a type with the same shape but with any index element type converted
368/// to the matching integer type. This is a noop when the element type is not
369/// the index type.
/// The integer type comes from getIndexType: i64 when `options.use64bitIndex`
/// is set, i32 otherwise.
370static ShapedType
371convertIndexElementType(ShapedType type,
373 Type indexType = dyn_cast<IndexType>(type.getElementType());
374 if (!indexType)
375 return type;
376
377 return type.clone(getIndexType(type.getContext(), options));
378}
379
380/// Converts a vector `type` to a suitable type under the given `targetEnv`.
///
/// First, index elements are mapped to the index integer type, and 8-bit
/// float elements are mapped to same-width integers when float emulation is
/// enabled. Then:
///   - If the element is not a SPIR-V scalar type, it is accepted only when it
///     is an integer that convertSubByteIntegerType can widen.
///     - Single-element vectors of rank <= 1 become that scalar.
///     - Vectors with more than 4 elements are rejected.
///     - Otherwise the shape is kept with the widened element type.
///   - Single-element vectors of rank <= 1 with a scalar element are
///     converted as scalars.
///   - Any other vector is returned as-is when `targetEnv` allows its
///     capability and extension requirements for the optional `storageClass`.
///     If it does not, its element type is converted with convertScalarType.
/// Returns nullptr when no conversion is possible.
381static Type
382convertVectorType(const spirv::TargetEnv &targetEnv,
383 const SPIRVConversionOptions &options, VectorType type,
384 std::optional<spirv::StorageClass> storageClass = {}) {
385 type = cast<VectorType>(convertIndexElementType(type, options));
386 type = cast<VectorType>(convertShaped8BitFloatType(type, options));
387 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
388 if (!scalarType) {
389 // If this is not a spec allowed scalar type, try to handle sub-byte integer
390 // types.
391 auto intType = dyn_cast<IntegerType>(type.getElementType());
392 if (!intType) {
393 LLVM_DEBUG(llvm::dbgs()
394 << type
395 << " illegal: cannot convert non-scalar element type\n");
396 return nullptr;
397 }
398
399 Type elementType = convertSubByteIntegerType(options, intType);
400 if (!elementType)
401 return nullptr;
402
403 if (type.getRank() <= 1 && type.getNumElements() == 1)
404 return elementType;
405
406 if (type.getNumElements() > 4) {
407 LLVM_DEBUG(llvm::dbgs()
408 << type << " illegal: > 4-element unimplemented\n");
409 return nullptr;
410 }
411
412 return VectorType::get(type.getShape(), elementType);
413 }
414
// A one-element vector is lowered as a plain scalar.
415 if (type.getRank() <= 1 && type.getNumElements() == 1)
416 return convertScalarType(targetEnv, options, scalarType, storageClass);
417
419 LLVM_DEBUG(llvm::dbgs()
420 << type << " illegal: not a valid composite type\n");
421 return nullptr;
422 }
423
424 // Get extension and capability requirements for the given type.
427 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
428 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
429
430 // If all requirements are met, then we can accept this type as-is.
431 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
432 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
433 return type;
434
// Otherwise keep the shape and try to convert (e.g. widen) the element type.
435 auto elementType =
436 convertScalarType(targetEnv, options, scalarType, storageClass);
437 if (elementType)
438 return VectorType::get(type.getShape(), elementType);
439 return nullptr;
440}
441
442static Type
443convertComplexType(const spirv::TargetEnv &targetEnv,
444 const SPIRVConversionOptions &options, ComplexType type,
445 std::optional<spirv::StorageClass> storageClass = {}) {
446 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
447 if (!scalarType) {
448 LLVM_DEBUG(llvm::dbgs()
449 << type << " illegal: cannot convert non-scalar element type\n");
450 return nullptr;
451 }
452
453 auto elementType =
454 convertScalarType(targetEnv, options, scalarType, storageClass);
455 if (!elementType)
456 return nullptr;
457 if (elementType != type.getElementType()) {
458 LLVM_DEBUG(llvm::dbgs()
459 << type << " illegal: complex type emulation unsupported\n");
460 return nullptr;
461 }
462
463 return VectorType::get(2, elementType);
464}
465
466/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
467///
468/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
469/// create composite constants with OpConstantComposite to embed relatively large
470/// constant values and use OpCompositeExtract and OpCompositeInsert to
471/// manipulate, like what we do for vectors.
///
/// Index elements are mapped first, and 8-bit float elements when emulation is
/// enabled. A statically shaped tensor with a SPIR-V scalar element then
/// becomes a spirv::ArrayType of the element type as converted by
/// convertScalarType, with one array element per tensor element.
/// Returns nullptr for any of the following:
///   - dynamic shapes,
///   - non-scalar elements,
///   - zero-element tensors,
///   - element counts that do not fit in `unsigned`.
472static Type convertTensorType(const spirv::TargetEnv &targetEnv,
474 TensorType type) {
475 // TODO: Handle dynamic shapes.
476 if (!type.hasStaticShape()) {
477 LLVM_DEBUG(llvm::dbgs()
478 << type << " illegal: dynamic shape unimplemented\n");
479 return nullptr;
480 }
481
482 type = cast<TensorType>(convertIndexElementType(type, options));
483 type = cast<TensorType>(convertShaped8BitFloatType(type, options));
484 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
485 if (!scalarType) {
486 LLVM_DEBUG(llvm::dbgs()
487 << type << " illegal: cannot convert non-scalar element type\n");
488 return nullptr;
489 }
490
491 std::optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
492 std::optional<int64_t> tensorSize = getTypeNumBytes(options, type);
493 if (!scalarSize || !tensorSize) {
494 LLVM_DEBUG(llvm::dbgs()
495 << type << " illegal: cannot deduce element count\n");
496 return nullptr;
497 }
498
// Number of tensor elements, derived from the byte sizes computed above.
499 int64_t arrayElemCount = *tensorSize / *scalarSize;
500 if (arrayElemCount == 0) {
501 LLVM_DEBUG(llvm::dbgs()
502 << type << " illegal: cannot handle zero-element tensors\n");
503 return nullptr;
504 }
505 if (arrayElemCount > std::numeric_limits<unsigned>::max()) {
506 LLVM_DEBUG(llvm::dbgs()
507 << type << " illegal: cannot fit tensor into target type\n");
508 return nullptr;
509 }
510
511 Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
512 if (!arrayElemType)
513 return nullptr;
// The converted element size is only validated here, not used afterwards.
514 std::optional<int64_t> arrayElemSize =
515 getTypeNumBytes(options, arrayElemType);
516 if (!arrayElemSize) {
517 LLVM_DEBUG(llvm::dbgs()
518 << type << " illegal: cannot deduce converted element size\n");
519 return nullptr;
520 }
521
522 return spirv::ArrayType::get(arrayElemType, arrayElemCount);
523}
524
/// Converts a memref of booleans `type` to a SPIR-V pointer type. Each bool is
/// stored in `options.boolNumBits` bits; only 8 is currently supported.
///
/// The array element type is the `boolNumBits`-wide integer, as converted by
/// convertScalarType for `storageClass`. The result depends on the shape and
/// on whether `targetEnv` allows the Kernel capability:
///   - dynamic shape: a pointer to that element (Kernel), or a pointer to a
///     struct wrapping a runtime array (otherwise);
///   - static shape: a pointer to an array (Kernel), or a pointer to a struct
///     wrapping that array (otherwise). The array is sized to hold all the
///     bools.
/// Strides are set only when `storageClass` needs explicit layout. Returns
/// nullptr on failure, including for zero-element memrefs.
525static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
527 MemRefType type,
528 spirv::StorageClass storageClass) {
529 unsigned numBoolBits = options.boolNumBits;
530 if (numBoolBits != 8) {
// NOTE(review): this debug message lacks a trailing newline.
531 LLVM_DEBUG(llvm::dbgs()
532 << "using non-8-bit storage for bool types unimplemented");
533 return nullptr;
534 }
535 auto elementType = dyn_cast<spirv::ScalarType>(
536 IntegerType::get(type.getContext(), numBoolBits));
537 if (!elementType)
538 return nullptr;
539 Type arrayElemType =
540 convertScalarType(targetEnv, options, elementType, storageClass);
541 if (!arrayElemType)
542 return nullptr;
543 std::optional<int64_t> arrayElemSize =
544 getTypeNumBytes(options, arrayElemType);
545 if (!arrayElemSize) {
546 LLVM_DEBUG(llvm::dbgs()
547 << type << " illegal: cannot deduce converted element size\n");
548 return nullptr;
549 }
550
551 if (!type.hasStaticShape()) {
552 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
553 // to the element.
554 if (targetEnv.allows(spirv::Capability::Kernel))
555 return spirv::PointerType::get(arrayElemType, storageClass);
556 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
557 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
558 // For Vulkan we need extra wrapping struct and array to satisfy interface
559 // needs.
560 return wrapInStructAndGetPointer(arrayType, storageClass);
561 }
562
563 if (type.getNumElements() == 0) {
564 LLVM_DEBUG(llvm::dbgs()
565 << type << " illegal: zero-element memrefs are not supported\n");
566 return nullptr;
567 }
568
// Total storage in bytes for all bools, then the number of (possibly wider)
// converted elements needed to hold it.
569 int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
570 int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
571 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
572 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
573 if (targetEnv.allows(spirv::Capability::Kernel))
574 return spirv::PointerType::get(arrayType, storageClass);
575 return wrapInStructAndGetPointer(arrayType, storageClass);
576}
577
/// Converts a memref `type` with a sub-byte integer element (e.g. i4) into a
/// SPIR-V pointer type. The narrow values are densely packed into the 32-bit
/// integers produced by convertSubByteIntegerType.
///
/// The result depends on the shape and on whether `targetEnv` allows the
/// Kernel capability:
///   - dynamic shape: a pointer to the element (Kernel), or a pointer to a
///     struct wrapping a runtime array (otherwise);
///   - static shape: a pointer to an array (Kernel), or a pointer to a struct
///     wrapping that array (otherwise). The array has enough elements to hold
///     numElements * bitWidth bits.
/// Returns nullptr when the element type is unsupported or the memref has no
/// elements.
578static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv,
580 MemRefType type,
581 spirv::StorageClass storageClass) {
582 IntegerType elementType = cast<IntegerType>(type.getElementType());
583 Type arrayElemType = convertSubByteIntegerType(options, elementType);
584 if (!arrayElemType)
585 return nullptr;
// Dereferenced without a check: assumes getTypeNumBytes always succeeds for the
// i32 returned by convertSubByteIntegerType. TODO confirm.
586 int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType);
587
588 if (!type.hasStaticShape()) {
589 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
590 // to the element.
591 if (targetEnv.allows(spirv::Capability::Kernel))
592 return spirv::PointerType::get(arrayElemType, storageClass);
593 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
594 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
595 // For Vulkan we need extra wrapping struct and array to satisfy interface
596 // needs.
597 return wrapInStructAndGetPointer(arrayType, storageClass);
598 }
599
600 if (type.getNumElements() == 0) {
601 LLVM_DEBUG(llvm::dbgs()
602 << type << " illegal: zero-element memrefs are not supported\n");
603 return nullptr;
604 }
605
// Packed byte size of all narrow values, rounded up to whole array elements.
606 int64_t memrefSize =
607 llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
608 int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
609 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
610 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
611 if (targetEnv.allows(spirv::Capability::Kernel))
612 return spirv::PointerType::get(arrayType, storageClass);
613 return wrapInStructAndGetPointer(arrayType, storageClass);
614}
615
616static spirv::Dim convertRank(int64_t rank) {
617 switch (rank) {
618 case 1:
619 return spirv::Dim::Dim1D;
620 case 2:
621 return spirv::Dim::Dim2D;
622 case 3:
623 return spirv::Dim::Dim3D;
624 default:
625 llvm_unreachable("Invalid memref rank!");
626 }
627}
628
629static spirv::ImageFormat getImageFormat(Type elementType) {
630 return TypeSwitch<Type, spirv::ImageFormat>(elementType)
631 .Case<Float16Type>([](Float16Type) { return spirv::ImageFormat::R16f; })
632 .Case<Float32Type>([](Float32Type) { return spirv::ImageFormat::R32f; })
633 .Case<IntegerType>([](IntegerType intType) {
634 auto const isSigned = intType.isSigned() || intType.isSignless();
635#define BIT_WIDTH_CASE(BIT_WIDTH) \
636 case BIT_WIDTH: \
637 return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
638 : spirv::ImageFormat::R##BIT_WIDTH##ui
639
640 switch (intType.getWidth()) {
641 BIT_WIDTH_CASE(16);
642 BIT_WIDTH_CASE(32);
643 default:
644 llvm_unreachable("Unhandled integer type!");
645 }
646 })
647 .DefaultUnreachable("Unhandled element type!");
648#undef BIT_WIDTH_CASE
649}
650
/// Converts a memref `type` to a SPIR-V pointer type under `targetEnv`.
///
/// The memref's memory space must already be a spirv::StorageClassAttr.
///   - Image storage class: rank 1-3 memrefs with a scalar element type become
///     a UniformConstant pointer to a sampled image type.
///   - i1 elements are delegated to convertBoolMemrefType. Other integers
///     narrower than 8 bits are delegated to convertSubByteMemrefType.
///   - Otherwise the element type is converted, and the result depends on the
///     shape and on whether `targetEnv` allows the Kernel capability:
///       * dynamic shape: a pointer to the element (Kernel), or a pointer to a
///         struct wrapping a runtime array (otherwise);
///       * static shape: a pointer to an array (Kernel), or a pointer to a
///         struct wrapping that array (otherwise). The element count is
///         ceil(memref bytes / converted element bytes).
/// Memref bytes are computed from the memref's own element type, after only
/// the index and 8-bit float rewrites. Strides are set only for storage
/// classes that need explicit layout. Returns nullptr on failure.
651static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
653 MemRefType type) {
654 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
655 if (!attr) {
656 LLVM_DEBUG(
657 llvm::dbgs()
658 << type
659 << " illegal: expected memory space to be a SPIR-V storage class "
660 "attribute; please use MemorySpaceToStorageClassConverter to map "
661 "numeric memory spaces beforehand\n");
662 return nullptr;
663 }
664 spirv::StorageClass storageClass = attr.getValue();
665
666 // Images are a special case since they are an opaque type from which elements
667 // may be accessed via image specific ops or directly through a texture
668 // pointer.
669 if (storageClass == spirv::StorageClass::Image) {
670 const int64_t rank = type.getRank();
671 if (rank < 1 || rank > 3) {
672 LLVM_DEBUG(llvm::dbgs()
673 << type << " illegal: cannot lower memref of rank " << rank
674 << " to a SPIR-V Image\n");
675 return nullptr;
676 }
677
678 // Note that we currently only support lowering to single element texels
679 // e.g. R32f.
680 auto elementType = type.getElementType();
681 if (!isa<spirv::ScalarType>(elementType)) {
682 LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot lower memref of "
683 << elementType << " to a SPIR-V Image\n");
684 return nullptr;
685 }
686
687 // Currently every memref in the image storage class is converted to a
688 // sampled image so we can hardcode the NeedSampler field. Future work
689 // will generalize this to support regular non-sampled images.
690 auto spvImageType = spirv::ImageType::get(
691 elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
692 spirv::ImageArrayedInfo::NonArrayed,
693 spirv::ImageSamplingInfo::SingleSampled,
694 spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
695 auto spvSampledImageType = spirv::SampledImageType::get(spvImageType);
696 auto imagePtrType = spirv::PointerType::get(
697 spvSampledImageType, spirv::StorageClass::UniformConstant);
698 return imagePtrType;
699 }
700
// Booleans and sub-byte integers need packed storage and have dedicated paths.
701 if (isa<IntegerType>(type.getElementType())) {
702 if (type.getElementTypeBitWidth() == 1)
703 return convertBoolMemrefType(targetEnv, options, type, storageClass);
704 if (type.getElementTypeBitWidth() < 8)
705 return convertSubByteMemrefType(targetEnv, options, type, storageClass);
706 }
707
708 Type arrayElemType;
709 Type elementType = type.getElementType();
710 if (auto vecType = dyn_cast<VectorType>(elementType)) {
711 arrayElemType =
712 convertVectorType(targetEnv, options, vecType, storageClass);
713 } else if (auto complexType = dyn_cast<ComplexType>(elementType)) {
714 arrayElemType =
715 convertComplexType(targetEnv, options, complexType, storageClass);
716 } else if (auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
717 arrayElemType =
718 convertScalarType(targetEnv, options, scalarType, storageClass);
719 } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
// `type` itself is rewritten so that the size computation below sees the
// integer element type.
720 type = cast<MemRefType>(convertIndexElementType(type, options));
721 arrayElemType = type.getElementType();
722 } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
723 // Handle 8 bit float types.
// NOTE(review): if emulation is off or this is not an 8-bit float, the type
// is returned unchanged and is not checked against `targetEnv` here.
724 type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
725 arrayElemType = type.getElementType();
726 } else {
727 LLVM_DEBUG(
728 llvm::dbgs()
729 << type
730 << " unhandled: can only convert scalar or vector element type\n");
731 return nullptr;
732 }
733 if (!arrayElemType)
734 return nullptr;
735
736 std::optional<int64_t> arrayElemSize =
737 getTypeNumBytes(options, arrayElemType);
738 if (!arrayElemSize) {
739 LLVM_DEBUG(llvm::dbgs()
740 << type << " illegal: cannot deduce converted element size\n");
741 return nullptr;
742 }
743
744 if (!type.hasStaticShape()) {
745 // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing
746 // to the element.
747 if (targetEnv.allows(spirv::Capability::Kernel))
748 return spirv::PointerType::get(arrayElemType, storageClass);
749 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
750 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
751 // For Vulkan we need extra wrapping struct and array to satisfy interface
752 // needs.
753 return wrapInStructAndGetPointer(arrayType, storageClass);
754 }
755
756 std::optional<int64_t> memrefSize = getTypeNumBytes(options, type);
757 if (!memrefSize) {
758 LLVM_DEBUG(llvm::dbgs()
759 << type << " illegal: cannot deduce element count\n");
760 return nullptr;
761 }
762
763 if (*memrefSize == 0) {
764 LLVM_DEBUG(llvm::dbgs()
765 << type << " illegal: zero-element memrefs are not supported\n");
766 return nullptr;
767 }
768
// Round up so a partially filled converted element still gets storage.
769 int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
770 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
771 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
772 if (targetEnv.allows(spirv::Capability::Kernel))
773 return spirv::PointerType::get(arrayType, storageClass);
774 return wrapInStructAndGetPointer(arrayType, storageClass);
775}
776
777//===----------------------------------------------------------------------===//
778// Type casting materialization
779//===----------------------------------------------------------------------===//
780
781/// Converts the given `inputs` to the original source `type` considering the
782/// `targetEnv`'s capabilities.
783///
784/// This function is meant to be used for source materialization in type
785/// converters. When the type converter needs to materialize a cast op back
786/// to some original source type, we need to check whether the original source
787/// type is supported in the target environment. If so, we can insert legal
788/// SPIR-V cast ops accordingly.
789///
790/// Note that in SPIR-V the capabilities for storage and compute are separate.
791/// This function is meant to handle the **compute** side; so it does not
792/// involve storage classes in its logic. The storage side is expected to be
793/// handled by MemRef conversion logic.
///
/// A SPIR-V op is emitted only for a single integer input being narrowed to a
/// supported integer `type`:
///   - i1 targets are produced with spirv.IEqual against 1;
///   - other targets use spirv.SConvert (signed `type`) or spirv.UConvert.
/// In every other case an unrealized_conversion_cast is emitted instead.
794static Value castToSourceType(const spirv::TargetEnv &targetEnv,
795 OpBuilder &builder, Type type, ValueRange inputs,
796 Location loc) {
797 // We can only cast one value in SPIR-V.
798 if (inputs.size() != 1) {
799 auto castOp =
800 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
801 return castOp.getResult(0);
802 }
803 Value input = inputs.front();
804
805 // Only support integer types for now. Floating point types to be implemented.
806 if (!isa<IntegerType>(type)) {
807 auto castOp =
808 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
809 return castOp.getResult(0);
810 }
// Assumes the converted input is also an integer; `cast` asserts otherwise.
811 auto inputType = cast<IntegerType>(input.getType());
812
813 auto scalarType = dyn_cast<spirv::ScalarType>(type);
814 if (!scalarType) {
815 auto castOp =
816 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
817 return castOp.getResult(0);
818 }
819
820 // Only support source type with a smaller bitwidth. This would mean we are
821 // truncating to go back so we don't need to worry about the signedness.
822 // For extension, we cannot have enough signal here to decide which op to use.
823 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
824 auto castOp =
825 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
826 return castOp.getResult(0);
827 }
828
829 // Boolean values would need to use different ops than normal integer values.
// Recover the bool by comparing the wider integer against 1.
830 if (type.isInteger(1)) {
831 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
832 return spirv::IEqualOp::create(builder, loc, input, one);
833 }
834
835 // Check that the source integer type is supported by the environment.
838 scalarType.getExtensions(exts);
839 scalarType.getCapabilities(caps);
840 if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
841 failed(checkExtensionRequirements(type, targetEnv, exts))) {
842 auto castOp =
843 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
844 return castOp.getResult(0);
845 }
846
847 // We've already made sure this is truncating previously, so we don't need to
848 // care about signedness here. Still try to use a corresponding op for better
849 // consistency though.
850 if (type.isSignedInteger()) {
851 return spirv::SConvertOp::create(builder, loc, type, input);
852 }
853 return spirv::UConvertOp::create(builder, loc, type, input);
854}
855
856//===----------------------------------------------------------------------===//
857// Builtin Variables
858//===----------------------------------------------------------------------===//
859
860static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
861 spirv::BuiltIn builtin) {
862 // Look through all global variables in the given `body` block and check if
863 // there is a spirv.GlobalVariable that has the same `builtin` attribute.
864 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
865 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
866 spirv::SPIRVDialect::getAttributeName(
867 spirv::Decoration::BuiltIn))) {
868 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
869 if (varBuiltIn == builtin) {
870 return varOp;
871 }
872 }
873 }
874 return nullptr;
875}
876
877/// Gets name of global variable for a builtin.
878std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
879 StringRef suffix) {
880 return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
881}
882
883/// Gets or inserts a global variable for a builtin within `body` block.
884static spirv::GlobalVariableOp
885getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
886 Type integerType, OpBuilder &builder,
887 StringRef prefix, StringRef suffix) {
888 if (auto varOp = getBuiltinVariable(body, builtin))
889 return varOp;
890
891 OpBuilder::InsertionGuard guard(builder);
892 builder.setInsertionPointToStart(&body);
893
894 spirv::GlobalVariableOp newVarOp;
895 switch (builtin) {
896 case spirv::BuiltIn::NumWorkgroups:
897 case spirv::BuiltIn::WorkgroupSize:
898 case spirv::BuiltIn::WorkgroupId:
899 case spirv::BuiltIn::LocalInvocationId:
900 case spirv::BuiltIn::GlobalInvocationId: {
901 auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
902 spirv::StorageClass::Input);
903 std::string name = getBuiltinVarName(builtin, prefix, suffix);
904 newVarOp =
905 spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
906 break;
907 }
908 case spirv::BuiltIn::SubgroupId:
909 case spirv::BuiltIn::NumSubgroups:
910 case spirv::BuiltIn::SubgroupSize:
911 case spirv::BuiltIn::SubgroupLocalInvocationId: {
912 auto ptrType =
913 spirv::PointerType::get(integerType, spirv::StorageClass::Input);
914 std::string name = getBuiltinVarName(builtin, prefix, suffix);
915 newVarOp =
916 spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
917 break;
918 }
919 default:
920 emitError(loc, "unimplemented builtin variable generation for ")
921 << stringifyBuiltIn(builtin);
922 }
923 return newVarOp;
924}
925
926//===----------------------------------------------------------------------===//
927// Push constant storage
928//===----------------------------------------------------------------------===//
929
930/// Returns the pointer type for the push constant storage containing
931/// `elementCount` 32-bit integer values.
932static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
933 Builder &builder,
934 Type indexType) {
935 auto arrayType = spirv::ArrayType::get(indexType, elementCount,
936 /*stride=*/4);
937 auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
938 return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
939}
940
941/// Returns the push constant varible containing `elementCount` 32-bit integer
942/// values in `body`. Returns null op if such an op does not exit.
943static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
944 unsigned elementCount) {
945 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
946 auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
947 if (!ptrType)
948 continue;
949
950 // Note that Vulkan requires "There must be no more than one push constant
951 // block statically used per shader entry point." So we should always reuse
952 // the existing one.
953 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
954 auto numElements = cast<spirv::ArrayType>(
955 cast<spirv::StructType>(ptrType.getPointeeType())
956 .getElementType(0))
957 .getNumElements();
958 if (numElements == elementCount)
959 return varOp;
960 }
961 }
962 return nullptr;
963}
964
965/// Gets or inserts a global variable for push constant storage containing
966/// `elementCount` 32-bit integer values in `block`.
967static spirv::GlobalVariableOp
968getOrInsertPushConstantVariable(Location loc, Block &block,
969 unsigned elementCount, OpBuilder &b,
970 Type indexType) {
971 if (auto varOp = getPushConstantVariable(block, elementCount))
972 return varOp;
973
974 auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
975 auto type = getPushConstantStorageType(elementCount, builder, indexType);
976 const char *name = "__push_constant_var__";
977 return spirv::GlobalVariableOp::create(builder, loc, type, name,
978 /*initializer=*/nullptr);
979}
980
981//===----------------------------------------------------------------------===//
982// func::FuncOp Conversion Patterns
983//===----------------------------------------------------------------------===//
984
985/// A pattern for rewriting function signature to convert arguments of functions
986/// to be of valid SPIR-V types.
987struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
988 using Base::Base;
989
990 LogicalResult
991 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
992 ConversionPatternRewriter &rewriter) const override {
993 FunctionType fnType = funcOp.getFunctionType();
994 if (fnType.getNumResults() > 1)
995 return failure();
996
997 TypeConverter::SignatureConversion signatureConverter(
998 fnType.getNumInputs());
999 for (const auto &argType : enumerate(fnType.getInputs())) {
1000 auto convertedType = getTypeConverter()->convertType(argType.value());
1001 if (!convertedType)
1002 return failure();
1003 signatureConverter.addInputs(argType.index(), convertedType);
1004 }
1005
1006 Type resultType;
1007 if (fnType.getNumResults() == 1) {
1008 resultType = getTypeConverter()->convertType(fnType.getResult(0));
1009 if (!resultType)
1010 return failure();
1011 }
1012
1013 // Create the converted spirv.func op.
1014 auto newFuncOp = spirv::FuncOp::create(
1015 rewriter, funcOp.getLoc(), funcOp.getName(),
1016 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
1017 resultType ? TypeRange(resultType)
1018 : TypeRange()));
1019
1020 // Copy over all attributes other than the function name and type.
1021 for (const auto &namedAttr : funcOp->getAttrs()) {
1022 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
1023 namedAttr.getName() != SymbolTable::getSymbolAttrName())
1024 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1025 }
1026
1027 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1028 newFuncOp.end());
1029 if (failed(rewriter.convertRegionTypes(
1030 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
1031 return failure();
1032 rewriter.eraseOp(funcOp);
1033 return success();
1034 }
1035};
1036
1037/// A pattern for rewriting function signature to convert vector arguments of
1038/// functions to be of valid types
1039struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
1040 using Base::Base;
1041
1042 LogicalResult matchAndRewrite(func::FuncOp funcOp,
1043 PatternRewriter &rewriter) const override {
1044 FunctionType fnType = funcOp.getFunctionType();
1045
1046 // TODO: Handle declarations.
1047 if (funcOp.isDeclaration()) {
1048 LLVM_DEBUG(llvm::dbgs()
1049 << fnType << " illegal: declarations are unsupported\n");
1050 return failure();
1051 }
1052
1053 // Create a new func op with the original type and copy the function body.
1054 auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),
1055 funcOp.getName(), fnType);
1056 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1057 newFuncOp.end());
1058
1059 Location loc = newFuncOp.getBody().getLoc();
1060
1061 Block &entryBlock = newFuncOp.getBlocks().front();
1062 OpBuilder::InsertionGuard guard(rewriter);
1063 rewriter.setInsertionPointToStart(&entryBlock);
1064
1065 TypeConverter::SignatureConversion oneToNTypeMapping(
1066 fnType.getInputs().size());
1067
1068 // For arguments that are of illegal types and require unrolling.
1069 // `unrolledInputNums` stores the indices of arguments that result from
1070 // unrolling in the new function signature. `newInputNo` is a counter.
1071 SmallVector<size_t> unrolledInputNums;
1072 size_t newInputNo = 0;
1073
1074 // For arguments that are of legal types and do not require unrolling.
1075 // `tmpOps` stores a mapping from temporary operations that serve as
1076 // placeholders for new arguments that will be added later. These operations
1077 // will be erased once the entry block's argument list is updated.
1078 llvm::SmallDenseMap<Operation *, size_t> tmpOps;
1079
1080 // This counts the number of new operations created.
1081 size_t newOpCount = 0;
1082
1083 // Enumerate through the arguments.
1084 for (auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
1085 // Check whether the argument is of vector type.
1086 auto origVecType = dyn_cast<VectorType>(origType);
1087 if (!origVecType) {
1088 // We need a placeholder for the old argument that will be erased later.
1089 Value result = arith::ConstantOp::create(
1090 rewriter, loc, origType, rewriter.getZeroAttr(origType));
1091 rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1092 tmpOps.insert({result.getDefiningOp(), newInputNo});
1093 oneToNTypeMapping.addInputs(origInputNo, origType);
1094 ++newInputNo;
1095 ++newOpCount;
1096 continue;
1097 }
1098 // Check whether the vector needs unrolling.
1099 auto targetShape = getTargetShape(origVecType);
1100 if (!targetShape) {
1101 // We need a placeholder for the old argument that will be erased later.
1102 Value result = arith::ConstantOp::create(
1103 rewriter, loc, origType, rewriter.getZeroAttr(origType));
1104 rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1105 tmpOps.insert({result.getDefiningOp(), newInputNo});
1106 oneToNTypeMapping.addInputs(origInputNo, origType);
1107 ++newInputNo;
1108 ++newOpCount;
1109 continue;
1110 }
1111 VectorType unrolledType =
1112 VectorType::get(*targetShape, origVecType.getElementType());
1113 auto originalShape =
1114 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1115
1116 // Prepare the result vector.
1117 Value result = arith::ConstantOp::create(
1118 rewriter, loc, origVecType, rewriter.getZeroAttr(origVecType));
1119 ++newOpCount;
1120 // Prepare the placeholder for the new arguments that will be added later.
1121 Value dummy = arith::ConstantOp::create(
1122 rewriter, loc, unrolledType, rewriter.getZeroAttr(unrolledType));
1123 ++newOpCount;
1124
1125 // Create the `vector.insert_strided_slice` ops.
1126 SmallVector<int64_t> strides(targetShape->size(), 1);
1127 SmallVector<Type> newTypes;
1128 for (SmallVector<int64_t> offsets :
1129 StaticTileOffsetRange(originalShape, *targetShape)) {
1130 result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,
1131 result, offsets, strides);
1132 newTypes.push_back(unrolledType);
1133 unrolledInputNums.push_back(newInputNo);
1134 ++newInputNo;
1135 ++newOpCount;
1136 }
1137 rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
1138 oneToNTypeMapping.addInputs(origInputNo, newTypes);
1139 }
1140
1141 // Change the function signature.
1142 auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1143 auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1144 rewriter.modifyOpInPlace(newFuncOp,
1145 [&] { newFuncOp.setFunctionType(newFnType); });
1146
1147 // Update the arguments in the entry block.
1148 entryBlock.eraseArguments(0, fnType.getNumInputs());
1149 SmallVector<Location> locs(convertedTypes.size(), newFuncOp.getLoc());
1150 entryBlock.addArguments(convertedTypes, locs);
1151
1152 // Replace all uses of placeholders for initially legal arguments with their
1153 // original function arguments (that were added to `newFuncOp`).
1154 for (auto &[placeholderOp, argIdx] : tmpOps) {
1155 if (!placeholderOp)
1156 continue;
1157 Value replacement = newFuncOp.getArgument(argIdx);
1158 rewriter.replaceAllUsesWith(placeholderOp->getResult(0), replacement);
1159 }
1160
1161 // Replace dummy operands of new `vector.insert_strided_slice` ops with
1162 // their corresponding new function arguments. The new
1163 // `vector.insert_strided_slice` ops are inserted only into the entry block,
1164 // so iterating over that block is sufficient.
1165 size_t unrolledInputIdx = 0;
1166 for (auto [count, op] : enumerate(entryBlock.getOperations())) {
1167 Operation &curOp = op;
1168 // Since all newly created operations are in the beginning, reaching the
1169 // end of them means that any later `vector.insert_strided_slice` should
1170 // not be touched.
1171 if (count >= newOpCount)
1172 continue;
1173 if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1174 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1175 rewriter.modifyOpInPlace(&curOp, [&] {
1176 curOp.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1177 });
1178 ++unrolledInputIdx;
1179 }
1180 }
1181
1182 // Erase the original funcOp. The `tmpOps` do not need to be erased since
1183 // they have no uses and will be handled by dead-code elimination.
1184 rewriter.eraseOp(funcOp);
1185 return success();
1186 }
1187};
1188
1189//===----------------------------------------------------------------------===//
1190// func::ReturnOp Conversion Patterns
1191//===----------------------------------------------------------------------===//
1192
1193/// A pattern for rewriting function signature and the return op to convert
1194/// vectors to be of valid types.
1195struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
1196 using Base::Base;
1197
1198 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1199 PatternRewriter &rewriter) const override {
1200 // Check whether the parent funcOp is valid.
1201 auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1202 if (!funcOp)
1203 return failure();
1204
1205 FunctionType fnType = funcOp.getFunctionType();
1206 TypeConverter::SignatureConversion oneToNTypeMapping(
1207 fnType.getResults().size());
1208 Location loc = returnOp.getLoc();
1209
1210 // For the new return op.
1211 SmallVector<Value> newOperands;
1212
1213 // Enumerate through the results.
1214 for (auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1215 // Check whether the argument is of vector type.
1216 auto origVecType = dyn_cast<VectorType>(origType);
1217 if (!origVecType) {
1218 oneToNTypeMapping.addInputs(origResultNo, origType);
1219 newOperands.push_back(returnOp.getOperand(origResultNo));
1220 continue;
1221 }
1222 // Check whether the vector needs unrolling.
1223 auto targetShape = getTargetShape(origVecType);
1224 if (!targetShape) {
1225 // The original argument can be used.
1226 oneToNTypeMapping.addInputs(origResultNo, origType);
1227 newOperands.push_back(returnOp.getOperand(origResultNo));
1228 continue;
1229 }
1230 VectorType unrolledType =
1231 VectorType::get(*targetShape, origVecType.getElementType());
1232
1233 // Create `vector.extract_strided_slice` ops to form legal vectors from
1234 // the original operand of illegal type.
1235 auto originalShape =
1236 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1237 SmallVector<int64_t> strides(originalShape.size(), 1);
1238 SmallVector<int64_t> extractShape(originalShape.size(), 1);
1239 extractShape.back() = targetShape->back();
1240 SmallVector<Type> newTypes;
1241 Value returnValue = returnOp.getOperand(origResultNo);
1242 for (SmallVector<int64_t> offsets :
1243 StaticTileOffsetRange(originalShape, *targetShape)) {
1244 Value result = vector::ExtractStridedSliceOp::create(
1245 rewriter, loc, returnValue, offsets, extractShape, strides);
1246 if (originalShape.size() > 1) {
1247 SmallVector<int64_t> extractIndices(originalShape.size() - 1, 0);
1248 result =
1249 vector::ExtractOp::create(rewriter, loc, result, extractIndices);
1250 }
1251 newOperands.push_back(result);
1252 newTypes.push_back(unrolledType);
1253 }
1254 oneToNTypeMapping.addInputs(origResultNo, newTypes);
1255 }
1256
1257 // Change the function signature.
1258 auto newFnType =
1259 FunctionType::get(rewriter.getContext(), TypeRange(fnType.getInputs()),
1260 TypeRange(oneToNTypeMapping.getConvertedTypes()));
1261 rewriter.modifyOpInPlace(funcOp,
1262 [&] { funcOp.setFunctionType(newFnType); });
1263
1264 // Replace the return op using the new operands. This will automatically
1265 // update the entry block as well.
1266 rewriter.replaceOp(returnOp,
1267 func::ReturnOp::create(rewriter, loc, newOperands));
1268
1269 return success();
1270 }
1271};
1272
1273} // namespace
1274
1275//===----------------------------------------------------------------------===//
1276// Public function for builtin variables
1277//===----------------------------------------------------------------------===//
1278
1280 spirv::BuiltIn builtin,
1281 Type integerType, OpBuilder &builder,
1282 StringRef prefix, StringRef suffix) {
1284 if (!parent) {
1285 op->emitError("expected operation to be within a module-like op");
1286 return nullptr;
1287 }
1288
1289 spirv::GlobalVariableOp varOp =
1290 getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
1291 builtin, integerType, builder, prefix, suffix);
1292 Value ptr = spirv::AddressOfOp::create(builder, op->getLoc(), varOp);
1293 return spirv::LoadOp::create(builder, op->getLoc(), ptr);
1294}
1295
1296//===----------------------------------------------------------------------===//
1297// Public function for pushing constant storage
1298//===----------------------------------------------------------------------===//
1299
1301 unsigned offset, Type integerType,
1302 OpBuilder &builder) {
1303 Location loc = op->getLoc();
1305 if (!parent) {
1306 op->emitError("expected operation to be within a module-like op");
1307 return nullptr;
1308 }
1309
1310 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1311 loc, parent->getRegion(0).front(), elementCount, builder, integerType);
1312
1313 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1314 Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,
1315 builder.getI32IntegerAttr(offset));
1316 auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);
1317 auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,
1318 llvm::ArrayRef({zeroOp, offsetOp}));
1319 return spirv::LoadOp::create(builder, loc, acOp);
1320}
1321
1322//===----------------------------------------------------------------------===//
1323// Public functions for index calculation
1324//===----------------------------------------------------------------------===//
1325
1327 int64_t offset, Type integerType,
1328 Location loc, OpBuilder &builder) {
1329 assert(indices.size() == strides.size() &&
1330 "must provide indices for all dimensions");
1331
1332 // TODO: Consider moving to use affine.apply and patterns converting
1333 // affine.apply to standard ops. This needs converting to SPIR-V passes to be
1334 // broken down into progressive small steps so we can have intermediate steps
1335 // using other dialects. At the moment SPIR-V is the final sink.
1336
1337 Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
1338 loc, integerType, IntegerAttr::get(integerType, offset));
1339 for (const auto &index : llvm::enumerate(indices)) {
1340 Value strideVal = builder.createOrFold<spirv::ConstantOp>(
1341 loc, integerType,
1342 IntegerAttr::get(integerType, strides[index.index()]));
1343 Value update =
1344 builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1345 linearizedIndex =
1346 builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1347 }
1348 return linearizedIndex;
1349}
1350
1352 MemRefType baseType, Value basePtr,
1354 OpBuilder &builder) {
1355 // Get base and offset of the MemRefType and verify they are static.
1356
1357 int64_t offset;
1359 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1360 llvm::is_contained(strides, ShapedType::kDynamic) ||
1361 ShapedType::isDynamic(offset)) {
1362 return nullptr;
1363 }
1364
1365 auto indexType = typeConverter.getIndexType();
1366
1367 SmallVector<Value, 2> linearizedIndices;
1368 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1369
1370 // Add a '0' at the start to index into the struct.
1371 linearizedIndices.push_back(zero);
1372
1373 if (baseType.getRank() == 0) {
1374 linearizedIndices.push_back(zero);
1375 } else {
1376 linearizedIndices.push_back(
1377 linearizeIndex(indices, strides, offset, indexType, loc, builder));
1378 }
1379 return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);
1380}
1381
1383 MemRefType baseType, Value basePtr,
1385 OpBuilder &builder) {
1386 // Get base and offset of the MemRefType and verify they are static.
1387
1388 int64_t offset;
1390 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1391 llvm::is_contained(strides, ShapedType::kDynamic) ||
1392 ShapedType::isDynamic(offset)) {
1393 return nullptr;
1394 }
1395
1396 auto indexType = typeConverter.getIndexType();
1397
1398 SmallVector<Value, 2> linearizedIndices;
1399 Value linearIndex;
1400 if (baseType.getRank() == 0) {
1401 linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1402 } else {
1403 linearIndex =
1404 linearizeIndex(indices, strides, offset, indexType, loc, builder);
1405 }
1406 Type pointeeType =
1407 cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
1408 if (isa<spirv::ArrayType>(pointeeType)) {
1409 linearizedIndices.push_back(linearIndex);
1410 return spirv::AccessChainOp::create(builder, loc, basePtr,
1411 linearizedIndices);
1412 }
1413 return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,
1414 linearizedIndices);
1415}
1416
1418 MemRefType baseType, Value basePtr,
1420 OpBuilder &builder) {
1421
1422 if (typeConverter.allows(spirv::Capability::Kernel)) {
1423 return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc,
1424 builder);
1425 }
1426
1427 return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc,
1428 builder);
1429}
1430
1431//===----------------------------------------------------------------------===//
1432// Public functions for vector unrolling
1433//===----------------------------------------------------------------------===//
1434
1436 for (int i : {4, 3, 2}) {
1437 if (size % i == 0)
1438 return i;
1439 }
1440 return 1;
1441}
1442
1445 VectorType srcVectorType = op.getSourceVectorType();
1446 assert(srcVectorType.getRank() == 1); // Guaranteed by semantics
1447 int64_t vectorSize =
1448 mlir::spirv::getComputeVectorSize(srcVectorType.getDimSize(0));
1449 return {vectorSize};
1450}
1451
1454 VectorType vectorType = op.getResultVectorType();
1455 SmallVector<int64_t> nativeSize(vectorType.getRank(), 1);
1456 nativeSize.back() =
1457 mlir::spirv::getComputeVectorSize(vectorType.getShape().back());
1458 return nativeSize;
1459}
1460
1461std::optional<SmallVector<int64_t>>
1464 if (auto vecType = dyn_cast<VectorType>(op->getResultTypes()[0])) {
1465 SmallVector<int64_t> nativeSize(vecType.getRank(), 1);
1466 nativeSize.back() =
1467 mlir::spirv::getComputeVectorSize(vecType.getShape().back());
1468 return nativeSize;
1469 }
1470 }
1471
1473 .Case<vector::ReductionOp, vector::TransposeOp>(
1474 [](auto typedOp) { return getNativeVectorShapeImpl(typedOp); })
1475 .Default(std::nullopt);
1476}
1477
1479 MLIRContext *context = op->getContext();
1480 RewritePatternSet patterns(context);
1483 // We only want to apply signature conversion once to the existing func ops.
1484 // Without specifying strictMode, the greedy pattern rewriter will keep
1485 // looking for newly created func ops.
1486 return applyPatternsGreedily(op, std::move(patterns),
1487 GreedyRewriteConfig().setStrictness(
1489}
1490
1492 MLIRContext *context = op->getContext();
1493
1494 // Unroll vectors in function bodies to native vector size.
1495 {
1496 RewritePatternSet patterns(context);
1498 [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
1499 populateVectorUnrollPatterns(patterns, options);
1500 if (failed(applyPatternsGreedily(op, std::move(patterns))))
1501 return failure();
1502 }
1503
1504 // Convert transpose ops into extract and insert pairs, in preparation of
1505 // further transformations to canonicalize/cancel.
1506 {
1507 RewritePatternSet patterns(context);
1509 patterns, vector::VectorTransposeLowering::EltWise);
1511 if (failed(applyPatternsGreedily(op, std::move(patterns))))
1512 return failure();
1513 }
1514
1515 // Run canonicalization to cast away leading size-1 dimensions.
1516 {
1517 RewritePatternSet patterns(context);
1518
1519 // We need to pull in casting way leading one dims.
1520 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1521 vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1522 vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1523
1524 // Decompose different rank insert_strided_slice and n-D
1525 // extract_slided_slice.
1526 vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1527 patterns);
1528 vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1529 vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1530
1531 // Trimming leading unit dims may generate broadcast/shape_cast ops. Clean
1532 // them up.
1533 vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1534 vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1535
1536 if (failed(applyPatternsGreedily(op, std::move(patterns))))
1537 return failure();
1538 }
1539 return success();
1540}
1541
1542//===----------------------------------------------------------------------===//
1543// SPIR-V TypeConverter
1544//===----------------------------------------------------------------------===//
1545
1547 const SPIRVConversionOptions &options)
1548 : targetEnv(targetAttr), options(options) {
1549 // Add conversions. The order matters here: later ones will be tried earlier.
1550
1551 // Allow all SPIR-V dialect specific types. This assumes all builtin types
1552 // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
1553 // were tried before.
1554 //
1555 // TODO: This assumes that the SPIR-V types are valid to use in the given
1556 // target environment, which should be the case if the whole pipeline is
1557 // driven by the same target environment. Still, we probably still want to
1558 // validate and convert to be safe.
1559 addConversion([](spirv::SPIRVType type) { return type; });
1560
1561 addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
1562
1563 addConversion([this](IntegerType intType) -> std::optional<Type> {
1564 if (auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1565 return convertScalarType(this->targetEnv, this->options, scalarType);
1566 if (intType.getWidth() < 8)
1567 return convertSubByteIntegerType(this->options, intType);
1568 return Type();
1569 });
1570
1571 addConversion([this](FloatType floatType) -> std::optional<Type> {
1572 if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1573 return convertScalarType(this->targetEnv, this->options, scalarType);
1574 if (floatType.getWidth() == 8)
1575 return convert8BitFloatType(this->options, floatType);
1576 return Type();
1577 });
1578
1579 addConversion([this](ComplexType complexType) {
1580 return convertComplexType(this->targetEnv, this->options, complexType);
1581 });
1582
1583 addConversion([this](VectorType vectorType) {
1584 return convertVectorType(this->targetEnv, this->options, vectorType);
1585 });
1586
1587 addConversion([this](TensorType tensorType) {
1588 return convertTensorType(this->targetEnv, this->options, tensorType);
1589 });
1590
1591 addConversion([this](MemRefType memRefType) {
1592 return convertMemrefType(this->targetEnv, this->options, memRefType);
1593 });
1594
1595 // Register some last line of defense casting logic.
1596 addSourceMaterialization(
1597 [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) {
1598 return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1599 });
1600 addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs,
1601 Location loc) {
1602 auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1603 return cast.getResult(0);
1604 });
1605}
1606
1608 return ::getIndexType(getContext(), options);
1609}
1610
1611MLIRContext *SPIRVTypeConverter::getContext() const {
1612 return targetEnv.getAttr().getContext();
1613}
1614
1615bool SPIRVTypeConverter::allows(spirv::Capability capability) const {
1616 return targetEnv.allows(capability);
1617}
1618
1619//===----------------------------------------------------------------------===//
1620// SPIR-V ConversionTarget
1621//===----------------------------------------------------------------------===//
1622
1623std::unique_ptr<SPIRVConversionTarget>
1625 std::unique_ptr<SPIRVConversionTarget> target(
1626 // std::make_unique does not work here because the constructor is private.
1627 new SPIRVConversionTarget(targetAttr));
1628 SPIRVConversionTarget *targetPtr = target.get();
1629 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1630 // We need to capture the raw pointer here because it is stable:
1631 // target will be destroyed once this function is returned.
1632 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
1633 return target;
1634}
1635
1636SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
1637 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
1638
1639bool SPIRVConversionTarget::isLegalOp(Operation *op) {
1640 // Make sure this op is available at the given version. Ops not implementing
1641 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
1642 // SPIR-V versions.
1643 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1644 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1645 if (minVersion && *minVersion > this->targetEnv.getVersion()) {
1646 LLVM_DEBUG(llvm::dbgs()
1647 << op->getName() << " illegal: requiring min version "
1648 << spirv::stringifyVersion(*minVersion) << "\n");
1649 return false;
1650 }
1651 }
1652 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1653 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1654 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1655 LLVM_DEBUG(llvm::dbgs()
1656 << op->getName() << " illegal: requiring max version "
1657 << spirv::stringifyVersion(*maxVersion) << "\n");
1658 return false;
1659 }
1660 }
1661
1662 // Make sure this op's required extensions are allowed to use. Ops not
1663 // implementing QueryExtensionInterface do not require extensions to be
1664 // available.
1665 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1666 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1667 extensions.getExtensions())))
1668 return false;
1669
1670 // Make sure this op's required extensions are allowed to use. Ops not
1671 // implementing QueryCapabilityInterface do not require capabilities to be
1672 // available.
1673 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1674 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1675 capabilities.getCapabilities())))
1676 return false;
1677
1678 SmallVector<Type, 4> valueTypes;
1679 valueTypes.append(op->operand_type_begin(), op->operand_type_end());
1680 valueTypes.append(op->result_type_begin(), op->result_type_end());
1681
1682 // Ensure that all types have been converted to SPIRV types.
1683 if (llvm::any_of(valueTypes,
1684 [](Type t) { return !isa<spirv::SPIRVType>(t); }))
1685 return false;
1686
1687 // Special treatment for global variables, whose type requirements are
1688 // conveyed by type attributes.
1689 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1690 valueTypes.push_back(globalVar.getType());
1691
1692 // Make sure the op's operands/results use types that are allowed by the
1693 // target environment.
1694 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1695 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1696 for (Type valueType : valueTypes) {
1697 typeExtensions.clear();
1698 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1699 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
1700 typeExtensions)))
1701 return false;
1702
1703 typeCapabilities.clear();
1704 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1705 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
1706 typeCapabilities)))
1707 return false;
1708 }
1709
1710 return true;
1711}
1712
1713//===----------------------------------------------------------------------===//
1714// Public functions for populating patterns
1715//===----------------------------------------------------------------------===//
1716
1718 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
1719 patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
1720}
1721
1723 patterns.add<FuncOpVectorUnroll>(patterns.getContext());
1724}
1725
1727 patterns.add<ReturnOpVectorUnroll>(patterns.getContext());
1728}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
#define BIT_WIDTH_CASE(BIT_WIDTH)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
Definition Block.h:33
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
Definition Block.h:193
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:160
OpListType & getOperations()
Definition Block.h:137
Operation & front()
Definition Block.h:153
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
Definition Block.cpp:201
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
FloatType getF32Type()
Definition Builders.cpp:43
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
MLIRContext * getContext() const
Definition Builders.h:56
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
Definition Builders.h:240
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
void setOperand(unsigned idx, Value value)
Definition Operation.h:351
operand_type_iterator operand_type_end()
Definition Operation.h:396
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
result_type_iterator result_type_end()
Definition Operation.h:427
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_type_iterator result_type_begin()
Definition Operation.h:426
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
result_type_range getResultTypes()
Definition Operation.h:428
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
operand_type_iterator operand_type_begin()
Definition Operation.h:395
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Block & front()
Definition Region.h:65
iterator begin()
Definition Region.h:55
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:76
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ArrayType get(Type elementType, unsigned elementCount)
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:147
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
Definition SPIRVTypes.h:65
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
Definition SPIRVTypes.h:54
static SampledImageType get(Type imageType)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
const FrozenRewritePatternSet & patterns
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)