31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/MathExtras.h"
39#define DEBUG_TYPE "mlir-spirv-conversion"
49static std::optional<SmallVector<int64_t>>
getTargetShape(VectorType vecType) {
50 LLVM_DEBUG(llvm::dbgs() <<
"Get target shape\n");
51 if (vecType.isScalable()) {
52 LLVM_DEBUG(llvm::dbgs()
53 <<
"--scalable vectors are not supported -> BAIL\n");
60 LLVM_DEBUG(llvm::dbgs() <<
"--no unrolling target shape defined\n");
64 if (!maybeShapeRatio) {
65 LLVM_DEBUG(llvm::dbgs()
66 <<
"--could not compute integral shape ratio -> BAIL\n");
69 if (llvm::all_of(*maybeShapeRatio, [](
int64_t v) {
return v == 1; })) {
70 LLVM_DEBUG(llvm::dbgs() <<
"--no unrolling needed -> SKIP\n");
73 LLVM_DEBUG(llvm::dbgs()
74 <<
"--found an integral shape ratio to unroll to -> SUCCESS\n");
84template <
typename LabelT>
85static LogicalResult checkExtensionRequirements(
88 for (
const auto &ors : candidates) {
94 for (spirv::Extension ext : ors)
95 extStrings.push_back(spirv::stringifyExtension(ext));
97 llvm::dbgs() << label <<
" illegal: requires at least one extension in ["
98 << llvm::join(extStrings,
", ")
99 <<
"] but none allowed in target environment\n";
112template <
typename LabelT>
113static LogicalResult checkCapabilityRequirements(
116 for (
const auto &ors : candidates) {
117 if (targetEnv.
allows(ors))
122 for (spirv::Capability cap : ors)
123 capStrings.push_back(spirv::stringifyCapability(cap));
125 llvm::dbgs() << label <<
" illegal: requires at least one capability in ["
126 << llvm::join(capStrings,
", ")
127 <<
"] but none allowed in target environment\n";
136static bool needsExplicitLayout(spirv::StorageClass storageClass) {
137 switch (storageClass) {
138 case spirv::StorageClass::PhysicalStorageBuffer:
139 case spirv::StorageClass::PushConstant:
140 case spirv::StorageClass::StorageBuffer:
141 case spirv::StorageClass::Uniform:
151wrapInStructAndGetPointer(
Type elementType, spirv::StorageClass storageClass) {
152 auto structType = needsExplicitLayout(storageClass)
164 return cast<spirv::ScalarType>(
165 IntegerType::get(ctx,
options.use64bitIndex ? 64 : 32));
170static std::optional<int64_t>
172 if (isa<spirv::ScalarType>(type)) {
186 if (
options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
193 if (
auto complexType = dyn_cast<ComplexType>(type)) {
194 auto elementSize = getTypeNumBytes(
options, complexType.getElementType());
197 return 2 * *elementSize;
200 if (
auto vecType = dyn_cast<VectorType>(type)) {
201 auto elementSize = getTypeNumBytes(
options, vecType.getElementType());
204 return vecType.getNumElements() * *elementSize;
207 if (
auto memRefType = dyn_cast<MemRefType>(type)) {
212 if (!memRefType.hasStaticShape() ||
213 failed(memRefType.getStridesAndOffset(strides, offset)))
219 auto elementSize = getTypeNumBytes(
options, memRefType.getElementType());
223 if (memRefType.getRank() == 0)
226 auto dims = memRefType.getShape();
227 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
228 ShapedType::isDynamic(offset) ||
229 llvm::is_contained(strides, ShapedType::kDynamic))
233 for (
const auto &
shape : enumerate(dims))
234 memrefSize = std::max(memrefSize,
shape.value() * strides[
shape.index()]);
236 return (offset + memrefSize) * *elementSize;
239 if (
auto tensorType = dyn_cast<TensorType>(type)) {
240 if (!tensorType.hasStaticShape())
243 auto elementSize = getTypeNumBytes(
options, tensorType.getElementType());
248 for (
auto shape : tensorType.getShape())
262 std::optional<spirv::StorageClass> storageClass = {}) {
266 type.getExtensions(extensions, storageClass);
267 type.getCapabilities(capabilities, storageClass);
270 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
271 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
276 if (!
options.emulateLT32BitScalarTypes)
281 LLVM_DEBUG(llvm::dbgs()
283 <<
" not converted to 32-bit for SPIR-V to avoid truncation\n");
287 if (
auto floatType = dyn_cast<FloatType>(type)) {
288 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
292 auto intType = cast<IntegerType>(type);
293 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
294 return IntegerType::get(targetEnv.
getContext(), 32,
295 intType.getSignedness());
308 if (type.getWidth() > 8) {
309 LLVM_DEBUG(llvm::dbgs() <<
"not a subbyte type\n");
313 LLVM_DEBUG(llvm::dbgs() <<
"unsupported sub-byte storage kind\n");
317 if (!llvm::isPowerOf2_32(type.getWidth())) {
318 LLVM_DEBUG(llvm::dbgs()
319 <<
"unsupported non-power-of-two bitwidth in sub-byte" << type
324 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
325 return IntegerType::get(type.getContext(), 32,
326 type.getSignedness());
333 if (!
options.emulateUnsupportedFloatTypes)
336 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
337 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
338 Float8E8M0FNUType>(type))
339 return IntegerType::get(type.getContext(), type.getWidth());
340 LLVM_DEBUG(llvm::dbgs() <<
"unsupported 8-bit float type: " << type <<
"\n");
348convertShaped8BitFloatType(ShapedType type,
350 if (!
options.emulateUnsupportedFloatTypes)
352 Type srcElementType = type.getElementType();
353 Type convertedElementType =
nullptr;
355 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
356 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
357 Float8E8M0FNUType>(srcElementType))
358 convertedElementType = IntegerType::get(
361 if (!convertedElementType)
364 return type.clone(convertedElementType);
371convertIndexElementType(ShapedType type,
373 Type indexType = dyn_cast<IndexType>(type.getElementType());
384 std::optional<spirv::StorageClass> storageClass = {}) {
385 type = cast<VectorType>(convertIndexElementType(type,
options));
386 type = cast<VectorType>(convertShaped8BitFloatType(type,
options));
387 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
391 auto intType = dyn_cast<IntegerType>(type.getElementType());
393 LLVM_DEBUG(llvm::dbgs()
395 <<
" illegal: cannot convert non-scalar element type\n");
399 Type elementType = convertSubByteIntegerType(
options, intType);
403 if (type.getRank() <= 1 && type.getNumElements() == 1)
406 if (type.getNumElements() > 4) {
407 LLVM_DEBUG(llvm::dbgs()
408 << type <<
" illegal: > 4-element unimplemented\n");
412 return VectorType::get(type.getShape(), elementType);
415 if (type.getRank() <= 1 && type.getNumElements() == 1)
416 return convertScalarType(targetEnv,
options, scalarType, storageClass);
419 LLVM_DEBUG(llvm::dbgs()
420 << type <<
" illegal: not a valid composite type\n");
427 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
428 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
431 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
432 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
436 convertScalarType(targetEnv,
options, scalarType, storageClass);
438 return VectorType::get(type.getShape(), elementType);
445 std::optional<spirv::StorageClass> storageClass = {}) {
446 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
448 LLVM_DEBUG(llvm::dbgs()
449 << type <<
" illegal: cannot convert non-scalar element type\n");
454 convertScalarType(targetEnv,
options, scalarType, storageClass);
457 if (elementType != type.getElementType()) {
458 LLVM_DEBUG(llvm::dbgs()
459 << type <<
" illegal: complex type emulation unsupported\n");
463 return VectorType::get(2, elementType);
476 if (!type.hasStaticShape()) {
477 LLVM_DEBUG(llvm::dbgs()
478 << type <<
" illegal: dynamic shape unimplemented\n");
482 type = cast<TensorType>(convertIndexElementType(type,
options));
483 type = cast<TensorType>(convertShaped8BitFloatType(type,
options));
484 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.
getElementType());
486 LLVM_DEBUG(llvm::dbgs()
487 << type <<
" illegal: cannot convert non-scalar element type\n");
491 std::optional<int64_t> scalarSize = getTypeNumBytes(
options, scalarType);
492 std::optional<int64_t> tensorSize = getTypeNumBytes(
options, type);
493 if (!scalarSize || !tensorSize) {
494 LLVM_DEBUG(llvm::dbgs()
495 << type <<
" illegal: cannot deduce element count\n");
499 int64_t arrayElemCount = *tensorSize / *scalarSize;
500 if (arrayElemCount == 0) {
501 LLVM_DEBUG(llvm::dbgs()
502 << type <<
" illegal: cannot handle zero-element tensors\n");
505 if (arrayElemCount > std::numeric_limits<unsigned>::max()) {
506 LLVM_DEBUG(llvm::dbgs()
507 << type <<
" illegal: cannot fit tensor into target type\n");
511 Type arrayElemType = convertScalarType(targetEnv,
options, scalarType);
514 std::optional<int64_t> arrayElemSize =
515 getTypeNumBytes(
options, arrayElemType);
516 if (!arrayElemSize) {
517 LLVM_DEBUG(llvm::dbgs()
518 << type <<
" illegal: cannot deduce converted element size\n");
528 spirv::StorageClass storageClass) {
529 unsigned numBoolBits =
options.boolNumBits;
530 if (numBoolBits != 8) {
531 LLVM_DEBUG(llvm::dbgs()
532 <<
"using non-8-bit storage for bool types unimplemented");
535 auto elementType = dyn_cast<spirv::ScalarType>(
536 IntegerType::get(type.
getContext(), numBoolBits));
540 convertScalarType(targetEnv,
options, elementType, storageClass);
543 std::optional<int64_t> arrayElemSize =
544 getTypeNumBytes(
options, arrayElemType);
545 if (!arrayElemSize) {
546 LLVM_DEBUG(llvm::dbgs()
547 << type <<
" illegal: cannot deduce converted element size\n");
551 if (!type.hasStaticShape()) {
554 if (targetEnv.
allows(spirv::Capability::Kernel))
556 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
560 return wrapInStructAndGetPointer(arrayType, storageClass);
563 if (type.getNumElements() == 0) {
564 LLVM_DEBUG(llvm::dbgs()
565 << type <<
" illegal: zero-element memrefs are not supported\n");
569 int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
570 int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
571 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
573 if (targetEnv.
allows(spirv::Capability::Kernel))
575 return wrapInStructAndGetPointer(arrayType, storageClass);
581 spirv::StorageClass storageClass) {
582 IntegerType elementType = cast<IntegerType>(type.getElementType());
583 Type arrayElemType = convertSubByteIntegerType(
options, elementType);
588 if (!type.hasStaticShape()) {
591 if (targetEnv.
allows(spirv::Capability::Kernel))
593 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
597 return wrapInStructAndGetPointer(arrayType, storageClass);
600 if (type.getNumElements() == 0) {
601 LLVM_DEBUG(llvm::dbgs()
602 << type <<
" illegal: zero-element memrefs are not supported\n");
607 llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
608 int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
609 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
611 if (targetEnv.
allows(spirv::Capability::Kernel))
613 return wrapInStructAndGetPointer(arrayType, storageClass);
616static spirv::Dim convertRank(
int64_t rank) {
619 return spirv::Dim::Dim1D;
621 return spirv::Dim::Dim2D;
623 return spirv::Dim::Dim3D;
625 llvm_unreachable(
"Invalid memref rank!");
629static spirv::ImageFormat getImageFormat(
Type elementType) {
631 .Case([](Float16Type) {
return spirv::ImageFormat::R16f; })
632 .Case([](Float32Type) {
return spirv::ImageFormat::R32f; })
633 .Case([](IntegerType intType) {
634 auto const isSigned = intType.isSigned() || intType.isSignless();
635#define BIT_WIDTH_CASE(BIT_WIDTH) \
637 return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
638 : spirv::ImageFormat::R##BIT_WIDTH##ui
640 switch (intType.getWidth()) {
644 llvm_unreachable(
"Unhandled integer type!");
647 .DefaultUnreachable(
"Unhandled element type!");
654 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
659 <<
" illegal: expected memory space to be a SPIR-V storage class "
660 "attribute; please use MemorySpaceToStorageClassConverter to map "
661 "numeric memory spaces beforehand\n");
664 spirv::StorageClass storageClass = attr.getValue();
669 if (storageClass == spirv::StorageClass::Image) {
670 const int64_t rank = type.getRank();
671 if (rank < 1 || rank > 3) {
672 LLVM_DEBUG(llvm::dbgs()
673 << type <<
" illegal: cannot lower memref of rank " << rank
674 <<
" to a SPIR-V Image\n");
680 auto elementType = type.getElementType();
681 if (!isa<spirv::ScalarType>(elementType)) {
682 LLVM_DEBUG(llvm::dbgs() << type <<
" illegal: cannot lower memref of "
683 << elementType <<
" to a SPIR-V Image\n");
691 elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
692 spirv::ImageArrayedInfo::NonArrayed,
693 spirv::ImageSamplingInfo::SingleSampled,
694 spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
697 spvSampledImageType, spirv::StorageClass::UniformConstant);
701 if (isa<IntegerType>(type.getElementType())) {
702 if (type.getElementTypeBitWidth() == 1)
703 return convertBoolMemrefType(targetEnv,
options, type, storageClass);
704 if (type.getElementTypeBitWidth() < 8)
705 return convertSubByteMemrefType(targetEnv,
options, type, storageClass);
709 Type elementType = type.getElementType();
710 if (
auto vecType = dyn_cast<VectorType>(elementType)) {
712 convertVectorType(targetEnv,
options, vecType, storageClass);
713 }
else if (
auto complexType = dyn_cast<ComplexType>(elementType)) {
715 convertComplexType(targetEnv,
options, complexType, storageClass);
716 }
else if (
auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
718 convertScalarType(targetEnv,
options, scalarType, storageClass);
719 }
else if (
auto indexType = dyn_cast<IndexType>(elementType)) {
720 type = cast<MemRefType>(convertIndexElementType(type,
options));
721 arrayElemType = type.getElementType();
722 }
else if (
auto floatType = dyn_cast<FloatType>(elementType)) {
724 type = cast<MemRefType>(convertShaped8BitFloatType(type,
options));
725 arrayElemType = type.getElementType();
730 <<
" unhandled: can only convert scalar or vector element type\n");
736 std::optional<int64_t> arrayElemSize =
737 getTypeNumBytes(
options, arrayElemType);
738 if (!arrayElemSize) {
739 LLVM_DEBUG(llvm::dbgs()
740 << type <<
" illegal: cannot deduce converted element size\n");
744 if (!type.hasStaticShape()) {
747 if (targetEnv.
allows(spirv::Capability::Kernel))
749 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
753 return wrapInStructAndGetPointer(arrayType, storageClass);
756 std::optional<int64_t> memrefSize = getTypeNumBytes(
options, type);
758 LLVM_DEBUG(llvm::dbgs()
759 << type <<
" illegal: cannot deduce element count\n");
763 if (*memrefSize == 0) {
764 LLVM_DEBUG(llvm::dbgs()
765 << type <<
" illegal: zero-element memrefs are not supported\n");
769 int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
770 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
772 if (targetEnv.
allows(spirv::Capability::Kernel))
774 return wrapInStructAndGetPointer(arrayType, storageClass);
798 if (inputs.size() != 1) {
800 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
801 return castOp.getResult(0);
803 Value input = inputs.front();
806 if (!isa<IntegerType>(type)) {
808 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
809 return castOp.getResult(0);
811 auto inputType = cast<IntegerType>(input.
getType());
813 auto scalarType = dyn_cast<spirv::ScalarType>(type);
816 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
817 return castOp.getResult(0);
823 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
825 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
826 return castOp.getResult(0);
831 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
832 return spirv::IEqualOp::create(builder, loc, input, one);
838 scalarType.getExtensions(exts);
839 scalarType.getCapabilities(caps);
840 if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
841 failed(checkExtensionRequirements(type, targetEnv, exts))) {
843 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
844 return castOp.getResult(0);
851 return spirv::SConvertOp::create(builder, loc, type, input);
853 return spirv::UConvertOp::create(builder, loc, type, input);
860static spirv::GlobalVariableOp getBuiltinVariable(
Block &body,
864 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
865 if (
auto builtinAttr = varOp->getAttrOfType<StringAttr>(
866 spirv::SPIRVDialect::getAttributeName(
867 spirv::Decoration::BuiltIn))) {
868 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
878std::string getBuiltinVarName(spirv::BuiltIn
builtin, StringRef prefix,
880 return Twine(prefix).concat(stringifyBuiltIn(
builtin)).concat(suffix).str();
884static spirv::GlobalVariableOp
887 StringRef prefix, StringRef suffix) {
888 if (
auto varOp = getBuiltinVariable(body,
builtin))
894 spirv::GlobalVariableOp newVarOp;
896 case spirv::BuiltIn::NumWorkgroups:
897 case spirv::BuiltIn::WorkgroupSize:
898 case spirv::BuiltIn::WorkgroupId:
899 case spirv::BuiltIn::LocalInvocationId:
900 case spirv::BuiltIn::GlobalInvocationId: {
902 spirv::StorageClass::Input);
903 std::string name = getBuiltinVarName(
builtin, prefix, suffix);
905 spirv::GlobalVariableOp::create(builder, loc, ptrType, name,
builtin);
908 case spirv::BuiltIn::SubgroupId:
909 case spirv::BuiltIn::NumSubgroups:
910 case spirv::BuiltIn::SubgroupSize:
911 case spirv::BuiltIn::SubgroupLocalInvocationId: {
914 std::string name = getBuiltinVarName(
builtin, prefix, suffix);
916 spirv::GlobalVariableOp::create(builder, loc, ptrType, name,
builtin);
920 emitError(loc,
"unimplemented builtin variable generation for ")
943static spirv::GlobalVariableOp getPushConstantVariable(
Block &body,
944 unsigned elementCount) {
945 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
946 auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
953 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
954 auto numElements = cast<spirv::ArrayType>(
955 cast<spirv::StructType>(ptrType.getPointeeType())
958 if (numElements == elementCount)
967static spirv::GlobalVariableOp
971 if (
auto varOp = getPushConstantVariable(block, elementCount))
975 auto type = getPushConstantStorageType(elementCount, builder, indexType);
976 const char *name =
"__push_constant_var__";
977 return spirv::GlobalVariableOp::create(builder, loc, type, name,
987struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
991 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
992 ConversionPatternRewriter &rewriter)
const override {
993 FunctionType fnType = funcOp.getFunctionType();
994 if (fnType.getNumResults() > 1)
997 TypeConverter::SignatureConversion signatureConverter(
998 fnType.getNumInputs());
999 for (
const auto &argType : enumerate(fnType.getInputs())) {
1000 auto convertedType = getTypeConverter()->convertType(argType.value());
1003 signatureConverter.addInputs(argType.index(), convertedType);
1007 if (fnType.getNumResults() == 1) {
1008 resultType = getTypeConverter()->convertType(fnType.getResult(0));
1014 auto newFuncOp = spirv::FuncOp::create(
1015 rewriter, funcOp.getLoc(), funcOp.getName(),
1016 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
1022 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
1024 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1027 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1029 if (failed(rewriter.convertRegionTypes(
1030 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
1032 rewriter.eraseOp(funcOp);
1042 LogicalResult matchAndRewrite(func::FuncOp funcOp,
1044 FunctionType fnType = funcOp.getFunctionType();
1047 if (funcOp.isDeclaration()) {
1048 LLVM_DEBUG(llvm::dbgs()
1049 << fnType <<
" illegal: declarations are unsupported\n");
1056 if (llvm::any_of(fnType.getInputs(), [](
Type argType) {
1057 auto shapedType = dyn_cast<ShapedType>(argType);
1058 return shapedType && !shapedType.hasStaticShape();
1063 auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),
1064 funcOp.getName(), fnType);
1068 Location loc = newFuncOp.getBody().getLoc();
1070 Block &entryBlock = newFuncOp.getBlocks().
front();
1074 TypeConverter::SignatureConversion oneToNTypeMapping(
1075 fnType.getInputs().size());
1081 size_t newInputNo = 0;
1087 llvm::SmallDenseMap<Operation *, size_t> tmpOps;
1090 size_t newOpCount = 0;
1093 for (
auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
1095 auto origVecType = dyn_cast<VectorType>(origType);
1099 rewriter, loc, origType, rewriter.
getZeroAttr(origType));
1101 tmpOps.insert({
result.getDefiningOp(), newInputNo});
1102 oneToNTypeMapping.addInputs(origInputNo, origType);
1112 rewriter, loc, origType, rewriter.
getZeroAttr(origType));
1114 tmpOps.insert({
result.getDefiningOp(), newInputNo});
1115 oneToNTypeMapping.addInputs(origInputNo, origType);
1120 VectorType unrolledType =
1121 VectorType::get(*targetShape, origVecType.getElementType());
1122 auto originalShape =
1123 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1127 rewriter, loc, origVecType, rewriter.
getZeroAttr(origVecType));
1130 Value dummy = arith::ConstantOp::create(
1131 rewriter, loc, unrolledType, rewriter.
getZeroAttr(unrolledType));
1139 result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,
1140 result, offsets, strides);
1141 newTypes.push_back(unrolledType);
1142 unrolledInputNums.push_back(newInputNo);
1147 oneToNTypeMapping.addInputs(origInputNo, newTypes);
1151 auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1152 auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1154 [&] { newFuncOp.setFunctionType(newFnType); });
1163 for (
auto &[placeholderOp, argIdx] : tmpOps) {
1174 size_t unrolledInputIdx = 0;
1175 for (
auto [count, op] : enumerate(entryBlock.
getOperations())) {
1180 if (count >= newOpCount)
1182 if (
auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1183 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1185 curOp.
setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1207 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1210 auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1214 FunctionType fnType = funcOp.getFunctionType();
1215 TypeConverter::SignatureConversion oneToNTypeMapping(
1216 fnType.getResults().size());
1223 for (
auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1225 auto origVecType = dyn_cast<VectorType>(origType);
1227 oneToNTypeMapping.addInputs(origResultNo, origType);
1228 newOperands.push_back(returnOp.getOperand(origResultNo));
1235 oneToNTypeMapping.addInputs(origResultNo, origType);
1236 newOperands.push_back(returnOp.getOperand(origResultNo));
1239 VectorType unrolledType =
1240 VectorType::get(*targetShape, origVecType.getElementType());
1244 auto originalShape =
1245 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1248 extractShape.back() = targetShape->back();
1250 Value returnValue = returnOp.getOperand(origResultNo);
1253 Value result = vector::ExtractStridedSliceOp::create(
1254 rewriter, loc, returnValue, offsets, extractShape, strides);
1255 if (originalShape.size() > 1) {
1258 vector::ExtractOp::create(rewriter, loc,
result, extractIndices);
1260 newOperands.push_back(
result);
1261 newTypes.push_back(unrolledType);
1263 oneToNTypeMapping.addInputs(origResultNo, newTypes);
1269 TypeRange(oneToNTypeMapping.getConvertedTypes()));
1271 [&] { funcOp.setFunctionType(newFnType); });
1276 func::ReturnOp::create(rewriter, loc, newOperands));
1291 StringRef prefix, StringRef suffix) {
1294 op->
emitError(
"expected operation to be within a module-like op");
1298 spirv::GlobalVariableOp varOp =
1300 builtin, integerType, builder, prefix, suffix);
1301 Value ptr = spirv::AddressOfOp::create(builder, op->
getLoc(), varOp);
1302 return spirv::LoadOp::create(builder, op->
getLoc(),
ptr);
1310 unsigned offset,
Type integerType,
1315 op->
emitError(
"expected operation to be within a module-like op");
1319 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1320 loc, parent->
getRegion(0).
front(), elementCount, builder, integerType);
1322 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1323 Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,
1325 auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);
1326 auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,
1328 return spirv::LoadOp::create(builder, loc, acOp);
1338 assert(
indices.size() == strides.size() &&
1339 "must provide indices for all dimensions");
1347 loc, integerType, IntegerAttr::get(integerType, offset));
1351 IntegerAttr::get(integerType, strides[
index.index()]));
1355 builder.
createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1357 return linearizedIndex;
1361 MemRefType baseType,
Value basePtr,
1368 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1369 llvm::is_contained(strides, ShapedType::kDynamic) ||
1370 ShapedType::isDynamic(offset)) {
1377 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1380 linearizedIndices.push_back(zero);
1382 if (baseType.getRank() == 0) {
1383 linearizedIndices.push_back(zero);
1385 linearizedIndices.push_back(
1388 return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);
1392 MemRefType baseType,
Value basePtr,
1399 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1400 llvm::is_contained(strides, ShapedType::kDynamic) ||
1401 ShapedType::isDynamic(offset)) {
1409 if (baseType.getRank() == 0) {
1410 linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1416 cast<spirv::PointerType>(basePtr.
getType()).getPointeeType();
1417 if (isa<spirv::ArrayType>(pointeeType)) {
1418 linearizedIndices.push_back(linearIndex);
1419 return spirv::AccessChainOp::create(builder, loc, basePtr,
1422 return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,
1427 MemRefType baseType,
Value basePtr,
1431 if (typeConverter.
allows(spirv::Capability::Kernel)) {
1445 for (
int i : {4, 3, 2}) {
1454 VectorType srcVectorType = op.getSourceVectorType();
1455 assert(srcVectorType.getRank() == 1);
1458 return {vectorSize};
1463 VectorType vectorType = op.getResultVectorType();
1470std::optional<SmallVector<int64_t>>
1473 if (
auto vecType = dyn_cast<VectorType>(op->
getResultTypes()[0])) {
1482 .Case<vector::ReductionOp, vector::TransposeOp>(
1484 .Default(std::nullopt);
1508 populateVectorUnrollPatterns(patterns,
options);
1518 patterns, vector::VectorTransposeLowering::EltWise);
1529 vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
1530 vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1531 vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1535 vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1537 vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1538 vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1542 vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1543 vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1557 : targetEnv(targetAttr), options(options) {
1570 addConversion([
this](IndexType ) {
return getIndexType(); });
1572 addConversion([
this](IntegerType intType) -> std::optional<Type> {
1573 if (
auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1574 return convertScalarType(this->targetEnv, this->options, scalarType);
1575 if (intType.getWidth() < 8)
1576 return convertSubByteIntegerType(this->options, intType);
1580 addConversion([
this](FloatType floatType) -> std::optional<Type> {
1581 if (
auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1582 return convertScalarType(this->targetEnv, this->options, scalarType);
1583 if (floatType.getWidth() == 8)
1584 return convert8BitFloatType(this->options, floatType);
1588 addConversion([
this](ComplexType complexType) {
1589 return convertComplexType(this->targetEnv, this->options, complexType);
1592 addConversion([
this](VectorType vectorType) {
1593 return convertVectorType(this->targetEnv, this->options, vectorType);
1596 addConversion([
this](
TensorType tensorType) {
1597 return convertTensorType(this->targetEnv, this->options, tensorType);
1600 addConversion([
this](MemRefType memRefType) {
1601 return convertMemrefType(this->targetEnv, this->options, memRefType);
1605 addSourceMaterialization(
1607 return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1611 auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1612 return cast.getResult(0);
1617 return ::getIndexType(
getContext(), options);
1620MLIRContext *SPIRVTypeConverter::getContext()
const {
1621 return targetEnv.
getAttr().getContext();
1625 return targetEnv.allows(capability);
1632std::unique_ptr<SPIRVConversionTarget>
1634 std::unique_ptr<SPIRVConversionTarget>
target(
1636 new SPIRVConversionTarget(targetAttr));
1637 SPIRVConversionTarget *targetPtr =
target.get();
1638 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1641 [targetPtr](
Operation *op) {
return targetPtr->isLegalOp(op); });
1648bool SPIRVConversionTarget::isLegalOp(
Operation *op) {
1652 if (
auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1653 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1654 if (minVersion && *minVersion > this->targetEnv.
getVersion()) {
1655 LLVM_DEBUG(llvm::dbgs()
1656 << op->
getName() <<
" illegal: requiring min version "
1657 << spirv::stringifyVersion(*minVersion) <<
"\n");
1661 if (
auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1662 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1663 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1664 LLVM_DEBUG(llvm::dbgs()
1665 << op->
getName() <<
" illegal: requiring max version "
1666 << spirv::stringifyVersion(*maxVersion) <<
"\n");
1674 if (
auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1675 if (
failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
1676 extensions.getExtensions())))
1682 if (
auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1683 if (
failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
1684 capabilities.getCapabilities())))
1687 SmallVector<Type, 4> valueTypes;
1692 if (llvm::any_of(valueTypes,
1693 [](Type t) {
return !isa<spirv::SPIRVType>(t); }))
1698 if (
auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1699 valueTypes.push_back(globalVar.getType());
1703 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1704 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1705 for (Type valueType : valueTypes) {
1706 typeExtensions.clear();
1707 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1708 if (
failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
1712 typeCapabilities.clear();
1713 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1714 if (
failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
1728 patterns.
add<FuncOpConversion>(typeConverter, patterns.
getContext());
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
#define BIT_WIDTH_CASE(BIT_WIDTH)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
OpListType & getOperations()
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setOperand(unsigned idx, Value value)
operand_type_iterator operand_type_end()
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
result_type_iterator result_type_end()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ArrayType get(Type elementType, unsigned elementCount)
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Capability::B) AND (Cap...
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
static SampledImageType get(Type imageType)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
llvm::TypeSwitch< T, ResultT > TypeSwitch
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)