33 #include "llvm/ADT/STLExtras.h"
34 #include "llvm/ADT/SmallVector.h"
35 #include "llvm/ADT/StringExtras.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/LogicalResult.h"
38 #include "llvm/Support/MathExtras.h"
43 #define DEBUG_TYPE "mlir-spirv-conversion"
53 static std::optional<SmallVector<int64_t>>
getTargetShape(VectorType vecType) {
54 LLVM_DEBUG(llvm::dbgs() <<
"Get target shape\n");
55 if (vecType.isScalable()) {
56 LLVM_DEBUG(llvm::dbgs()
57 <<
"--scalable vectors are not supported -> BAIL\n");
64 LLVM_DEBUG(llvm::dbgs() <<
"--no unrolling target shape defined\n");
68 if (!maybeShapeRatio) {
69 LLVM_DEBUG(llvm::dbgs()
70 <<
"--could not compute integral shape ratio -> BAIL\n");
73 if (llvm::all_of(*maybeShapeRatio, [](int64_t v) {
return v == 1; })) {
74 LLVM_DEBUG(llvm::dbgs() <<
"--no unrolling needed -> SKIP\n");
77 LLVM_DEBUG(llvm::dbgs()
78 <<
"--found an integral shape ratio to unroll to -> SUCCESS\n");
88 template <
typename LabelT>
89 static LogicalResult checkExtensionRequirements(
92 for (
const auto &ors : candidates) {
98 for (spirv::Extension ext : ors)
99 extStrings.push_back(spirv::stringifyExtension(ext));
101 llvm::dbgs() << label <<
" illegal: requires at least one extension in ["
102 << llvm::join(extStrings,
", ")
103 <<
"] but none allowed in target environment\n";
116 template <
typename LabelT>
117 static LogicalResult checkCapabilityRequirements(
120 for (
const auto &ors : candidates) {
121 if (targetEnv.
allows(ors))
126 for (spirv::Capability cap : ors)
127 capStrings.push_back(spirv::stringifyCapability(cap));
129 llvm::dbgs() << label <<
" illegal: requires at least one capability in ["
130 << llvm::join(capStrings,
", ")
131 <<
"] but none allowed in target environment\n";
140 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
141 switch (storageClass) {
142 case spirv::StorageClass::PhysicalStorageBuffer:
143 case spirv::StorageClass::PushConstant:
144 case spirv::StorageClass::StorageBuffer:
145 case spirv::StorageClass::Uniform:
155 wrapInStructAndGetPointer(
Type elementType, spirv::StorageClass storageClass) {
156 auto structType = needsExplicitLayout(storageClass)
168 return cast<spirv::ScalarType>(
174 static std::optional<int64_t>
176 if (isa<spirv::ScalarType>(type)) {
189 if (
auto complexType = dyn_cast<ComplexType>(type)) {
190 auto elementSize = getTypeNumBytes(
options, complexType.getElementType());
193 return 2 * *elementSize;
196 if (
auto vecType = dyn_cast<VectorType>(type)) {
197 auto elementSize = getTypeNumBytes(
options, vecType.getElementType());
200 return vecType.getNumElements() * *elementSize;
203 if (
auto memRefType = dyn_cast<MemRefType>(type)) {
208 if (!memRefType.hasStaticShape() ||
215 auto elementSize = getTypeNumBytes(
options, memRefType.getElementType());
219 if (memRefType.getRank() == 0)
222 auto dims = memRefType.getShape();
223 if (llvm::is_contained(dims, ShapedType::kDynamic) ||
224 ShapedType::isDynamic(offset) ||
225 llvm::is_contained(strides, ShapedType::kDynamic))
228 int64_t memrefSize = -1;
229 for (
const auto &shape :
enumerate(dims))
230 memrefSize =
std::max(memrefSize, shape.value() * strides[shape.index()]);
232 return (offset + memrefSize) * *elementSize;
235 if (
auto tensorType = dyn_cast<TensorType>(type)) {
236 if (!tensorType.hasStaticShape())
239 auto elementSize = getTypeNumBytes(
options, tensorType.getElementType());
243 int64_t size = *elementSize;
244 for (
auto shape : tensorType.getShape())
258 std::optional<spirv::StorageClass> storageClass = {}) {
262 type.getExtensions(extensions, storageClass);
263 type.getCapabilities(capabilities, storageClass);
266 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
267 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
272 if (!
options.emulateLT32BitScalarTypes)
277 LLVM_DEBUG(llvm::dbgs()
279 <<
" not converted to 32-bit for SPIR-V to avoid truncation\n");
283 if (
auto floatType = dyn_cast<FloatType>(type)) {
284 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
288 auto intType = cast<IntegerType>(type);
289 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
291 intType.getSignedness());
303 LLVM_DEBUG(llvm::dbgs() <<
"unsupported sub-byte storage kind\n");
307 if (!llvm::isPowerOf2_32(type.getWidth())) {
308 LLVM_DEBUG(llvm::dbgs()
309 <<
"unsupported non-power-of-two bitwidth in sub-byte" << type
314 LLVM_DEBUG(llvm::dbgs() << type <<
" converted to 32-bit for SPIR-V\n");
316 type.getSignedness());
323 convertIndexElementType(ShapedType type,
325 Type indexType = dyn_cast<IndexType>(type.getElementType());
336 std::optional<spirv::StorageClass> storageClass = {}) {
337 type = cast<VectorType>(convertIndexElementType(type,
options));
338 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
342 auto intType = dyn_cast<IntegerType>(type.getElementType());
344 LLVM_DEBUG(llvm::dbgs()
346 <<
" illegal: cannot convert non-scalar element type\n");
350 Type elementType = convertSubByteIntegerType(
options, intType);
351 if (type.getRank() <= 1 && type.getNumElements() == 1)
354 if (type.getNumElements() > 4) {
355 LLVM_DEBUG(llvm::dbgs()
356 << type <<
" illegal: > 4-element unimplemented\n");
363 if (type.getRank() <= 1 && type.getNumElements() == 1)
364 return convertScalarType(targetEnv,
options, scalarType, storageClass);
367 LLVM_DEBUG(llvm::dbgs()
368 << type <<
" illegal: not a valid composite type\n");
375 cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
376 cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
379 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
380 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
384 convertScalarType(targetEnv,
options, scalarType, storageClass);
393 std::optional<spirv::StorageClass> storageClass = {}) {
394 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
396 LLVM_DEBUG(llvm::dbgs()
397 << type <<
" illegal: cannot convert non-scalar element type\n");
402 convertScalarType(targetEnv,
options, scalarType, storageClass);
405 if (elementType != type.getElementType()) {
406 LLVM_DEBUG(llvm::dbgs()
407 << type <<
" illegal: complex type emulation unsupported\n");
424 if (!type.hasStaticShape()) {
425 LLVM_DEBUG(llvm::dbgs()
426 << type <<
" illegal: dynamic shape unimplemented\n");
430 type = cast<TensorType>(convertIndexElementType(type,
options));
431 auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.
getElementType());
433 LLVM_DEBUG(llvm::dbgs()
434 << type <<
" illegal: cannot convert non-scalar element type\n");
438 std::optional<int64_t> scalarSize = getTypeNumBytes(
options, scalarType);
439 std::optional<int64_t> tensorSize = getTypeNumBytes(
options, type);
440 if (!scalarSize || !tensorSize) {
441 LLVM_DEBUG(llvm::dbgs()
442 << type <<
" illegal: cannot deduce element count\n");
446 int64_t arrayElemCount = *tensorSize / *scalarSize;
447 if (arrayElemCount == 0) {
448 LLVM_DEBUG(llvm::dbgs()
449 << type <<
" illegal: cannot handle zero-element tensors\n");
453 Type arrayElemType = convertScalarType(targetEnv,
options, scalarType);
456 std::optional<int64_t> arrayElemSize =
457 getTypeNumBytes(
options, arrayElemType);
458 if (!arrayElemSize) {
459 LLVM_DEBUG(llvm::dbgs()
460 << type <<
" illegal: cannot deduce converted element size\n");
470 spirv::StorageClass storageClass) {
471 unsigned numBoolBits =
options.boolNumBits;
472 if (numBoolBits != 8) {
473 LLVM_DEBUG(llvm::dbgs()
474 <<
"using non-8-bit storage for bool types unimplemented");
477 auto elementType = dyn_cast<spirv::ScalarType>(
482 convertScalarType(targetEnv,
options, elementType, storageClass);
485 std::optional<int64_t> arrayElemSize =
486 getTypeNumBytes(
options, arrayElemType);
487 if (!arrayElemSize) {
488 LLVM_DEBUG(llvm::dbgs()
489 << type <<
" illegal: cannot deduce converted element size\n");
493 if (!type.hasStaticShape()) {
496 if (targetEnv.
allows(spirv::Capability::Kernel))
498 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
502 return wrapInStructAndGetPointer(arrayType, storageClass);
505 if (type.getNumElements() == 0) {
506 LLVM_DEBUG(llvm::dbgs()
507 << type <<
" illegal: zero-element memrefs are not supported\n");
511 int64_t memrefSize =
llvm::divideCeil(type.getNumElements() * numBoolBits, 8);
513 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
515 if (targetEnv.
allows(spirv::Capability::Kernel))
517 return wrapInStructAndGetPointer(arrayType, storageClass);
523 spirv::StorageClass storageClass) {
524 IntegerType elementType = cast<IntegerType>(type.getElementType());
525 Type arrayElemType = convertSubByteIntegerType(
options, elementType);
528 int64_t arrayElemSize = *getTypeNumBytes(
options, arrayElemType);
530 if (!type.hasStaticShape()) {
533 if (targetEnv.
allows(spirv::Capability::Kernel))
535 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
539 return wrapInStructAndGetPointer(arrayType, storageClass);
542 if (type.getNumElements() == 0) {
543 LLVM_DEBUG(llvm::dbgs()
544 << type <<
" illegal: zero-element memrefs are not supported\n");
551 int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0;
553 if (targetEnv.
allows(spirv::Capability::Kernel))
555 return wrapInStructAndGetPointer(arrayType, storageClass);
561 auto attr = dyn_cast_or_null<spirv::StorageClassAttr>(type.getMemorySpace());
566 <<
" illegal: expected memory space to be a SPIR-V storage class "
567 "attribute; please use MemorySpaceToStorageClassConverter to map "
568 "numeric memory spaces beforehand\n");
571 spirv::StorageClass storageClass = attr.getValue();
573 if (isa<IntegerType>(type.getElementType())) {
574 if (type.getElementTypeBitWidth() == 1)
575 return convertBoolMemrefType(targetEnv,
options, type, storageClass);
576 if (type.getElementTypeBitWidth() < 8)
577 return convertSubByteMemrefType(targetEnv,
options, type, storageClass);
581 Type elementType = type.getElementType();
582 if (
auto vecType = dyn_cast<VectorType>(elementType)) {
584 convertVectorType(targetEnv,
options, vecType, storageClass);
585 }
else if (
auto complexType = dyn_cast<ComplexType>(elementType)) {
587 convertComplexType(targetEnv,
options, complexType, storageClass);
588 }
else if (
auto scalarType = dyn_cast<spirv::ScalarType>(elementType)) {
590 convertScalarType(targetEnv,
options, scalarType, storageClass);
591 }
else if (
auto indexType = dyn_cast<IndexType>(elementType)) {
592 type = cast<MemRefType>(convertIndexElementType(type,
options));
593 arrayElemType = type.getElementType();
598 <<
" unhandled: can only convert scalar or vector element type\n");
604 std::optional<int64_t> arrayElemSize =
605 getTypeNumBytes(
options, arrayElemType);
606 if (!arrayElemSize) {
607 LLVM_DEBUG(llvm::dbgs()
608 << type <<
" illegal: cannot deduce converted element size\n");
612 if (!type.hasStaticShape()) {
615 if (targetEnv.
allows(spirv::Capability::Kernel))
617 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
621 return wrapInStructAndGetPointer(arrayType, storageClass);
624 std::optional<int64_t> memrefSize = getTypeNumBytes(
options, type);
626 LLVM_DEBUG(llvm::dbgs()
627 << type <<
" illegal: cannot deduce element count\n");
631 if (*memrefSize == 0) {
632 LLVM_DEBUG(llvm::dbgs()
633 << type <<
" illegal: zero-element memrefs are not supported\n");
638 int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0;
640 if (targetEnv.
allows(spirv::Capability::Kernel))
642 return wrapInStructAndGetPointer(arrayType, storageClass);
662 static std::optional<Value> castToSourceType(
const spirv::TargetEnv &targetEnv,
666 if (inputs.size() != 1) {
667 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
670 Value input = inputs.front();
673 if (!isa<IntegerType>(type)) {
674 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
677 auto inputType = cast<IntegerType>(input.
getType());
679 auto scalarType = dyn_cast<spirv::ScalarType>(type);
681 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
688 if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) {
689 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
695 Value one = spirv::ConstantOp::getOne(inputType, loc, builder);
696 return builder.
create<spirv::IEqualOp>(loc, input, one);
702 scalarType.getExtensions(exts);
703 scalarType.getCapabilities(caps);
704 if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
705 failed(checkExtensionRequirements(type, targetEnv, exts))) {
706 auto castOp = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
714 return builder.
create<spirv::SConvertOp>(loc, type, input);
716 return builder.
create<spirv::UConvertOp>(loc, type, input);
723 static spirv::GlobalVariableOp getBuiltinVariable(
Block &body,
724 spirv::BuiltIn builtin) {
727 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
728 if (
auto builtinAttr = varOp->getAttrOfType<StringAttr>(
729 spirv::SPIRVDialect::getAttributeName(
730 spirv::Decoration::BuiltIn))) {
731 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
732 if (varBuiltIn && *varBuiltIn == builtin) {
741 std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix,
743 return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str();
747 static spirv::GlobalVariableOp
748 getOrInsertBuiltinVariable(
Block &body,
Location loc, spirv::BuiltIn builtin,
750 StringRef prefix, StringRef suffix) {
751 if (
auto varOp = getBuiltinVariable(body, builtin))
757 spirv::GlobalVariableOp newVarOp;
759 case spirv::BuiltIn::NumWorkgroups:
760 case spirv::BuiltIn::WorkgroupSize:
761 case spirv::BuiltIn::WorkgroupId:
762 case spirv::BuiltIn::LocalInvocationId:
763 case spirv::BuiltIn::GlobalInvocationId: {
765 spirv::StorageClass::Input);
766 std::string name = getBuiltinVarName(builtin, prefix, suffix);
768 builder.
create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
771 case spirv::BuiltIn::SubgroupId:
772 case spirv::BuiltIn::NumSubgroups:
773 case spirv::BuiltIn::SubgroupSize: {
776 std::string name = getBuiltinVarName(builtin, prefix, suffix);
778 builder.
create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
782 emitError(loc,
"unimplemented builtin variable generation for ")
783 << stringifyBuiltIn(builtin);
805 static spirv::GlobalVariableOp getPushConstantVariable(
Block &body,
806 unsigned elementCount) {
807 for (
auto varOp : body.
getOps<spirv::GlobalVariableOp>()) {
808 auto ptrType = dyn_cast<spirv::PointerType>(varOp.getType());
815 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
816 auto numElements = cast<spirv::ArrayType>(
817 cast<spirv::StructType>(ptrType.getPointeeType())
820 if (numElements == elementCount)
829 static spirv::GlobalVariableOp
833 if (
auto varOp = getPushConstantVariable(block, elementCount))
837 auto type = getPushConstantStorageType(elementCount, builder, indexType);
838 const char *name =
"__push_constant_var__";
839 return builder.
create<spirv::GlobalVariableOp>(loc, type, name,
853 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
855 FunctionType fnType = funcOp.getFunctionType();
856 if (fnType.getNumResults() > 1)
860 fnType.getNumInputs());
861 for (
const auto &argType :
enumerate(fnType.getInputs())) {
862 auto convertedType = getTypeConverter()->convertType(argType.value());
865 signatureConverter.
addInputs(argType.index(), convertedType);
869 if (fnType.getNumResults() == 1) {
870 resultType = getTypeConverter()->convertType(fnType.getResult(0));
876 auto newFuncOp = rewriter.
create<spirv::FuncOp>(
877 funcOp.getLoc(), funcOp.getName(),
883 for (
const auto &namedAttr : funcOp->getAttrs()) {
884 if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
886 newFuncOp->
setAttr(namedAttr.getName(), namedAttr.getValue());
892 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
904 LogicalResult matchAndRewrite(func::FuncOp funcOp,
906 FunctionType fnType = funcOp.getFunctionType();
909 if (funcOp.isDeclaration()) {
910 LLVM_DEBUG(llvm::dbgs()
911 << fnType <<
" illegal: declarations are unsupported\n");
916 auto newFuncOp = rewriter.
create<func::FuncOp>(funcOp.getLoc(),
917 funcOp.getName(), fnType);
921 Location loc = newFuncOp.getBody().getLoc();
923 Block &entryBlock = newFuncOp.getBlocks().
front();
933 size_t newInputNo = 0;
939 llvm::SmallDenseMap<Operation *, size_t> tmpOps;
942 size_t newOpCount = 0;
945 for (
auto [origInputNo, origType] :
enumerate(fnType.getInputs())) {
947 auto origVecType = dyn_cast<VectorType>(origType);
954 oneToNTypeMapping.
addInputs(origInputNo, origType);
967 oneToNTypeMapping.
addInputs(origInputNo, origType);
972 VectorType unrolledType =
975 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
979 loc, origVecType, rewriter.
getZeroAttr(origVecType));
983 loc, unrolledType, rewriter.
getZeroAttr(unrolledType));
991 result = rewriter.
create<vector::InsertStridedSliceOp>(
992 loc, dummy, result, offsets, strides);
993 newTypes.push_back(unrolledType);
994 unrolledInputNums.push_back(newInputNo);
999 oneToNTypeMapping.
addInputs(origInputNo, newTypes);
1004 auto newFnType = fnType.clone(convertedTypes, fnType.getResults());
1006 [&] { newFuncOp.setFunctionType(newFnType); });
1015 size_t unrolledInputIdx = 0;
1021 Operation *operandOp = operandVal.getDefiningOp();
1022 if (
auto it = tmpOps.find(operandOp); it != tmpOps.end()) {
1023 size_t idx = operandIdx;
1025 curOp.
setOperand(idx, newFuncOp.getArgument(it->second));
1032 if (count >= newOpCount)
1034 if (
auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1035 size_t unrolledInputNo = unrolledInputNums[unrolledInputIdx];
1037 curOp.
setOperand(0, newFuncOp.getArgument(unrolledInputNo));
1059 LogicalResult matchAndRewrite(func::ReturnOp returnOp,
1062 auto funcOp = dyn_cast<func::FuncOp>(returnOp->getParentOp());
1066 FunctionType fnType = funcOp.getFunctionType();
1074 for (
auto [origResultNo, origType] :
enumerate(fnType.getResults())) {
1076 auto origVecType = dyn_cast<VectorType>(origType);
1078 oneToNTypeMapping.
addInputs(origResultNo, origType);
1079 newOperands.push_back(returnOp.getOperand(origResultNo));
1086 oneToNTypeMapping.
addInputs(origResultNo, origType);
1087 newOperands.push_back(returnOp.getOperand(origResultNo));
1090 VectorType unrolledType =
1095 auto originalShape =
1096 llvm::to_vector_of<int64_t, 4>(origVecType.getShape());
1099 extractShape.back() = targetShape->back();
1101 Value returnValue = returnOp.getOperand(origResultNo);
1104 Value result = rewriter.
create<vector::ExtractStridedSliceOp>(
1105 loc, returnValue, offsets, extractShape, strides);
1106 if (originalShape.size() > 1) {
1109 rewriter.
create<vector::ExtractOp>(loc, result, extractIndices);
1111 newOperands.push_back(result);
1112 newTypes.push_back(unrolledType);
1114 oneToNTypeMapping.
addInputs(origResultNo, newTypes);
1122 [&] { funcOp.setFunctionType(newFnType); });
1127 rewriter.
create<func::ReturnOp>(loc, newOperands));
1140 spirv::BuiltIn builtin,
1142 StringRef prefix, StringRef suffix) {
1145 op->
emitError(
"expected operation to be within a module-like op");
1149 spirv::GlobalVariableOp varOp =
1151 builtin, integerType, builder, prefix, suffix);
1153 return builder.
create<spirv::LoadOp>(op->
getLoc(), ptr);
1161 unsigned offset,
Type integerType,
1166 op->
emitError(
"expected operation to be within a module-like op");
1170 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
1171 loc, parent->
getRegion(0).
front(), elementCount, builder, integerType);
1174 Value offsetOp = builder.
create<spirv::ConstantOp>(
1176 auto addrOp = builder.
create<spirv::AddressOfOp>(loc, varOp);
1177 auto acOp = builder.
create<spirv::AccessChainOp>(
1179 return builder.
create<spirv::LoadOp>(loc, acOp);
1187 int64_t offset,
Type integerType,
1189 assert(indices.size() == strides.size() &&
1190 "must provide indices for all dimensions");
1204 builder.
createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
1206 builder.
createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
1208 return linearizedIndex;
1212 MemRefType baseType,
Value basePtr,
1220 llvm::is_contained(strides, ShapedType::kDynamic) ||
1221 ShapedType::isDynamic(offset)) {
1231 linearizedIndices.push_back(zero);
1233 if (baseType.getRank() == 0) {
1234 linearizedIndices.push_back(zero);
1236 linearizedIndices.push_back(
1237 linearizeIndex(indices, strides, offset, indexType, loc, builder));
1239 return builder.
create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
1243 MemRefType baseType,
Value basePtr,
1251 llvm::is_contained(strides, ShapedType::kDynamic) ||
1252 ShapedType::isDynamic(offset)) {
1260 if (baseType.getRank() == 0) {
1264 linearizeIndex(indices, strides, offset, indexType, loc, builder);
1267 cast<spirv::PointerType>(basePtr.
getType()).getPointeeType();
1268 if (isa<spirv::ArrayType>(pointeeType)) {
1269 linearizedIndices.push_back(linearIndex);
1270 return builder.
create<spirv::AccessChainOp>(loc, basePtr,
1273 return builder.
create<spirv::PtrAccessChainOp>(loc, basePtr, linearIndex,
1278 MemRefType baseType,
Value basePtr,
1282 if (typeConverter.
allows(spirv::Capability::Kernel)) {
1296 for (
int i : {4, 3, 2}) {
1305 VectorType srcVectorType = op.getSourceVectorType();
1306 assert(srcVectorType.getRank() == 1);
1307 int64_t vectorSize =
1309 return {vectorSize};
1314 VectorType vectorType = op.getResultVectorType();
1321 std::optional<SmallVector<int64_t>>
1324 if (
auto vecType = dyn_cast<VectorType>(op->
getResultTypes()[0])) {
1333 .Case<vector::ReductionOp, vector::TransposeOp>(
1335 .Default([](
Operation *) {
return std::nullopt; });
1369 vector::VectorTransposeLowering::EltWise);
1382 vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
1383 vector::TransposeOp::getCanonicalizationPatterns(patterns, context);
1389 vector::InsertOp::getCanonicalizationPatterns(patterns, context);
1390 vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
1394 vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
1395 vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
1424 addConversion([
this](IntegerType intType) -> std::optional<Type> {
1425 if (
auto scalarType = dyn_cast<spirv::ScalarType>(intType))
1426 return convertScalarType(this->targetEnv, this->options, scalarType);
1427 if (intType.getWidth() < 8)
1428 return convertSubByteIntegerType(this->options, intType);
1433 if (
auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
1434 return convertScalarType(this->targetEnv, this->options, scalarType);
1439 return convertComplexType(this->targetEnv, this->options, complexType);
1443 return convertVectorType(this->targetEnv, this->options, vectorType);
1447 return convertTensorType(this->targetEnv, this->options, tensorType);
1451 return convertMemrefType(this->targetEnv, this->options, memRefType);
1457 return castToSourceType(this->targetEnv, builder, type, inputs, loc);
1461 auto cast = builder.
create<UnrealizedConversionCastOp>(loc, type, inputs);
1462 return std::optional<Value>(cast.getResult(0));
1467 return ::getIndexType(getContext(), options);
1470 MLIRContext *SPIRVTypeConverter::getContext()
const {
1471 return targetEnv.
getAttr().getContext();
1475 return targetEnv.
allows(capability);
1482 std::unique_ptr<SPIRVConversionTarget>
1484 std::unique_ptr<SPIRVConversionTarget> target(
1488 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
1491 [targetPtr](
Operation *op) {
return targetPtr->isLegalOp(op); });
1498 bool SPIRVConversionTarget::isLegalOp(
Operation *op) {
1502 if (
auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
1503 std::optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
1504 if (minVersion && *minVersion > this->targetEnv.
getVersion()) {
1505 LLVM_DEBUG(llvm::dbgs()
1506 << op->
getName() <<
" illegal: requiring min version "
1507 << spirv::stringifyVersion(*minVersion) <<
"\n");
1511 if (
auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
1512 std::optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
1513 if (maxVersion && *maxVersion < this->targetEnv.
getVersion()) {
1514 LLVM_DEBUG(llvm::dbgs()
1515 << op->
getName() <<
" illegal: requiring max version "
1516 << spirv::stringifyVersion(*maxVersion) <<
"\n");
1524 if (
auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
1525 if (failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
1526 extensions.getExtensions())))
1532 if (
auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
1533 if (failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
1534 capabilities.getCapabilities())))
1542 if (llvm::any_of(valueTypes,
1543 [](
Type t) {
return !isa<spirv::SPIRVType>(t); }))
1548 if (
auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
1549 valueTypes.push_back(globalVar.getType());
1555 for (
Type valueType : valueTypes) {
1556 typeExtensions.clear();
1557 cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
1558 if (failed(checkExtensionRequirements(op->
getName(), this->targetEnv,
1562 typeCapabilities.clear();
1563 cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
1564 if (failed(checkCapabilityRequirements(op->
getName(), this->targetEnv,
1578 patterns.
add<FuncOpConversion>(typeConverter, patterns.
getContext());
static Value getZero(OpBuilder &b, Location loc, Type elementType)
Get zero value for an element type.
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static std::optional< SmallVector< int64_t > > getTargetShape(const vector::UnrollVectorOptions &options, Operation *op)
Return the target shape for unrolling for the given op.
Block represents an ordered list of Operations.
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
void eraseArguments(unsigned start, unsigned num)
Erases 'num' arguments from the index 'start'.
OpListType & getOperations()
iterator_range< op_iterator< OpT > > getOps()
Return an iterator range over the operations within this block that are of 'OpT'.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getI32IntegerAttr(int32_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
TypedAttr getZeroAttr(Type type)
MLIRContext * getContext() const
This class implements a pattern rewriter for use with ConversionPatterns.
FailureOr< Block * > convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion=nullptr)
Apply a signature conversion to each block in the given region.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
This class allows control over how the GreedyPatternRewriteDriver works.
GreedyRewriteStrictness strictMode
Strict mode can restrict the ops that are added to the worklist during the rewrite.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Stores a 1:N mapping of types and provides several useful accessors.
TypeRange getConvertedTypes(unsigned originalTypeNo) const
Returns the list of types that corresponds to the original type at the given index.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
static OpBuilder atBlockBegin(Block *block, Listener *listener=nullptr)
Create a builder and set the insertion point to before the first operation in the block but still ins...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
operand_type_iterator operand_type_end()
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
result_type_iterator result_type_end()
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_iterator result_type_begin()
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
OperationName getName()
The name of an operation is the key identifier for it.
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
operand_type_iterator operand_type_begin()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
static std::unique_ptr< SPIRVConversionTarget > get(spirv::TargetEnvAttr targetAttr)
Creates a SPIR-V conversion target for the given target environment.
Type conversion from builtin types to SPIR-V types for shader interface.
Type getIndexType() const
Gets the SPIR-V correspondence for the standard index type.
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options={})
bool allows(spirv::Capability capability) const
Checks if the SPIR-V capability inquired is supported.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * getNearestSymbolTable(Operation *from)
Returns the nearest symbol table from a given operation from.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides all of the information necessary to convert a type signature.
void addInputs(unsigned origInputNo, ArrayRef< Type > types)
Remap an input of the original signature with a new set of types.
ArrayRef< Type > getConvertedTypes() const
Return the argument types for the new signature.
void addConversion(FnT &&callback)
Register a conversion function.
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal replacement value...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting an illegal (source) value...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isInteger() const
Return true if this is an integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ArrayType get(Type elementType, unsigned elementCount)
static bool isValid(VectorType)
Returns true if the given vector type is valid for the SPIR-V dialect.
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
An attribute that specifies the target version, allowed extensions and capabilities,...
A wrapper class around a spirv::TargetEnvAttr to provide query methods for allowed version/capabiliti...
Version getVersion() const
bool allows(Capability) const
Returns true if the given capability is allowed.
TargetEnvAttr getAttr() const
MLIRContext * getContext() const
Returns the MLIRContext.
bool hasElementwiseMappableTraits(Operation *op)
Together, Elementwise, Scalarizable, Vectorizable, and Tensorizable provide an easy way for scalar op...
OpFoldResult linearizeIndex(ArrayRef< OpFoldResult > multiIndex, ArrayRef< OpFoldResult > basis, ImplicitLocOpBuilder &builder)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
llvm::TypeSize divideCeil(llvm::TypeSize numerator, uint64_t denominator)
Divides the known min value of the numerator by the denominator and rounds the result up to the next ...
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, Type integerType, OpBuilder &builder, StringRef prefix="__builtin__", StringRef suffix="__")
Returns the value for the given builtin variable.
Value getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Performs the index computation to get to the element at indices of the memory pointed to by basePtr,...
Value getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
Value getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder)
Gets the value at the given offset of the push constant storage with a total of elementCount integerT...
std::optional< SmallVector< int64_t > > getNativeVectorShape(Operation *op)
Value linearizeIndex(ValueRange indices, ArrayRef< int64_t > strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder)
Generates IR to perform index linearization with the given indices and their corresponding strides,...
LogicalResult unrollVectorsInFuncBodies(Operation *op)
Value getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder)
SmallVector< int64_t > getNativeVectorShapeImpl(vector::ReductionOp op)
int getComputeVectorSize(int64_t size)
LogicalResult unrollVectorsInSignatures(Operation *op)
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, VectorTransformsOptions options, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorInsertExtractStridedSliceDecompositionPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate patterns with the following patterns.
void populateVectorUnrollPatterns(RewritePatternSet &patterns, const UnrollVectorOptions &options, PatternBenefit benefit=1)
Collect a set of pattern to unroll vector operations to a smaller shapes.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
Include the generated interface declarations.
void populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns)
void populateReturnOpVectorRewritePatterns(RewritePatternSet &patterns)
@ Packed
Sub-byte values are tightly packed without any padding, e.g., 4xi2 -> i8.
void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns)
Appends to a pattern list additional patterns for translating the builtin func op to the SPIR-V diale...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
@ ExistingOps
Only pre-existing ops are processed.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Options that control the vector unrolling.
UnrollVectorOptions & setNativeShapeFn(NativeShapeFnType fn)