31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/SmallVector.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/MathExtras.h"
39#define DEBUG_TYPE "mlir-spirv-conversion"
49static std::optional<SmallVector<int64_t>>
getTargetShape(VectorType vecType) {
50 LLVM_DEBUG(llvm::dbgs() <<
"Get target shape\n");
51 if (vecType.isScalable()) {
52 LLVM_DEBUG(llvm::dbgs()
53 <<
"--scalable vectors are not supported -> BAIL\n");
60 LLVM_DEBUG(llvm::dbgs() <<
"--no unrolling target shape defined\n");
64 if (!maybeShapeRatio) {
65 LLVM_DEBUG(llvm::dbgs()
66 <<
"--could not compute integral shape ratio -> BAIL\n");
69 if (llvm::all_of(*maybeShapeRatio, [](
int64_t v) {
return v == 1; })) {
70 LLVM_DEBUG(llvm::dbgs() <<
"--no unrolling needed -> SKIP\n");
73 LLVM_DEBUG(llvm::dbgs()
74 <<
"--found an integral shape ratio to unroll to -> SUCCESS\n");
84template <
typename LabelT>
85static LogicalResult checkExtensionRequirements(
88 for (
const auto &ors : candidates) {
94 for (spirv::Extension ext : ors)
95 extStrings.push_back(spirv::stringifyExtension(ext));
97 llvm::dbgs() << label <<
" illegal: requires at least one extension in ["
98 << llvm::join(extStrings,
", ")
99 <<
"] but none allowed in target environment\n";
112template <
typename LabelT>
113static LogicalResult checkCapabilityRequirements(
116 for (
const auto &ors : candidates) {
117 if (targetEnv.
allows(ors))
122 for (spirv::Capability cap : ors)
123 capStrings.push_back(spirv::stringifyCapability(cap));
125 llvm::dbgs() << label <<
" illegal: requires at least one capability in ["
126 << llvm::join(capStrings,
", ")
127 <<
"] but none allowed in target environment\n";
136static bool needsExplicitLayout(spirv::StorageClass storageClass) {
137 switch (storageClass) {
138 case spirv::StorageClass::PhysicalStorageBuffer:
139 case spirv::StorageClass::PushConstant:
140 case spirv::StorageClass::StorageBuffer:
141 case spirv::StorageClass::Uniform:
151wrapInStructAndGetPointer(
Type elementType, spirv::StorageClass storageClass) {
152 auto structType = needsExplicitLayout(storageClass)
164 return cast<spirv::ScalarType>(
165 IntegerType::get(ctx,
options.use64bitIndex ? 64 : 32));
170static std::optional<int64_t>
172 if (isa<spirv::ScalarType>(type)) {
186 if (
options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
193 if (
auto complexType = dyn_cast<ComplexType>(type)) {
194 auto elementSize = getTypeNumBytes(
options, complexType.getElementType());
197 return 2 * *elementSize;
200 if (
auto vecType = dyn_cast<VectorType>(type)) {
201 auto elementSize = getTypeNumBytes(
options, vecType.getElementType());
204 return vecType.getNumElements() * *elementSize;
207 if (
auto memRefType = dyn_cast<MemRefType>(type)) {
212 if (!memRefType.hasStaticShape() ||
213 failed(memRefType.getStridesAndOffset(strides, offset)))
219 auto elementSize = getTypeNumBytes(
options, memRefType.getElementType());
223 if (memRefType.getRank() == 0)
226 auto dims = memRefType.getShape();
227 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
228 ShapedType::isDynamic(offset) ||
229 llvm::is_contained(strides, ShapedType::kDynamic))
233 for (
const auto &
shape : enumerate(dims))
234 memrefSize = std::max(memrefSize,
shape.value() * strides[
shape.index()]);
236 return (offset + memrefSize) * *elementSize;
239 if (
auto tensorType = dyn_cast<TensorType>(type)) {
240 if (!tensorType.hasStaticShape())
243 auto elementSize = getTypeNumBytes(
options, tensorType.getElementType());
248 for (
auto shape : tensorType.getShape())
262 std::optional<spirv::StorageClass> storageClass = {}) {
266 type.getExtensions(extensions, storageClass);
267 type.getCapabilities(capabilities, storageClass);
270 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
271 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
276 if (!
options.emulateLT32BitScalarTypes)
281 LLVM_DEBUG(llvm::dbgs()
283 <<
" not converted to 32-bit for SPIR-V to avoid truncation\n");
287 if (
auto floatType = dyn_cast<FloatType>(type)) {
288 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
292 auto intType = cast<IntegerType>(type);
293 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
294 return IntegerType::get(targetEnv.
getContext(), 32,
295 intType.getSignedness());
308 if (type.getWidth() > 8) {
309 LLVM_DEBUG(llvm::dbgs() <<
"not a subbyte type\n");
313 LLVM_DEBUG(llvm::dbgs() <<
"unsupported sub-byte storage kind\n");
317 if (!llvm::isPowerOf2_32(type.getWidth())) {
318 LLVM_DEBUG(llvm::dbgs()
319 <<
"unsupported non-power-of-two bitwidth in sub-byte" << type
324 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
325 return IntegerType::get(type.getContext(), 32,
326 type.getSignedness());
333 if (!
options.emulateUnsupportedFloatTypes)
336 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
337 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
338 Float8E8M0FNUType>(type))
339 return IntegerType::get(type.getContext(), type.getWidth());
340 LLVM_DEBUG(llvm::dbgs() <<
"unsupported 8-bit float type: " << type <<
"\n");
348convertShaped8BitFloatType(ShapedType type,
350 if (!
options.emulateUnsupportedFloatTypes)
352 Type srcElementType = type.getElementType();
353 Type convertedElementType =
nullptr;
355 if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
356 Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
357 Float8E8M0FNUType>(srcElementType))
358 convertedElementType = IntegerType::get(
361 if (!convertedElementType)
364 return type.clone(convertedElementType);
371convertIndexElementType(ShapedType type,
373 Type indexType = dyn_cast<IndexType>(type.getElementType());
384 std::optional<spirv::StorageClass> storageClass = {}) {
385 type = cast<VectorType>(convertIndexElementType(type,
options));
386 type = cast<VectorType>(convertShaped8BitFloatType(type,
options));
387 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
391 auto intType = dyn_cast<IntegerType>(type.getElementType());
393 LLVM_DEBUG(llvm::dbgs()
395 <<
" illegal: cannot convert non-scalar element type\n");
399 Type elementType = convertSubByteIntegerType(
options, intType);
403 if (type.getRank() <= 1 && type.getNumElements() == 1)
406 if (type.getNumElements() > 4) {
407 LLVM_DEBUG(llvm::dbgs()
408 << type <<
" illegal: > 4-element unimplemented\n");
412 return VectorType::get(type.getShape(), elementType);
415 if (type.getRank() <= 1 && type.getNumElements() == 1)
416 return convertScalarType(targetEnv,
options, scalarType, storageClass);
419 LLVM_DEBUG(llvm::dbgs()
420 << type <<
" illegal: not a valid composite type\n");
427 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
428 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
431 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
432 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
436 convertScalarType(targetEnv,
options, scalarType, storageClass);
438 return VectorType::get(type.getShape(), elementType);
445 std::optional<spirv::StorageClass> storageClass = {}) {
446 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
448 LLVM_DEBUG(llvm::dbgs()
449 << type <<
" illegal: cannot convert non-scalar element type\n");
454 convertScalarType(targetEnv,
options, scalarType, storageClass);
457 if (elementType != type.getElementType()) {
458 LLVM_DEBUG(llvm::dbgs()
459 << type <<
" illegal: complex type emulation unsupported\n");
463 return VectorType::get(2, elementType);
476 if (!type.hasStaticShape()) {
477 LLVM_DEBUG(llvm::dbgs()
478 << type <<
" illegal: dynamic shape unimplemented\n");
482 type = cast<TensorType>(convertIndexElementType(type,
options));
483 type = cast<TensorType>(convertShaped8BitFloatType(type,
options));
484 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.
getElementType());
486 LLVM_DEBUG(llvm::dbgs()
487 << type <<
" illegal: cannot convert non-scalar element type\n");
491 std::optional<int64_t> scalarSize = getTypeNumBytes(
options, scalarType);
492 std::optional<int64_t> tensorSize = getTypeNumBytes(
options, type);
493 if (!scalarSize || !tensorSize) {
494 LLVM_DEBUG(llvm::dbgs()
495 << type <<
" illegal: cannot deduce element count\n");
499 int64_t arrayElemCount = *tensorSize / *scalarSize;
500 if (arrayElemCount == 0) {
501 LLVM_DEBUG(llvm::dbgs()
502 << type <<
" illegal: cannot handle zero-element tensors\n");
506 Type arrayElemType = convertScalarType(targetEnv,
options, scalarType);
509 std::optional<int64_t> arrayElemSize =
510 getTypeNumBytes(
options, arrayElemType);
511 if (!arrayElemSize) {
512 LLVM_DEBUG(llvm::dbgs()
513 << type <<
" illegal: cannot deduce converted element size\n");
523 spirv::StorageClass storageClass) {
524 unsigned numBoolBits =
options.boolNumBits;
525 if (numBoolBits != 8) {
526 LLVM_DEBUG(llvm::dbgs()
527 <<
"using non-8-bit storage for bool types unimplemented");
530 auto elementType = dyn_cast<spirv::ScalarType>(
531 IntegerType::get(type.
getContext(), numBoolBits));
535 convertScalarType(targetEnv,
options, elementType, storageClass);
538 std::optional<int64_t> arrayElemSize =
539 getTypeNumBytes(
options, arrayElemType);
540 if (!arrayElemSize) {
541 LLVM_DEBUG(llvm::dbgs()
542 << type <<
" illegal: cannot deduce converted element size\n");
546 if (!type.hasStaticShape()) {
549 if (targetEnv.
allows(spirv::Capability::Kernel))
551 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
555 return wrapInStructAndGetPointer(arrayType, storageClass);
558 if (type.getNumElements() == 0) {
559 LLVM_DEBUG(llvm::dbgs()
560 << type <<
" illegal: zero-element memrefs are not supported\n");
564 int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
565 int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
566 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
568 if (targetEnv.
allows(spirv::Capability::Kernel))
570 return wrapInStructAndGetPointer(arrayType, storageClass);
576 spirv::StorageClass storageClass) {
577 IntegerType elementType = cast<IntegerType>(type.getElementType());
578 Type arrayElemType = convertSubByteIntegerType(
options, elementType);
583 if (!type.hasStaticShape()) {
586 if (targetEnv.
allows(spirv::Capability::Kernel))
588 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
592 return wrapInStructAndGetPointer(arrayType, storageClass);
595 if (type.getNumElements() == 0) {
596 LLVM_DEBUG(llvm::dbgs()
597 << type <<
" illegal: zero-element memrefs are not supported\n");
602 llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8);
603 int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize);
604 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
606 if (targetEnv.
allows(spirv::Capability::Kernel))
608 return wrapInStructAndGetPointer(arrayType, storageClass);
611static spirv::Dim convertRank(
int64_t rank) {
614 return spirv::Dim::Dim1D;
616 return spirv::Dim::Dim2D;
618 return spirv::Dim::Dim3D;
620 llvm_unreachable(
"Invalid memref rank!");
624static spirv::ImageFormat getImageFormat(
Type elementType) {
626 .Case<Float16Type>([](Float16Type) {
return spirv::ImageFormat::R16f; })
627 .Case<Float32Type>([](Float32Type) {
return spirv::ImageFormat::R32f; })
628 .Case<IntegerType>([](IntegerType intType) {
629 auto const isSigned = intType.isSigned() || intType.isSignless();
630#define BIT_WIDTH_CASE(BIT_WIDTH) \
632 return isSigned ? spirv::ImageFormat::R##BIT_WIDTH##i \
633 : spirv::ImageFormat::R##BIT_WIDTH##ui
635 switch (intType.getWidth()) {
639 llvm_unreachable(
"Unhandled integer type!");
642 .DefaultUnreachable(
"Unhandled element type!");
649 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
654 <<
" illegal: expected memory space to be a SPIR-V storage class "
655 "attribute; please use MemorySpaceToStorageClassConverter to map "
656 "numeric memory spaces beforehand\n");
659 spirv::StorageClass storageClass = attr.getValue();
664 if (storageClass == spirv::StorageClass::Image) {
665 const int64_t rank = type.getRank();
666 if (rank < 1 || rank > 3) {
667 LLVM_DEBUG(llvm::dbgs()
668 << type <<
" illegal: cannot lower memref of rank " << rank
669 <<
" to a SPIR-V Image\n");
675 auto elementType = type.getElementType();
676 if (!isa<spirv::ScalarType>(elementType)) {
677 LLVM_DEBUG(llvm::dbgs() << type <<
" illegal: cannot lower memref of "
678 << elementType <<
" to a SPIR-V Image\n");
686 elementType, convertRank(rank), spirv::ImageDepthInfo::DepthUnknown,
687 spirv::ImageArrayedInfo::NonArrayed,
688 spirv::ImageSamplingInfo::SingleSampled,
689 spirv::ImageSamplerUseInfo::NeedSampler, getImageFormat(elementType));
692 spvSampledImageType, spirv::StorageClass::UniformConstant);
696 if (isa<IntegerType>(type.getElementType())) {
697 if (type.getElementTypeBitWidth() == 1)
698 return convertBoolMemrefType(targetEnv,
options, type, storageClass);
699 if (type.getElementTypeBitWidth() < 8)
700 return convertSubByteMemrefType(targetEnv,
options, type, storageClass);
704 Type elementType = type.getElementType();
705 if (
auto vecType = dyn_cast<VectorType>(elementType)) {
707 convertVectorType(targetEnv,
options, vecType, storageClass);
708 }
else if (
auto complexType = dyn_cast<ComplexType>(elementType)) {
710 convertComplexType(targetEnv,
options, complexType, storageClass);
711 }
else if (
auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
713 convertScalarType(targetEnv,
options, scalarType, storageClass);
714 }
else if (
auto indexType = dyn_cast<IndexType>(elementType)) {
715 type = cast<MemRefType>(convertIndexElementType(type,
options));
716 arrayElemType = type.getElementType();
717 }
else if (
auto floatType = dyn_cast<FloatType>(elementType)) {
719 type = cast<MemRefType>(convertShaped8BitFloatType(type,
options));
720 arrayElemType = type.getElementType();
725 <<
" unhandled: can only convert scalar or vector element type\n");
731 std::optional<int64_t> arrayElemSize =
732 getTypeNumBytes(
options, arrayElemType);
733 if (!arrayElemSize) {
734 LLVM_DEBUG(llvm::dbgs()
735 << type <<
" illegal: cannot deduce converted element size\n");
739 if (!type.hasStaticShape()) {
742 if (targetEnv.
allows(spirv::Capability::Kernel))
744 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
748 return wrapInStructAndGetPointer(arrayType, storageClass);
751 std::optional<int64_t> memrefSize = getTypeNumBytes(
options, type);
753 LLVM_DEBUG(llvm::dbgs()
754 << type <<
" illegal: cannot deduce element count\n");
758 if (*memrefSize == 0) {
759 LLVM_DEBUG(llvm::dbgs()
760 << type <<
" illegal: zero-element memrefs are not supported\n");
764 int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
765 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
767 if (targetEnv.
allows(spirv::Capability::Kernel))
769 return wrapInStructAndGetPointer(arrayType, storageClass);
793 if (inputs.size() != 1) {
795 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
796 return castOp.getResult(0);
798 Value input = inputs.front();
801 if (!isa<IntegerType>(type)) {
803 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
804 return castOp.getResult(0);
806 auto inputType = cast<IntegerType>(input.
getType());
808 auto scalarType = dyn_cast<spirv::ScalarType>(type);
811 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
812 return castOp.getResult(0);
818 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
820 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
821 return castOp.getResult(0);
826 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
827 return spirv::IEqualOp::create(builder, loc, input, one);
833 scalarType.getExtensions(exts);
834 scalarType.getCapabilities(caps);
835 if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
836 failed(checkExtensionRequirements(type, targetEnv, exts))) {
838 UnrealizedConversionCastOp::create(builder, loc, type, inputs);
839 return castOp.getResult(0);
846 return spirv::SConvertOp::create(builder, loc, type, input);
848 return spirv::UConvertOp::create(builder, loc, type, input);
855static spirv::GlobalVariableOp getBuiltinVariable(
Block &body,
859 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
860 if (
auto builtinAttr = varOp->getAttrOfType<StringAttr>(
861 spirv::SPIRVDialect::getAttributeName(
862 spirv::Decoration::BuiltIn))) {
863 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
873std::string getBuiltinVarName(spirv::BuiltIn
builtin, StringRef prefix,
875 return Twine(prefix).concat(stringifyBuiltIn(
builtin)).concat(suffix).str();
879static spirv::GlobalVariableOp
882 StringRef prefix, StringRef suffix) {
883 if (
auto varOp = getBuiltinVariable(body,
builtin))
889 spirv::GlobalVariableOp newVarOp;
891 case spirv::BuiltIn::NumWorkgroups:
892 case spirv::BuiltIn::WorkgroupSize:
893 case spirv::BuiltIn::WorkgroupId:
894 case spirv::BuiltIn::LocalInvocationId:
895 case spirv::BuiltIn::GlobalInvocationId: {
897 spirv::StorageClass::Input);
898 std::string name = getBuiltinVarName(
builtin, prefix, suffix);
900 spirv::GlobalVariableOp::create(builder, loc, ptrType, name,
builtin);
903 case spirv::BuiltIn::SubgroupId:
904 case spirv::BuiltIn::NumSubgroups:
905 case spirv::BuiltIn::SubgroupSize:
906 case spirv::BuiltIn::SubgroupLocalInvocationId: {
909 std::string name = getBuiltinVarName(
builtin, prefix, suffix);
911 spirv::GlobalVariableOp::create(builder, loc, ptrType, name,
builtin);
915 emitError(loc,
"unimplemented builtin variable generation for ")
938static spirv::GlobalVariableOp getPushConstantVariable(
Block &body,
939 unsigned elementCount) {
940 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
941 auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
948 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
949 auto numElements = cast<spirv::ArrayType>(
950 cast<spirv::StructType>(ptrType.getPointeeType())
953 if (numElements == elementCount)
962static spirv::GlobalVariableOp
966 if (
auto varOp = getPushConstantVariable(block, elementCount))
970 auto type = getPushConstantStorageType(elementCount, builder, indexType);
971 const char *name =
"__push_constant_var__";
972 return spirv::GlobalVariableOp::create(builder, loc, type, name,
982struct FuncOpConversion final : OpConversionPattern<func::FuncOp> {
986 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
987 ConversionPatternRewriter &rewriter)
const override {
988 FunctionType fnType = funcOp.getFunctionType();
989 if (fnType.getNumResults() > 1)
992 TypeConverter::SignatureConversion signatureConverter(
993 fnType.getNumInputs());
994 for (
const auto &argType : enumerate(fnType.getInputs())) {
995 auto convertedType = getTypeConverter()->convertType(argType.value());
998 signatureConverter.addInputs(argType.index(), convertedType);
1002 if (fnType.getNumResults() == 1) {
1003 resultType = getTypeConverter()->convertType(fnType.getResult(0));
1009 auto newFuncOp = spirv::FuncOp::create(
1010 rewriter, funcOp.getLoc(), funcOp.getName(),
1011 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
1016 for (
const auto &namedAttr : funcOp->getAttrs()) {
1017 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
1019 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
1022 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1024 if (failed(rewriter.convertRegionTypes(
1025 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
1027 rewriter.eraseOp(funcOp);
1037 LogicalResult matchAndRewrite(func::FuncOp funcOp,
1039 FunctionType fnType = funcOp.getFunctionType();
1042 if (funcOp.isDeclaration()) {
1043 LLVM_DEBUG(llvm::dbgs()
1044 << fnType <<
" illegal: declarations are unsupported\n");
1049 auto newFuncOp = func::FuncOp::create(rewriter, funcOp.getLoc(),
1050 funcOp.getName(), fnType);
1054 Location loc = newFuncOp.getBody().getLoc();
1056 Block &entryBlock = newFuncOp.getBlocks().
front();
1060 TypeConverter::SignatureConversion oneToNTypeMapping(
1061 fnType.getInputs().size());
1067 size_t newInputNo = 0;
1073 llvm::SmallDenseMap<Operation *, size_t> tmpOps;
1076 size_t newOpCount = 0;
1079 for (
auto [origInputNo, origType] : enumerate(fnType.getInputs())) {
1081 auto origVecType = dyn_cast<VectorType>(origType);
1085 rewriter, loc, origType, rewriter.
getZeroAttr(origType));
1087 tmpOps.insert({
result.getDefiningOp(), newInputNo});
1088 oneToNTypeMapping.addInputs(origInputNo, origType);
1098 rewriter, loc, origType, rewriter.
getZeroAttr(origType));
1100 tmpOps.insert({
result.getDefiningOp(), newInputNo});
1101 oneToNTypeMapping.addInputs(origInputNo, origType);
1106 VectorType unrolledType =
1107 VectorType::get(*targetShape, origVecType.getElementType());
1108 auto originalShape =
1109 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1113 rewriter, loc, origVecType, rewriter.
getZeroAttr(origVecType));
1116 Value dummy = arith::ConstantOp::create(
1117 rewriter, loc, unrolledType, rewriter.
getZeroAttr(unrolledType));
1125 result = vector::InsertStridedSliceOp::create(rewriter, loc, dummy,
1126 result, offsets, strides);
1127 newTypes.push_back(unrolledType);
1128 unrolledInputNums.push_back(newInputNo);
1133 oneToNTypeMapping.addInputs(origInputNo, newTypes);
1137 auto convertedTypes = oneToNTypeMapping.getConvertedTypes();
1138 auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1140 [&] { newFuncOp.setFunctionType(newFnType); });
1149 for (
auto &[placeholderOp, argIdx] : tmpOps) {
1160 size_t unrolledInputIdx = 0;
1161 for (
auto [count, op] : enumerate(entryBlock.
getOperations())) {
1166 if (count >= newOpCount)
1168 if (
auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1169 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1171 curOp.
setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1193 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1196 auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1200 FunctionType fnType = funcOp.getFunctionType();
1201 TypeConverter::SignatureConversion oneToNTypeMapping(
1202 fnType.getResults().size());
1209 for (
auto [origResultNo, origType] : enumerate(fnType.getResults())) {
1211 auto origVecType = dyn_cast<VectorType>(origType);
1213 oneToNTypeMapping.addInputs(origResultNo, origType);
1214 newOperands.push_back(returnOp.getOperand(origResultNo));
1221 oneToNTypeMapping.addInputs(origResultNo, origType);
1222 newOperands.push_back(returnOp.getOperand(origResultNo));
1225 VectorType unrolledType =
1226 VectorType::get(*targetShape, origVecType.getElementType());
1230 auto originalShape =
1231 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1234 extractShape.back() = targetShape->back();
1236 Value returnValue = returnOp.getOperand(origResultNo);
1239 Value result = vector::ExtractStridedSliceOp::create(
1240 rewriter, loc, returnValue, offsets, extractShape, strides);
1241 if (originalShape.size() > 1) {
1244 vector::ExtractOp::create(rewriter, loc,
result, extractIndices);
1246 newOperands.push_back(
result);
1247 newTypes.push_back(unrolledType);
1249 oneToNTypeMapping.addInputs(origResultNo, newTypes);
1255 TypeRange(oneToNTypeMapping.getConvertedTypes()));
1257 [&] { funcOp.setFunctionType(newFnType); });
1262 func::ReturnOp::create(rewriter, loc, newOperands));
1277 StringRef prefix, StringRef suffix) {
1280 op->
emitError(
"expected operation to be within a module-like op");
1284 spirv::GlobalVariableOp varOp =
1286 builtin, integerType, builder, prefix, suffix);
1287 Value ptr = spirv::AddressOfOp::create(builder, op->
getLoc(), varOp);
1288 return spirv::LoadOp::create(builder, op->
getLoc(),
ptr);
1296 unsigned offset,
Type integerType,
1301 op->
emitError(
"expected operation to be within a module-like op");
1305 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1306 loc, parent->
getRegion(0).
front(), elementCount, builder, integerType);
1308 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
1309 Value offsetOp = spirv::ConstantOp::create(builder, loc, integerType,
1311 auto addrOp = spirv::AddressOfOp::create(builder, loc, varOp);
1312 auto acOp = spirv::AccessChainOp::create(builder, loc, addrOp,
1314 return spirv::LoadOp::create(builder, loc, acOp);
1324 assert(
indices.size() == strides.size() &&
1325 "must provide indices for all dimensions");
1333 loc, integerType, IntegerAttr::get(integerType, offset));
1337 IntegerAttr::get(integerType, strides[
index.index()]));
1341 builder.
createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1343 return linearizedIndex;
1347 MemRefType baseType,
Value basePtr,
1354 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1355 llvm::is_contained(strides, ShapedType::kDynamic) ||
1356 ShapedType::isDynamic(offset)) {
1363 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
1366 linearizedIndices.push_back(zero);
1368 if (baseType.getRank() == 0) {
1369 linearizedIndices.push_back(zero);
1371 linearizedIndices.push_back(
1374 return spirv::AccessChainOp::create(builder, loc, basePtr, linearizedIndices);
1378 MemRefType baseType,
Value basePtr,
1385 if (failed(baseType.getStridesAndOffset(strides, offset)) ||
1386 llvm::is_contained(strides, ShapedType::kDynamic) ||
1387 ShapedType::isDynamic(offset)) {
1395 if (baseType.getRank() == 0) {
1396 linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder);
1402 cast<spirv::PointerType>(basePtr.
getType()).getPointeeType();
1403 if (isa<spirv::ArrayType>(pointeeType)) {
1404 linearizedIndices.push_back(linearIndex);
1405 return spirv::AccessChainOp::create(builder, loc, basePtr,
1408 return spirv::PtrAccessChainOp::create(builder, loc, basePtr, linearIndex,
1413 MemRefType baseType,
Value basePtr,
1417 if (typeConverter.
allows(spirv::Capability::Kernel)) {
1431 for (
int i : {4, 3, 2}) {
1440 VectorType srcVectorType = op.getSourceVectorType();
1441 assert(srcVectorType.getRank() == 1);
1444 return {vectorSize};
1449 VectorType vectorType = op.getResultVectorType();
1456std::optional<SmallVector<int64_t>>
1459 if (
auto vecType = dyn_cast<VectorType>(op->
getResultTypes()[0])) {
1468 .Case<vector::ReductionOp, vector::TransposeOp>(
1470 .Default(std::nullopt);
1504 patterns, vector::VectorTransposeLowering::EltWise);
1515 vector::populateCastAwayVectorLeadingOneDimPatterns(
patterns);
1516 vector::ReductionOp::getCanonicalizationPatterns(
patterns, context);
1517 vector::TransposeOp::getCanonicalizationPatterns(
patterns, context);
1521 vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
1523 vector::InsertOp::getCanonicalizationPatterns(
patterns, context);
1524 vector::ExtractOp::getCanonicalizationPatterns(
patterns, context);
1528 vector::BroadcastOp::getCanonicalizationPatterns(
patterns, context);
1529 vector::ShapeCastOp::getCanonicalizationPatterns(
patterns, context);
1543 : targetEnv(targetAttr), options(options) {
1556 addConversion([
this](IndexType ) {
return getIndexType(); });
1558 addConversion([
this](IntegerType intType) -> std::optional<Type> {
1559 if (
auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1560 return convertScalarType(this->targetEnv, this->options, scalarType);
1561 if (intType.getWidth() < 8)
1562 return convertSubByteIntegerType(this->options, intType);
1566 addConversion([
this](FloatType floatType) -> std::optional<Type> {
1567 if (
auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1568 return convertScalarType(this->targetEnv, this->options, scalarType);
1569 if (floatType.getWidth() == 8)
1570 return convert8BitFloatType(this->options, floatType);
1574 addConversion([
this](ComplexType complexType) {
1575 return convertComplexType(this->targetEnv, this->options, complexType);
1578 addConversion([
this](VectorType vectorType) {
1579 return convertVectorType(this->targetEnv, this->options, vectorType);
1582 addConversion([
this](
TensorType tensorType) {
1583 return convertTensorType(this->targetEnv, this->options, tensorType);
1586 addConversion([
this](MemRefType memRefType) {
1587 return convertMemrefType(this->targetEnv, this->options, memRefType);
1591 addSourceMaterialization(
1593 return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1597 auto cast = UnrealizedConversionCastOp::create(builder, loc, type, inputs);
1598 return cast.getResult(0);
1603 return ::getIndexType(
getContext(), options);
1606MLIRContext *SPIRVTypeConverter::getContext()
const {
1607 return targetEnv.
getAttr().getContext();
1611 return targetEnv.allows(capability);
1618std::unique_ptr<SPIRVConversionTarget>
1620 std::unique_ptr<SPIRVConversionTarget>
target(
1622 new SPIRVConversionTarget(targetAttr));
1623 SPIRVConversionTarget *targetPtr =
target.get();
1624 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1627 [targetPtr](
Operation *op) {
return targetPtr->isLegalOp(op); });
1634bool SPIRVConversionTarget::isLegalOp(
Operation *op) {
1638 if (
auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1639 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1640 if (minVersion && *minVersion > this->targetEnv.
getVersion()) {
1641 LLVM_DEBUG(llvm::dbgs()
1642 << op->
getName() <<
" illegal: requiring min version "
1643 << spirv::stringifyVersion(*minVersion) <<
"\n");
1647 if (
auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1648 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1649 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
1650 LLVM_DEBUG(llvm::dbgs()
1651 << op->
getName() <<
" illegal: requiring max version "
1652 << spirv::stringifyVersion(*maxVersion) <<
"\n");
1660 if (
auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1661 if (
failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
1662 extensions.getExtensions())))
1668 if (
auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1669 if (
failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
1670 capabilities.getCapabilities())))
1673 SmallVector<Type, 4> valueTypes;
1678 if (llvm::any_of(valueTypes,
1679 [](Type t) {
return !isa<spirv::SPIRVType>(t); }))
1684 if (
auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1685 valueTypes.push_back(globalVar.getType());
1689 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
1690 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
1691 for (Type valueType : valueTypes) {
1692 typeExtensions.clear();
1693 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1694 if (
failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
1698 typeCapabilities.clear();
1699 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1700 if (
failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
[…] if copies could not be generated due to yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart` in `copyPlacementBlock` specify the insertion points where the incoming copies and outgoing copies […] should be inserted. […] the output argument `nBegin` is set to its replacement (set to `begin` if no invalidation happens). Since outgoing copies could have been inserted at `end`, […]
static llvm::ManagedStatic< PassManagerOptions > options
#define BIT_WIDTH_CASE(BIT_WIDTH)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
OpListType & getOperations()
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class allows control over how the GreedyPatternRewriteDriver works.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setOperand(unsigned idx, Value value)
operand_type_iterator operand_type_end()
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
result_type_iterator result_type_end()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_type_iterator result_type_begin()
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
MLIRContext * getContext()
Return the context this operation is associated with.
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ArrayType get(Type elementType, unsigned elementCount)
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
SmallVectorImpl< ArrayRef< Capability > > CapabilityArrayRefVector
The capability requirements for each type are following the ((Capability::A OR Extension::B) AND (Cap...
SmallVectorImpl< ArrayRef< Extension > > ExtensionArrayRefVector
The extension requirements for each type are following the ((Extension::A OR Extension::B) AND (Exten...
static SampledImageType get(Type imageType)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void populateBuiltinFuncToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
const FrozenRewritePatternSet & patterns
llvm::TypeSwitch< T, ResultT > TypeSwitch
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)