31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/SmallVector.h"
33 #include "llvm/ADT/StringExtras.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/MathExtras.h"
39 #define DEBUG_TYPE "mlir-spirv-conversion"
49 static std::optional<SmallVector<int64_t>>
getTargetShape(VectorType vecType) {
50 LLVM_DEBUG(llvm::dbgs() <<
"Get target shape\n");
51 if (vecType.isScalable()) {
52 LLVM_DEBUG(llvm::dbgs()
53 <<
"--scalable vectors are not supported -> BAIL\n");
60 LLVM_DEBUG(llvm::dbgs() <<
"--no unrolling target shape defined\n");
64 if (!maybeShapeRatio) {
65 LLVM_DEBUG(llvm::dbgs()
66 <<
"--could not compute integral shape ratio -> BAIL\n");
69 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) {
return v == 1; })) {
70 LLVM_DEBUG(llvm::dbgs() <<
"--no unrolling needed -> SKIP\n");
73 LLVM_DEBUG(llvm::dbgs()
74 <<
"--found an integral shape ratio to unroll to -> SUCCESS\n");
84 template <
typename LabelT>
85 static LogicalResult checkExtensionRequirements(
88 for (
const auto &ors : candidates) {
94 for (spirv::Extension ext : ors)
95 extStrings.push_back(spirv::stringifyExtension(ext));
97 llvm::dbgs() << label <<
" illegal: requires at least one extension in ["
98 << llvm::join(extStrings,
", ")
99 <<
"] but none allowed in target environment\n";
112 template <
typename LabelT>
113 static LogicalResult checkCapabilityRequirements(
116 for (
const auto &ors : candidates) {
117 if (targetEnv.
allows(ors))
122 for (spirv::Capability cap : ors)
123 capStrings.push_back(spirv::stringifyCapability(cap));
125 llvm::dbgs() << label <<
" illegal: requires at least one capability in ["
126 << llvm::join(capStrings,
", ")
127 <<
"] but none allowed in target environment\n";
136 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
137 switch (storageClass) {
138 case spirv::StorageClass::PhysicalStorageBuffer:
139 case spirv::StorageClass::PushConstant:
140 case spirv::StorageClass::StorageBuffer:
141 case spirv::StorageClass::Uniform:
151 wrapInStructAndGetPointer(
Type elementType, spirv::StorageClass storageClass) {
152 auto structType = needsExplicitLayout(storageClass)
164 return cast<spirv::ScalarType>(
170 static std::optional<int64_t>
172 if (isa<spirv::ScalarType>(type)) {
186 if (
options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
193 if (
auto complexType = dyn_cast<ComplexType>(type)) {
194 auto elementSize = getTypeNumBytes(
options, complexType.getElementType());
197 return 2 * *elementSize;
200 if (
auto vecType = dyn_cast<VectorType>(type)) {
201 auto elementSize = getTypeNumBytes(
options, vecType.getElementType());
204 return vecType.getNumElements() * *elementSize;
207 if (
auto memRefType = dyn_cast<MemRefType>(type)) {
212 if (!memRefType.hasStaticShape() ||
213 failed(memRefType.getStridesAndOffset(strides, offset)))
219 auto elementSize = getTypeNumBytes(
options, memRefType.getElementType());
223 if (memRefType.getRank() == 0)
226 auto dims = memRefType.getShape();
227 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
228 ShapedType::isDynamic(offset) ||
229 llvm::is_contained(strides, ShapedType::kDynamic))
232 int64_t memrefSize = -1;
233 for (
const auto &shape :
enumerate(dims))
234 memrefSize =
std::max(memrefSize, shape.value() * strides[shape.index()]);
236 return (offset + memrefSize) * *elementSize;
239 if (
auto tensorType = dyn_cast<TensorType>(type)) {
240 if (!tensorType.hasStaticShape())
243 auto elementSize = getTypeNumBytes(
options, tensorType.getElementType());
247 int64_t size = *elementSize;
248 for (
auto shape : tensorType.getShape())
262 std::optional<spirv::StorageClass> storageClass = {}) {
266 type.getExtensions(extensions, storageClass);
267 type.getCapabilities(capabilities, storageClass);
270 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
271 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
276 if (!
options.emulateLT32BitScalarTypes)
281 LLVM_DEBUG(llvm::dbgs()
283 <<
" not converted to 32-bit for SPIR-V to avoid truncation\n");
287 if (
auto floatType = dyn_cast<FloatType>(type)) {
288 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
292 auto intType = cast<IntegerType>(type);
293 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
295 intType.getSignedness());
308 if (type.getWidth() > 8) {
309 LLVM_DEBUG(llvm::dbgs() <<
"not a subbyte type\n");
313 LLVM_DEBUG(llvm::dbgs() <<
"unsupported sub-byte storage kind\n");
317 if (!llvm::isPowerOf2_32(type.getWidth())) {
318 LLVM_DEBUG(llvm::dbgs()
319 <<
"unsupported non-power-of-two bitwidth in sub-byte" << type
324 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
326 type.getSignedness());
333 if (!
options.emulateUnsupportedFloatTypes)
336 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
337 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
338 Float8E8M0FNUType>(type))
340 LLVM_DEBUG(llvm::dbgs() <<
"unsupported 8-bit float type: " << type <<
"\n");
348 convertShaped8BitFloatType(ShapedType type,
350 if (!
options.emulateUnsupportedFloatTypes)
352 Type srcElementType = type.getElementType();
353 Type convertedElementType =
nullptr;
355 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
356 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
357 Float8E8M0FNUType>(srcElementType))
361 if (!convertedElementType)
364 return type.clone(convertedElementType);
371 convertIndexElementType(ShapedType type,
373 Type indexType = dyn_cast<IndexType>(type.getElementType());
384 std::optional<spirv::StorageClass> storageClass = {}) {
385 type = cast<VectorType>(convertIndexElementType(type,
options));
386 type = cast<VectorType>(convertShaped8BitFloatType(type,
options));
387 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
391 auto intType = dyn_cast<IntegerType>(type.getElementType());
393 LLVM_DEBUG(llvm::dbgs()
395 <<
" illegal: cannot convert non-scalar element type\n");
399 Type elementType = convertSubByteIntegerType(
options, intType);
403 if (type.getRank() <= 1 && type.getNumElements() == 1)
406 if (type.getNumElements() > 4) {
407 LLVM_DEBUG(llvm::dbgs()
408 << type <<
" illegal: > 4-element unimplemented\n");
415 if (type.getRank() <= 1 && type.getNumElements() == 1)
416 return convertScalarType(targetEnv,
options, scalarType, storageClass);
419 LLVM_DEBUG(llvm::dbgs()
420 << type <<
" illegal: not a valid composite type\n");
427 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
428 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
431 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
432 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
436 convertScalarType(targetEnv,
options, scalarType, storageClass);
445 std::optional<spirv::StorageClass> storageClass = {}) {
446 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
448 LLVM_DEBUG(llvm::dbgs()
449 << type <<
" illegal: cannot convert non-scalar element type\n");
454 convertScalarType(targetEnv,
options, scalarType, storageClass);
457 if (elementType != type.getElementType()) {
458 LLVM_DEBUG(llvm::dbgs()
459 << type <<
" illegal: complex type emulation unsupported\n");
476 if (!type.hasStaticShape()) {
477 LLVM_DEBUG(llvm::dbgs()
478 << type <<
" illegal: dynamic shape unimplemented\n");
482 type = cast<TensorType>(convertIndexElementType(type,
options));
483 type = cast<TensorType>(convertShaped8BitFloatType(type,
options));
484 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.
getElementType());
486 LLVM_DEBUG(llvm::dbgs()
487 << type <<
" illegal: cannot convert non-scalar element type\n");
491 std::optional<int64_t> scalarSize = getTypeNumBytes(
options, scalarType);
492 std::optional<int64_t> tensorSize = getTypeNumBytes(
options, type);
493 if (!scalarSize || !tensorSize) {
494 LLVM_DEBUG(llvm::dbgs()
495 << type <<
" illegal: cannot deduce element count\n");
499 int64_t arrayElemCount = *tensorSize / *scalarSize;
500 if (arrayElemCount == 0) {
501 LLVM_DEBUG(llvm::dbgs()
502 << type <<
" illegal: cannot handle zero-element tensors\n");
506 Type arrayElemType = convertScalarType(targetEnv,
options, scalarType);
509 std::optional<int64_t> arrayElemSize =
510 getTypeNumBytes(
options, arrayElemType);
511 if (!arrayElemSize) {
512 LLVM_DEBUG(llvm::dbgs()
513 << type <<
" illegal: cannot deduce converted element size\n");
523 spirv::StorageClass storageClass) {
524 unsigned numBoolBits =
options.boolNumBits;
525 if (numBoolBits != 8) {
526 LLVM_DEBUG(llvm::dbgs()
527 <<
"using non-8-bit storage for bool types unimplemented");
530 auto elementType = dyn_cast<spirv::ScalarType>(
535 convertScalarType(targetEnv,
options, elementType, storageClass);
538 std::optional<int64_t> arrayElemSize =
539 getTypeNumBytes(
options, arrayElemType);
540 if (!arrayElemSize) {
541 LLVM_DEBUG(llvm::dbgs()
542 << type <<
" illegal: cannot deduce converted element size\n");
546 if (!type.hasStaticShape()) {
549 if (targetEnv.
allows(spirv::Capability::Kernel))
551 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
555 return wrapInStructAndGetPointer(arrayType, storageClass);
558 if (type.getNumElements() == 0) {
559 LLVM_DEBUG(llvm::dbgs()
560 << type <<
" illegal: zero-element memrefs are not supported\n");
564 int64_t memrefSize =
llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
566 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
568 if (targetEnv.
allows(spirv::Capability::Kernel))
570 return wrapInStructAndGetPointer(arrayType, storageClass);
576 spirv::StorageClass storageClass) {
577 IntegerType elementType = cast<IntegerType>(type.getElementType());
578 Type arrayElemType = convertSubByteIntegerType(
options, elementType);
581 int64_t arrayElemSize = *getTypeNumBytes(
options, arrayElemType);
583 if (!type.hasStaticShape()) {
586 if (targetEnv.
allows(spirv::Capability::Kernel))
588 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
592 return wrapInStructAndGetPointer(arrayType, storageClass);
595 if (type.getNumElements() == 0) {
596 LLVM_DEBUG(llvm::dbgs()
597 << type <<
" illegal: zero-element memrefs are not supported\n");
604 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
606 if (targetEnv.
allows(spirv::Capability::Kernel))
608 return wrapInStructAndGetPointer(arrayType, storageClass);
611 static spirv::Dim convertRank(int64_t rank) {
614 return spirv::Dim::Dim1D;
616 return spirv::Dim::Dim2D;
618 return spirv::Dim::Dim3D;
620 llvm_unreachable(
"Invalid memref rank!");
624 static spirv::ImageFormat getImageFormat(
Type elementType) {
626 .Case<Float16Type>([](Float16Type) {
return spirv::ImageFormat::R16f; })
627 .Case<Float32Type>([](Float32Type) {
return spirv::ImageFormat::R32f; })
628 .Case<IntegerType>([](IntegerType intType) {
629 auto const isSigned = intType.isSigned() || intType.isSignless();
630 #define BIT_WIDTH_CASE(BIT_WIDTH) \
632 return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
633 : spirv::ImageFormat::R##BIT_WIDTH##ui
635 switch (intType.getWidth()) {
639 llvm_unreachable(
"Unhandled integer type!");
643 llvm_unreachable(
"Unhandled element type!");
645 return spirv::ImageFormat::R32f;
647 #undef BIT_WIDTH_CASE
653 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
658 <<
" illegal: expected memory space to be a SPIR-V storage class "
659 "attribute; please use MemorySpaceToStorageClassConverter to map "
660 "numeric memory spaces beforehand\n");
663 spirv::StorageClass storageClass = attr.getValue();
668 if (storageClass == spirv::StorageClass::Image) {
669 const int64_t rank = type.getRank();
670 if (rank < 1 || rank > 3) {
671 LLVM_DEBUG(llvm::dbgs()
672 << type <<
" illegal: cannot lower memref of rank " << rank
673 <<
" to a SPIR-V Image\n");
679 auto elementType = type.getElementType();
680 if (!isa<spirv::ScalarType>(elementType)) {
681 LLVM_DEBUG(llvm::dbgs() << type <<
" illegal: cannot lower memref of "
682 << elementType <<
" to a SPIR-V Image\n");
690 elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
691 spirv::ImageArrayedInfo::NonArrayed,
692 spirv::ImageSamplingInfo::SingleSampled,
693 spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
696 spvSampledImageType, spirv::StorageClass::UniformConstant);
700 if (isa<IntegerType>(type.getElementType())) {
701 if (type.getElementTypeBitWidth() == 1)
702 return convertBoolMemrefType(targetEnv,
options, type, storageClass);
703 if (type.getElementTypeBitWidth() < 8)
704 return convertSubByteMemrefType(targetEnv,
options, type, storageClass);
708 Type elementType = type.getElementType();
709 if (
auto vecType = dyn_cast<VectorType>(elementType)) {
711 convertVectorType(targetEnv,
options, vecType, storageClass);
712 }
else if (
auto complexType = dyn_cast<ComplexType>(elementType)) {
714 convertComplexType(targetEnv,
options, complexType, storageClass);
715 }
else if (
auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
717 convertScalarType(targetEnv,
options, scalarType, storageClass);
718 }
else if (
auto indexType = dyn_cast<IndexType>(elementType)) {
719 type = cast<MemRefType>(convertIndexElementType(type,
options));
720 arrayElemType = type.getElementType();
721 }
else if (
auto floatType = dyn_cast<FloatType>(elementType)) {
723 type = cast<MemRefType>(convertShaped8BitFloatType(type,
options));
724 arrayElemType = type.getElementType();
729 <<
" unhandled: can only convert scalar or vector element type\n");
735 std::optional<int64_t> arrayElemSize =
736 getTypeNumBytes(
options, arrayElemType);
737 if (!arrayElemSize) {
738 LLVM_DEBUG(llvm::dbgs()
739 << type <<
" illegal: cannot deduce converted element size\n");
743 if (!type.hasStaticShape()) {
746 if (targetEnv.
allows(spirv::Capability::Kernel))
748 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
752 return wrapInStructAndGetPointer(arrayType, storageClass);
755 std::optional<int64_t> memrefSize = getTypeNumBytes(
options, type);
757 LLVM_DEBUG(llvm::dbgs()
758 << type <<
" illegal: cannot deduce element count\n");
762 if (*memrefSize == 0) {
763 LLVM_DEBUG(llvm::dbgs()
764 << type <<
" illegal: zero-element memrefs are not supported\n");
769 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
771 if (targetEnv.
allows(spirv::Capability::Kernel))
773 return wrapInStructAndGetPointer(arrayType, storageClass);
797 if (inputs.size() != 1) {
799 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
800 return castOp.getResult(0);
802 Value input = inputs.front();
805 if (!isa<IntegerType>(type)) {
807 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
808 return castOp.getResult(0);
810 auto inputType = cast<IntegerType>(input.
getType());
812 auto scalarType = dyn_cast<spirv::ScalarType>(type);
815 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
816 return castOp.getResult(0);
822 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
824 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
825 return castOp.getResult(0);
830 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
831 return spirv::IEqualOp::create(builder, loc, input, one);
837 scalarType.getExtensions(exts);
838 scalarType.getCapabilities(caps);
839 if (
failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
840 failed(checkExtensionRequirements(type, targetEnv, exts))) {
842 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
843 return castOp.getResult(0);
850 return spirv::SConvertOp::create(builder, loc, type, input);
852 return spirv::UConvertOp::create(builder, loc, type, input);
859 static spirv::GlobalVariableOp getBuiltinVariable(
Block &body,
860 spirv::BuiltIn builtin) {
863 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
864 if (
auto builtinAttr = varOp->getAttrOfType<StringAttr>(
865 spirv::SPIRVDialect::getAttributeName(
866 spirv::Decoration::BuiltIn))) {
867 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
868 if (varBuiltIn == builtin) {
877 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
879 return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
883 static spirv::GlobalVariableOp
884 getOrInsertBuiltinVariable(
Block &body,
Location loc, spirv::BuiltIn builtin,
886 StringRef prefix, StringRef suffix) {
887 if (
auto varOp = getBuiltinVariable(body, builtin))
893 spirv::GlobalVariableOp newVarOp;
895 case spirv::BuiltIn::NumWorkgroups:
896 case spirv::BuiltIn::WorkgroupSize:
897 case spirv::BuiltIn::WorkgroupId:
898 case spirv::BuiltIn::LocalInvocationId:
899 case spirv::BuiltIn::GlobalInvocationId: {
901 spirv::StorageClass::Input);
902 std::string name = getBuiltinVarName(builtin, prefix, suffix);
904 spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
907 case spirv::BuiltIn::SubgroupId:
908 case spirv::BuiltIn::NumSubgroups:
909 case spirv::BuiltIn::SubgroupSize:
910 case spirv::BuiltIn::SubgroupLocalInvocationId: {
913 std::string name = getBuiltinVarName(builtin, prefix, suffix);
915 spirv::GlobalVariableOp::create(builder, loc, ptrType, name, builtin);
919 emitError(loc,
"unimplemented builtin variable generation for ")
920 << stringifyBuiltIn(builtin);
942 static spirv::GlobalVariableOp getPushConstantVariable(
Block &body,
943 unsigned elementCount) {
944 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
945 auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
952 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
953 auto numElements = cast<spirv::ArrayType>(
954 cast<spirv::StructType>(ptrType.getPointeeType())
957 if (numElements == elementCount)
966 static spirv::GlobalVariableOp
970 if (
auto varOp = getPushConstantVariable(block, elementCount))
974 auto type = getPushConstantStorageType(elementCount, builder, indexType);
975 const char *name =
"__push_constant_var__";
976 return spirv::GlobalVariableOp::create(builder, loc, type, name,
990 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
992 FunctionType fnType = funcOp.getFunctionType();
993 if (fnType.getNumResults() > 1)
997 fnType.getNumInputs());
998 for (
const auto &argType :
enumerate(fnType.getInputs())) {
999 auto convertedType = getTypeConverter()->convertType(argType.value());
1002 signatureConverter.
addInputs(argType.index(), convertedType);
1006 if (fnType.getNumResults() == 1) {
1007 resultType = getTypeConverter()->convertType(fnType.getResult(0));
1013 auto newFuncOp = spirv::FuncOp::create(
1014 rewriter, funcOp.getLoc(), funcOp.getName(),
1020 for (
const auto &namedAttr : funcOp->getAttrs()) {
1021 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
1023 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1029 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
1041 LogicalResult matchAndRewrite(func::FuncOp funcOp,
1043 FunctionType fnType = funcOp.getFunctionType();
1046 if (funcOp.isDeclaration()) {
1047 LLVM_DEBUG(llvm::dbgs()
1048 << fnType <<
" illegal: declarations are unsupported\n");
1053 auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),
1054 funcOp.getName(), fnType);
1058 Location loc = newFuncOp.getBody().getLoc();
1060 Block &entryBlock = newFuncOp.getBlocks().
front();
1065 fnType.getInputs().size());
1071 size_t newInputNo = 0;
1077 llvm::SmallDenseMap<Operation *, size_t> tmpOps;
1080 size_t newOpCount = 0;
1083 for (
auto [origInputNo, origType] :
enumerate(fnType.getInputs())) {
1085 auto origVecType = dyn_cast<VectorType>(origType);
1088 Value result = arith::ConstantOp::create(
1089 rewriter, loc, origType, rewriter.
getZeroAttr(origType));
1092 oneToNTypeMapping.
addInputs(origInputNo, origType);
1101 Value result = arith::ConstantOp::create(
1102 rewriter, loc, origType, rewriter.
getZeroAttr(origType));
1105 oneToNTypeMapping.
addInputs(origInputNo, origType);
1110 VectorType unrolledType =
1112 auto originalShape =
1113 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1116 Value result = arith::ConstantOp::create(
1117 rewriter, loc, origVecType, rewriter.
getZeroAttr(origVecType));
1120 Value dummy = arith::ConstantOp::create(
1121 rewriter, loc, unrolledType, rewriter.
getZeroAttr(unrolledType));
1129 result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,
1130 result, offsets, strides);
1131 newTypes.push_back(unrolledType);
1132 unrolledInputNums.push_back(newInputNo);
1137 oneToNTypeMapping.
addInputs(origInputNo, newTypes);
1142 auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1144 [&] { newFuncOp.setFunctionType(newFnType); });
1153 for (
auto &[placeholderOp, argIdx] : tmpOps) {
1156 Value replacement = newFuncOp.getArgument(argIdx);
1164 size_t unrolledInputIdx = 0;
1170 if (count >= newOpCount)
1172 if (
auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1173 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1175 curOp.
setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1197 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1200 auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1204 FunctionType fnType = funcOp.getFunctionType();
1206 fnType.getResults().size());
1213 for (
auto [origResultNo, origType] :
enumerate(fnType.getResults())) {
1215 auto origVecType = dyn_cast<VectorType>(origType);
1217 oneToNTypeMapping.
addInputs(origResultNo, origType);
1218 newOperands.push_back(returnOp.getOperand(origResultNo));
1225 oneToNTypeMapping.
addInputs(origResultNo, origType);
1226 newOperands.push_back(returnOp.getOperand(origResultNo));
1229 VectorType unrolledType =
1234 auto originalShape =
1235 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1238 extractShape.back() = targetShape->back();
1240 Value returnValue = returnOp.getOperand(origResultNo);
1243 Value result = vector::ExtractStridedSliceOp::create(
1244 rewriter, loc, returnValue, offsets, extractShape, strides);
1245 if (originalShape.size() > 1) {
1248 vector::ExtractOp::create(rewriter, loc, result, extractIndices);
1250 newOperands.push_back(result);
1251 newTypes.push_back(unrolledType);
1253 oneToNTypeMapping.
addInputs(origResultNo, newTypes);
1261 [&] { funcOp.setFunctionType(newFnType); });
1266 func::ReturnOp::create(rewriter, loc, newOperands));
1279 spirv::BuiltIn builtin,
1281 StringRef prefix, StringRef suffix) {
1284 op->
emitError(
"expected operation to be within a module-like op");
1288 spirv::GlobalVariableOp varOp =
1290 builtin, integerType, builder, prefix, suffix);
1291 Value ptr = spirv::AddressOfOp::create(builder, op->
getLoc(), varOp);
1292 return spirv::LoadOp::create(builder, op->
getLoc(), ptr);
1300 unsigned offset,
Type integerType,
1305 op->
emitError(
"expected operation to be within a module-like op");
1309 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1310 loc, parent->
getRegion(0).
front(), elementCount, builder, integerType);
1313 Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,
1315 auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);
1316 auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,
1318 return spirv::LoadOp::create(builder, loc, acOp);
1326 int64_t offset,
Type integerType,
1328 assert(indices.size() == strides.size() &&
1329 "must provide indices for all dimensions");
1343 builder.
createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1345 builder.
createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1347 return linearizedIndex;
1351 MemRefType baseType,
Value basePtr,
1358 if (
failed(baseType.getStridesAndOffset(strides, offset)) ||
1359 llvm::is_contained(strides, ShapedType::kDynamic) ||
1360 ShapedType::isDynamic(offset)) {
1370 linearizedIndices.push_back(zero);
1372 if (baseType.getRank() == 0) {
1373 linearizedIndices.push_back(zero);
1375 linearizedIndices.push_back(
1376 linearizeIndex(indices, strides, offset, indexType, loc, builder));
1378 return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);
1382 MemRefType baseType,
Value basePtr,
1389 if (
failed(baseType.getStridesAndOffset(strides, offset)) ||
1390 llvm::is_contained(strides, ShapedType::kDynamic) ||
1391 ShapedType::isDynamic(offset)) {
1399 if (baseType.getRank() == 0) {
1403 linearizeIndex(indices, strides, offset, indexType, loc, builder);
1406 cast<spirv::PointerType>(basePtr.
getType()).getPointeeType();
1407 if (isa<spirv::ArrayType>(pointeeType)) {
1408 linearizedIndices.push_back(linearIndex);
1409 return spirv::AccessChainOp::create(builder, loc, basePtr,
1412 return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,
1417 MemRefType baseType,
Value basePtr,
1421 if (typeConverter.
allows(spirv::Capability::Kernel)) {
1435 for (
int i : {4, 3, 2}) {
1444 VectorType srcVectorType = op.getSourceVectorType();
1445 assert(srcVectorType.getRank() == 1);
1446 int64_t vectorSize =
1448 return {vectorSize};
1453 VectorType vectorType = op.getResultVectorType();
1460 std::optional<SmallVector<int64_t>>
1463 if (
auto vecType = dyn_cast<VectorType>(op->
getResultTypes()[0])) {
1472 .Case<vector::ReductionOp, vector::TransposeOp>(
1474 .Default([](
Operation *) {
return std::nullopt; });
1508 patterns, vector::VectorTransposeLowering::EltWise);
1519 vector::populateCastAwayVectorLeadingOneDimPatterns(
patterns);
1520 vector::ReductionOp::getCanonicalizationPatterns(
patterns, context);
1521 vector::TransposeOp::getCanonicalizationPatterns(
patterns, context);
1525 vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1527 vector::InsertOp::getCanonicalizationPatterns(
patterns, context);
1528 vector::ExtractOp::getCanonicalizationPatterns(
patterns, context);
1532 vector::BroadcastOp::getCanonicalizationPatterns(
patterns, context);
1533 vector::ShapeCastOp::getCanonicalizationPatterns(
patterns, context);
1562 addConversion([
this](IntegerType intType) -> std::optional<Type> {
1563 if (
auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1564 return convertScalarType(this->targetEnv, this->options, scalarType);
1565 if (intType.getWidth() < 8)
1566 return convertSubByteIntegerType(this->options, intType);
1570 addConversion([
this](FloatType floatType) -> std::optional<Type> {
1571 if (
auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1572 return convertScalarType(this->targetEnv, this->options, scalarType);
1573 if (floatType.getWidth() == 8)
1574 return convert8BitFloatType(this->options, floatType);
1579 return convertComplexType(this->targetEnv, this->options, complexType);
1583 return convertVectorType(this->targetEnv, this->options, vectorType);
1587 return convertTensorType(this->targetEnv, this->options, tensorType);
1591 return convertMemrefType(this->targetEnv, this->options, memRefType);
1597 return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1601 auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1602 return cast.getResult(0);
1607 return ::getIndexType(getContext(), options);
1610 MLIRContext *SPIRVTypeConverter::getContext()
const {
1611 return targetEnv.
getAttr().getContext();
1615 return targetEnv.
allows(capability);
1622 std::unique_ptr<SPIRVConversionTarget>
1624 std::unique_ptr<SPIRVConversionTarget> target(
1628 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1631 [targetPtr](
Operation *op) {
return targetPtr->isLegalOp(op); });
1638 bool SPIRVConversionTarget::isLegalOp(
Operation *op) {
1642 if (
auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1643 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1644 if (minVersion && *minVersion > this->targetEnv.
getVersion()) {
1645 LLVM_DEBUG(llvm::dbgs()
1646 << op->
getName() <<
" illegal: requiring min version "
1647 << spirv::stringifyVersion(*minVersion) <<
"\n");
1651 if (
auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1652 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1653 if (maxVersion && *maxVersion < this->targetEnv.
getVersion()) {
1654 LLVM_DEBUG(llvm::dbgs()
1655 << op->
getName() <<
" illegal: requiring max version "
1656 << spirv::stringifyVersion(*maxVersion) <<
"\n");
1664 if (
auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1665 if (
failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
1666 extensions.getExtensions())))
1672 if (
auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1673 if (
failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
1674 capabilities.getCapabilities())))
1682 if (llvm::any_of(valueTypes,
1683 [](
Type t) {
return !isa<spirv::SPIRVType>(t); }))
1688 if (
auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1689 valueTypes.push_back(globalVar.getType());
1695 for (
Type valueType : valueTypes) {
1696 typeExtensions.clear();
1697 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1698 if (
failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
1702 typeCapabilities.clear();
1703 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1704 if (
failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
#define BIT_WIDTH_CASE(BIT_WIDTH)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
OpListType & getOperations()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
operand_type_iterator operand_type_end()
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
result_type_iterator result_type_end()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ArrayType get(Type elementType, unsigned elementCount)
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)